diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 9d9924a6baadb9a7a6c2e098f83d833b93d8870c..0d43213f96043e0541c47151f33ce69ff4d3ccc2 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -40,6 +40,8 @@ class InterpolateOp : public framework::OperatorWithKernel { int out_h = ctx->Attrs().Get("out_h"); int out_w = ctx->Attrs().Get("out_w"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4"); + PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0."); + PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0."); if (ctx->HasInput("OutSize") && ctx->IsRuntime()) { auto out_size_dim = ctx->GetInputDim("OutSize"); @@ -50,12 +52,8 @@ class InterpolateOp : public framework::OperatorWithKernel { return; } - if (ctx->IsRuntime() || (out_h > 0 && out_w > 0)) { - std::vector dim_out({dim_x[0], dim_x[1], out_h, out_w}); - ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); - } else { - ctx->SetOutputDim("Out", dim_x); - } + std::vector dim_out({dim_x[0], dim_x[1], out_h, out_w}); + ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); } protected: