diff --git a/paddle/fluid/eager/custom_operator/custom_operator_node.cc b/paddle/fluid/eager/custom_operator/custom_operator_node.cc
index 5c3c2fbe7e9c6000a205e8500d0c5cd248268ab6..57932ec4c1e693e0db01684c0a0ef6cdb5964f24 100644
--- a/paddle/fluid/eager/custom_operator/custom_operator_node.cc
+++ b/paddle/fluid/eager/custom_operator/custom_operator_node.cc
@@ -410,17 +410,19 @@ RunCustomOpDoubleGradNode::operator()(
   for (size_t i = 0; i < OutputMeta().size(); i++) {
     if (map[1][0].find(i) != map[1][0].end()) {
+      int grad_output_idx = map[1][0][i];
       VLOG(7) << "Insert grad outputs: " << i
-              << " with size: " << OutputMeta()[i].size()
-              << " to tmp_outputs: " << map[1][0][i];
-      for (size_t j = 0; j < OutputMeta()[i].size(); j++) {
-        outs[i].emplace_back(/* init it incase of copy nullptr of shared_ptr */
-                             std::make_shared<phi::DenseTensor>(
-                                 phi::DataType::UNDEFINED),
-                             egr::Controller::Instance().GenerateUniqueName(
-                                 "custom_tmp_grad"));
+              << " with size: " << OutputMeta()[grad_output_idx].size()
+              << " to tmp_outputs: " << grad_output_idx;
+      for (size_t j = 0; j < OutputMeta()[grad_output_idx].size(); j++) {
+        outs[grad_output_idx]
+            .emplace_back(/* init it incase of copy nullptr of shared_ptr */
+                          std::make_shared<phi::DenseTensor>(
+                              phi::DataType::UNDEFINED),
+                          egr::Controller::Instance().GenerateUniqueName(
+                              "custom_tmp_grad"));
       }
-      tmp_outs[map[1][0][i]] = outs[i];
+      tmp_outs[grad_output_idx] = outs[grad_output_idx];
     }
   }
   for (size_t i = 0; i < tmp_outs.size(); i++) {
diff --git a/python/paddle/fluid/tests/custom_op/custom_relu_op.cu b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu
index e791ea8cb7600eb78b54b80f8af6265261b1bc66..53b61b4bb6611ba63028bfaea4b236c0ac51b77b 100644
--- a/python/paddle/fluid/tests/custom_op/custom_relu_op.cu
+++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu
@@ -44,7 +44,7 @@ __global__ void relu_cuda_double_backward_kernel(const data_t* out_data,
                                                  data_t* ddout_data,
                                                  int64_t num) {
   int64_t gid = blockIdx.x * blockDim.x + threadIdx.x;
-  for (int64_t i = num; i < num; i += blockDim.x * gridDim.x) {
+  for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) {
     ddout_data[i] = ddx_data[i] * (out_data[i] > static_cast<data_t>(0.)
                                        ? static_cast<data_t>(1.)
                                        : static_cast<data_t>(0.));
diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
index c2cd953b47a4e5d8efcff2d8628365a7c9c16179..599edf09b7f1b1d1954a01d4ef1b8ea00f4ebea8 100644
--- a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
+++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
@@ -148,16 +148,23 @@ def custom_relu_double_grad_dynamic(func, device, dtype, np_x, use_func=True):
     t = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
 
     out = func(t) if use_func else paddle.nn.functional.relu(t)
-    out.stop_gradient = False
-
     dx = paddle.grad(
-        outputs=[out], inputs=[t], create_graph=True, retain_graph=True
+        outputs=out,
+        inputs=t,
+        grad_outputs=paddle.ones_like(t),
+        create_graph=True,
+        retain_graph=True,
     )
 
-    dx[0].backward()
+    ddout = paddle.grad(
+        outputs=dx[0],
+        inputs=out.grad,
+        grad_outputs=paddle.ones_like(t),
+        create_graph=False,
+    )
 
-    assert dx[0].grad is not None
-    return dx[0].numpy(), dx[0].grad.numpy()
+    assert ddout[0].numpy() is not None
+    return dx[0].numpy(), ddout[0].numpy()
 
 
 class TestNewCustomOpSetUpInstall(unittest.TestCase):
@@ -346,7 +353,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
         )
         paddle.disable_static()
 
-    def test_func_double_grad_dynamic(self):
+    def test_double_grad_dynamic(self):
         fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
         for device in self.devices:
             for dtype in self.dtypes: