提交 534beb5a 编写于 作者: D dengkaipeng

fix for itnerpolate. test=release/1.4

上级 70a967df
......@@ -40,6 +40,8 @@ class InterpolateOp : public framework::OperatorWithKernel {
int out_h = ctx->Attrs().Get<int>("out_h");
int out_w = ctx->Attrs().Get<int>("out_w");
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4");
PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0.");
PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0.");
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
auto out_size_dim = ctx->GetInputDim("OutSize");
......@@ -50,12 +52,8 @@ class InterpolateOp : public framework::OperatorWithKernel {
return;
}
if (ctx->IsRuntime() || (out_h > 0 && out_w > 0)) {
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
} else {
ctx->SetOutputDim("Out", dim_x);
}
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
}
protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册