“52e8ebf246faa737a9553e1fe7df531a10234074”上不存在“develop/doc/dev/index_en.html”
未验证 提交 7a245b7a 编写于 作者: Z zhulei 提交者: GitHub

[Rocm] fix test_var_base (#32639)

上级 6d3eb3d0
...@@ -84,7 +84,7 @@ paddle::framework::GarbageCollector* Tracer::MutableGarbageCollectorIfNotExists( ...@@ -84,7 +84,7 @@ paddle::framework::GarbageCollector* Tracer::MutableGarbageCollectorIfNotExists(
if (gcs_.count(place) == 0) { if (gcs_.count(place) == 0) {
std::unique_ptr<framework::GarbageCollector> gc; std::unique_ptr<framework::GarbageCollector> gc;
if (platform::is_gpu_place(place)) { if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
gc.reset(new framework::DefaultStreamGarbageCollector( gc.reset(new framework::DefaultStreamGarbageCollector(
BOOST_GET_CONST(platform::CUDAPlace, place), 0)); BOOST_GET_CONST(platform::CUDAPlace, place), 0));
...@@ -95,7 +95,7 @@ paddle::framework::GarbageCollector* Tracer::MutableGarbageCollectorIfNotExists( ...@@ -95,7 +95,7 @@ paddle::framework::GarbageCollector* Tracer::MutableGarbageCollectorIfNotExists(
"Please recompile or reinstall Paddle with GPU support.")); "Please recompile or reinstall Paddle with GPU support."));
#endif #endif
} else if (platform::is_cuda_pinned_place(place)) { } else if (platform::is_cuda_pinned_place(place)) {
#ifdef PADDLE_WITH_CUDA #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
gc.reset(new framework::CUDAPinnedGarbageCollector( gc.reset(new framework::CUDAPinnedGarbageCollector(
BOOST_GET_CONST(platform::CUDAPinnedPlace, place), 0)); BOOST_GET_CONST(platform::CUDAPinnedPlace, place), 0));
......
...@@ -256,19 +256,21 @@ class TestVarBase(unittest.TestCase): ...@@ -256,19 +256,21 @@ class TestVarBase(unittest.TestCase):
detach_x = x.detach() detach_x = x.detach()
self.assertTrue(detach_x.stop_gradient, True) self.assertTrue(detach_x.stop_gradient, True)
cmp_float = np.allclose if core.is_compiled_with_rocm(
) else np.array_equal
detach_x[:] = 10.0 detach_x[:] = 10.0
self.assertTrue(np.array_equal(x.numpy(), [10.0])) self.assertTrue(cmp_float(x.numpy(), [10.0]))
y = x**2 y = x**2
y.backward() y.backward()
self.assertTrue(np.array_equal(x.grad.numpy(), [20.0])) self.assertTrue(cmp_float(x.grad.numpy(), [20.0]))
self.assertEqual(detach_x.grad, None) self.assertEqual(detach_x.grad, None)
detach_x.stop_gradient = False # Set stop_gradient to be False, supported auto-grad detach_x.stop_gradient = False # Set stop_gradient to be False, supported auto-grad
z = 3 * detach_x**2 z = 3 * detach_x**2
z.backward() z.backward()
self.assertTrue(np.array_equal(x.grad.numpy(), [20.0])) self.assertTrue(cmp_float(x.grad.numpy(), [20.0]))
self.assertTrue(np.array_equal(detach_x.grad.numpy(), [60.0])) self.assertTrue(cmp_float(detach_x.grad.numpy(), [60.0]))
# Due to sharing of data with origin Tensor, There are some unsafe operations: # Due to sharing of data with origin Tensor, There are some unsafe operations:
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册