提交 05eafcca 编写于 作者: C chenweihang

refine some messages and adjust data type

上级 8f2486ca
......@@ -30,13 +30,14 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
const auto &x_dims = ctx->GetInputDim("X");
// Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE(x_dims.size() <= 6,
"Invalid dimnesions, dynamic dimensions must have "
"between [1, 6] dimensions (Eigen limit).");
"Invalid dimnesions, the rank of Input(X) "
"should be in the range of [1, 6] (Eigen limit).");
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
for (int a : axes) {
PADDLE_ENFORCE_LT(a, x_dims.size(),
"The axis must be less than input tensor's rank.");
"The squeeze axis should be less than input "
"tensor's rank.");
}
auto out_dims = GetOutputShape(axes, x_dims);
......@@ -50,30 +51,29 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
const framework::DDim &in_dims) {
int num_squeeze_dims = static_cast<int>(squeeze_dims.size());
size_t num_squeeze_dims = squeeze_dims.size();
int cnt_squeezed_dims = 0;
bool should_squeeze[9] = {false};
// Determines number of dimensions of output tensor after squeeze.
// Mark and count the dimensions need to be squeezed
if (num_squeeze_dims == 0) {
for (int idx = 0; idx < static_cast<int>(in_dims.size()); ++idx) {
for (int idx = 0; idx < in_dims.size(); ++idx) {
if (in_dims[idx] == 1) {
should_squeeze[idx] = true;
++cnt_squeezed_dims;
}
}
} else {
for (int idx = 0; idx < num_squeeze_dims; ++idx) {
for (size_t idx = 0; idx < num_squeeze_dims; ++idx) {
int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
: squeeze_dims[idx];
// Check current index.
// Check current index, the upper limit has beed checked in line 36.
PADDLE_ENFORCE(current >= 0,
"Invalid axis, negative axis is out of range.");
// PADDLE_ENFORCE_LT(current, in_dims.size(), "Invalid axis is given.");
PADDLE_ENFORCE(
in_dims[current] == 1,
"Invalid axis index, the axis will be squeezed should be 1.");
"Invalid axis, the negative axis is out of range.");
PADDLE_ENFORCE(in_dims[current] == 1,
"Invalid axis index, the axis that will be squeezed "
"should equal 1.");
if (!(should_squeeze[current])) {
++cnt_squeezed_dims;
......@@ -84,8 +84,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
// Make output dimensions
std::vector<int64_t> output_shape(in_dims.size() - cnt_squeezed_dims, 0);
for (int in_idx = 0, out_idx = 0; in_idx < static_cast<int>(in_dims.size());
++in_idx) {
for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) {
if (!should_squeeze[in_idx]) {
output_shape[out_idx++] = in_dims[in_idx];
}
......@@ -123,7 +122,7 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "(Tensor). The input tensor of squeeze operator.");
AddOutput("Out", "(Tensor). The output tensor of squeeze operator.");
AddAttr<std::vector<int>>("axes",
"(std::vector<int>). List of positive integers,"
"(std::vector<int>). List of integers,"
" indicate the dimensions to squeeze.")
.SetDefault({});
AddAttr<bool>("inplace",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册