From f820573b9c6ffee12aaf64b656d902dc0c9532f5 Mon Sep 17 00:00:00 2001 From: Michal Gallus Date: Wed, 7 Nov 2018 11:37:27 +0100 Subject: [PATCH] MKLDNN elementwise_mul: Add UTs --- .../test_elementwise_mul_mkldnn_op.py | 119 +++++++++++++++++- 1 file changed, 118 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py index a89f439664..a008979801 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py @@ -49,7 +49,37 @@ class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp): def test_check_grad_ingore_y(self): pass -class TestElementwiseMulMKLDNNOp_UnsupportedFormat(ElementwiseMulOp): +@unittest.skip("Not implemented yet.") +class TestElementwiseMulMKLDNNOp_BroadcastNCHW8c(ElementwiseMulOp): + def init_input_output(self): + x = np.random.rand(1, 8, 2, 2).astype(self.dtype) + self.x = x.transpose(0, 2, 3, 1).reshape(1, 8, 2, 2) + self.y = np.random.rand(1, 8).astype(self.dtype) + + self.out = x * self.y.reshape(1, 8, 1, 1) + self.out = self.out.transpose(0, 2, 3, 1).reshape(1, 8, 2, 2) + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_BroadcastNCHW8c, self).setUp() + self.attrs["x_data_format"] = "nchw8c" + self.attrs["y_data_format"] = "nc" + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + +class TestElementwiseMulMKLDNNOp_FallbackNCHW(ElementwiseMulOp): def init_input_output(self): self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype) self.y = np.random.rand(1, 16).astype(self.dtype) @@ -71,5 +101,92 @@ class TestElementwiseMulMKLDNNOp_UnsupportedFormat(ElementwiseMulOp): def test_check_grad_ingore_y(self): pass +class TestElementwiseMulMKLDNNOp_FallbackNCHW16C(ElementwiseMulOp): + def init_input_output(self): + x = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + y = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + + self.out = self.x * self.y + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_FallbackNCHW16C, self).setUp() + self.attrs["x_data_format"] = "nchw16c" + self.attrs["y_data_format"] = "nchw16c" + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + +class TestElementwiseMulMKLDNNOp_FallbackNoReorders(ElementwiseMulOp): + def init_input_output(self): + x = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + y = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + + self.out = self.x * self.y + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_FallbackNoReorders, self).setUp() + self.attrs["x_data_format"] = "nchw16c" + self.attrs["y_data_format"] = "nchw16c" + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + +@unittest.skip("Not implemented yet.") +class TestElementwiseMulMKLDNNOp_FallbackWithReorder(ElementwiseMulOp): + def init_input_output(self): + self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype) + y = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + + self.out = self.x * y + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_FallbackNCHW16C, self).setUp() + self.attrs["x_data_format"] = "nchw" + self.attrs["y_data_format"] = "nchw16c" + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + if __name__ == '__main__': unittest.main() -- GitLab