提交 0f7411a1 编写于 作者: D dengkaipeng

round down for scale. test=develop

上级 2078f420
...@@ -41,8 +41,9 @@ class InterpolateOp : public framework::OperatorWithKernel { ...@@ -41,8 +41,9 @@ class InterpolateOp : public framework::OperatorWithKernel {
int out_h, out_w; int out_h, out_w;
float scale = ctx->Attrs().Get<float>("scale"); float scale = ctx->Attrs().Get<float>("scale");
if (scale > 0) { if (scale > 0) {
out_h = dim_x[2] * scale; // round down
out_w = dim_x[3] * scale; out_h = static_cast<int>(dim_x[2] * scale);
out_w = static_cast<int>(dim_x[3] * scale);
} else { } else {
out_h = ctx->Attrs().Get<int>("out_h"); out_h = ctx->Attrs().Get<int>("out_h");
out_w = ctx->Attrs().Get<int>("out_w"); out_w = ctx->Attrs().Get<int>("out_w");
......
...@@ -174,8 +174,8 @@ class InterpolateKernel : public framework::OpKernel<T> { ...@@ -174,8 +174,8 @@ class InterpolateKernel : public framework::OpKernel<T> {
float scale = ctx.Attr<float>("scale"); float scale = ctx.Attr<float>("scale");
if (scale > 0) { if (scale > 0) {
out_h = in_h * scale; out_h = static_cast<int>(in_h * scale);
out_w = in_w * scale; out_w = static_cast<int>(in_w * scale);
} }
auto out_size = ctx.Input<Tensor>("OutSize"); auto out_size = ctx.Input<Tensor>("OutSize");
...@@ -239,8 +239,8 @@ class InterpolateGradKernel : public framework::OpKernel<T> { ...@@ -239,8 +239,8 @@ class InterpolateGradKernel : public framework::OpKernel<T> {
float scale = ctx.Attr<float>("scale"); float scale = ctx.Attr<float>("scale");
if (scale > 0) { if (scale > 0) {
out_h = in_h * scale; out_h = static_cast<int>(in_h * scale);
out_w = in_w * scale; out_w = static_cast<int>(in_w * scale);
} }
auto out_size = ctx.Input<Tensor>("OutSize"); auto out_size = ctx.Input<Tensor>("OutSize");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册