未验证 提交 7a245b7a 编写于 作者: Z zhulei 提交者: GitHub

[Rocm] fix test_var_base (#32639)

上级 6d3eb3d0
......@@ -84,7 +84,7 @@ paddle::framework::GarbageCollector* Tracer::MutableGarbageCollectorIfNotExists(
if (gcs_.count(place) == 0) {
std::unique_ptr<framework::GarbageCollector> gc;
if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
gc.reset(new framework::DefaultStreamGarbageCollector(
BOOST_GET_CONST(platform::CUDAPlace, place), 0));
......@@ -95,7 +95,7 @@ paddle::framework::GarbageCollector* Tracer::MutableGarbageCollectorIfNotExists(
"Please recompile or reinstall Paddle with GPU support."));
#endif
} else if (platform::is_cuda_pinned_place(place)) {
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
gc.reset(new framework::CUDAPinnedGarbageCollector(
BOOST_GET_CONST(platform::CUDAPinnedPlace, place), 0));
......
......@@ -256,19 +256,21 @@ class TestVarBase(unittest.TestCase):
detach_x = x.detach()
self.assertTrue(detach_x.stop_gradient, True)
cmp_float = np.allclose if core.is_compiled_with_rocm(
) else np.array_equal
detach_x[:] = 10.0
self.assertTrue(np.array_equal(x.numpy(), [10.0]))
self.assertTrue(cmp_float(x.numpy(), [10.0]))
y = x**2
y.backward()
self.assertTrue(np.array_equal(x.grad.numpy(), [20.0]))
self.assertTrue(cmp_float(x.grad.numpy(), [20.0]))
self.assertEqual(detach_x.grad, None)
detach_x.stop_gradient = False # Set stop_gradient to be False, supported auto-grad
z = 3 * detach_x**2
z.backward()
self.assertTrue(np.array_equal(x.grad.numpy(), [20.0]))
self.assertTrue(np.array_equal(detach_x.grad.numpy(), [60.0]))
self.assertTrue(cmp_float(x.grad.numpy(), [20.0]))
self.assertTrue(cmp_float(detach_x.grad.numpy(), [60.0]))
# Due to sharing of data with origin Tensor, There are some unsafe operations:
with self.assertRaises(RuntimeError):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册