未验证 提交 56e2a6a6 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Add nearest_interp/v2 int8 and uint8 support (#37985)

上级 abb07f35
...@@ -176,11 +176,15 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> { ...@@ -176,11 +176,15 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;

// MKLDNN (oneDNN) CPU kernel registrations for the interpolate ops.
// The commit adds int8/uint8 instantiations for the nearest-neighbor
// kernels only; the bilinear kernels stay float-only.
REGISTER_OP_KERNEL(nearest_interp, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::InterpolateMKLDNNKernel<float>,
                   ops::InterpolateMKLDNNKernel<int8_t>,
                   ops::InterpolateMKLDNNKernel<uint8_t>);

REGISTER_OP_KERNEL(bilinear_interp, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::InterpolateMKLDNNKernel<float>);

REGISTER_OP_KERNEL(nearest_interp_v2, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::InterpolateMKLDNNKernel<float>,
                   ops::InterpolateMKLDNNKernel<int8_t>,
                   ops::InterpolateMKLDNNKernel<uint8_t>);

REGISTER_OP_KERNEL(bilinear_interp_v2, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::InterpolateMKLDNNKernel<float>);
...@@ -16,9 +16,6 @@ from __future__ import print_function ...@@ -16,9 +16,6 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.tests.unittests.op_test import OpTest from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci
...@@ -66,6 +63,9 @@ class TestNearestInterpMKLDNNOp(OpTest): ...@@ -66,6 +63,9 @@ class TestNearestInterpMKLDNNOp(OpTest):
def init_test_case(self): def init_test_case(self):
pass pass
def init_data_type(self):
    """Hook for subclasses to set self.dtype; base class leaves the
    default (np.float32 assigned in setUp) untouched.
    Overridden by the classes generated in create_test_class below."""
    pass
def setUp(self): def setUp(self):
self.op_type = "nearest_interp" self.op_type = "nearest_interp"
self.interp_method = 'nearest' self.interp_method = 'nearest'
...@@ -73,6 +73,7 @@ class TestNearestInterpMKLDNNOp(OpTest): ...@@ -73,6 +73,7 @@ class TestNearestInterpMKLDNNOp(OpTest):
self.use_mkldnn = True self.use_mkldnn = True
self.input_shape = [1, 1, 2, 2] self.input_shape = [1, 1, 2, 2]
self.data_layout = 'NCHW' self.data_layout = 'NCHW'
self.dtype = np.float32
# priority: actual_shape > out_size > scale > out_h & out_w # priority: actual_shape > out_size > scale > out_h & out_w
self.out_h = 1 self.out_h = 1
self.out_w = 1 self.out_w = 1
...@@ -81,8 +82,15 @@ class TestNearestInterpMKLDNNOp(OpTest): ...@@ -81,8 +82,15 @@ class TestNearestInterpMKLDNNOp(OpTest):
self.actual_shape = None self.actual_shape = None
self.init_test_case() self.init_test_case()
self.init_data_type()
if self.dtype == np.float32:
input_np = np.random.random(self.input_shape).astype(self.dtype)
else:
init_low, init_high = (-5, 5) if self.dtype == np.int8 else (0, 10)
input_np = np.random.randint(init_low, init_high,
self.input_shape).astype(self.dtype)
input_np = np.random.random(self.input_shape).astype("float32")
if self.data_layout == "NCHW": if self.data_layout == "NCHW":
in_h = self.input_shape[2] in_h = self.input_shape[2]
in_w = self.input_shape[3] in_w = self.input_shape[3]
...@@ -162,6 +170,35 @@ class TestNearestNeighborInterpSame(TestNearestInterpMKLDNNOp): ...@@ -162,6 +170,35 @@ class TestNearestNeighborInterpSame(TestNearestInterpMKLDNNOp):
self.scale = 0. self.scale = 0.
def create_test_class(parent):
    """Generate and register fp32/int8/uint8 variants of *parent*.

    Creates three subclasses of the given test case, each overriding
    init_data_type() with a different dtype, and publishes them in the
    module namespace so unittest discovery picks them up.

    Bug fix: the original assigned all three classes the same __name__
    and the same globals() key, so the fp32 and int8 variants were
    overwritten and only the uint8 tests ever ran.  Each variant now
    gets a distinct name; the parent's original name keeps pointing at
    the fp32 case for backward compatibility.
    """

    class TestFp32Case(parent):
        def init_data_type(self):
            self.dtype = np.float32

    class TestInt8Case(parent):
        def init_data_type(self):
            self.dtype = np.int8

    class TestUint8Case(parent):
        def init_data_type(self):
            self.dtype = np.uint8

    TestFp32Case.__name__ = parent.__name__
    TestInt8Case.__name__ = parent.__name__ + "_Int8"
    TestUint8Case.__name__ = parent.__name__ + "_Uint8"
    # Distinct keys so no variant shadows another.
    globals()[TestFp32Case.__name__] = TestFp32Case
    globals()[TestInt8Case.__name__] = TestInt8Case
    globals()[TestUint8Case.__name__] = TestUint8Case
# Register the dtype-parameterized (fp32/int8/uint8) variants of each
# nearest_interp base case.
# NOTE(review): the original listed TestNearestInterpOpMKLDNNNHWC twice;
# it is registered once here — the duplicate call was redundant.
create_test_class(TestNearestInterpMKLDNNOp)
create_test_class(TestNearestInterpOpMKLDNNNHWC)
create_test_class(TestNearestNeighborInterpMKLDNNCase2)
create_test_class(TestNearestNeighborInterpCase3)
create_test_class(TestNearestNeighborInterpCase4)
create_test_class(TestNearestNeighborInterpSame)
if __name__ == "__main__": if __name__ == "__main__":
from paddle import enable_static from paddle import enable_static
enable_static() enable_static()
......
...@@ -16,9 +16,6 @@ from __future__ import print_function ...@@ -16,9 +16,6 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.tests.unittests.op_test import OpTest from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci
...@@ -66,6 +63,9 @@ class TestNearestInterpV2MKLDNNOp(OpTest): ...@@ -66,6 +63,9 @@ class TestNearestInterpV2MKLDNNOp(OpTest):
def init_test_case(self): def init_test_case(self):
pass pass
def init_data_type(self):
    """Hook for subclasses to set self.dtype; base class leaves the
    default (np.float32 assigned in setUp) untouched.
    Overridden by the classes generated in create_test_class below."""
    pass
def setUp(self): def setUp(self):
self.op_type = "nearest_interp_v2" self.op_type = "nearest_interp_v2"
self.interp_method = 'nearest' self.interp_method = 'nearest'
...@@ -73,6 +73,7 @@ class TestNearestInterpV2MKLDNNOp(OpTest): ...@@ -73,6 +73,7 @@ class TestNearestInterpV2MKLDNNOp(OpTest):
self.use_mkldnn = True self.use_mkldnn = True
self.input_shape = [1, 1, 2, 2] self.input_shape = [1, 1, 2, 2]
self.data_layout = 'NCHW' self.data_layout = 'NCHW'
self.dtype = np.float32
# priority: actual_shape > out_size > scale > out_h & out_w # priority: actual_shape > out_size > scale > out_h & out_w
self.out_h = 1 self.out_h = 1
self.out_w = 1 self.out_w = 1
...@@ -81,8 +82,15 @@ class TestNearestInterpV2MKLDNNOp(OpTest): ...@@ -81,8 +82,15 @@ class TestNearestInterpV2MKLDNNOp(OpTest):
self.actual_shape = None self.actual_shape = None
self.init_test_case() self.init_test_case()
self.init_data_type()
if self.dtype == np.float32:
input_np = np.random.random(self.input_shape).astype(self.dtype)
else:
init_low, init_high = (-5, 5) if self.dtype == np.int8 else (0, 10)
input_np = np.random.randint(init_low, init_high,
self.input_shape).astype(self.dtype)
input_np = np.random.random(self.input_shape).astype("float32")
if self.data_layout == "NCHW": if self.data_layout == "NCHW":
in_h = self.input_shape[2] in_h = self.input_shape[2]
in_w = self.input_shape[3] in_w = self.input_shape[3]
...@@ -178,6 +186,34 @@ class TestNearestNeighborInterpV2MKLDNNSame(TestNearestInterpV2MKLDNNOp): ...@@ -178,6 +186,34 @@ class TestNearestNeighborInterpV2MKLDNNSame(TestNearestInterpV2MKLDNNOp):
self.out_size = np.array([65, 129]).astype("int32") self.out_size = np.array([65, 129]).astype("int32")
def create_test_class(parent):
    """Generate and register fp32/int8/uint8 variants of *parent*.

    Creates three subclasses of the given test case, each overriding
    init_data_type() with a different dtype, and publishes them in the
    module namespace so unittest discovery picks them up.

    Bug fix: the original assigned all three classes the same __name__
    and the same globals() key, so the fp32 and int8 variants were
    overwritten and only the uint8 tests ever ran.  Each variant now
    gets a distinct name; the parent's original name keeps pointing at
    the fp32 case for backward compatibility.
    """

    class TestFp32Case(parent):
        def init_data_type(self):
            self.dtype = np.float32

    class TestInt8Case(parent):
        def init_data_type(self):
            self.dtype = np.int8

    class TestUint8Case(parent):
        def init_data_type(self):
            self.dtype = np.uint8

    TestFp32Case.__name__ = parent.__name__
    TestInt8Case.__name__ = parent.__name__ + "_Int8"
    TestUint8Case.__name__ = parent.__name__ + "_Uint8"
    # Distinct keys so no variant shadows another.
    globals()[TestFp32Case.__name__] = TestFp32Case
    globals()[TestInt8Case.__name__] = TestInt8Case
    globals()[TestUint8Case.__name__] = TestUint8Case
# Register the dtype-parameterized (fp32/int8/uint8) variants of each
# nearest_interp_v2 base case.
_v2_base_cases = (
    TestNearestInterpV2MKLDNNOp,
    TestNearestInterpOpV2MKLDNNNHWC,
    TestNearestNeighborInterpV2MKLDNNCase2,
    TestNearestNeighborInterpV2MKLDNNCase3,
    TestNearestNeighborInterpV2MKLDNNCase4,
    TestNearestNeighborInterpV2MKLDNNSame,
)
for _base in _v2_base_cases:
    create_test_class(_base)
if __name__ == "__main__": if __name__ == "__main__":
from paddle import enable_static from paddle import enable_static
enable_static() enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册