Unverified commit a4689c90, authored by wangzhen38, committed by GitHub

Add concat grad cinn (#50972)

* [cinn] concat_grad

* [cinn] concat_grad

* [cinn] concat_grad build success

* [Add PGLBOX] fix unittest

* [Add PGLBOX] fix unittest

* [Add PGLBOX] fix codestyle

* [cinn] update by comments

* [cinn] update by comment

* [cinn] add axis check
Parent d1dd7302
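For orientation, the composite rule added in this commit decomposes concat_grad into a split of out_grad along the concat axis, with section sizes taken from each forward input. The following is an illustrative numpy sketch of that decomposition (not part of the commit; the function name and variables are made up for the example):

# Reference decomposition of concat_grad: x_grad[i] is the slice of
# out_grad corresponding to xs[i] along `axis`, i.e. a split by sections.
import numpy as np

def concat_grad_reference(xs, out_grad, axis):
    rank = xs[0].ndim
    if axis < 0:
        axis += rank
    # Section sizes along the concat axis, one per forward input.
    sections = [x.shape[axis] for x in xs]
    # np.split takes cumulative split indices, not section sizes.
    offsets = np.cumsum(sections)[:-1]
    return np.split(out_grad, offsets, axis=axis)

# Usage: gradients of concat([x0, x1], axis=1) w.r.t. x0 and x1.
x0 = np.random.rand(2, 3)
x1 = np.random.rand(2, 5)
out_grad = np.ones((2, 8))
g0, g1 = concat_grad_reference([x0, x1], out_grad, axis=1)
assert g0.shape == x0.shape and g1.shape == x1.shape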
...@@ -21,6 +21,9 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/infermeta/multiary.h" #include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/kernels/funcs/concat_funcs.h" #include "paddle/phi/kernels/funcs/concat_funcs.h"
...@@ -153,6 +156,41 @@ class ConcatGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -153,6 +156,41 @@ class ConcatGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
class ConcatCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
public:
void Apply() override {
std::vector<paddle::experimental::Tensor> input =
this->GetMultiForwardInput("X");
paddle::optional<paddle::experimental::Tensor> tensor_axis =
this->GetOptionalSingleForwardInput("AxisTensor");
paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out");
std::vector<paddle::experimental::Tensor> input_grad =
this->GetMultiInputGrad("X");
std::vector<paddle::experimental::Tensor *> input_grad_ptr(
input_grad.size());
for (size_t i = 0; i < input_grad.size(); ++i) {
input_grad_ptr[i] = &input_grad[i];
}
int axis = static_cast<int>(this->Attr<int>("axis"));
std::vector<paddle::experimental::Tensor *> dx_ptr =
this->GetOutputPtr(input_grad_ptr);
std::vector<std::string> dx_name = this->GetOutputName(input_grad);
VLOG(6) << "Runing concat_grad composite func";
if (tensor_axis.is_initialized()) {
PADDLE_THROW(platform::errors::Unimplemented(
"We don't support dynamic index from tensor for concat composite "
"grad for now. "));
} else {
prim::concat_grad<prim::DescTensor>(input, out_grad, axis, dx_ptr);
}
this->RecoverOutputName(input_grad, dx_name);
}
};
template <typename T>
class ConcatDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
...@@ -184,6 +222,7 @@ REGISTER_OPERATOR(concat,
ConcatInferShapeFunctor);
REGISTER_OPERATOR(concat_grad,
ops::ConcatOpGrad,
ops::ConcatCompositeGradOpMaker,
ops::ConcatDoubleGradOpMaker<paddle::framework::OpDesc>,
ops::ConcatDoubleGradOpMaker<paddle::imperative::OpBase>,
ops::ConcatOpGradNoNeedBufferVarInferer);
...@@ -284,6 +284,29 @@ void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
}
}
template <typename T>
void concat_grad(const std::vector<Tensor>& x,
const Tensor& out_grad,
const Scalar& axis,
std::vector<Tensor*> x_grad) {
int axis_value = axis.to<int>();
int rank = x[0].dims().size();
if (axis_value < 0) {
axis_value = axis_value + rank;
}
axis_value = axis_value > 0 ? axis_value : 0;
std::vector<int> sections;
int x_num = x.size();
for (int i = 0; i < x_num; ++i) {
sections.push_back(x[i].dims()[axis_value]);
}
std::vector<Tensor> x_grad_tmp =
split<T>(out_grad, phi::IntArray(sections), axis);
for (int i = 0; i < x_num; ++i) {
set_output<T>(x_grad_tmp.at(i), x_grad.at(i));
}
}
template <typename T>
void multiply_grad(const Tensor& x,
const Tensor& y,
...
...@@ -33,6 +33,14 @@ Tensor full<Tensor>(const IntArray& shape,
VLOG(4) << "Eager Prim API full_ad_func call";
return ::full_ad_func(shape, value, dtype, place);
}
template <>
std::vector<Tensor> split<Tensor>(const Tensor& x,
const IntArray& sections,
const Scalar& axis) {
VLOG(4) << "Eager Prim API split_ad_func call";
return ::split_ad_func(x, sections, axis);
}
template <>
Tensor cast<Tensor>(const Tensor& x, DataType dtype) {
return ::cast_ad_func(x, dtype);
...
...@@ -14,6 +14,7 @@
#pragma once
#include <vector>
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
...@@ -36,7 +37,14 @@ Tensor full(const IntArray& shape, ...@@ -36,7 +37,14 @@ Tensor full(const IntArray& shape,
const Scalar& value, const Scalar& value,
DataType dtype = DataType::FLOAT32, DataType dtype = DataType::FLOAT32,
const Place& place = CPUPlace()); const Place& place = CPUPlace());
template <typename T>
std::vector<Tensor> split(const Tensor& x,
const IntArray& sections,
const Scalar& axis);
template <typename T>
Tensor cast(const Tensor& x, DataType dtype);
} // namespace prim
} // namespace paddle
...@@ -120,6 +120,33 @@ Tensor full<DescTensor>(const IntArray& shape,
op->InferShape(*block);
return out;
}
template <>
std::vector<Tensor> split<DescTensor>(const Tensor& x,
const IntArray& sections,
const Scalar& axis) {
int elem_num = sections.size();
std::vector<std::string> outs_name;
std::vector<Tensor> outs;
for (int i = 0; i < elem_num; ++i) {
Tensor out = empty<DescTensor>({}, x.dtype(), paddle::Place());
std::string out_name =
std::static_pointer_cast<prim::DescTensor>(out.impl())->Name();
outs_name.push_back(std::move(out_name));
outs.push_back(out);
}
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("split");
op->SetAttr("sections", sections.GetData());
op->SetAttr("axis", axis.to<int>());
op->SetOutput("Out", outs_name);
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
return outs;
}
template <>
Tensor cast<DescTensor>(const Tensor& x, DataType dtype) {
Tensor out = empty<DescTensor>({}, DataType::FLOAT32, paddle::Place());
...
...@@ -207,6 +207,7 @@
param : [x]
kernel :
func : concat_grad
composite : concat_grad(x, out_grad, axis, x_grad)
no_need_buffer : x
backward : concat_double_grad
...
...@@ -1214,6 +1214,7 @@ set(TEST_CINN_OPS
test_activation_op
test_full_like_op
test_fill_any_like_op
test_concat_op
test_top_k_v2_op)
foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
...
...@@ -32,6 +32,8 @@ class TestConcatOp(OpTest):
def setUp(self):
self.op_type = "concat"
self.python_api = paddle.concat
self.prim_op_type = "prim"
self.enable_cinn = False
self.dtype = self.get_dtype()
self.init_test_data()
self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}
...@@ -61,13 +63,13 @@ class TestConcatOp(OpTest):
def test_check_grad(self):
if self.dtype == np.uint16:
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['x0'], 'Out', check_prim=True)
self.check_grad_with_place(place, ['x1'], 'Out', check_prim=True)
self.check_grad_with_place(place, ['x2'], 'Out', check_prim=True)
else:
self.check_grad(['x0'], 'Out', check_eager=True, check_prim=True)
self.check_grad(['x1'], 'Out', check_eager=True, check_prim=True)
self.check_grad(['x2'], 'Out', check_eager=True, check_prim=True)
def init_test_data(self):
if self.dtype == np.uint16:
...@@ -133,6 +135,8 @@ class TestConcatOp6(TestConcatOp):
self.op_type = "concat"
self.dtype = self.get_dtype()
self.python_api = paddle.concat
self.prim_op_type = "prim"
self.enable_cinn = False
self.init_test_data()
self.lod = [[20, 80]]
self.out_lod = [[20, 80, 20, 80, 20, 80]]
...@@ -167,6 +171,54 @@ class TestConcatOp6(TestConcatOp):
self.axis = 0
class TestConcatOp7(TestConcatOp):
def setUp(self):
self.op_type = "concat"
self.python_api = paddle.concat
self.prim_op_type = "prim"
self.enable_cinn = True
self.dtype = self.get_dtype()
self.init_test_data()
self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}
self.attrs = {'axis': self.axis}
if self.axis < 0:
self.actual_axis = self.axis + len(self.x0.shape)
self.actual_axis = self.actual_axis if self.actual_axis > 0 else 0
else:
self.actual_axis = self.axis
self.outputs = {
'Out': np.concatenate(
(self.x0, self.x1, self.x2), axis=self.actual_axis
)
}
def get_dtype(self):
return "float64"
def test_check_output(self):
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['x0'], 'Out', check_eager=True, check_prim=True)
self.check_grad(['x1'], 'Out', check_eager=True, check_prim=True)
self.check_grad(['x2'], 'Out', check_eager=True, check_prim=True)
def init_test_data(self):
if self.dtype == np.uint16:
x0 = np.random.random((5, 1, 4, 5)).astype(np.float32)
self.x0 = convert_float_to_uint16(x0)
x1 = np.random.random((5, 2, 4, 5)).astype(np.float32)
self.x1 = convert_float_to_uint16(x1)
x2 = np.random.random((5, 3, 4, 5)).astype(np.float32)
self.x2 = convert_float_to_uint16(x2)
else:
self.x0 = np.random.random((5, 1, 4, 5)).astype(self.dtype)
self.x1 = np.random.random((5, 2, 4, 5)).astype(self.dtype)
self.x2 = np.random.random((5, 3, 4, 5)).astype(self.dtype)
self.axis = 1
def create_test_AxisTensor(parent):
class TestConcatAxisTensor(parent):
def setUp(self):
...@@ -175,6 +227,8 @@ def create_test_AxisTensor(parent):
self.dtype = self.get_dtype()
self.init_test_data()
self.prim_op_type = "prim"
self.enable_cinn = False
self.inputs = {
'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)],
'AxisTensor': np.array([self.axis]).astype("int32"),
...
...@@ -1136,6 +1136,7 @@ def concat(x, axis=0, name=None):
'int64',
'int8',
'unit8',
'uint16',
],
'concat',
)
...