diff --git a/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc b/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc index f448ac99aced9eaecf156255cf570806edf8d7f3..3ff18200a0d92d426dd5466ee67e88b20d856a93 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc @@ -59,7 +59,10 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const { } // CINN ops in this white list support 0D-Tensor - const std::unordered_set white_op_list{"elementwise_add"}; + const std::unordered_set white_op_list{"elementwise_add", + "elementwise_sub", + "elementwise_mul", + "elementwise_div"}; std::unordered_set white_tensor_name; // enable white_op_list only when graph_node_size = 1, which means single op // test diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py index 15880ec1fb2ed0776219226fd681a2863d2aefca..671ed2b4131d445a5f6ce3e72be792a5fa7d5089 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py @@ -143,9 +143,6 @@ class TestElementwiseDivOp_ZeroDim1(ElementwiseDivOp): self.x_shape = [] self.y_shape = [] - def if_enable_cinn(self): - self.enable_cinn = False - class TestElementwiseDivOp_ZeroDim2(ElementwiseDivOp): def init_shape(self): @@ -161,9 +158,6 @@ class TestElementwiseDivOp_ZeroDim2(ElementwiseDivOp): def compute_gradient_y(self, grad_out, out, y): return np.sum(-1 * grad_out * out / y.reshape([1, 1])) - def if_enable_cinn(self): - self.enable_cinn = False - class TestElementwiseDivOp_ZeroDim3(ElementwiseDivOp): def init_shape(self): @@ -179,9 +173,6 @@ class TestElementwiseDivOp_ZeroDim3(ElementwiseDivOp): def compute_gradient_y(self, grad_out, out, y): return -1 * grad_out * out / y - def if_enable_cinn(self): - self.enable_cinn = False - @unittest.skipIf( not core.is_compiled_with_cuda() diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py index 79b9bcba7660fcf80c287cea8294c034eb6c95f7..0b773108a8c9083f81590ad90832dbc9911a9fe0 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py @@ -148,9 +148,6 @@ class TestElementwiseMulOp_ZeroDim1(ElementwiseMulOp): self.y = np.random.uniform(0.1, 1, []).astype(self.dtype) self.out = np.multiply(self.x, self.y) - def if_enable_cinn(self): - self.enable_cinn = False - class TestElementwiseMulOp_ZeroDim2(ElementwiseMulOp): def init_input_output(self): @@ -158,9 +155,6 @@ class TestElementwiseMulOp_ZeroDim2(ElementwiseMulOp): self.y = np.random.uniform(0.1, 1, []).astype(self.dtype) self.out = np.multiply(self.x, self.y) - def if_enable_cinn(self): - self.enable_cinn = False - class TestElementwiseMulOp_ZeroDim3(ElementwiseMulOp): def init_input_output(self): @@ -168,9 +162,6 @@ class TestElementwiseMulOp_ZeroDim3(ElementwiseMulOp): self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) self.out = np.multiply(self.x, self.y) - def if_enable_cinn(self): - self.enable_cinn = False - class TestBF16ElementwiseMulOp(OpTest): def setUp(self): diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py index b979fa339de94745f22c059ca58ca6bfdf3bc909..59871057f0060fb22d7deffa5917a5195072dc8a 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py @@ -135,7 +135,6 @@ class TestElementwiseSubOp_ZeroDim1(TestElementwiseOp): } self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} self.if_check_prim() - self.if_enable_cinn() class TestElementwiseSubFP16OP_ZeroDim1(TestElementwiseSubOp_ZeroDim1): @@ -182,7 +181,6 @@ class TestElementwiseSubOp_ZeroDim2(TestElementwiseOp): } self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} self.if_check_prim() - self.if_enable_cinn() class TestElementwiseSubFP16OP_ZeroDim2(TestElementwiseSubOp_ZeroDim2): @@ -229,7 +227,6 @@ class TestElementwiseSubOp_ZeroDim3(TestElementwiseOp): } self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} self.if_check_prim() - self.if_enable_cinn() class TestElementwiseSubFP16OP_ZeroDim3(TestElementwiseSubOp_ZeroDim3):