Unverified commit 6768c6ec, authored by xiaoguoguo626807, committed by GitHub

【prim】Concat bug (#53350)

* modify concat_grad; add sum comp rule

* modify op_compat

Parent: c0ee14f6
@@ -470,7 +470,7 @@ void concat_grad(const std::vector<Tensor>& x,
     sections.push_back(x[i].dims()[axis_value]);
   }
   std::vector<Tensor> x_grad_tmp =
-      split<T>(out_grad, phi::IntArray(sections), axis);
+      split<T>(out_grad, phi::IntArray(sections), axis_value);
   for (int i = 0; i < x_num; ++i) {
     set_output<T>(x_grad_tmp.at(i), x_grad.at(i));
   }
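
For context, a minimal Python sketch (an illustration only, not the C++ prim rule itself) of what the fixed line computes: the upstream gradient of concat is split back into one slice per input, and the split must use the normalized `axis_value` so a negative `axis` is resolved before sectioning.

import paddle

def concat_grad_sketch(xs, out_grad, axis):
    # Resolve a possibly negative axis, mirroring `axis_value` in the rule above.
    axis_value = axis if axis >= 0 else axis + len(xs[0].shape)
    # One section per input: its extent along the concat axis.
    sections = [x.shape[axis_value] for x in xs]
    # Split the upstream gradient back into one slice per input.
    return paddle.split(out_grad, sections, axis=axis_value)

xs = [paddle.ones([2, 3]), paddle.ones([2, 5])]
out_grad = paddle.ones_like(paddle.concat(xs, axis=-1))
grads = concat_grad_sketch(xs, out_grad, axis=-1)
print([g.shape for g in grads])  # [[2, 3], [2, 5]]
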
@@ -95,6 +95,12 @@
     attrs : [bool use_mkldnn = false, str x_data_format = "", str y_data_format = "", str mkldnn_data_type = "float32",
              bool use_quantizer = false, float Scale_x = 1.0f, float Scale_y = 1.0f, float Scale_out = 1.0f]
 
+- op : add_n (sum)
+  inputs:
+    {inputs : X}
+  outputs:
+    {out : Out}
+
 - op : addmm
   backward : addmm_grad
   inputs :
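
The entry above only records a name mapping for the legacy `sum` operator. A hypothetical snippet (not part of the patch) of the Python API it maps to, `paddle.add_n`, whose list input and single output correspond to the op's `X` and `Out`:

import paddle

# `add_n` sums a list of tensors elementwise; in the legacy op description the
# list is the input `X` and the result is the output `Out`.
inputs = [paddle.to_tensor([1.0, 2.0]), paddle.to_tensor([3.0, 4.0])]
out = paddle.add_n(inputs)
print(out.numpy())  # [4. 6.]
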
@@ -1113,7 +1113,8 @@ set(TEST_CINN_OPS
     test_dropout_op
     test_group_norm_op
     test_tile_op
-    test_roll_op)
+    test_roll_op
+    test_sum_op)
 
 foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
   if(WITH_CINN)
@@ -45,6 +45,8 @@ class TestSumOp(OpTest):
     def setUp(self):
         self.op_type = "sum"
         self.python_api = sum_wrapper
+        self.public_python_api = paddle.add_n
+        self.prim_op_type = "comp"
         self.init_kernel_type()
         self.use_mkldnn = False
         self.init_kernel_type()
@@ -60,10 +62,10 @@ class TestSumOp(OpTest):
         self.dtype = np.float64
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(['x0'], 'Out')
+        self.check_grad(['x0'], 'Out', check_prim=True)
 
 
 class TestSelectedRowsSumOp(unittest.TestCase):
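
Roughly speaking, `check_prim=True` asks OpTest to also run the op through its composite decomposition and compare the results. The standalone sketch below (an assumption for illustration, not the OpTest machinery) shows the kind of equivalence being checked for `sum`/`add_n`, in both the forward value and the gradient.

import numpy as np
import paddle

xs = [paddle.rand([3, 4]) for _ in range(3)]
for x in xs:
    x.stop_gradient = False

fused = paddle.add_n(xs)           # the original op
composite = xs[0] + xs[1] + xs[2]  # the decomposition: a chain of primitive adds
np.testing.assert_allclose(fused.numpy(), composite.numpy(), rtol=1e-6)

# The gradient of a plain sum w.r.t. each input is a tensor of ones.
fused.sum().backward()
np.testing.assert_allclose(xs[0].grad.numpy(), np.ones([3, 4]), rtol=1e-6)
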
@@ -679,6 +679,14 @@ def group_norm_composite(x, scale, bias, epsilon, groups, data_layout):
     return out, ret_mean_, ret_var_
 
 
+@REGISTER_COMPOSITE('sum')
+def sum_composite(x):
+    ans = 0
+    for xi in x:
+        ans += xi
+    return ans
+
+
 @REGISTER_COMPOSITE('leaky_relu')
 def leaky_relu_composite(x, negative_slope):
     """define composite rule of op leaky_relu."""
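
The registered rule above lowers `add_n` (legacy name `sum`) over a variable-length list of tensors to a chain of primitive elementwise adds. A small plain-Python analog of that lowering (illustrative only, outside the composite-rule machinery):

import paddle

def lowered_add_n(xs):
    ans = 0             # scalar start; the first add broadcasts it against a tensor
    for xi in xs:
        ans = ans + xi  # one primitive elementwise add per summand
    return ans

for n in (1, 2, 5):     # works for any number of summands
    xs = [paddle.ones([2, 3]) for _ in range(n)]
    assert float(lowered_add_n(xs).sum()) == n * 6
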