From c40c16a964013dd62bb9c97977f860ff53657575 Mon Sep 17 00:00:00 2001 From: xiemoyuan <71377852+xiemoyuan@users.noreply.github.com> Date: Mon, 26 Apr 2021 15:19:34 +0800 Subject: [PATCH] Modified the return value of tensor.grad from numpy to tensor. (#32142) * Modified the return value of tensor.grad from numpy as tensor. * Modify unittests. * fixed bugs. * Add warning info for x.grad * fixed unittests which used x.grad * fixed bug. --- .../fluid/dygraph/varbase_patch_methods.py | 35 +++++++++++-- .../tests/custom_op/test_custom_concat.py | 2 +- .../fluid/tests/custom_op/test_custom_conj.py | 5 +- .../custom_op/test_custom_relu_op_setup.py | 5 +- .../parallel_dygraph_gradient_check.py | 3 +- .../fluid/tests/unittests/test_base_layer.py | 6 ++- .../tests/unittests/test_custom_grad_input.py | 9 ++-- .../tests/unittests/test_imperative_basic.py | 50 +++++++++++-------- .../fluid/tests/unittests/test_inplace.py | 8 +-- .../fluid/tests/unittests/test_lookahead.py | 3 +- .../fluid/tests/unittests/test_pylayer_op.py | 6 ++- .../unittests/test_tensor_register_hook.py | 39 ++++++++------- .../fluid/tests/unittests/test_var_base.py | 9 ++-- 13 files changed, 114 insertions(+), 66 deletions(-) diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index 11bc150b281..dbc2b24aeea 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -26,6 +26,7 @@ from .base import switch_to_static_graph from .math_op_patch import monkey_patch_math_varbase from .parallel import scale_loss from paddle.fluid.data_feeder import convert_dtype, _PADDLE_DTYPE_2_NUMPY_DTYPE +import paddle.utils.deprecated as deprecated class TensorHookRemoveHelper(object): @@ -238,8 +239,16 @@ def monkey_patch_varbase(): "Variable.backward() is only available in DyGraph mode") @framework.dygraph_only + @deprecated( + since="2.1.0", + reason="Please use x.grad, which returns the tensor value of the gradient." + ) def gradient(self): """ + .. warning:: + This API will be deprecated in the future, it is recommended to use + :code:`x.grad` which returns the tensor value of the gradient. + Get the Gradient of Current Tensor. Returns: @@ -253,7 +262,7 @@ def monkey_patch_varbase(): x = paddle.to_tensor(5., stop_gradient=False) y = paddle.pow(x, 4.0) y.backward() - print("grad of x: {}".format(x.grad)) + print("grad of x: {}".format(x.gradient())) # [500.] """ @@ -337,10 +346,28 @@ def monkey_patch_varbase(): @property def grad(self): """ - The alias of gradient(). - """ + .. warning:: + This API will return the tensor value of the gradient. If you want + to get the numpy value of the gradient, you can use :code:`x.grad.numpy()`. + + Get the Gradient of Current Tensor. + + Returns: + Tensor: the gradient of current Tensor + + Examples: + .. 
code-block:: python + + import paddle - return self.gradient() + x = paddle.to_tensor(5., stop_gradient=False) + y = paddle.pow(x, 4.0) + y.backward() + print("grad of x: {}".format(x.grad)) + # Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False, [500.]) + + """ + return self._grad_ivar() def clear_grad(self): """ diff --git a/python/paddle/fluid/tests/custom_op/test_custom_concat.py b/python/paddle/fluid/tests/custom_op/test_custom_concat.py index ea41126c1c4..d796c3b5fbd 100644 --- a/python/paddle/fluid/tests/custom_op/test_custom_concat.py +++ b/python/paddle/fluid/tests/custom_op/test_custom_concat.py @@ -58,7 +58,7 @@ def concat_dynamic(func, dtype, np_inputs, axis_v, with_attr=False): out = func(inputs, axis) out.stop_gradient = False out.backward() - grad_inputs = [x.grad for x in inputs] + grad_inputs = [x.grad.numpy() for x in inputs] return out.numpy(), grad_inputs diff --git a/python/paddle/fluid/tests/custom_op/test_custom_conj.py b/python/paddle/fluid/tests/custom_op/test_custom_conj.py index 3a8f79a06fc..a8e40198803 100644 --- a/python/paddle/fluid/tests/custom_op/test_custom_conj.py +++ b/python/paddle/fluid/tests/custom_op/test_custom_conj.py @@ -63,7 +63,10 @@ def conj_dynamic(func, dtype, np_input): sum_out.real().backward() else: sum_out.backward() - return out.numpy(), x.grad + if x.grad is None: + return out.numpy(), x.grad + else: + return out.numpy(), x.grad.numpy() def conj_static(func, shape, dtype, np_input): diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py index 642e93ebcb8..0af0aa16466 100644 --- a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py +++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py @@ -34,7 +34,10 @@ def custom_relu_dynamic(func, device, dtype, np_x, use_func=True): out.backward() - return out.numpy(), t.grad + if t.grad is None: + return out.numpy(), t.grad + else: + return out.numpy(), t.grad.numpy() def custom_relu_static(func, diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check.py index 0d2631fa108..70023522409 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check.py @@ -110,7 +110,8 @@ class TestDistTraning(unittest.TestCase): def check_acc(self, grad, grad_sum, acc_grad): if grad is not None: - grad_sum = grad_sum + grad + grad_sum = grad_sum + grad.numpy() + acc_grad = acc_grad.numpy() if acc_grad is not None else None np.testing.assert_allclose(grad_sum, acc_grad, rtol=1e-6) return grad_sum diff --git a/python/paddle/fluid/tests/unittests/test_base_layer.py b/python/paddle/fluid/tests/unittests/test_base_layer.py index e6e15575f2c..27c8869b21d 100644 --- a/python/paddle/fluid/tests/unittests/test_base_layer.py +++ b/python/paddle/fluid/tests/unittests/test_base_layer.py @@ -349,7 +349,8 @@ class TestLayerTo(unittest.TestCase): paddle.fluid.core.VarDesc.VarType.FP64) self.assertEqual(self.linear.buf_name.dtype, paddle.fluid.core.VarDesc.VarType.FP64) - self.assertTrue(np.allclose(self.linear.weight.grad, self.new_grad)) + self.assertTrue( + np.allclose(self.linear.weight.grad.numpy(), self.new_grad)) self.assertTrue(self.linear.weight._grad_ivar().dtype, paddle.fluid.core.VarDesc.VarType.FP64) @@ -358,7 +359,8 @@ class TestLayerTo(unittest.TestCase): 
paddle.fluid.core.VarDesc.VarType.FP64) self.assertEqual(self.linear.buf_name.dtype, paddle.fluid.core.VarDesc.VarType.FP64) - self.assertTrue(np.allclose(self.linear.weight.grad, self.new_grad)) + self.assertTrue( + np.allclose(self.linear.weight.grad.numpy(), self.new_grad)) self.assertTrue(self.linear.weight._grad_ivar().dtype, paddle.fluid.core.VarDesc.VarType.FP64) diff --git a/python/paddle/fluid/tests/unittests/test_custom_grad_input.py b/python/paddle/fluid/tests/unittests/test_custom_grad_input.py index a7472e7ffd7..623b7e68b3f 100644 --- a/python/paddle/fluid/tests/unittests/test_custom_grad_input.py +++ b/python/paddle/fluid/tests/unittests/test_custom_grad_input.py @@ -46,7 +46,7 @@ class TestTensorBackward(unittest.TestCase): x_grad = np.matmul(grad, y.T) - self.assertTrue(np.allclose(x_grad, x_tensor.grad)) + self.assertTrue(np.allclose(x_grad, x_tensor.grad.numpy())) class TestBackwardAPI(unittest.TestCase): @@ -75,7 +75,8 @@ class TestBackwardAPI(unittest.TestCase): x_grad = np.matmul(grad, y.T) - self.assertTrue(np.allclose(x_grad * 2, x_tensor.grad)) + self.assertTrue( + np.allclose(x_grad * 2, x_tensor.grad.numpy())) def test_backward_single_tensor(self): for dtype in self._dtypes: @@ -94,7 +95,7 @@ class TestBackwardAPI(unittest.TestCase): x_grad = np.matmul(grad, y.T) - self.assertTrue(np.allclose(x_grad, x_tensor.grad)) + self.assertTrue(np.allclose(x_grad, x_tensor.grad.numpy())) def test_backward_none_grad_tensor(self): for dtype in self._dtypes: @@ -112,7 +113,7 @@ class TestBackwardAPI(unittest.TestCase): x_grad = np.matmul(grad, y.T) - self.assertTrue(np.allclose(x_grad, x_tensor.grad)) + self.assertTrue(np.allclose(x_grad, x_tensor.grad.numpy())) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_imperative_basic.py b/python/paddle/fluid/tests/unittests/test_imperative_basic.py index 9dae36c3c22..1cdb57c540a 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_basic.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_basic.py @@ -506,15 +506,15 @@ class TestImperative(unittest.TestCase): for i in range(10): y = paddle.pow(x, 4.0) y.backward() - self.assertEqual(x.grad, (i + 1) * 500) + self.assertEqual(x.grad.numpy(), (i + 1) * 500) x.clear_gradient() - self.assertEqual(x.grad, 0.) + self.assertEqual(x.grad.numpy(), 0.) for i in range(10): y = paddle.pow(x, 4.0) y.backward() - self.assertEqual(x.grad, (i + 1) * 500) + self.assertEqual(x.grad.numpy(), (i + 1) * 500) x.clear_grad() - self.assertEqual(x.grad, 0.) + self.assertEqual(x.grad.numpy(), 0.) 
def test_simple_net(sort_sum_gradient): fluid.set_flags({'FLAGS_sort_sum_gradient': sort_sum_gradient}) @@ -527,9 +527,9 @@ class TestImperative(unittest.TestCase): loss2 = x * z loss1.backward(retain_graph=True) loss2.backward(retain_graph=True) - self.assertTrue(np.array_equal(x.grad, [23.])) - self.assertTrue(np.array_equal(y.grad, [25.])) - self.assertTrue(np.array_equal(z.grad, [5.])) + self.assertTrue(np.array_equal(x.grad.numpy(), [23.])) + self.assertTrue(np.array_equal(y.grad.numpy(), [25.])) + self.assertTrue(np.array_equal(z.grad.numpy(), [5.])) x.clear_grad() y.clear_grad() z.clear_grad() @@ -542,13 +542,13 @@ class TestImperative(unittest.TestCase): loss = fun(x, y, z) loss.backward(retain_graph=True) # x.grad = 2*x*y + z + 2*y = 27 - self.assertTrue(np.array_equal(x.grad, [27])) + self.assertTrue(np.array_equal(x.grad.numpy(), [27])) loss.backward(retain_graph=True) - self.assertTrue(np.array_equal(x.grad, [54])) + self.assertTrue(np.array_equal(x.grad.numpy(), [54])) loss.backward() - self.assertTrue(np.array_equal(x.grad, [81])) + self.assertTrue(np.array_equal(x.grad.numpy(), [81])) with self.assertRaises(RuntimeError): loss.backward() @@ -558,8 +558,8 @@ class TestImperative(unittest.TestCase): dx = paddle.grad([loss1], x, create_graph=True)[0] loss = loss1 + loss2 + dx loss.backward() - self.assertTrue(np.array_equal(dx.grad, [1])) - self.assertTrue(np.array_equal(x.grad, [108])) + self.assertTrue(np.array_equal(dx.grad.numpy(), [1])) + self.assertTrue(np.array_equal(x.grad.numpy(), [108])) def test_mlp(sort_sum_gradient): fluid.set_flags({'FLAGS_sort_sum_gradient': sort_sum_gradient}) @@ -579,28 +579,34 @@ class TestImperative(unittest.TestCase): detach_x = x.detach() clear_loss = mlp2(detach_x) clear_loss.backward() - expected_weight1_grad = expected_weight1_grad + mlp2._linear1.weight.grad - expected_bias1_grad = expected_bias1_grad + mlp2._linear1.bias.grad - expected_weight2_grad = expected_weight2_grad + mlp2._linear2.weight.grad - expected_bias2_grad = expected_bias2_grad + mlp2._linear2.bias.grad + expected_weight1_grad = ( + expected_weight1_grad + mlp2._linear1.weight.grad.numpy()) + expected_bias1_grad = ( + expected_bias1_grad + mlp2._linear1.bias.grad.numpy()) + expected_weight2_grad = ( + expected_weight2_grad + mlp2._linear2.weight.grad.numpy()) + expected_bias2_grad = ( + expected_bias2_grad + mlp2._linear2.bias.grad.numpy()) loss = mlp1(x) loss.backward() - self.assertTrue(np.array_equal(loss.grad, [1])) + self.assertTrue(np.array_equal(loss.grad.numpy(), [1])) self.assertTrue( - np.allclose(mlp1._linear1.weight.grad, + np.allclose(mlp1._linear1.weight.grad.numpy(), expected_weight1_grad)) self.assertTrue( - np.allclose(mlp1._linear1.bias.grad, expected_bias1_grad)) + np.allclose(mlp1._linear1.bias.grad.numpy(), + expected_bias1_grad)) self.assertTrue( - np.allclose(mlp1._linear2.weight.grad, + np.allclose(mlp1._linear2.weight.grad.numpy(), expected_weight2_grad)) self.assertTrue( - np.allclose(mlp1._linear2.bias.grad, expected_bias2_grad)) + np.allclose(mlp1._linear2.bias.grad.numpy(), + expected_bias2_grad)) mlp2.clear_gradients() - self.assertTrue(np.array_equal(clear_loss.grad, [1])) + self.assertTrue(np.array_equal(clear_loss.grad.numpy(), [1])) if ((batch_id + 1) % 10) == 0: mlp1.clear_gradients() expected_weight1_grad = 0. 
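
Note for reviewers: the unittest hunks above and below all apply the same mechanical migration. Wherever a test compared x.grad directly against a numpy value, it now calls x.grad.numpy() first, and tests that may run without a gradient guard against None before converting. A condensed sketch of that pattern (illustrative only, not part of the patch; it assumes a dygraph context with this change applied):

import numpy as np
import paddle

x = paddle.to_tensor(2., stop_gradient=False)
y = paddle.pow(x, 2.0)
y.backward()

# Before this patch the asserts compared a numpy array directly:
#     np.testing.assert_allclose(x.grad, [4.])
# After it, x.grad is a paddle.Tensor, so the tests convert explicitly:
np.testing.assert_allclose(x.grad.numpy(), [4.], rtol=1e-6)

# Tests that may legitimately see no gradient keep a None guard before
# converting (see the custom_op tests above):
z = paddle.to_tensor(1.)  # stop_gradient defaults to True
grad_value = z.grad.numpy() if z.grad is not None else None
assert grad_value is None
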
diff --git a/python/paddle/fluid/tests/unittests/test_inplace.py b/python/paddle/fluid/tests/unittests/test_inplace.py index 2c6507c486e..7b9becacd82 100644 --- a/python/paddle/fluid/tests/unittests/test_inplace.py +++ b/python/paddle/fluid/tests/unittests/test_inplace.py @@ -177,7 +177,7 @@ class TestDygraphInplace(unittest.TestCase): var_d = var_c**2 loss = var_d.sum() loss.backward() - grad_var_a_inplace = var_a.grad + grad_var_a_inplace = var_a.grad.numpy() with paddle.fluid.dygraph.guard(): var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) @@ -188,7 +188,7 @@ class TestDygraphInplace(unittest.TestCase): var_d = var_c**2 loss = var_d.sum() loss.backward() - grad_var_a = var_a.grad + grad_var_a = var_a.grad.numpy() self.assertTrue(np.array_equal(grad_var_a_inplace, grad_var_a)) @@ -209,7 +209,7 @@ class TestDygraphInplace(unittest.TestCase): loss = var_d.sum() loss.backward() - grad_var_a_inplace = var_a.grad + grad_var_a_inplace = var_a.grad.numpy() with paddle.fluid.dygraph.guard(): var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) @@ -224,7 +224,7 @@ class TestDygraphInplace(unittest.TestCase): loss = var_d.sum() loss.backward() - grad_var_a = var_a.grad + grad_var_a = var_a.grad.numpy() self.assertTrue(np.array_equal(grad_var_a_inplace, grad_var_a)) diff --git a/python/paddle/fluid/tests/unittests/test_lookahead.py b/python/paddle/fluid/tests/unittests/test_lookahead.py index 98349be93db..a4b5e6d0d95 100644 --- a/python/paddle/fluid/tests/unittests/test_lookahead.py +++ b/python/paddle/fluid/tests/unittests/test_lookahead.py @@ -110,7 +110,8 @@ class TestLookAhead(unittest.TestCase): out = layer(image) loss = loss_fn(out, label) loss.backward() - fast_param = layer.bias.numpy() - SGD_LR * layer.bias.grad + fast_param = ( + layer.bias.numpy() - SGD_LR * layer.bias.grad.numpy()) opt.step() if idx == 1: slow_param = fast_param diff --git a/python/paddle/fluid/tests/unittests/test_pylayer_op.py b/python/paddle/fluid/tests/unittests/test_pylayer_op.py index f00db0b3693..565ed992bc5 100644 --- a/python/paddle/fluid/tests/unittests/test_pylayer_op.py +++ b/python/paddle/fluid/tests/unittests/test_pylayer_op.py @@ -50,7 +50,8 @@ class TestPyLayer(unittest.TestCase): z2 = paddle.tanh(input2) + paddle.tanh(input2) z2.mean().backward() - self.assertTrue(np.max(np.abs((input1.grad - input2.grad))) < 1e-10) + self.assertTrue( + np.max(np.abs((input1.grad.numpy() - input2.grad.numpy()))) < 1e-10) def test_simple_pylayer_return_none_with_no_grad(self): class tanh(PyLayer): @@ -110,7 +111,8 @@ class TestPyLayer(unittest.TestCase): z2 = paddle.tanh(input2) z2.mean().backward() - self.assertTrue(np.max(np.abs((input1.grad - input2.grad))) < 1e-10) + self.assertTrue( + np.max(np.abs((input1.grad.numpy() - input2.grad.numpy()))) < 1e-10) def test_pylayer_dtype(self): class tanh(PyLayer): diff --git a/python/paddle/fluid/tests/unittests/test_tensor_register_hook.py b/python/paddle/fluid/tests/unittests/test_tensor_register_hook.py index 50b00ab34fd..a03e4ae4bd9 100644 --- a/python/paddle/fluid/tests/unittests/test_tensor_register_hook.py +++ b/python/paddle/fluid/tests/unittests/test_tensor_register_hook.py @@ -75,15 +75,15 @@ class TestTensorRegisterHook(unittest.TestCase): o.backward() # z.grad is not affected - self.assertTrue(np.array_equal(z.grad, w.numpy())) + self.assertTrue(np.array_equal(z.grad.numpy(), w.numpy())) # w.grad is not changed by hook - self.assertTrue(np.array_equal(w.grad, z.numpy())) + self.assertTrue(np.array_equal(w.grad.numpy(), z.numpy())) # 
x.grad and y.grad are changed if run hook self.assertTrue( - np.array_equal(x.grad, + np.array_equal(x.grad.numpy(), z.numpy() * 2 if not removed else z.numpy())) self.assertTrue( - np.array_equal(y.grad, + np.array_equal(y.grad.numpy(), z.numpy() * 2 if not removed else z.numpy())) def run_print_hook_for_interior_var(print_hook, removed=False): @@ -111,10 +111,10 @@ class TestTensorRegisterHook(unittest.TestCase): o.backward() # all grads are not affected - self.assertTrue(np.array_equal(z.grad, w.numpy())) - self.assertTrue(np.array_equal(w.grad, z.numpy())) - self.assertTrue(np.array_equal(x.grad, z.numpy())) - self.assertTrue(np.array_equal(y.grad, z.numpy())) + self.assertTrue(np.array_equal(z.grad.numpy(), w.numpy())) + self.assertTrue(np.array_equal(w.grad.numpy(), z.numpy())) + self.assertTrue(np.array_equal(x.grad.numpy(), z.numpy())) + self.assertTrue(np.array_equal(y.grad.numpy(), z.numpy())) def double_hook(grad): grad = grad * 2 @@ -165,12 +165,12 @@ class TestTensorRegisterHook(unittest.TestCase): o.backward() # z.grad, w.grad, x.grad is not affected - self.assertTrue(np.array_equal(z.grad, w.numpy())) - self.assertTrue(np.array_equal(w.grad, z.numpy())) - self.assertTrue(np.array_equal(x.grad, z.numpy())) + self.assertTrue(np.array_equal(z.grad.numpy(), w.numpy())) + self.assertTrue(np.array_equal(w.grad.numpy(), z.numpy())) + self.assertTrue(np.array_equal(x.grad.numpy(), z.numpy())) # y.grad are changed if run hook self.assertTrue( - np.array_equal(y.grad, + np.array_equal(y.grad.numpy(), z.numpy() * 2 if not removed else z.numpy())) # register hook @@ -217,14 +217,14 @@ class TestTensorRegisterHook(unittest.TestCase): base_grad = np.array([5., 9., 13., 19.]) # x.grad is not changed - self.assertTrue(np.array_equal(x.grad, base_grad)) + self.assertTrue(np.array_equal(x.grad.numpy(), base_grad)) # b.grad is changed by x.hook self.assertTrue( - np.array_equal(b.grad, base_grad * 2 + np.array_equal(b.grad.numpy(), base_grad * 2 if not removed else base_grad)) # a.grad is changed by x.hook and a.hook self.assertTrue( - np.array_equal(a.grad, base_grad * 4 + np.array_equal(a.grad.numpy(), base_grad * 4 if not removed else base_grad)) # register hook @@ -265,7 +265,7 @@ class TestTensorRegisterHook(unittest.TestCase): base_grad = np.array([5., 9., 13., 19.]) # x.grad is changed by x.hook self.assertTrue( - np.array_equal(x.grad, base_grad * 2 + np.array_equal(x.grad.numpy(), base_grad * 2 if not removed else base_grad)) # register hook @@ -294,7 +294,8 @@ class TestTensorRegisterHook(unittest.TestCase): loss = loss_fn(out, label) loss.backward() - return ret1.grad, net.linear1.weight.grad, net.linear1.bias.grad + return (ret1.grad.numpy(), net.linear1.weight.grad.numpy(), + net.linear1.bias.grad.numpy()) data = np.random.uniform( size=[self.batch_size, self.in_size]).astype('float32') @@ -355,7 +356,7 @@ class TestTensorRegisterHook(unittest.TestCase): o.backward() - return z.numpy(), w.grad, x.grad, y.grad + return z.numpy(), w.grad.numpy(), x.grad.numpy(), y.grad.numpy() def double_hook(grad): return grad * 2 @@ -428,7 +429,7 @@ class TestTensorRegisterHook(unittest.TestCase): # after changed by hook: 8.0 z.backward() - self.assertTrue(np.array_equal(x.grad, np.array([8.]))) + self.assertTrue(np.array_equal(x.grad.numpy(), np.array([8.]))) def test_remove_one_hook_multiple_times(self): for device in self.devices: diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 7901df79171..a65308c84e7 100644 --- 
a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -65,7 +65,8 @@ class TestVarBase(unittest.TestCase): y = clone_x**2 y.backward() self.assertTrue( - np.array_equal(x.grad, np.array([2.4]).astype('float32'))) + np.array_equal(x.grad.numpy(), + np.array([2.4]).astype('float32'))) y = x.cpu() self.assertEqual(y.place.__repr__(), "CPUPlace") if core.is_compiled_with_cuda(): @@ -260,14 +261,14 @@ class TestVarBase(unittest.TestCase): y = x**2 y.backward() - self.assertTrue(np.array_equal(x.grad, [20.0])) + self.assertTrue(np.array_equal(x.grad.numpy(), [20.0])) self.assertEqual(detach_x.grad, None) detach_x.stop_gradient = False # Set stop_gradient to be False, supported auto-grad z = 3 * detach_x**2 z.backward() - self.assertTrue(np.array_equal(x.grad, [20.0])) - self.assertTrue(np.array_equal(detach_x.grad, [60.0])) + self.assertTrue(np.array_equal(x.grad.numpy(), [20.0])) + self.assertTrue(np.array_equal(detach_x.grad.numpy(), [60.0])) # Due to sharing of data with origin Tensor, There are some unsafe operations: with self.assertRaises(RuntimeError): -- GitLab
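
For reference, a minimal sketch of the user-facing behaviour after this change (assuming Paddle 2.1 or later with this patch applied): x.grad now returns the gradient as a paddle.Tensor, or None if no gradient has been recorded, while x.gradient() still returns the numpy value but is marked deprecated since 2.1.0.

import numpy as np
import paddle

x = paddle.to_tensor(5., stop_gradient=False)
y = paddle.pow(x, 4.0)
y.backward()

# x.grad is now a paddle.Tensor holding the gradient value ...
print(x.grad)
# Tensor(shape=[1], dtype=float32, place=..., stop_gradient=False, [500.])

# ... so code that needs the raw numpy value calls .numpy() explicitly,
# guarding against None for tensors that have no gradient yet.
if x.grad is not None:
    assert np.array_equal(x.grad.numpy(), np.array([500.], dtype='float32'))

# x.gradient() keeps returning the numpy value, but is deprecated since
# 2.1.0 in favour of x.grad.
print(x.gradient())  # [500.]
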