提交 cef8dbc1 编写于 作者: C chenweihang

refine some messages and adjust data type

上级 7526eaaf
......@@ -30,9 +30,9 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase {
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
const auto &x_dims = ctx->GetInputDim("X");
// Validity Check: input tensor dims (<6).
PADDLE_ENFORCE(static_cast<int>(x_dims.size()) <= 6,
"Invalid dimensions, dynamic dimensions should within "
"[1, 6] dimensions (Eigen limit).");
PADDLE_ENFORCE(x_dims.size() <= 6,
"Invalid dimensions, the rank of Input(X) "
"should be in the range of [1, 6] (Eigen limit)");
auto out_dims = GetOutputShape(axes, x_dims);
ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) {
......@@ -44,8 +44,8 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase {
static framework::DDim GetOutputShape(const std::vector<int> unsqz_dims,
const framework::DDim &in_dims) {
int output_size = static_cast<int>(in_dims.size() + unsqz_dims.size());
int cur_output_size = static_cast<int>(in_dims.size());
int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
int cur_output_size = in_dims.size();
std::vector<int64_t> output_shape(output_size, 0);
// Validity Check: rank range.
......@@ -110,12 +110,11 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "(Tensor). The input tensor of unsqueeze operator.");
AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator.");
AddAttr<std::vector<int>>("axes",
"(std::vector<int>). List of positive integers,"
"(std::vector<int>). List of integers,"
" indicate the dimensions to be inserted")
.AddCustomChecker([](const std::vector<int> &axes) {
PADDLE_ENFORCE(
!axes.empty(),
"The unsqueeze axes information must be set by Attr(axes).");
PADDLE_ENFORCE(!axes.empty(),
"Invalid axes, The unsqueeze axes is empty.");
// Validity Check: axes dims (<6).
PADDLE_ENFORCE(static_cast<int>(axes.size()) < 6,
"Invalid dimensions, dynamic dimensions should within "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册