diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc index ecfb4e89566f3d72b3c262946c370bf34ce7515a..e570331aa05d435366982f149944724d652cbe42 100644 --- a/paddle/fluid/operators/squeeze_op.cc +++ b/paddle/fluid/operators/squeeze_op.cc @@ -40,7 +40,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase { "tensor's rank."); } - auto out_dims = GetOutputShape(axes, x_dims); + auto out_dims = GetOutputShape(axes, x_dims, ctx); ctx->SetOutputDim("Out", out_dims); if (x_dims[0] == out_dims[0]) { // Only pass LoD when the first dimension of output and Input(X) @@ -50,7 +50,8 @@ class SqueezeOpInferShape : public framework::InferShapeBase { } static framework::DDim GetOutputShape(const std::vector squeeze_dims, - const framework::DDim &in_dims) { + const framework::DDim &in_dims, + framework::InferShapeContext *ctx) { size_t num_squeeze_dims = squeeze_dims.size(); int cnt_squeezed_dims = 0; bool should_squeeze[9] = {false}; @@ -71,9 +72,12 @@ class SqueezeOpInferShape : public framework::InferShapeBase { // Check current index, the upper limit has beed checked in line 36. PADDLE_ENFORCE(current >= 0, "Invalid axis, the negative axis is out of range."); - PADDLE_ENFORCE(in_dims[current] == 1, - "Invalid axis index, the axis that will be squeezed " - "should be equal to 1."); + + if (ctx->IsRuntime()) { + PADDLE_ENFORCE(in_dims[current] == 1, + "Invalid axis index, the axis that will be squeezed " + "should be equal to 1."); + } if (!(should_squeeze[current])) { ++cnt_squeezed_dims;