Unverified commit bb9eb20f, authored by ccrrong, committed by GitHub

add split and split_with_num composite rule (#51341)

* add split_with_num composite rule

* add split_with_num composite rule

* add split composite rule

* update

* update test

* update test

* delete split_with_num_grad
Parent commit: a5ebe6ae
......@@ -18,6 +18,9 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
......@@ -204,6 +207,36 @@ Example:
}
};
// Composite (primitive-op) backward rule for split: since split partitions X
// along `axis`, dX is simply the concat of all output gradients along the
// same axis. Registered alongside the classic SplitGradMaker (see the
// REGISTER_OPERATOR call at the bottom of this file).
class SplitCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

 public:
  void Apply() override {
    // `axis`/`sections` supplied as runtime tensors cannot be folded into a
    // static composite rule, so reject them before wiring any gradients.
    paddle::optional<std::vector<paddle::Tensor>> tensor_sections =
        this->GetOptionalMultiForwardInput("SectionsTensorList");
    paddle::optional<paddle::Tensor> tensor_axis =
        this->GetOptionalSingleForwardInput("AxisTensor");
    if (tensor_axis.is_initialized() || tensor_sections.is_initialized()) {
      PADDLE_THROW(platform::errors::Unimplemented(
          "We don't support dynamic index or sections from tensor for split "
          "composite grad for now. "));
    }
    // Attr<int> already yields int; no cast needed. The "sections" attr is
    // not required here: concat infers each piece's extent from out_grad.
    int axis = this->Attr<int>("axis");
    paddle::Tensor input_grad = this->GetSingleInputGrad("X");
    auto dx_ptr = this->GetOutputPtr(&input_grad);
    std::string dx_name = this->GetOutputName(input_grad);
    std::vector<paddle::Tensor> out_grad = this->GetMultiOutputGrad("Out");
    VLOG(6) << "Running split_grad composite func";
    prim::split_grad<prim::DescTensor>(out_grad, axis, dx_ptr);
    this->RecoverOutputName(input_grad, dx_name);
  }
};
} // namespace operators
} // namespace paddle
......@@ -212,5 +245,6 @@ namespace ops = paddle::operators;
// Register split with its op definition, proto maker, the composite (prim)
// grad maker, and the classic grad makers for static and dygraph modes.
// NOTE(review): both SplitCompositeGradOpMaker and SplitGradMaker are
// registered — presumably the framework selects the composite rule only when
// prim mode is enabled; confirm against CompositeGradOpMakerBase's dispatch.
REGISTER_OPERATOR(split,
ops::SplitOp,
ops::SplitOpMaker,
ops::SplitCompositeGradOpMaker,
ops::SplitGradMaker<paddle::framework::OpDesc>,
ops::SplitGradMaker<paddle::imperative::OpBase>);
......@@ -859,6 +859,16 @@ void cumsum_grad(const Tensor& x,
}
}
// Backward rule for split: the gradient w.r.t. the un-split input is the
// concatenation of every output gradient along the split axis. No-op when
// the caller does not request x_grad.
template <typename T>
void split_grad(const std::vector<Tensor>& out_grad,
                const Scalar& axis,
                Tensor* x_grad) {
  if (x_grad == nullptr) {
    return;
  }
  set_output<T>(concat<T>(out_grad, axis), x_grad);
}
template <typename T>
void topk_grad(const Tensor& x,
const Tensor& indices,
......
......@@ -41,6 +41,7 @@ std::vector<Tensor> split<Tensor>(const Tensor& x,
VLOG(4) << "Eager Prim API split_ad_func call";
return ::split_ad_func(x, sections, axis);
}
template <>
Tensor cast<Tensor>(const Tensor& x, DataType dtype) {
return ::cast_ad_func(x, dtype);
......
......@@ -1160,12 +1160,14 @@
args : (Tensor[] out_grad, Scalar axis = -1)
output : Tensor(x_grad)
invoke : concat( out_grad, axis)
composite : split_grad(out_grad, axis, x_grad)
- backward_op : split_with_num_grad
forward : split_with_num (Tensor x, int num, Scalar axis) -> Tensor[](out)
args : (Tensor[] out_grad, Scalar axis = -1)
output : Tensor(x_grad)
invoke : concat( out_grad, axis)
composite : split_grad(out_grad, axis, x_grad)
- backward_op : squared_l2_norm_grad
forward : squared_l2_norm(Tensor x) -> Tensor(out)
......
......@@ -27,9 +27,11 @@ class TestSplitOp(OpTest):
self.python_api = paddle.split
self.python_out_sig = ['out0', 'out1', 'out2']
self._set_op_type()
self.prim_op_type = "prim"
self.dtype = self.get_dtype()
axis = 1
if self.dtype == np.uint16:
self.enable_cinn = False
x = np.random.random((4, 5, 6)).astype(np.float32)
out = np.split(x, [2, 3], axis)
self.inputs = {'X': convert_float_to_uint16(x)}
......@@ -58,7 +60,7 @@ class TestSplitOp(OpTest):
self.check_output()
def test_check_grad(self):
    # Diff scrape kept both the pre- and post-change check_grad calls; the
    # committed version is the one below, which also exercises the
    # composite (prim) split_grad rule via check_prim=True.
    self.check_grad(['X'], ['out0', 'out1', 'out2'], check_prim=True)
# test with attr(num)
......@@ -67,6 +69,7 @@ class TestSplitOp_2(OpTest):
self.python_api = paddle.split
self.python_out_sig = ['out0', 'out1', 'out2']
self._set_op_type()
self.prim_op_type = "prim"
self.dtype = self.get_dtype()
self.init_data()
self.inputs = {'X': self.x}
......@@ -96,7 +99,7 @@ class TestSplitOp_2(OpTest):
self.check_output()
def test_check_grad(self):
    # Diff scrape kept both the pre- and post-change check_grad calls; the
    # committed version is the one below, which also exercises the
    # composite (prim) split_grad rule via check_prim=True.
    self.check_grad(['X'], ['out0', 'out1', 'out2'], check_prim=True)
# attr(axis) is Tensor
......@@ -189,6 +192,8 @@ class TestSplitOp_unk_section(OpTest):
self.python_api = paddle.split
self.python_out_sig = ['out0', 'out1', 'out2']
self._set_op_type()
self.prim_op_type = "prim"
self.enable_cinn = False
self.dtype = self.get_dtype()
self.init_data()
self.inputs = {'X': self.x}
......@@ -218,7 +223,7 @@ class TestSplitOp_unk_section(OpTest):
self.check_output()
def test_check_grad(self):
    # Diff scrape kept both the pre- and post-change check_grad calls; the
    # committed version is the one below, which also exercises the
    # composite (prim) split_grad rule via check_prim=True.
    self.check_grad(['X'], ['out0', 'out1', 'out2'], check_prim=True)
class TestSplitByrefOp(OpTest):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.