From 41ad3a603b61f3f2a87d7312416afc67b32d0e8a Mon Sep 17 00:00:00 2001 From: TomorrowIsAnOtherDay <2466956298@qq.com> Date: Thu, 20 Aug 2020 12:30:05 +0800 Subject: [PATCH] update paddle_gpu_test.py --- parl/remote/tests/paddle_gpu_test.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/parl/remote/tests/paddle_gpu_test.py b/parl/remote/tests/paddle_gpu_test.py index 256814c..d2d433a 100644 --- a/parl/remote/tests/paddle_gpu_test.py +++ b/parl/remote/tests/paddle_gpu_test.py @@ -28,8 +28,8 @@ import os @parl.remote_class class Actor(object): - def __init__(self): - if parl.utils.is_gpu_available(): + def __init__(self, cuda=False): + if cuda: place = fluid.CUDAPlace(0) else: place = fluid.CPUPlace() @@ -51,7 +51,10 @@ class TestCluster(unittest.TestCase): parl.connect('localhost:8241') - actor = Actor() + if parl.utils.is_gpu_available(): + actor = Actor(cuda=True) + else: + actor = Actor(cuda=False) del actor master.exit() -- GitLab