diff --git a/examples/A2C/actor.py b/examples/A2C/actor.py
index cf026747f7eb0d06704332aeb5fc9ef64eb852c9..e85fbc00cc4a0ca962dd4315525e32f99fc68c61 100644
--- a/examples/A2C/actor.py
+++ b/examples/A2C/actor.py
@@ -48,12 +48,7 @@ class Actor(object):
         model = AtariModel(act_dim)
         algorithm = parl.algorithms.A3C(
             model, vf_loss_coeff=config['vf_loss_coeff'])
-        self.agent = AtariAgent(
-            algorithm,
-            obs_shape=self.config['obs_shape'],
-            lr_scheduler=self.config['lr_scheduler'],
-            entropy_coeff_scheduler=self.config['entropy_coeff_scheduler'],
-        )
+        self.agent = AtariAgent(algorithm, config)
 
     def sample(self):
         sample_data = defaultdict(list)
diff --git a/examples/A2C/atari_agent.py b/examples/A2C/atari_agent.py
index 1c0e47772d1c747bcd91dc6bd8e7f9984c0f9010..126f403a1f5163f88564c030c5b442ff3b63eda5 100644
--- a/examples/A2C/atari_agent.py
+++ b/examples/A2C/atari_agent.py
@@ -21,30 +21,22 @@ from parl.utils.scheduler import PiecewiseScheduler, LinearDecayScheduler
 
 
 class AtariAgent(parl.Agent):
-    def __init__(self, algorithm, obs_shape, lr_scheduler,
-                 entropy_coeff_scheduler):
+    def __init__(self, algorithm, config):
         """
 
         Args:
-            algorithm (`parl.Algorithm`): a2c algorithm
-            obs_shape (list/tuple): observation shape of atari environment
-            lr_scheduler (list/tuple): learning rate adjustment schedule: (train_step, learning_rate)
-            entropy_coeff_scheduler (list/tuple): coefficient of policy entropy adjustment schedule: (train_step, coefficient)
+            algorithm (`parl.Algorithm`): algorithm to be used in this agent.
+            config (dict): config file describing the training hyper-parameters(see a2c_config.py)
         """
-        assert isinstance(obs_shape, (list, tuple))
-        assert isinstance(lr_scheduler, (list, tuple))
-        assert isinstance(entropy_coeff_scheduler, (list, tuple))
 
-        self.obs_shape = obs_shape
-        self.lr_scheduler = lr_scheduler
-        self.entropy_coeff_scheduler = entropy_coeff_scheduler
+        self.obs_shape = config['obs_shape']
         super(AtariAgent, self).__init__(algorithm)
 
         self.lr_scheduler = LinearDecayScheduler(config['start_lr'],
                                                  config['max_sample_steps'])
 
         self.entropy_coeff_scheduler = PiecewiseScheduler(
-            self.entropy_coeff_scheduler)
+            config['entropy_coeff_scheduler'])
 
         exec_strategy = fluid.ExecutionStrategy()
         exec_strategy.use_experimental_executor = True
diff --git a/examples/A2C/learner.py b/examples/A2C/learner.py
index b6b034e1059ab3c57e762e63a153087e95ec72db..49a77ddfb4d3da5319417e6349f132c1304ee042 100644
--- a/examples/A2C/learner.py
+++ b/examples/A2C/learner.py
@@ -47,12 +47,7 @@ class Learner(object):
         model = AtariModel(act_dim)
         algorithm = parl.algorithms.A3C(
             model, vf_loss_coeff=config['vf_loss_coeff'])
-        self.agent = AtariAgent(
-            algorithm,
-            obs_shape=self.config['obs_shape'],
-            lr_scheduler=self.config['lr_scheduler'],
-            entropy_coeff_scheduler=self.config['entropy_coeff_scheduler'],
-        )
+        self.agent = AtariAgent(algorithm, config)
 
         if machine_info.is_gpu_available():
             assert get_gpu_count() == 1, 'Only support training in single GPU,\
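
For context, a minimal sketch of how the refactored constructor is now driven by a single config dict, assuming only the keys the diff above reads (obs_shape, start_lr, max_sample_steps, entropy_coeff_scheduler, vf_loss_coeff); the concrete values, act_dim, and the import paths are illustrative placeholders rather than the real a2c_config.py.

# Illustrative sketch only: placeholder values, not the actual a2c_config.py contents.
import parl
from atari_model import AtariModel    # assumed module paths inside examples/A2C
from atari_agent import AtariAgent

config = {
    'obs_shape': (4, 84, 84),                 # stacked Atari frames, read in __init__
    'start_lr': 0.001,                        # initial lr for LinearDecayScheduler
    'max_sample_steps': int(1e7),             # horizon over which the lr decays
    'entropy_coeff_scheduler': [(0, -0.01)],  # (train_step, coefficient) pairs for PiecewiseScheduler
    'vf_loss_coeff': 0.5,                     # value-function loss weight passed to A3C
}

act_dim = 6  # e.g. action-space size of the chosen Atari game
model = AtariModel(act_dim)
algorithm = parl.algorithms.A3C(model, vf_loss_coeff=config['vf_loss_coeff'])

# Old call sites passed obs_shape / lr_scheduler / entropy_coeff_scheduler as
# separate keyword arguments; the new signature takes the whole dict, so Actor
# and Learner construct the agent the same way.
agent = AtariAgent(algorithm, config)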