diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 16291a06dcf09ab5cf4506c76deff9d2983f32b5..900b0c636ddafc8c033560adf58d596eb696621f 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -45,9 +45,14 @@ class InterpolateOp : public framework::OperatorWithKernel { // round down out_h = static_cast(dim_x[2] * scale); out_w = static_cast(dim_x[3] * scale); + // protect when input shape is -1 + out_h = out_h > 0 ? out_h : -1; + out_w = out_w > 0 ? out_w : -1; } else { out_h = ctx->Attrs().Get("out_h"); out_w = ctx->Attrs().Get("out_w"); + PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0."); + PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0."); } if (ctx->HasInput("OutSize") && ctx->IsRuntime()) { @@ -59,12 +64,8 @@ class InterpolateOp : public framework::OperatorWithKernel { return; } - if (ctx->IsRuntime() || (out_h > 0 && out_w > 0)) { - std::vector dim_out({dim_x[0], dim_x[1], out_h, out_w}); - ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); - } else { - ctx->SetOutputDim("Out", dim_x); - } + std::vector dim_out({dim_x[0], dim_x[1], out_h, out_w}); + ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); } protected: