Unverified commit 7bf97d2c, authored by HongyuJia, committed by GitHub

[0D-Tensor] Support elementwise_sub,mul,div (#54064)

Parent 48edbb28
...@@ -59,7 +59,10 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const {
   }
   // CINN ops in this white list support 0D-Tensor
-  const std::unordered_set<std::string> white_op_list{"elementwise_add"};
+  const std::unordered_set<std::string> white_op_list{"elementwise_add",
+                                                      "elementwise_sub",
+                                                      "elementwise_mul",
+                                                      "elementwise_div"};
   std::unordered_set<std::string> white_tensor_name;
   // enable white_op_list only when graph_node_size = 1, which means single op
   // test
......
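The white list only takes effect for single-op graphs, so the change above lets CINN keep 0D-Tensor inputs for standalone subtract/multiply/divide in addition to add. A minimal sketch of the eager-mode semantics these ops must preserve, assuming a Paddle build where Python scalars become true 0-D tensors; the values are arbitrary examples, not taken from this PR:

```python
import paddle

x = paddle.to_tensor(3.0)  # 0-D tensor, shape []
y = paddle.to_tensor(2.0)  # 0-D tensor, shape []

for op in (paddle.subtract, paddle.multiply, paddle.divide):
    out = op(x, y)
    # The newly white-listed elementwise ops keep the 0-D shape.
    assert out.shape == []
    print(op.__name__, float(out))
```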
...@@ -143,9 +143,6 @@ class TestElementwiseDivOp_ZeroDim1(ElementwiseDivOp):
         self.x_shape = []
         self.y_shape = []
 
-    def if_enable_cinn(self):
-        self.enable_cinn = False
-
 
 class TestElementwiseDivOp_ZeroDim2(ElementwiseDivOp):
     def init_shape(self):
...@@ -161,9 +158,6 @@ class TestElementwiseDivOp_ZeroDim2(ElementwiseDivOp):
     def compute_gradient_y(self, grad_out, out, y):
         return np.sum(-1 * grad_out * out / y.reshape([1, 1]))
 
-    def if_enable_cinn(self):
-        self.enable_cinn = False
-
 
 class TestElementwiseDivOp_ZeroDim3(ElementwiseDivOp):
     def init_shape(self):
...@@ -179,9 +173,6 @@ class TestElementwiseDivOp_ZeroDim3(ElementwiseDivOp):
     def compute_gradient_y(self, grad_out, out, y):
         return -1 * grad_out * out / y
 
-    def if_enable_cinn(self):
-        self.enable_cinn = False
-
 
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
......
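The `compute_gradient_y` references above follow from `out = x / y`: d(out)/dy = -x / y² = -out / y, summed over the broadcast axes when `y` has lower rank. A quick standalone numpy finite-difference check of that identity, as a sanity sketch rather than part of the Paddle test suite; the shapes mirror the ZeroDim cases:

```python
import numpy as np

np.random.seed(0)
x = np.random.uniform(0.1, 1, [13, 17])
y = np.random.uniform(0.1, 1, [])   # 0-D divisor, broadcast over x
grad_out = np.ones_like(x)          # upstream gradient of a sum() loss

out = x / y
# Analytic gradient, same form as compute_gradient_y in the tests.
grad_y = np.sum(-1 * grad_out * out / y)

# Central finite difference of loss(y) = sum(x / y).
eps = 1e-6
numeric = (np.sum(x / (y + eps)) - np.sum(x / (y - eps))) / (2 * eps)
assert np.isclose(grad_y, numeric, rtol=1e-4)
```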
...@@ -148,9 +148,6 @@ class TestElementwiseMulOp_ZeroDim1(ElementwiseMulOp): ...@@ -148,9 +148,6 @@ class TestElementwiseMulOp_ZeroDim1(ElementwiseMulOp):
self.y = np.random.uniform(0.1, 1, []).astype(self.dtype) self.y = np.random.uniform(0.1, 1, []).astype(self.dtype)
self.out = np.multiply(self.x, self.y) self.out = np.multiply(self.x, self.y)
def if_enable_cinn(self):
self.enable_cinn = False
class TestElementwiseMulOp_ZeroDim2(ElementwiseMulOp): class TestElementwiseMulOp_ZeroDim2(ElementwiseMulOp):
def init_input_output(self): def init_input_output(self):
...@@ -158,9 +155,6 @@ class TestElementwiseMulOp_ZeroDim2(ElementwiseMulOp): ...@@ -158,9 +155,6 @@ class TestElementwiseMulOp_ZeroDim2(ElementwiseMulOp):
self.y = np.random.uniform(0.1, 1, []).astype(self.dtype) self.y = np.random.uniform(0.1, 1, []).astype(self.dtype)
self.out = np.multiply(self.x, self.y) self.out = np.multiply(self.x, self.y)
def if_enable_cinn(self):
self.enable_cinn = False
class TestElementwiseMulOp_ZeroDim3(ElementwiseMulOp): class TestElementwiseMulOp_ZeroDim3(ElementwiseMulOp):
def init_input_output(self): def init_input_output(self):
...@@ -168,9 +162,6 @@ class TestElementwiseMulOp_ZeroDim3(ElementwiseMulOp): ...@@ -168,9 +162,6 @@ class TestElementwiseMulOp_ZeroDim3(ElementwiseMulOp):
self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
self.out = np.multiply(self.x, self.y) self.out = np.multiply(self.x, self.y)
def if_enable_cinn(self):
self.enable_cinn = False
class TestBF16ElementwiseMulOp(OpTest): class TestBF16ElementwiseMulOp(OpTest):
def setUp(self): def setUp(self):
......
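The ZeroDim mul cases differ only in which operand is 0-D; the expected outputs come from plain numpy broadcasting. A short standalone check of the shapes those tests rely on, with illustrative values only:

```python
import numpy as np

scalar = np.random.uniform(0.1, 1, [])        # 0-D operand
matrix = np.random.uniform(0.1, 1, [13, 17])  # full-rank operand

# Two 0-D operands give a 0-D result.
assert np.multiply(scalar, scalar).shape == ()
# A 0-D operand broadcasts against [13, 17] in the mixed-shape cases.
assert np.multiply(scalar, matrix).shape == (13, 17)
```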
...@@ -135,7 +135,6 @@ class TestElementwiseSubOp_ZeroDim1(TestElementwiseOp):
         }
         self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
         self.if_check_prim()
-        self.if_enable_cinn()
 
 
 class TestElementwiseSubFP16OP_ZeroDim1(TestElementwiseSubOp_ZeroDim1):
...@@ -182,7 +181,6 @@ class TestElementwiseSubOp_ZeroDim2(TestElementwiseOp):
         }
         self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
         self.if_check_prim()
-        self.if_enable_cinn()
 
 
 class TestElementwiseSubFP16OP_ZeroDim2(TestElementwiseSubOp_ZeroDim2):
...@@ -229,7 +227,6 @@ class TestElementwiseSubOp_ZeroDim3(TestElementwiseOp):
         }
         self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
         self.if_check_prim()
-        self.if_enable_cinn()
 
 
 class TestElementwiseSubFP16OP_ZeroDim3(TestElementwiseSubOp_ZeroDim3):
......
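In the div/mul tests the `if_enable_cinn` overrides are deleted; here the call in `setUp` is dropped instead. Either way the hook that switched CINN off for zero-dim inputs no longer fires, so these cases fall back to the base class's default, consistent with the pass change above. A hypothetical, self-contained sketch of that hook pattern; the class names and the default value are illustrative, not Paddle's actual OpTest:

```python
class FakeOpTest:
    """Illustrative stand-in for an op-test base class."""

    def setUp(self):
        self.enable_cinn = True  # illustrative default: CINN on
        self.if_enable_cinn()    # hook that subclasses may override

    def if_enable_cinn(self):
        pass  # base class keeps the default


class ZeroDimCaseBefore(FakeOpTest):
    # Old pattern: the override switched CINN off for 0-D inputs.
    def if_enable_cinn(self):
        self.enable_cinn = False


class ZeroDimCaseAfter(FakeOpTest):
    # New pattern: no override (or no call site), so the default is kept.
    pass
```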