未验证 提交 d83d89ed 编写于 作者: H HongyuJia 提交者: GitHub

[Custom Op] Polish customOP optional implementation for PaddleDoc (#52368)

上级 fcd77346
......@@ -48,18 +48,11 @@ std::vector<paddle::Tensor> AddForward(
PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
paddle::Tensor out = paddle::empty(x.shape(), x.dtype(), x.place());
PD_DISPATCH_FLOATING_TYPES(
x.type(), "AddForward", ([&] {
if (y) {
add_two_pointers<data_t>(x.data<data_t>(),
y->data<data_t>(),
out.data<data_t>(),
x.size());
} else {
add_two_pointers<data_t>(
x.data<data_t>(), x.data<data_t>(), out.data<data_t>(), x.size());
}
}));
if (y) {
out = x + y.get();
} else {
out = x + x;
}
return {out};
}
......@@ -93,19 +86,13 @@ std::vector<paddle::Tensor> AddBackward(
const paddle::optional<paddle::Tensor>& y,
const paddle::Tensor& out_grad) { // NOLINT
PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
paddle::Tensor x_grad = paddle::zeros(x.shape(), x.dtype(), x.place());
PD_DISPATCH_FLOATING_TYPES(
out_grad.type(), "AddBackward", ([&] {
add_one_pointer<data_t>(
out_grad.data<data_t>(), x_grad.data<data_t>(), out_grad.size());
if (!y) {
add_one_pointer<data_t>(
out_grad.data<data_t>(), x_grad.data<data_t>(), out_grad.size());
}
}));
if (y) {
x_grad = out_grad;
} else {
x_grad = out_grad + out_grad;
}
return {x_grad};
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册