diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc index 5e089d77f438e3404378af2f80598f926ca729f8..da542aa852b04d99f25d804ea328fd539cc2152e 100644 --- a/paddle/fluid/operators/unsqueeze_op.cc +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -30,9 +30,9 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { const auto &axes = ctx->Attrs().Get>("axes"); const auto &x_dims = ctx->GetInputDim("X"); // Validity Check: input tensor dims (<6). - PADDLE_ENFORCE(static_cast(x_dims.size()) <= 6, - "Invalid dimensions, dynamic dimensions should within " - "[1, 6] dimensions (Eigen limit)."); + PADDLE_ENFORCE(x_dims.size() <= 6, + "Invalid dimensions, the rank of Input(X) " + "should be in the range of [1, 6] (Eigen limit)"); auto out_dims = GetOutputShape(axes, x_dims); ctx->SetOutputDim("Out", out_dims); if (x_dims[0] == out_dims[0]) { @@ -44,8 +44,8 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { static framework::DDim GetOutputShape(const std::vector unsqz_dims, const framework::DDim &in_dims) { - int output_size = static_cast(in_dims.size() + unsqz_dims.size()); - int cur_output_size = static_cast(in_dims.size()); + int output_size = in_dims.size() + static_cast(unsqz_dims.size()); + int cur_output_size = in_dims.size(); std::vector output_shape(output_size, 0); // Validity Check: rank range. @@ -110,12 +110,11 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "(Tensor). The input tensor of unsqueeze operator."); AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator."); AddAttr>("axes", - "(std::vector). List of positive integers," + "(std::vector). List of integers," " indicate the dimensions to be inserted") .AddCustomChecker([](const std::vector &axes) { - PADDLE_ENFORCE( - !axes.empty(), - "The unsqueeze axes information must be set by Attr(axes)."); + PADDLE_ENFORCE(!axes.empty(), + "Invalid axes, The unsqueeze axes is empty."); // Validity Check: axes dims (<6). PADDLE_ENFORCE(static_cast(axes.size()) < 6, "Invalid dimensions, dynamic dimensions should within "