未验证 提交 10c487eb 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix interpolate cu. test=develop (#17101)

上级 aca60e9a
......@@ -286,7 +286,7 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = in_h * scale;
out_w - in_w* scale;
out_w = in_w * scale;
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
......
......@@ -305,7 +305,7 @@ class TestBilinearInterpWithMethod3(TestBilinearInterpOp):
class TestBilinearInterpScale1(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 16, 32]
self.input_shape = [2, 3, 5, 7]
self.out_h = 60
self.out_w = 25
self.scale = 2.
......@@ -316,7 +316,7 @@ class TestBilinearInterpScale1(TestBilinearInterpOp):
class TestBilinearInterpScale2(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 16, 32]
self.input_shape = [2, 3, 5, 7]
self.out_h = 60
self.out_w = 25
self.scale = 1.
......@@ -327,7 +327,7 @@ class TestBilinearInterpScale2(TestBilinearInterpOp):
class TestBilinearInterpScale3(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 16, 32]
self.input_shape = [2, 3, 5, 7]
self.out_h = 60
self.out_w = 25
self.scale = 1.5
......
......@@ -259,7 +259,7 @@ class TestNearestInterpWithoutCorners(TestNearestInterpOp):
class TestNearestNeighborInterpScale1(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16]
self.input_shape = [3, 2, 7, 5]
self.out_h = 64
self.out_w = 32
self.scale = 2.
......@@ -270,7 +270,7 @@ class TestNearestNeighborInterpScale1(TestNearestInterpOp):
class TestNearestNeighborInterpScale2(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16]
self.input_shape = [3, 2, 5, 7]
self.out_h = 64
self.out_w = 32
self.scale = 1.5
......@@ -281,7 +281,7 @@ class TestNearestNeighborInterpScale2(TestNearestInterpOp):
class TestNearestNeighborInterpScale3(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16]
self.input_shape = [3, 2, 7, 5]
self.out_h = 64
self.out_w = 32
self.scale = 1.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册