未验证 提交 31fc763a 编写于 作者: W wangzhen38 提交者: GitHub

[CINN] fix concat (#52341)

* [CINN] fix concat&pow

* update concat

* composite_backward_api

* for ci

* for ci

* update test & fix opmaker
上级 b8a848bb
......@@ -165,11 +165,11 @@ class ConcatCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
paddle::optional<paddle::Tensor> tensor_axis =
this->GetOptionalSingleForwardInput("AxisTensor");
paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
std::vector<paddle::Tensor> input_grad = this->GetMultiForwardInput("X");
std::vector<paddle::Tensor> input_grad = this->GetMultiInputGrad("X");
std::vector<paddle::Tensor *> input_grad_ptr(input_grad.size());
for (auto sub_tensor : input_grad) {
input_grad_ptr.push_back(&sub_tensor);
std::vector<paddle::Tensor *> input_grad_ptr;
for (auto i = 0; i < static_cast<int>(input_grad.size()); ++i) {
input_grad_ptr.push_back(&input_grad[i]);
}
int axis = static_cast<int>(this->Attr<int>("axis"));
std::vector<paddle::Tensor *> dx_ptr = this->GetOutputPtr(input_grad_ptr);
......@@ -215,10 +215,10 @@ REGISTER_OPERATOR(concat,
ops::ConcatOpMaker,
ops::ConcatGradOpMaker<paddle::framework::OpDesc>,
ops::ConcatGradOpMaker<paddle::imperative::OpBase>,
ops::ConcatCompositeGradOpMaker,
ConcatInferShapeFunctor);
REGISTER_OPERATOR(concat_grad,
ops::ConcatOpGrad,
ops::ConcatCompositeGradOpMaker,
ops::ConcatDoubleGradOpMaker<paddle::framework::OpDesc>,
ops::ConcatDoubleGradOpMaker<paddle::imperative::OpBase>,
ops::ConcatOpGradNoNeedBufferVarInferer);
......@@ -2040,6 +2040,11 @@
outputs :
out : Out
- op : split
int_array:
sections :
data_type : int
- op : sqrt
backward : sqrt_grad, sqrt_double_grad (sqrt_grad_grad)
inputs :
......
......@@ -136,7 +136,6 @@ class TestConcatOp6(TestConcatOp):
self.dtype = self.get_dtype()
self.python_api = paddle.concat
self.public_python_api = paddle.concat
self.prim_op_type = "prim"
self.enable_cinn = False
self.init_test_data()
self.lod = [[20, 80]]
......@@ -227,11 +226,9 @@ def create_test_AxisTensor(parent):
self.op_type = "concat"
self.python_api = paddle.concat
self.public_python_api = paddle.concat
self.enable_cinn = False
self.dtype = self.get_dtype()
self.init_test_data()
self.prim_op_type = "prim"
self.enable_cinn = False
self.inputs = {
'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)],
'AxisTensor': np.array([self.axis]).astype("int32"),
......@@ -252,6 +249,22 @@ def create_test_AxisTensor(parent):
)
}
def test_check_grad(self):
if (
parent.__name__ == 'TestConcatOp4'
or parent.__name__ == 'TestConcatOp3'
):
return
if self.dtype == np.uint16:
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['x0'], 'Out')
self.check_grad_with_place(place, ['x1'], 'Out')
self.check_grad_with_place(place, ['x2'], 'Out')
else:
self.check_grad(['x0'], 'Out')
self.check_grad(['x1'], 'Out')
self.check_grad(['x2'], 'Out')
cls_name = "{}_{}".format(parent.__name__, "AxisTensor")
TestConcatAxisTensor.__name__ = cls_name
globals()[cls_name] = TestConcatAxisTensor
......@@ -269,6 +282,47 @@ create_test_AxisTensor(TestConcatOp6)
def create_test_fp16(parent):
class TestConcatFp16(parent):
def setUp(self):
self.op_type = "concat"
self.python_api = paddle.concat
self.public_python_api = paddle.concat
self.enable_cinn = False
self.dtype = self.get_dtype()
self.init_test_data()
self.inputs = {
'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]
}
self.attrs = {'axis': self.axis}
if self.axis < 0:
self.actual_axis = self.axis + len(self.x0.shape)
self.actual_axis = (
self.actual_axis if self.actual_axis > 0 else 0
)
else:
self.actual_axis = self.axis
self.outputs = {
'Out': np.concatenate(
(self.x0, self.x1, self.x2), axis=self.actual_axis
)
}
def test_check_grad(self):
if (
parent.__name__ == 'TestConcatOp4'
or parent.__name__ == 'TestConcatOp3'
):
return
if self.dtype == np.uint16:
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['x0'], 'Out')
self.check_grad_with_place(place, ['x1'], 'Out')
self.check_grad_with_place(place, ['x2'], 'Out')
else:
self.check_grad(['x0'], 'Out')
self.check_grad(['x1'], 'Out')
self.check_grad(['x2'], 'Out')
def get_dtype(self):
return np.float16
......@@ -291,6 +345,47 @@ def create_test_bf16(parent):
not paddle.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestConcatBf16(parent):
def setUp(self):
self.op_type = "concat"
self.python_api = paddle.concat
self.public_python_api = paddle.concat
self.enable_cinn = False
self.dtype = self.get_dtype()
self.init_test_data()
self.inputs = {
'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]
}
self.attrs = {'axis': self.axis}
if self.axis < 0:
self.actual_axis = self.axis + len(self.x0.shape)
self.actual_axis = (
self.actual_axis if self.actual_axis > 0 else 0
)
else:
self.actual_axis = self.axis
self.outputs = {
'Out': np.concatenate(
(self.x0, self.x1, self.x2), axis=self.actual_axis
)
}
def test_check_grad(self):
if (
parent.__name__ == 'TestConcatOp4'
or parent.__name__ == 'TestConcatOp3'
):
return
if self.dtype == np.uint16:
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['x0'], 'Out')
self.check_grad_with_place(place, ['x1'], 'Out')
self.check_grad_with_place(place, ['x2'], 'Out')
else:
self.check_grad(['x0'], 'Out')
self.check_grad(['x1'], 'Out')
self.check_grad(['x2'], 'Out')
def get_dtype(self):
return np.uint16
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册