未验证 提交 33516338 编写于 作者: B Bo Zhou 提交者: GitHub

fix the compatibility issue in the A2C example. (#98)

* fix the compatibility issue

* fix the comment issue
上级 d18f19a9
......@@ -48,12 +48,7 @@ class Actor(object):
model = AtariModel(act_dim)
algorithm = parl.algorithms.A3C(
model, vf_loss_coeff=config['vf_loss_coeff'])
self.agent = AtariAgent(
algorithm,
obs_shape=self.config['obs_shape'],
lr_scheduler=self.config['lr_scheduler'],
entropy_coeff_scheduler=self.config['entropy_coeff_scheduler'],
)
self.agent = AtariAgent(algorithm, config)
def sample(self):
sample_data = defaultdict(list)
......
......@@ -21,30 +21,22 @@ from parl.utils.scheduler import PiecewiseScheduler, LinearDecayScheduler
class AtariAgent(parl.Agent):
def __init__(self, algorithm, obs_shape, lr_scheduler,
entropy_coeff_scheduler):
def __init__(self, algorithm, config):
"""
Args:
algorithm (`parl.Algorithm`): a2c algorithm
obs_shape (list/tuple): observation shape of atari environment
lr_scheduler (list/tuple): learning rate adjustment schedule: (train_step, learning_rate)
entropy_coeff_scheduler (list/tuple): coefficient of policy entropy adjustment schedule: (train_step, coefficient)
algorithm (`parl.Algorithm`): algorithm to be used in this agent.
            config (dict): config file describing the training hyper-parameters (see a2c_config.py)
"""
assert isinstance(obs_shape, (list, tuple))
assert isinstance(lr_scheduler, (list, tuple))
assert isinstance(entropy_coeff_scheduler, (list, tuple))
self.obs_shape = obs_shape
self.lr_scheduler = lr_scheduler
self.entropy_coeff_scheduler = entropy_coeff_scheduler
self.obs_shape = config['obs_shape']
super(AtariAgent, self).__init__(algorithm)
self.lr_scheduler = LinearDecayScheduler(config['start_lr'],
config['max_sample_steps'])
self.entropy_coeff_scheduler = PiecewiseScheduler(
self.entropy_coeff_scheduler)
config['entropy_coeff_scheduler'])
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.use_experimental_executor = True
......
......@@ -47,12 +47,7 @@ class Learner(object):
model = AtariModel(act_dim)
algorithm = parl.algorithms.A3C(
model, vf_loss_coeff=config['vf_loss_coeff'])
self.agent = AtariAgent(
algorithm,
obs_shape=self.config['obs_shape'],
lr_scheduler=self.config['lr_scheduler'],
entropy_coeff_scheduler=self.config['entropy_coeff_scheduler'],
)
self.agent = AtariAgent(algorithm, config)
if machine_info.is_gpu_available():
assert get_gpu_count() == 1, 'Only support training in single GPU,\
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册