未验证 提交 7bf97d2c 编写于 作者: H HongyuJia 提交者: GitHub

[0D-Tensor] Support elementwise_sub,mul,div (#54064)

上级 48edbb28
......@@ -59,7 +59,10 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const {
}
// CINN ops in this white list support 0D-Tensor
const std::unordered_set<std::string> white_op_list{"elementwise_add"};
const std::unordered_set<std::string> white_op_list{"elementwise_add",
"elementwise_sub",
"elementwise_mul",
"elementwise_div"};
std::unordered_set<std::string> white_tensor_name;
// enable white_op_list only when graph_node_size = 1, which means single op
// test
......
......@@ -143,9 +143,6 @@ class TestElementwiseDivOp_ZeroDim1(ElementwiseDivOp):
self.x_shape = []
self.y_shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestElementwiseDivOp_ZeroDim2(ElementwiseDivOp):
def init_shape(self):
......@@ -161,9 +158,6 @@ class TestElementwiseDivOp_ZeroDim2(ElementwiseDivOp):
def compute_gradient_y(self, grad_out, out, y):
return np.sum(-1 * grad_out * out / y.reshape([1, 1]))
def if_enable_cinn(self):
self.enable_cinn = False
class TestElementwiseDivOp_ZeroDim3(ElementwiseDivOp):
def init_shape(self):
......@@ -179,9 +173,6 @@ class TestElementwiseDivOp_ZeroDim3(ElementwiseDivOp):
def compute_gradient_y(self, grad_out, out, y):
return -1 * grad_out * out / y
def if_enable_cinn(self):
self.enable_cinn = False
@unittest.skipIf(
not core.is_compiled_with_cuda()
......
......@@ -148,9 +148,6 @@ class TestElementwiseMulOp_ZeroDim1(ElementwiseMulOp):
self.y = np.random.uniform(0.1, 1, []).astype(self.dtype)
self.out = np.multiply(self.x, self.y)
def if_enable_cinn(self):
self.enable_cinn = False
class TestElementwiseMulOp_ZeroDim2(ElementwiseMulOp):
def init_input_output(self):
......@@ -158,9 +155,6 @@ class TestElementwiseMulOp_ZeroDim2(ElementwiseMulOp):
self.y = np.random.uniform(0.1, 1, []).astype(self.dtype)
self.out = np.multiply(self.x, self.y)
def if_enable_cinn(self):
self.enable_cinn = False
class TestElementwiseMulOp_ZeroDim3(ElementwiseMulOp):
def init_input_output(self):
......@@ -168,9 +162,6 @@ class TestElementwiseMulOp_ZeroDim3(ElementwiseMulOp):
self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
self.out = np.multiply(self.x, self.y)
def if_enable_cinn(self):
self.enable_cinn = False
class TestBF16ElementwiseMulOp(OpTest):
def setUp(self):
......
......@@ -135,7 +135,6 @@ class TestElementwiseSubOp_ZeroDim1(TestElementwiseOp):
}
self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
self.if_check_prim()
self.if_enable_cinn()
class TestElementwiseSubFP16OP_ZeroDim1(TestElementwiseSubOp_ZeroDim1):
......@@ -182,7 +181,6 @@ class TestElementwiseSubOp_ZeroDim2(TestElementwiseOp):
}
self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
self.if_check_prim()
self.if_enable_cinn()
class TestElementwiseSubFP16OP_ZeroDim2(TestElementwiseSubOp_ZeroDim2):
......@@ -229,7 +227,6 @@ class TestElementwiseSubOp_ZeroDim3(TestElementwiseOp):
}
self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
self.if_check_prim()
self.if_enable_cinn()
class TestElementwiseSubFP16OP_ZeroDim3(TestElementwiseSubOp_ZeroDim3):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册