未验证 提交 10c487eb 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix interpolate cu. test=develop (#17101)

上级 aca60e9a
...@@ -286,7 +286,7 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -286,7 +286,7 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
float scale = ctx.Attr<float>("scale"); float scale = ctx.Attr<float>("scale");
if (scale > 0) { if (scale > 0) {
out_h = in_h * scale; out_h = in_h * scale;
out_w - in_w* scale; out_w = in_w * scale;
} }
auto out_size = ctx.Input<Tensor>("OutSize"); auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) { if (out_size != nullptr) {
......
...@@ -305,7 +305,7 @@ class TestBilinearInterpWithMethod3(TestBilinearInterpOp): ...@@ -305,7 +305,7 @@ class TestBilinearInterpWithMethod3(TestBilinearInterpOp):
class TestBilinearInterpScale1(TestBilinearInterpOp): class TestBilinearInterpScale1(TestBilinearInterpOp):
def init_test_case(self): def init_test_case(self):
self.interp_method = 'bilinear' self.interp_method = 'bilinear'
self.input_shape = [2, 3, 16, 32] self.input_shape = [2, 3, 5, 7]
self.out_h = 60 self.out_h = 60
self.out_w = 25 self.out_w = 25
self.scale = 2. self.scale = 2.
...@@ -316,7 +316,7 @@ class TestBilinearInterpScale1(TestBilinearInterpOp): ...@@ -316,7 +316,7 @@ class TestBilinearInterpScale1(TestBilinearInterpOp):
class TestBilinearInterpScale2(TestBilinearInterpOp): class TestBilinearInterpScale2(TestBilinearInterpOp):
def init_test_case(self): def init_test_case(self):
self.interp_method = 'bilinear' self.interp_method = 'bilinear'
self.input_shape = [2, 3, 16, 32] self.input_shape = [2, 3, 5, 7]
self.out_h = 60 self.out_h = 60
self.out_w = 25 self.out_w = 25
self.scale = 1. self.scale = 1.
...@@ -327,7 +327,7 @@ class TestBilinearInterpScale2(TestBilinearInterpOp): ...@@ -327,7 +327,7 @@ class TestBilinearInterpScale2(TestBilinearInterpOp):
class TestBilinearInterpScale3(TestBilinearInterpOp): class TestBilinearInterpScale3(TestBilinearInterpOp):
def init_test_case(self): def init_test_case(self):
self.interp_method = 'bilinear' self.interp_method = 'bilinear'
self.input_shape = [2, 3, 16, 32] self.input_shape = [2, 3, 5, 7]
self.out_h = 60 self.out_h = 60
self.out_w = 25 self.out_w = 25
self.scale = 1.5 self.scale = 1.5
......
...@@ -259,7 +259,7 @@ class TestNearestInterpWithoutCorners(TestNearestInterpOp): ...@@ -259,7 +259,7 @@ class TestNearestInterpWithoutCorners(TestNearestInterpOp):
class TestNearestNeighborInterpScale1(TestNearestInterpOp): class TestNearestNeighborInterpScale1(TestNearestInterpOp):
def init_test_case(self): def init_test_case(self):
self.interp_method = 'nearest' self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16] self.input_shape = [3, 2, 7, 5]
self.out_h = 64 self.out_h = 64
self.out_w = 32 self.out_w = 32
self.scale = 2. self.scale = 2.
...@@ -270,7 +270,7 @@ class TestNearestNeighborInterpScale1(TestNearestInterpOp): ...@@ -270,7 +270,7 @@ class TestNearestNeighborInterpScale1(TestNearestInterpOp):
class TestNearestNeighborInterpScale2(TestNearestInterpOp): class TestNearestNeighborInterpScale2(TestNearestInterpOp):
def init_test_case(self): def init_test_case(self):
self.interp_method = 'nearest' self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16] self.input_shape = [3, 2, 5, 7]
self.out_h = 64 self.out_h = 64
self.out_w = 32 self.out_w = 32
self.scale = 1.5 self.scale = 1.5
...@@ -281,7 +281,7 @@ class TestNearestNeighborInterpScale2(TestNearestInterpOp): ...@@ -281,7 +281,7 @@ class TestNearestNeighborInterpScale2(TestNearestInterpOp):
class TestNearestNeighborInterpScale3(TestNearestInterpOp): class TestNearestNeighborInterpScale3(TestNearestInterpOp):
def init_test_case(self): def init_test_case(self):
self.interp_method = 'nearest' self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16] self.input_shape = [3, 2, 7, 5]
self.out_h = 64 self.out_h = 64
self.out_w = 32 self.out_w = 32
self.scale = 1. self.scale = 1.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册