diff --git a/paddle/phi/api/lib/op_meta_info.cc b/paddle/phi/api/lib/op_meta_info.cc
index 14dba664c41b3d7b138630c739bfe7b934d04e9f..048e4f2b428f27087a43e595b02d0f28b4ad32b1 100644
--- a/paddle/phi/api/lib/op_meta_info.cc
+++ b/paddle/phi/api/lib/op_meta_info.cc
@@ -192,6 +192,7 @@ OpMetaInfoBuilder::OpMetaInfoBuilder(std::string&& name, size_t index) {
       break;
     case 2:
       name_ = name_ + "_grad_grad";
+      break;
     default:
       PADDLE_THROW(phi::errors::InvalidArgument(
           "Not support index `%d` when construct OpMetaInfoBuilder, "
diff --git a/python/paddle/fluid/tests/custom_op/custom_relu_op.cc b/python/paddle/fluid/tests/custom_op/custom_relu_op.cc
index acaf7cb74280bed23b20feab2a96aa85a9bb5cea..4ff9adf4f8fecaf552938ac96e8fc13df9b37c44 100644
--- a/python/paddle/fluid/tests/custom_op/custom_relu_op.cc
+++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cc
@@ -17,6 +17,9 @@
 
 #include "paddle/extension.h"
 
+#define CHECK_CPU_INPUT(x) \
+  PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")
+
 template <typename data_t>
 void relu_cpu_forward_kernel(const data_t* x_data,
                              data_t* out_data,
@@ -39,6 +42,17 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data,
   }
 }
 
+template <typename data_t>
+void relu_cpu_double_backward_kernel(const data_t* out_data,
+                                     const data_t* ddx_data,
+                                     data_t* ddout_data,
+                                     int64_t ddout_numel) {
+  for (int64_t i = 0; i < ddout_numel; ++i) {
+    ddout_data[i] =
+        ddx_data[i] * (out_data[i] > static_cast<data_t>(0) ? 1. : 0.);
+  }
+}
+
 std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
   auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());
 
@@ -67,10 +81,31 @@ std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor& x,
   return {grad_x};
 }
 
+std::vector<paddle::Tensor> relu_cpu_double_backward(
+    const paddle::Tensor& out, const paddle::Tensor& ddx) {
+  CHECK_CPU_INPUT(out);
+  CHECK_CPU_INPUT(ddx);
+  auto ddout = paddle::Tensor(paddle::PlaceType::kCPU, out.shape());
+
+  PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_double_backward", ([&] {
+                               relu_cpu_double_backward_kernel<data_t>(
+                                   out.data<data_t>(),
+                                   ddx.data<data_t>(),
+                                   ddout.mutable_data<data_t>(out.place()),
+                                   ddout.size());
+                             }));
+
+  std::cout << "Debug info: run relu cpu double backward success." << std::endl;
+
+  return {ddout};
+}
+
 std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x);
 std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
                                                const paddle::Tensor& out,
                                                const paddle::Tensor& grad_out);
+std::vector<paddle::Tensor> relu_cuda_double_backward(
+    const paddle::Tensor& out, const paddle::Tensor& ddx);
 
 std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
   // TODO(chenweihang): Check Input
@@ -96,6 +131,23 @@
   }
 }
 
+std::vector<paddle::Tensor> ReluDoubleBackward(const paddle::Tensor& out,
+                                               const paddle::Tensor& ddx) {
+  if (out.place() == paddle::PlaceType::kCPU) {
+    return relu_cpu_double_backward(out, ddx);
+  } else if (out.place() == paddle::PlaceType::kGPU) {
+    return relu_cuda_double_backward(out, ddx);
+  } else {
+    PD_THROW("Not implemented.");
+  }
+}
+
+std::vector<std::vector<int64_t>> ReluDoubleBackwardInferShape(
+    const std::vector<int64_t>& out_shape,
+    const std::vector<int64_t>& ddx_shape) {
+  return {out_shape};
+}
+
 PD_BUILD_OP(custom_relu)
     .Inputs({"X"})
     .Outputs({"Out"})
@@ -106,6 +158,12 @@ PD_BUILD_GRAD_OP(custom_relu)
     .Outputs({paddle::Grad("X")})
     .SetKernelFn(PD_KERNEL(ReluBackward));
 
+PD_BUILD_DOUBLE_GRAD_OP(custom_relu)
+    .Inputs({"Out", paddle::Grad(paddle::Grad("X"))})
+    .Outputs({paddle::Grad(paddle::Grad("Out"))})
+    .SetKernelFn(PD_KERNEL(ReluDoubleBackward))
+    .SetInferShapeFn(PD_INFER_SHAPE(ReluDoubleBackwardInferShape));
+
 std::vector<paddle::Tensor> relu_cpu_backward_without_x(
     const paddle::Tensor& out, const paddle::Tensor& grad_out) {
   auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, out.shape());
diff --git a/python/paddle/fluid/tests/custom_op/custom_relu_op.cu b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu
index 4bb773cdaec21712f262bcb217710f6909efd20a..8b9693054d1c4e39c43c049c66000e0486aad98c 100644
--- a/python/paddle/fluid/tests/custom_op/custom_relu_op.cu
+++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu
@@ -14,6 +14,9 @@
 
 #include "paddle/extension.h"
 
+#define CHECK_GPU_INPUT(x) \
+  PD_CHECK(x.place() == paddle::PlaceType::kGPU, #x " must be a GPU Tensor.")
+
 template <typename data_t>
 __global__ void relu_cuda_forward_kernel(const data_t* x,
                                          data_t* y,
@@ -36,6 +39,19 @@ __global__ void relu_cuda_backward_kernel(const data_t* dy,
   }
 }
 
+template <typename data_t>
+__global__ void relu_cuda_double_backward_kernel(const data_t* out_data,
+                                                 const data_t* ddx_data,
+                                                 data_t* ddout_data,
+                                                 int64_t num) {
+  int64_t gid = blockIdx.x * blockDim.x + threadIdx.x;
+  for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) {
+    ddout_data[i] = ddx_data[i] * (out_data[i] > static_cast<data_t>(0.)
+                                       ? static_cast<data_t>(1.)
+                                       : static_cast<data_t>(0.));
+  }
+}
+
 std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
   auto out = paddle::Tensor(paddle::PlaceType::kGPU, x.shape());
 
@@ -71,6 +87,30 @@ std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
   return {grad_x};
 }
 
+std::vector<paddle::Tensor> relu_cuda_double_backward(
+    const paddle::Tensor& out, const paddle::Tensor& ddx) {
+  CHECK_GPU_INPUT(out);
+  CHECK_GPU_INPUT(ddx);
+  auto ddout = paddle::Tensor(paddle::PlaceType::kGPU, out.shape());
+
+  int64_t numel = out.size();
+  int64_t block = 512;
+  int64_t grid = (numel + block - 1) / block;
+  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
+      out.type(), "relu_cuda_double_backward_kernel", ([&] {
+        relu_cuda_double_backward_kernel<
+            data_t><<<grid, block, 0, out.stream()>>>(
+            out.data<data_t>(),
+            ddx.data<data_t>(),
+            ddout.mutable_data<data_t>(out.place()),
+            numel);
+      }));
+
+  std::cout << "Debug info: run relu gpu double backward success." << std::endl;
+
+  return {ddout};
+}
+
 std::vector<paddle::Tensor> relu_cuda_backward_without_x(
     const paddle::Tensor& out, const paddle::Tensor& grad_out) {
   auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU, out.shape());
diff --git a/python/paddle/fluid/tests/custom_op/custom_relu_setup.py b/python/paddle/fluid/tests/custom_op/custom_relu_setup.py
index cbc4d17a4c72b9faa4312999afd84db50546b26b..c3a76d65351c4965f28aef062ddd02dcf04c06f8 100644
--- a/python/paddle/fluid/tests/custom_op/custom_relu_setup.py
+++ b/python/paddle/fluid/tests/custom_op/custom_relu_setup.py
@@ -31,4 +31,5 @@ setup(
     ext_modules=Extension(  # test for not specific name here.
         sources=sources,  # test for multi ops
         include_dirs=paddle_includes,
-        extra_compile_args=extra_compile_args))
+        extra_compile_args=extra_compile_args,
+        verbose=True))
diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
index 7c61e11a18ecd2ebbcd87fae37a8ba0a39ad56d1..582b14c82b52d89bf61ad6eed2971358bb299048 100644
--- a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
+++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
@@ -138,6 +138,23 @@ def custom_relu_static_inference(func, device, np_data, np_label, path_prefix):
     return predict_v
 
 
+def custom_relu_double_grad_dynamic(func, device, dtype, np_x, use_func=True):
+    paddle.set_device(device)
+
+    t = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
+
+    out = func(t) if use_func else paddle.nn.functional.relu(t)
+    out.stop_gradient = False
+
+    dx = paddle.grad(
+        outputs=[out], inputs=[t], create_graph=True, retain_graph=True)
+
+    dx[0].backward()
+
+    assert dx[0].grad is not None
+    return dx[0].numpy(), dx[0].grad.numpy()
+
+
 class TestNewCustomOpSetUpInstall(unittest.TestCase):
     def setUp(self):
         cur_dir = os.path.dirname(os.path.abspath(__file__))
@@ -293,6 +310,25 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
                     predict, predict_infer))
         paddle.disable_static()
 
+    def test_func_double_grad_dynamic(self):
+        for device in self.devices:
+            for dtype in self.dtypes:
+                if device == 'cpu' and dtype == 'float16':
+                    continue
+                x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
+                out, dx_grad = custom_relu_double_grad_dynamic(
+                    self.custom_ops[0], device, dtype, x)
+                pd_out, pd_dx_grad = custom_relu_double_grad_dynamic(
+                    self.custom_ops[0], device, dtype, x, False)
+                self.assertTrue(
+                    np.array_equal(out, pd_out),
+                    "custom op out: {},\n paddle api out: {}".format(out,
+                                                                     pd_out))
+                self.assertTrue(
+                    np.array_equal(dx_grad, pd_dx_grad),
+                    "custom op dx grad: {},\n paddle api dx grad: {}".format(
+                        dx_grad, pd_dx_grad))
+
 
 if __name__ == '__main__':
     unittest.main()
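Usage note: a minimal sketch of how the new double-grad path is reached from Python, mirroring custom_relu_double_grad_dynamic above. The import path custom_relu_module_setup is an assumption (the name= field of the setup script is not shown in this hunk); everything else follows the test helper: paddle.grad with create_graph=True records the first-order backward, so backward() through dx[0] invokes the kernel registered via PD_BUILD_DOUBLE_GRAD_OP.

    import numpy as np
    import paddle

    # Hypothetical import: module name comes from custom_relu_setup.py's
    # name= field (not shown in this diff); the op name follows PD_BUILD_OP.
    from custom_relu_module_setup import custom_relu

    x = paddle.to_tensor(
        np.random.uniform(-1, 1, [4, 8]).astype('float32'), stop_gradient=False)
    out = custom_relu(x)
    out.stop_gradient = False

    # create_graph=True keeps the first-order graph differentiable, so
    # backward() on dx[0] runs the custom_relu_grad_grad kernel.
    dx = paddle.grad(
        outputs=[out], inputs=[x], create_graph=True, retain_graph=True)
    dx[0].backward()
    assert dx[0].grad is not None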