diff --git a/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc index f567f4660534c7ffcbfffadbccac29ff68c8b648..833535eb878e9ee07b37b9550bd98fa7d6070342 100644 --- a/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc @@ -176,11 +176,15 @@ class InterpolateMKLDNNKernel : public framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_KERNEL(nearest_interp, MKLDNN, ::paddle::platform::CPUPlace, - ops::InterpolateMKLDNNKernel); + ops::InterpolateMKLDNNKernel, + ops::InterpolateMKLDNNKernel, + ops::InterpolateMKLDNNKernel); REGISTER_OP_KERNEL(bilinear_interp, MKLDNN, ::paddle::platform::CPUPlace, ops::InterpolateMKLDNNKernel); REGISTER_OP_KERNEL(nearest_interp_v2, MKLDNN, ::paddle::platform::CPUPlace, - ops::InterpolateMKLDNNKernel); + ops::InterpolateMKLDNNKernel, + ops::InterpolateMKLDNNKernel, + ops::InterpolateMKLDNNKernel); REGISTER_OP_KERNEL(bilinear_interp_v2, MKLDNN, ::paddle::platform::CPUPlace, ops::InterpolateMKLDNNKernel); diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_nearest_interp_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_nearest_interp_mkldnn_op.py index 9f39826cb3ed2875993452269a66559ec2e84782..a802ef4c61285f54a6280fd3638aa71681aefaf0 100755 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_nearest_interp_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_nearest_interp_mkldnn_op.py @@ -16,9 +16,6 @@ from __future__ import print_function import unittest import numpy as np -import paddle -import paddle.fluid.core as core -import paddle.fluid as fluid from paddle.fluid.tests.unittests.op_test import OpTest from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci @@ -66,6 +63,9 @@ class TestNearestInterpMKLDNNOp(OpTest): def init_test_case(self): pass + def init_data_type(self): + pass + def setUp(self): self.op_type = "nearest_interp" self.interp_method = 'nearest' @@ -73,6 +73,7 @@ class TestNearestInterpMKLDNNOp(OpTest): self.use_mkldnn = True self.input_shape = [1, 1, 2, 2] self.data_layout = 'NCHW' + self.dtype = np.float32 # priority: actual_shape > out_size > scale > out_h & out_w self.out_h = 1 self.out_w = 1 @@ -81,8 +82,15 @@ class TestNearestInterpMKLDNNOp(OpTest): self.actual_shape = None self.init_test_case() + self.init_data_type() + + if self.dtype == np.float32: + input_np = np.random.random(self.input_shape).astype(self.dtype) + else: + init_low, init_high = (-5, 5) if self.dtype == np.int8 else (0, 10) + input_np = np.random.randint(init_low, init_high, + self.input_shape).astype(self.dtype) - input_np = np.random.random(self.input_shape).astype("float32") if self.data_layout == "NCHW": in_h = self.input_shape[2] in_w = self.input_shape[3] @@ -162,6 +170,35 @@ class TestNearestNeighborInterpSame(TestNearestInterpMKLDNNOp): self.scale = 0. +def create_test_class(parent): + class TestFp32Case(parent): + def init_data_type(self): + self.dtype = np.float32 + + class TestInt8Case(parent): + def init_data_type(self): + self.dtype = np.int8 + + class TestUint8Case(parent): + def init_data_type(self): + self.dtype = np.uint8 + + TestFp32Case.__name__ = parent.__name__ + TestInt8Case.__name__ = parent.__name__ + TestUint8Case.__name__ = parent.__name__ + globals()[parent.__name__] = TestFp32Case + globals()[parent.__name__] = TestInt8Case + globals()[parent.__name__] = TestUint8Case + + +create_test_class(TestNearestInterpMKLDNNOp) +create_test_class(TestNearestInterpOpMKLDNNNHWC) +create_test_class(TestNearestNeighborInterpMKLDNNCase2) +create_test_class(TestNearestNeighborInterpCase3) +create_test_class(TestNearestNeighborInterpCase4) +create_test_class(TestNearestInterpOpMKLDNNNHWC) +create_test_class(TestNearestNeighborInterpSame) + if __name__ == "__main__": from paddle import enable_static enable_static() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_nearest_interp_v2_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_nearest_interp_v2_mkldnn_op.py index b608ca3af2f3660347278135e1118bb3a3c817d5..24ebf40216f4bad6f351f865b86662ce0718f690 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_nearest_interp_v2_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_nearest_interp_v2_mkldnn_op.py @@ -16,9 +16,6 @@ from __future__ import print_function import unittest import numpy as np -import paddle -import paddle.fluid.core as core -import paddle.fluid as fluid from paddle.fluid.tests.unittests.op_test import OpTest from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci @@ -66,6 +63,9 @@ class TestNearestInterpV2MKLDNNOp(OpTest): def init_test_case(self): pass + def init_data_type(self): + pass + def setUp(self): self.op_type = "nearest_interp_v2" self.interp_method = 'nearest' @@ -73,6 +73,7 @@ class TestNearestInterpV2MKLDNNOp(OpTest): self.use_mkldnn = True self.input_shape = [1, 1, 2, 2] self.data_layout = 'NCHW' + self.dtype = np.float32 # priority: actual_shape > out_size > scale > out_h & out_w self.out_h = 1 self.out_w = 1 @@ -81,8 +82,15 @@ class TestNearestInterpV2MKLDNNOp(OpTest): self.actual_shape = None self.init_test_case() + self.init_data_type() + + if self.dtype == np.float32: + input_np = np.random.random(self.input_shape).astype(self.dtype) + else: + init_low, init_high = (-5, 5) if self.dtype == np.int8 else (0, 10) + input_np = np.random.randint(init_low, init_high, + self.input_shape).astype(self.dtype) - input_np = np.random.random(self.input_shape).astype("float32") if self.data_layout == "NCHW": in_h = self.input_shape[2] in_w = self.input_shape[3] @@ -178,6 +186,34 @@ class TestNearestNeighborInterpV2MKLDNNSame(TestNearestInterpV2MKLDNNOp): self.out_size = np.array([65, 129]).astype("int32") +def create_test_class(parent): + class TestFp32Case(parent): + def init_data_type(self): + self.dtype = np.float32 + + class TestInt8Case(parent): + def init_data_type(self): + self.dtype = np.int8 + + class TestUint8Case(parent): + def init_data_type(self): + self.dtype = np.uint8 + + TestFp32Case.__name__ = parent.__name__ + TestInt8Case.__name__ = parent.__name__ + TestUint8Case.__name__ = parent.__name__ + globals()[parent.__name__] = TestFp32Case + globals()[parent.__name__] = TestInt8Case + globals()[parent.__name__] = TestUint8Case + + +create_test_class(TestNearestInterpV2MKLDNNOp) +create_test_class(TestNearestInterpOpV2MKLDNNNHWC) +create_test_class(TestNearestNeighborInterpV2MKLDNNCase2) +create_test_class(TestNearestNeighborInterpV2MKLDNNCase3) +create_test_class(TestNearestNeighborInterpV2MKLDNNCase4) +create_test_class(TestNearestNeighborInterpV2MKLDNNSame) + if __name__ == "__main__": from paddle import enable_static enable_static()