未验证 提交 d83d89ed 编写于 作者: H HongyuJia 提交者: GitHub

[Custom Op] Polish customOP optional implementation for PaddleDoc (#52368)

上级 fcd77346
...@@ -48,18 +48,11 @@ std::vector<paddle::Tensor> AddForward( ...@@ -48,18 +48,11 @@ std::vector<paddle::Tensor> AddForward(
PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor."); PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
paddle::Tensor out = paddle::empty(x.shape(), x.dtype(), x.place()); paddle::Tensor out = paddle::empty(x.shape(), x.dtype(), x.place());
PD_DISPATCH_FLOATING_TYPES(
x.type(), "AddForward", ([&] {
if (y) { if (y) {
add_two_pointers<data_t>(x.data<data_t>(), out = x + y.get();
y->data<data_t>(),
out.data<data_t>(),
x.size());
} else { } else {
add_two_pointers<data_t>( out = x + x;
x.data<data_t>(), x.data<data_t>(), out.data<data_t>(), x.size());
} }
}));
return {out}; return {out};
} }
...@@ -93,19 +86,13 @@ std::vector<paddle::Tensor> AddBackward( ...@@ -93,19 +86,13 @@ std::vector<paddle::Tensor> AddBackward(
const paddle::optional<paddle::Tensor>& y, const paddle::optional<paddle::Tensor>& y,
const paddle::Tensor& out_grad) { // NOLINT const paddle::Tensor& out_grad) { // NOLINT
PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor."); PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
paddle::Tensor x_grad = paddle::zeros(x.shape(), x.dtype(), x.place()); paddle::Tensor x_grad = paddle::zeros(x.shape(), x.dtype(), x.place());
PD_DISPATCH_FLOATING_TYPES( if (y) {
out_grad.type(), "AddBackward", ([&] { x_grad = out_grad;
add_one_pointer<data_t>( } else {
out_grad.data<data_t>(), x_grad.data<data_t>(), out_grad.size()); x_grad = out_grad + out_grad;
if (!y) {
add_one_pointer<data_t>(
out_grad.data<data_t>(), x_grad.data<data_t>(), out_grad.size());
} }
}));
return {x_grad}; return {x_grad};
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册