diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc index e570331aa05d435366982f149944724d652cbe42..dc15df2c3c1b8a2964312d983be8ce362d3ab95d 100644 --- a/paddle/fluid/operators/squeeze_op.cc +++ b/paddle/fluid/operators/squeeze_op.cc @@ -40,7 +40,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase { "tensor's rank."); } - auto out_dims = GetOutputShape(axes, x_dims, ctx); + auto out_dims = GetOutputShape(axes, x_dims, false); ctx->SetOutputDim("Out", out_dims); if (x_dims[0] == out_dims[0]) { // Only pass LoD when the first dimension of output and Input(X) @@ -51,7 +51,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase { static framework::DDim GetOutputShape(const std::vector squeeze_dims, const framework::DDim &in_dims, - framework::InferShapeContext *ctx) { + bool is_runtime) { size_t num_squeeze_dims = squeeze_dims.size(); int cnt_squeezed_dims = 0; bool should_squeeze[9] = {false}; @@ -73,7 +73,7 @@ class SqueezeOpInferShape : public framework::InferShapeBase { PADDLE_ENFORCE(current >= 0, "Invalid axis, the negative axis is out of range."); - if (ctx->IsRuntime()) { + if (is_runtime) { PADDLE_ENFORCE(in_dims[current] == 1, "Invalid axis index, the axis that will be squeezed " "should be equal to 1."); @@ -108,7 +108,7 @@ class SqueezeOp : public framework::OperatorBase { const platform::Place &place) const override { auto &axes = Attr>("axes"); auto x_dims = scope.FindVar(Input("X"))->Get().dims(); - auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims); + auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims, true); framework::AttributeMap attrs; attrs["shape"] = framework::vectorize2int(out_dims); @@ -228,7 +228,7 @@ class Squeeze2Op : public framework::OperatorBase { const platform::Place &place) const override { auto &axes = Attr>("axes"); auto x_dims = scope.FindVar(Input("X"))->Get().dims(); - auto out_dims = Squeeze2OpInferShape::GetOutputShape(axes, x_dims); + auto out_dims = Squeeze2OpInferShape::GetOutputShape(axes, x_dims, true); framework::AttributeMap attrs; attrs["shape"] = framework::vectorize2int(out_dims);