提交 60a6867b 作者: niuyazhe

hotfix(nyz): fix random policy typo in serial entry and base policy model device problem

上级 020eba28
......
@@ -82,7 +82,7 @@ def serial_pipeline(
if cfg.policy.get('random_collect_size', 0) > 0:
action_space = collector_env.env_info().act_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
- collector.reset_policy(policy.collect_mode)
+ collector.reset_policy(random_policy)
collect_kwargs = commander.step()
new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
replay_buffer.push(new_data, cur_collector_envstep=0)
......
......
@@ -75,14 +75,14 @@ class Policy(ABC):
if len(set(self._enable_field).intersection(set(['learn']))) > 0:
self._rank = get_rank() if self._cfg.learn.multi_gpu else 0
if self._cuda:
- torch.cuda.set_device(self._rank)
+ torch.cuda.set_device(self._rank % torch.cuda.device_count())
model.cuda()
if self._cfg.learn.multi_gpu:
self._init_multi_gpu_setting(model)
else:
self._rank = 0
if self._cuda:
- torch.cuda.set_device(self._rank)
+ torch.cuda.set_device(self._rank % torch.cuda.device_count())
model.cuda()
self._model = model
self._device = 'cuda:{}'.format(self._rank % torch.cuda.device_count()) if self._cuda else 'cpu'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册