提交 190cfd69 编写于 作者: P phlrain

fix squeeze shape check; test=develop

上级 b7baeed7
...@@ -40,7 +40,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase { ...@@ -40,7 +40,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
"tensor's rank."); "tensor's rank.");
} }
auto out_dims = GetOutputShape(axes, x_dims); auto out_dims = GetOutputShape(axes, x_dims, ctx);
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) { if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X) // Only pass LoD when the first dimension of output and Input(X)
...@@ -50,7 +50,8 @@ class SqueezeOpInferShape : public framework::InferShapeBase { ...@@ -50,7 +50,8 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
} }
static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims, static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
const framework::DDim &in_dims) { const framework::DDim &in_dims,
framework::InferShapeContext *ctx) {
size_t num_squeeze_dims = squeeze_dims.size(); size_t num_squeeze_dims = squeeze_dims.size();
int cnt_squeezed_dims = 0; int cnt_squeezed_dims = 0;
bool should_squeeze[9] = {false}; bool should_squeeze[9] = {false};
...@@ -71,9 +72,12 @@ class SqueezeOpInferShape : public framework::InferShapeBase { ...@@ -71,9 +72,12 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
// Check current index, the upper limit has beed checked in line 36. // Check current index, the upper limit has beed checked in line 36.
PADDLE_ENFORCE(current >= 0, PADDLE_ENFORCE(current >= 0,
"Invalid axis, the negative axis is out of range."); "Invalid axis, the negative axis is out of range.");
PADDLE_ENFORCE(in_dims[current] == 1,
"Invalid axis index, the axis that will be squeezed " if (ctx->IsRuntime()) {
"should be equal to 1."); PADDLE_ENFORCE(in_dims[current] == 1,
"Invalid axis index, the axis that will be squeezed "
"should be equal to 1.");
}
if (!(should_squeeze[current])) { if (!(should_squeeze[current])) {
++cnt_squeezed_dims; ++cnt_squeezed_dims;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册