提交 0f7411a1 编写于 作者: D dengkaipeng

round down for scale. test=develop

上级 2078f420
......@@ -41,8 +41,9 @@ class InterpolateOp : public framework::OperatorWithKernel {
int out_h, out_w;
float scale = ctx->Attrs().Get<float>("scale");
if (scale > 0) {
out_h = dim_x[2] * scale;
out_w = dim_x[3] * scale;
// round down
out_h = static_cast<int>(dim_x[2] * scale);
out_w = static_cast<int>(dim_x[3] * scale);
} else {
out_h = ctx->Attrs().Get<int>("out_h");
out_w = ctx->Attrs().Get<int>("out_w");
......
......@@ -174,8 +174,8 @@ class InterpolateKernel : public framework::OpKernel<T> {
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = in_h * scale;
out_w = in_w * scale;
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
......@@ -239,8 +239,8 @@ class InterpolateGradKernel : public framework::OpKernel<T> {
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = in_h * scale;
out_w = in_w * scale;
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册