diff --git a/test/custom_op/custom_optional.cc b/test/custom_op/custom_optional.cc
index 498c56ce5071ef159d415d08a848735450ea623b..52c8e989d0e692c0c93162c09a323fda79abe161 100644
--- a/test/custom_op/custom_optional.cc
+++ b/test/custom_op/custom_optional.cc
@@ -48,18 +48,11 @@ std::vector<paddle::Tensor> AddForward(
   PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
 
   paddle::Tensor out = paddle::empty(x.shape(), x.dtype(), x.place());
 
-  PD_DISPATCH_FLOATING_TYPES(
-      x.type(), "AddForward", ([&] {
-        if (y) {
-          add_two_pointers<data_t>(x.data<data_t>(),
-                                   y->data<data_t>(),
-                                   out.data<data_t>(),
-                                   x.size());
-        } else {
-          add_two_pointers<data_t>(
-              x.data<data_t>(), x.data<data_t>(), out.data<data_t>(), x.size());
-        }
-      }));
+  if (y) {
+    out = x + y.get();
+  } else {
+    out = x + x;
+  }
   return {out};
 }
 
@@ -93,19 +86,14 @@ std::vector<paddle::Tensor> AddBackward(
     const paddle::Tensor& x,
     const paddle::optional<paddle::Tensor>& y,
     const paddle::Tensor& out_grad) {  // NOLINT
   PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
 
-  paddle::Tensor x_grad = paddle::zeros(x.shape(), x.dtype(), x.place());
-  PD_DISPATCH_FLOATING_TYPES(
-      out_grad.type(), "AddBackward", ([&] {
-        add_one_pointer<data_t>(
-            out_grad.data<data_t>(), x_grad.data<data_t>(), out_grad.size());
-        if (!y) {
-          add_one_pointer<data_t>(
-              out_grad.data<data_t>(), x_grad.data<data_t>(), out_grad.size());
-        }
-      }));
-
+  paddle::Tensor x_grad;
+  if (y) {
+    x_grad = out_grad;
+  } else {
+    x_grad = out_grad + out_grad;
+  }
   return {x_grad};
 }
 
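For context, the simplified kernels above receive their `y` argument as a `paddle::optional<paddle::Tensor>`, which Paddle supplies when an input is registered as optional. Below is a minimal registration sketch under assumed names: the op name `custom_add` and the `"X"`/`"Y"`/`"Out"` input-output names are illustrative and not taken from this diff; only `PD_BUILD_OP`, `PD_BUILD_GRAD_OP`, `PD_KERNEL`, `paddle::Optional`, and `paddle::Grad` are Paddle's documented custom-op API.

#include "paddle/extension.h"

// Hypothetical registration sketch. paddle::Optional("Y") marks "Y" as an
// optional input, which is what delivers the
// paddle::optional<paddle::Tensor>& parameter to AddForward/AddBackward.
PD_BUILD_OP(custom_add)
    .Inputs({"X", paddle::Optional("Y")})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(AddForward));

PD_BUILD_GRAD_OP(custom_add)
    .Inputs({"X", paddle::Optional("Y"), paddle::Grad("Out")})
    .Outputs({paddle::Grad("X")})
    .SetKernelFn(PD_KERNEL(AddBackward));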