Unverified · Commit 89c3366b · Authored by: Hongsheng Zeng · Committed by: GitHub

Limit impala to single GPU training (#152)

* Limit impala to single GPU training

* refine comment of scheduler

* refine comment
Parent 2ce7d216
@@ -53,7 +53,7 @@ class Learner(object):
         if machine_info.is_gpu_available():
             assert get_gpu_count() == 1, 'Only support training in single GPU,\
-                Please set environment variable: `export CUDA_VISIBLE_DEVICES=[GPU_ID_YOU_WANT_TO_USE]` .'
+                Please set environment variable: `export CUDA_VISIBLE_DEVICES=[GPU_ID_TO_USE]` .'
         else:
             cpu_num = os.environ.get('CPU_NUM')
...
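The assertion above asks users to restrict training to a single visible GPU. A minimal, hedged usage sketch of that instruction follows; it only relies on the `machine_info.is_gpu_available()` and `get_gpu_count()` helpers shown in the diff, and the GPU id `0` is an arbitrary example, not a requirement.

```python
# Hedged usage sketch: pin a single GPU before launching the IMPALA learner so
# that the new `get_gpu_count() == 1` assertion passes. GPU id '0' is only an
# example; use whichever device id you want.
import os

# CUDA_VISIBLE_DEVICES must be set before any CUDA initialization, i.e. before
# importing paddle/parl in the training script (or export it in the shell).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from parl.utils import get_gpu_count  # noqa: E402
from parl.utils import machine_info  # noqa: E402

if machine_info.is_gpu_available():
    # With CUDA_VISIBLE_DEVICES restricted to a single id, this should print 1.
    print('Visible GPUs:', get_gpu_count())
else:
    print('No GPU visible; training will fall back to CPU.')
```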
@@ -22,10 +22,11 @@ import parl
 from atari_model import AtariModel
 from atari_agent import AtariAgent
 from parl.env.atari_wrappers import wrap_deepmind
-from parl.utils import logger, tensorboard
+from parl.utils import logger, tensorboard, get_gpu_count
 from parl.utils.scheduler import PiecewiseScheduler
 from parl.utils.time_stat import TimeStat
 from parl.utils.window_stat import WindowStat
+from parl.utils import machine_info
 from actor import Actor
...
@@ -54,6 +55,10 @@ class Learner(object):
         self.agent = AtariAgent(algorithm, obs_shape, act_dim,
                                 self.learn_data_provider)
+        if machine_info.is_gpu_available():
+            assert get_gpu_count() == 1, 'Only support training in single GPU,\
+                Please set environment variable: `export CUDA_VISIBLE_DEVICES=[GPU_ID_TO_USE]` .'
         self.cache_params = self.agent.get_weights()
         self.params_lock = threading.Lock()
         self.params_updated = False
...
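The context lines above also show the learner's weight-caching pattern (`cache_params`, `params_lock`, `params_updated`): the learner keeps a cached copy of the agent weights behind a lock so actor-facing threads can read a consistent snapshot. A rough, hedged sketch of that pattern, independent of the actual Learner class:

```python
# Hedged sketch of the weight-caching pattern visible in the context lines.
# Only `cache_params`, `params_lock` and `params_updated` come from the diff;
# the class and method names below are illustrative.
import threading


class ParamsCache(object):
    def __init__(self, initial_weights):
        self.cache_params = initial_weights
        self.params_lock = threading.Lock()
        self.params_updated = False

    def publish(self, new_weights):
        # Called by the learner after a training step.
        with self.params_lock:
            self.cache_params = new_weights
            self.params_updated = True

    def fetch(self):
        # Called by actor threads that need the latest weights.
        with self.params_lock:
            self.params_updated = False
            return self.cache_params
```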
@@ -67,7 +67,7 @@ class LinearDecayScheduler(object):
     def __init__(self, start_value, max_steps):
         """Linear decay scheduler of hyper parameter.
+        Decay value linearly untill 0.
         Args:
             start_value (float): start value
...
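For reference, the docstring being refined describes a scheduler that decays a value linearly from `start_value` to 0 over `max_steps`. A hedged sketch of that behaviour is below; the class and method names are illustrative, not PARL's actual API.

```python
# Hedged sketch of what `LinearDecayScheduler(start_value, max_steps)` is
# documented to do: decay a value linearly from `start_value` down to 0 over
# `max_steps` steps.
class LinearDecaySketch(object):
    def __init__(self, start_value, max_steps):
        assert max_steps > 0
        self.start_value = float(start_value)
        self.max_steps = max_steps
        self.cur_step = 0

    def step(self):
        # Clamp at max_steps so the returned value never drops below 0.
        self.cur_step = min(self.cur_step + 1, self.max_steps)
        return self.start_value * (1.0 - float(self.cur_step) / self.max_steps)


# Example: anneal a learning rate of 1e-3 to 0 over 100 steps.
lr_sched = LinearDecaySketch(1e-3, 100)
print([round(lr_sched.step(), 6) for _ in range(3)])  # [0.00099, 0.00098, 0.00097]
```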