提交 05eafcca 编写于 作者: C chenweihang

refine some messages and adjust data type

上级 8f2486ca
...@@ -30,13 +30,14 @@ class SqueezeOpInferShape : public framework::InferShapeBase { ...@@ -30,13 +30,14 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
const auto &x_dims = ctx->GetInputDim("X"); const auto &x_dims = ctx->GetInputDim("X");
// Check input tensor dims (<6) Eigen limit. // Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE(x_dims.size() <= 6, PADDLE_ENFORCE(x_dims.size() <= 6,
"Invalid dimnesions, dynamic dimensions must have " "Invalid dimnesions, the rank of Input(X) "
"between [1, 6] dimensions (Eigen limit)."); "should be in the range of [1, 6] (Eigen limit).");
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes"); const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
for (int a : axes) { for (int a : axes) {
PADDLE_ENFORCE_LT(a, x_dims.size(), PADDLE_ENFORCE_LT(a, x_dims.size(),
"The axis must be less than input tensor's rank."); "The squeeze axis should be less than input "
"tensor's rank.");
} }
auto out_dims = GetOutputShape(axes, x_dims); auto out_dims = GetOutputShape(axes, x_dims);
...@@ -50,30 +51,29 @@ class SqueezeOpInferShape : public framework::InferShapeBase { ...@@ -50,30 +51,29 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims, static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
const framework::DDim &in_dims) { const framework::DDim &in_dims) {
int num_squeeze_dims = static_cast<int>(squeeze_dims.size()); size_t num_squeeze_dims = squeeze_dims.size();
int cnt_squeezed_dims = 0; int cnt_squeezed_dims = 0;
bool should_squeeze[9] = {false}; bool should_squeeze[9] = {false};
// Determines number of dimensions of output tensor after squeeze. // Determines number of dimensions of output tensor after squeeze.
// Mark and count the dimensions need to be squeezed // Mark and count the dimensions need to be squeezed
if (num_squeeze_dims == 0) { if (num_squeeze_dims == 0) {
for (int idx = 0; idx < static_cast<int>(in_dims.size()); ++idx) { for (int idx = 0; idx < in_dims.size(); ++idx) {
if (in_dims[idx] == 1) { if (in_dims[idx] == 1) {
should_squeeze[idx] = true; should_squeeze[idx] = true;
++cnt_squeezed_dims; ++cnt_squeezed_dims;
} }
} }
} else { } else {
for (int idx = 0; idx < num_squeeze_dims; ++idx) { for (size_t idx = 0; idx < num_squeeze_dims; ++idx) {
int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size() int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
: squeeze_dims[idx]; : squeeze_dims[idx];
// Check current index. // Check current index, the upper limit has beed checked in line 36.
PADDLE_ENFORCE(current >= 0, PADDLE_ENFORCE(current >= 0,
"Invalid axis, negative axis is out of range."); "Invalid axis, the negative axis is out of range.");
// PADDLE_ENFORCE_LT(current, in_dims.size(), "Invalid axis is given."); PADDLE_ENFORCE(in_dims[current] == 1,
PADDLE_ENFORCE( "Invalid axis index, the axis that will be squeezed "
in_dims[current] == 1, "should equal 1.");
"Invalid axis index, the axis will be squeezed should be 1.");
if (!(should_squeeze[current])) { if (!(should_squeeze[current])) {
++cnt_squeezed_dims; ++cnt_squeezed_dims;
...@@ -84,8 +84,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase { ...@@ -84,8 +84,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
// Make output dimensions // Make output dimensions
std::vector<int64_t> output_shape(in_dims.size() - cnt_squeezed_dims, 0); std::vector<int64_t> output_shape(in_dims.size() - cnt_squeezed_dims, 0);
for (int in_idx = 0, out_idx = 0; in_idx < static_cast<int>(in_dims.size()); for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) {
++in_idx) {
if (!should_squeeze[in_idx]) { if (!should_squeeze[in_idx]) {
output_shape[out_idx++] = in_dims[in_idx]; output_shape[out_idx++] = in_dims[in_idx];
} }
...@@ -123,7 +122,7 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -123,7 +122,7 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "(Tensor). The input tensor of squeeze operator."); AddInput("X", "(Tensor). The input tensor of squeeze operator.");
AddOutput("Out", "(Tensor). The output tensor of squeeze operator."); AddOutput("Out", "(Tensor). The output tensor of squeeze operator.");
AddAttr<std::vector<int>>("axes", AddAttr<std::vector<int>>("axes",
"(std::vector<int>). List of positive integers," "(std::vector<int>). List of integers,"
" indicate the dimensions to squeeze.") " indicate the dimensions to squeeze.")
.SetDefault({}); .SetDefault({});
AddAttr<bool>("inplace", AddAttr<bool>("inplace",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册