Unverified  Commit 33516338, authored by Bo Zhou, committed by GitHub

fix the compatibility issue in the A2C example. (#98)

* fix the compatibility issue

* fix the comment issue
Parent d18f19a9
@@ -48,12 +48,7 @@ class Actor(object):
         model = AtariModel(act_dim)
         algorithm = parl.algorithms.A3C(
             model, vf_loss_coeff=config['vf_loss_coeff'])
-        self.agent = AtariAgent(
-            algorithm,
-            obs_shape=self.config['obs_shape'],
-            lr_scheduler=self.config['lr_scheduler'],
-            entropy_coeff_scheduler=self.config['entropy_coeff_scheduler'],
-        )
+        self.agent = AtariAgent(algorithm, config)

     def sample(self):
         sample_data = defaultdict(list)
...
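The config passed here is the dict defined in a2c_config.py. Its full contents are outside this diff, but the keys touched by the changed code suggest a construction path roughly like the sketch below (values are illustrative placeholders, and the AtariModel/AtariAgent imports are assumed to follow the example's module layout):

import parl
from atari_model import AtariModel    # assumption: example module name
from atari_agent import AtariAgent    # assumption: example module name

# Placeholder config; only keys visible in this diff are listed.
config = {
    'obs_shape': [4, 84, 84],                 # assumed stacked-frame shape
    'vf_loss_coeff': 0.5,                     # weight of the value loss in A3C
    'start_lr': 0.001,                        # consumed by LinearDecayScheduler
    'max_sample_steps': int(1e7),             # consumed by LinearDecayScheduler
    'entropy_coeff_scheduler': [(0, -0.01)],  # (train_step, coefficient) pairs
}

act_dim = 4                                   # assumed action-space size
model = AtariModel(act_dim)
algorithm = parl.algorithms.A3C(model, vf_loss_coeff=config['vf_loss_coeff'])

# New interface: the whole config dict is handed to the agent, which reads
# obs_shape and the scheduler settings itself (see the next hunk).
agent = AtariAgent(algorithm, config)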
@@ -21,30 +21,22 @@ from parl.utils.scheduler import PiecewiseScheduler, LinearDecayScheduler

 class AtariAgent(parl.Agent):
-    def __init__(self, algorithm, obs_shape, lr_scheduler,
-                 entropy_coeff_scheduler):
+    def __init__(self, algorithm, config):
         """
         Args:
-            algorithm (`parl.Algorithm`): a2c algorithm
-            obs_shape (list/tuple): observation shape of atari environment
-            lr_scheduler (list/tuple): learning rate adjustment schedule: (train_step, learning_rate)
-            entropy_coeff_scheduler (list/tuple): coefficient of policy entropy adjustment schedule: (train_step, coefficient)
+            algorithm (`parl.Algorithm`): algorithm to be used in this agent.
+            config (dict): config file describing the training hyper-parameters(see a2c_config.py)
         """
-        assert isinstance(obs_shape, (list, tuple))
-        assert isinstance(lr_scheduler, (list, tuple))
-        assert isinstance(entropy_coeff_scheduler, (list, tuple))
-        self.obs_shape = obs_shape
-        self.lr_scheduler = lr_scheduler
-        self.entropy_coeff_scheduler = entropy_coeff_scheduler
+        self.obs_shape = config['obs_shape']

         super(AtariAgent, self).__init__(algorithm)

         self.lr_scheduler = LinearDecayScheduler(config['start_lr'],
                                                  config['max_sample_steps'])
         self.entropy_coeff_scheduler = PiecewiseScheduler(
-            self.entropy_coeff_scheduler)
+            config['entropy_coeff_scheduler'])

         exec_strategy = fluid.ExecutionStrategy()
         exec_strategy.use_experimental_executor = True
...
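Assembling the right-hand side of this hunk, the new constructor reads roughly as follows (a sketch built only from the lines shown above; the fluid executor setup that follows is truncated in the diff):

class AtariAgent(parl.Agent):
    def __init__(self, algorithm, config):
        """
        Args:
            algorithm (`parl.Algorithm`): algorithm to be used in this agent.
            config (dict): config file describing the training hyper-parameters(see a2c_config.py)
        """
        self.obs_shape = config['obs_shape']
        super(AtariAgent, self).__init__(algorithm)

        # Learning rate decays linearly from start_lr over max_sample_steps.
        self.lr_scheduler = LinearDecayScheduler(config['start_lr'],
                                                 config['max_sample_steps'])
        # Entropy coefficient follows the piecewise (train_step, coefficient)
        # schedule given in the config.
        self.entropy_coeff_scheduler = PiecewiseScheduler(
            config['entropy_coeff_scheduler'])
        # ... fluid executor setup continues unchanged below this point.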
@@ -47,12 +47,7 @@ class Learner(object):
         model = AtariModel(act_dim)
         algorithm = parl.algorithms.A3C(
             model, vf_loss_coeff=config['vf_loss_coeff'])
-        self.agent = AtariAgent(
-            algorithm,
-            obs_shape=self.config['obs_shape'],
-            lr_scheduler=self.config['lr_scheduler'],
-            entropy_coeff_scheduler=self.config['entropy_coeff_scheduler'],
-        )
+        self.agent = AtariAgent(algorithm, config)

         if machine_info.is_gpu_available():
             assert get_gpu_count() == 1, 'Only support training in single GPU,\
...
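Design-wise, the Actor and the Learner now build the agent through the same two-argument call, so a change to the training hyper-parameters only needs to touch a2c_config.py and AtariAgent itself rather than every call site.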