未验证 提交 ca4155c8 编写于 作者: H HongyuJia 提交者: GitHub

[Custom Extension] Fix custom double_grad backward=None (#49224)

* fix custom double_grad backward=None

* fix custom_relu.cu bug && polish testcase of double_grad

* remove old dynamic graph test
上级 941444aa
......@@ -410,17 +410,19 @@ RunCustomOpDoubleGradNode::operator()(
for (size_t i = 0; i < OutputMeta().size(); i++) {
if (map[1][0].find(i) != map[1][0].end()) {
int grad_output_idx = map[1][0][i];
VLOG(7) << "Insert grad outputs: " << i
<< " with size: " << OutputMeta()[i].size()
<< " to tmp_outputs: " << map[1][0][i];
for (size_t j = 0; j < OutputMeta()[i].size(); j++) {
outs[i].emplace_back(/* init it incase of copy nullptr of shared_ptr */
std::make_shared<phi::DenseTensor>(
phi::DataType::UNDEFINED),
egr::Controller::Instance().GenerateUniqueName(
"custom_tmp_grad"));
<< " with size: " << OutputMeta()[grad_output_idx].size()
<< " to tmp_outputs: " << grad_output_idx;
for (size_t j = 0; j < OutputMeta()[grad_output_idx].size(); j++) {
outs[grad_output_idx]
.emplace_back(/* init it incase of copy nullptr of shared_ptr */
std::make_shared<phi::DenseTensor>(
phi::DataType::UNDEFINED),
egr::Controller::Instance().GenerateUniqueName(
"custom_tmp_grad"));
}
tmp_outs[map[1][0][i]] = outs[i];
tmp_outs[grad_output_idx] = outs[grad_output_idx];
}
}
for (size_t i = 0; i < tmp_outs.size(); i++) {
......
......@@ -44,7 +44,7 @@ __global__ void relu_cuda_double_backward_kernel(const data_t* out_data,
data_t* ddout_data,
int64_t num) {
int64_t gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int64_t i = num; i < num; i += blockDim.x * gridDim.x) {
for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) {
ddout_data[i] = ddx_data[i] * (out_data[i] > static_cast<data_t>(0.)
? static_cast<data_t>(1.)
: static_cast<data_t>(0.));
......
......@@ -148,16 +148,23 @@ def custom_relu_double_grad_dynamic(func, device, dtype, np_x, use_func=True):
t = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
out = func(t) if use_func else paddle.nn.functional.relu(t)
out.stop_gradient = False
dx = paddle.grad(
outputs=[out], inputs=[t], create_graph=True, retain_graph=True
outputs=out,
inputs=t,
grad_outputs=paddle.ones_like(t),
create_graph=True,
retain_graph=True,
)
dx[0].backward()
ddout = paddle.grad(
outputs=dx[0],
inputs=out.grad,
grad_outputs=paddle.ones_like(t),
create_graph=False,
)
assert dx[0].grad is not None
return dx[0].numpy(), dx[0].grad.numpy()
assert ddout[0].numpy() is not None
return dx[0].numpy(), ddout[0].numpy()
class TestNewCustomOpSetUpInstall(unittest.TestCase):
......@@ -346,7 +353,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
)
paddle.disable_static()
def test_func_double_grad_dynamic(self):
def test_double_grad_dynamic(self):
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
for device in self.devices:
for dtype in self.dtypes:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册