提交 e590588a 编写于 作者: D dengkaipeng

fix for itnerpolate. test=develop

上级 b2dcdb51
...@@ -45,9 +45,14 @@ class InterpolateOp : public framework::OperatorWithKernel { ...@@ -45,9 +45,14 @@ class InterpolateOp : public framework::OperatorWithKernel {
// round down // round down
out_h = static_cast<int>(dim_x[2] * scale); out_h = static_cast<int>(dim_x[2] * scale);
out_w = static_cast<int>(dim_x[3] * scale); out_w = static_cast<int>(dim_x[3] * scale);
// protect when input shape is -1
out_h = out_h > 0 ? out_h : -1;
out_w = out_w > 0 ? out_w : -1;
} else { } else {
out_h = ctx->Attrs().Get<int>("out_h"); out_h = ctx->Attrs().Get<int>("out_h");
out_w = ctx->Attrs().Get<int>("out_w"); out_w = ctx->Attrs().Get<int>("out_w");
PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0.");
PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0.");
} }
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) { if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
...@@ -59,12 +64,8 @@ class InterpolateOp : public framework::OperatorWithKernel { ...@@ -59,12 +64,8 @@ class InterpolateOp : public framework::OperatorWithKernel {
return; return;
} }
if (ctx->IsRuntime() || (out_h > 0 && out_w > 0)) {
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w}); std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
} else {
ctx->SetOutputDim("Out", dim_x);
}
} }
protected: protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册