From 05eafcca7373d690916244db0890370d64d7b535 Mon Sep 17 00:00:00 2001 From: chenweihang Date: Tue, 10 Jul 2018 08:41:36 +0000 Subject: [PATCH] refine some messages and adjust data type --- paddle/fluid/operators/squeeze_op.cc | 29 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc index 1b656ea138e..805f198bf3c 100644 --- a/paddle/fluid/operators/squeeze_op.cc +++ b/paddle/fluid/operators/squeeze_op.cc @@ -30,13 +30,14 @@ class SqueezeOpInferShape : public framework::InferShapeBase { const auto &x_dims = ctx->GetInputDim("X"); // Check input tensor dims (<6) Eigen limit. PADDLE_ENFORCE(x_dims.size() <= 6, - "Invalid dimnesions, dynamic dimensions must have " - "between [1, 6] dimensions (Eigen limit)."); + "Invalid dimnesions, the rank of Input(X) " + "should be in the range of [1, 6] (Eigen limit)."); const auto &axes = ctx->Attrs().Get>("axes"); for (int a : axes) { PADDLE_ENFORCE_LT(a, x_dims.size(), - "The axis must be less than input tensor's rank."); + "The squeeze axis should be less than input " + "tensor's rank."); } auto out_dims = GetOutputShape(axes, x_dims); @@ -50,30 +51,29 @@ class SqueezeOpInferShape : public framework::InferShapeBase { static framework::DDim GetOutputShape(const std::vector squeeze_dims, const framework::DDim &in_dims) { - int num_squeeze_dims = static_cast(squeeze_dims.size()); + size_t num_squeeze_dims = squeeze_dims.size(); int cnt_squeezed_dims = 0; bool should_squeeze[9] = {false}; // Determines number of dimensions of output tensor after squeeze. // Mark and count the dimensions need to be squeezed if (num_squeeze_dims == 0) { - for (int idx = 0; idx < static_cast(in_dims.size()); ++idx) { + for (int idx = 0; idx < in_dims.size(); ++idx) { if (in_dims[idx] == 1) { should_squeeze[idx] = true; ++cnt_squeezed_dims; } } } else { - for (int idx = 0; idx < num_squeeze_dims; ++idx) { + for (size_t idx = 0; idx < num_squeeze_dims; ++idx) { int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size() : squeeze_dims[idx]; - // Check current index. + // Check current index, the upper limit has beed checked in line 36. PADDLE_ENFORCE(current >= 0, - "Invalid axis, negative axis is out of range."); - // PADDLE_ENFORCE_LT(current, in_dims.size(), "Invalid axis is given."); - PADDLE_ENFORCE( - in_dims[current] == 1, - "Invalid axis index, the axis will be squeezed should be 1."); + "Invalid axis, the negative axis is out of range."); + PADDLE_ENFORCE(in_dims[current] == 1, + "Invalid axis index, the axis that will be squeezed " + "should equal 1."); if (!(should_squeeze[current])) { ++cnt_squeezed_dims; @@ -84,8 +84,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase { // Make output dimensions std::vector output_shape(in_dims.size() - cnt_squeezed_dims, 0); - for (int in_idx = 0, out_idx = 0; in_idx < static_cast(in_dims.size()); - ++in_idx) { + for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) { if (!should_squeeze[in_idx]) { output_shape[out_idx++] = in_dims[in_idx]; } @@ -123,7 +122,7 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "(Tensor). The input tensor of squeeze operator."); AddOutput("Out", "(Tensor). The output tensor of squeeze operator."); AddAttr>("axes", - "(std::vector). List of positive integers," + "(std::vector). List of integers," " indicate the dimensions to squeeze.") .SetDefault({}); AddAttr("inplace", -- GitLab