未验证 提交 89c3366b 编写于 作者: H Hongsheng Zeng 提交者: GitHub

Limit impala to single GPU training (#152)

* Limit impala to single GPU training

* refine comment of scheduler

* refine comment
上级 2ce7d216
......@@ -53,7 +53,7 @@ class Learner(object):
if machine_info.is_gpu_available():
assert get_gpu_count() == 1, 'Only support training in single GPU,\
Please set environment variable: `export CUDA_VISIBLE_DEVICES=[GPU_ID_YOU_WANT_TO_USE]` .'
Please set environment variable: `export CUDA_VISIBLE_DEVICES=[GPU_ID_TO_USE]` .'
else:
cpu_num = os.environ.get('CPU_NUM')
......
......@@ -22,10 +22,11 @@ import parl
from atari_model import AtariModel
from atari_agent import AtariAgent
from parl.env.atari_wrappers import wrap_deepmind
from parl.utils import logger, tensorboard
from parl.utils import logger, tensorboard, get_gpu_count
from parl.utils.scheduler import PiecewiseScheduler
from parl.utils.time_stat import TimeStat
from parl.utils.window_stat import WindowStat
from parl.utils import machine_info
from actor import Actor
......@@ -54,6 +55,10 @@ class Learner(object):
self.agent = AtariAgent(algorithm, obs_shape, act_dim,
self.learn_data_provider)
if machine_info.is_gpu_available():
assert get_gpu_count() == 1, 'Only support training in single GPU,\
Please set environment variable: `export CUDA_VISIBLE_DEVICES=[GPU_ID_TO_USE]` .'
self.cache_params = self.agent.get_weights()
self.params_lock = threading.Lock()
self.params_updated = False
......
......@@ -67,7 +67,7 @@ class LinearDecayScheduler(object):
def __init__(self, start_value, max_steps):
"""Linear decay scheduler of hyper parameter.
Decay value linearly untill 0.
Args:
start_value (float): start value
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册