提交 cef8dbc1 编写于 作者: C chenweihang

refine some messages and adjust data type

上级 7526eaaf
...@@ -30,9 +30,9 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { ...@@ -30,9 +30,9 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase {
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes"); const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
const auto &x_dims = ctx->GetInputDim("X"); const auto &x_dims = ctx->GetInputDim("X");
// Validity Check: input tensor dims (<6). // Validity Check: input tensor dims (<6).
PADDLE_ENFORCE(static_cast<int>(x_dims.size()) <= 6, PADDLE_ENFORCE(x_dims.size() <= 6,
"Invalid dimensions, dynamic dimensions should within " "Invalid dimensions, the rank of Input(X) "
"[1, 6] dimensions (Eigen limit)."); "should be in the range of [1, 6] (Eigen limit)");
auto out_dims = GetOutputShape(axes, x_dims); auto out_dims = GetOutputShape(axes, x_dims);
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) { if (x_dims[0] == out_dims[0]) {
...@@ -44,8 +44,8 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { ...@@ -44,8 +44,8 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase {
static framework::DDim GetOutputShape(const std::vector<int> unsqz_dims, static framework::DDim GetOutputShape(const std::vector<int> unsqz_dims,
const framework::DDim &in_dims) { const framework::DDim &in_dims) {
int output_size = static_cast<int>(in_dims.size() + unsqz_dims.size()); int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
int cur_output_size = static_cast<int>(in_dims.size()); int cur_output_size = in_dims.size();
std::vector<int64_t> output_shape(output_size, 0); std::vector<int64_t> output_shape(output_size, 0);
// Validity Check: rank range. // Validity Check: rank range.
...@@ -110,12 +110,11 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -110,12 +110,11 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "(Tensor). The input tensor of unsqueeze operator."); AddInput("X", "(Tensor). The input tensor of unsqueeze operator.");
AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator."); AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator.");
AddAttr<std::vector<int>>("axes", AddAttr<std::vector<int>>("axes",
"(std::vector<int>). List of positive integers," "(std::vector<int>). List of integers,"
" indicate the dimensions to be inserted") " indicate the dimensions to be inserted")
.AddCustomChecker([](const std::vector<int> &axes) { .AddCustomChecker([](const std::vector<int> &axes) {
PADDLE_ENFORCE( PADDLE_ENFORCE(!axes.empty(),
!axes.empty(), "Invalid axes, The unsqueeze axes is empty.");
"The unsqueeze axes information must be set by Attr(axes).");
// Validity Check: axes dims (<6). // Validity Check: axes dims (<6).
PADDLE_ENFORCE(static_cast<int>(axes.size()) < 6, PADDLE_ENFORCE(static_cast<int>(axes.size()) < 6,
"Invalid dimensions, dynamic dimensions should within " "Invalid dimensions, dynamic dimensions should within "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册