From cef8dbc1f7867d013046227f8283ee249bda8a0f Mon Sep 17 00:00:00 2001 From: chenweihang Date: Tue, 10 Jul 2018 09:09:55 +0000 Subject: [PATCH] refine some messages and adjust data type --- paddle/fluid/operators/unsqueeze_op.cc | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc index 5e089d77f..da542aa85 100644 --- a/paddle/fluid/operators/unsqueeze_op.cc +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -30,9 +30,9 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { const auto &axes = ctx->Attrs().Get>("axes"); const auto &x_dims = ctx->GetInputDim("X"); // Validity Check: input tensor dims (<6). - PADDLE_ENFORCE(static_cast(x_dims.size()) <= 6, - "Invalid dimensions, dynamic dimensions should within " - "[1, 6] dimensions (Eigen limit)."); + PADDLE_ENFORCE(x_dims.size() <= 6, + "Invalid dimensions, the rank of Input(X) " + "should be in the range of [1, 6] (Eigen limit)"); auto out_dims = GetOutputShape(axes, x_dims); ctx->SetOutputDim("Out", out_dims); if (x_dims[0] == out_dims[0]) { @@ -44,8 +44,8 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { static framework::DDim GetOutputShape(const std::vector unsqz_dims, const framework::DDim &in_dims) { - int output_size = static_cast(in_dims.size() + unsqz_dims.size()); - int cur_output_size = static_cast(in_dims.size()); + int output_size = in_dims.size() + static_cast(unsqz_dims.size()); + int cur_output_size = in_dims.size(); std::vector output_shape(output_size, 0); // Validity Check: rank range. @@ -110,12 +110,11 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "(Tensor). The input tensor of unsqueeze operator."); AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator."); AddAttr>("axes", - "(std::vector). List of positive integers," + "(std::vector). List of integers," " indicate the dimensions to be inserted") .AddCustomChecker([](const std::vector &axes) { - PADDLE_ENFORCE( - !axes.empty(), - "The unsqueeze axes information must be set by Attr(axes)."); + PADDLE_ENFORCE(!axes.empty(), + "Invalid axes, The unsqueeze axes is empty."); // Validity Check: axes dims (<6). PADDLE_ENFORCE(static_cast(axes.size()) < 6, "Invalid dimensions, dynamic dimensions should within " -- GitLab