Commit 255ef4f7 authored by Hongsheng Zeng, committed by Bo Zhou

refine A2C example (#80)

* refine A2C example

* fix unittest in python2; fix codestyle

* fix codestyle

* refine comment
Parent: c3b34fd9
@@ -7,10 +7,12 @@ A2C is a synchronous, deterministic variant of [Asynchronous Advantage Actor Cri
 Please see [here](https://gym.openai.com/envs/#atari) to know more about Atari games.
 
 ### Benchmark result
-Results with one learner (in a P40 GPU) and 5 actors in 10 million sample steps.
-<img src=".benchmark/A2C_Pong.jpg" width = "400" height ="300" alt="A2C_Pong" /> <img src=".benchmark/A2C_Breakout.jpg" width = "400" height ="300" alt="A2C_Breakout"/>
-<img src=".benchmark/A2C_BeamRider.jpg" width = "400" height ="300" alt="A2C_BeamRider" /> <img src=".benchmark/A2C_Qbert.jpg" width = "400" height ="300" alt="A2C_Qbert"/>
-<img src=".benchmark/A2C_SpaceInvaders.jpg" width = "400" height ="300" alt="A2C_SpaceInvaders" />
+Mean episode reward in training process after 10 million sample steps.
+
+|              |                |                  |               |                     |
+|--------------|----------------|------------------|---------------|---------------------|
+| Alien (1278) | Amidar (380)   | Assault (4659)   | Aterix (3883) | Atlantis (3040000)  |
+| Pong (20)    | Breakout (405) | Beamrider (3394) | Qbert (14528) | SpaceInvaders (819) |
 
 ## How to use
 ### Dependencies
......
@@ -19,7 +19,7 @@ config = {
     #========== env config ==========
     'env_name': 'PongNoFrameskip-v4',
-    'env_dim': 42,
+    'env_dim': 84,
 
     #========== actor config ==========
     'actor_num': 5,
@@ -27,11 +27,12 @@ config = {
     'sample_batch_steps': 20,
 
     #========== learner config ==========
+    'max_sample_steps': int(1e7),
     'gamma': 0.99,
     'lambda': 1.0,  # GAE
-    # learning rate adjustment schedule: (train_step, learning_rate)
-    'lr_scheduler': [(0, 0.001), (20000, 0.0005), (40000, 0.0001)],
+    # start learning rate
+    'start_lr': 0.001,
     # coefficient of policy entropy adjustment schedule: (train_step, coefficient)
     'entropy_coeff_scheduler': [(0, -0.01)],
......
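The learning-rate schedule changes from an explicit piecewise list to a single start value that decays linearly over the sampling budget. A minimal sketch of how the two new entries are consumed, assuming the `LinearDecayScheduler` added later in this commit (the `5 * 20` batch below is only an illustrative value of actor count times sample steps per batch):

```python
# Sketch only: how 'start_lr' and 'max_sample_steps' drive the linear decay.
from parl.utils.scheduler import LinearDecayScheduler

config = {'start_lr': 0.001, 'max_sample_steps': int(1e7)}
lr_scheduler = LinearDecayScheduler(config['start_lr'],
                                    config['max_sample_steps'])

# Advance by the number of newly sampled transitions, e.g. 5 actors * 20 steps.
lr = lr_scheduler.step(step_num=5 * 20)
print(lr)  # 0.001 * (1 - 100 / 1e7) = 0.00099999
```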
@@ -16,7 +16,7 @@ import numpy as np
 import paddle.fluid as fluid
 import parl.layers as layers
 from parl.framework.agent_base import Agent
-from parl.utils.scheduler import PiecewiseScheduler
+from parl.utils.scheduler import PiecewiseScheduler, LinearDecayScheduler
 
 
 class AtariAgent(Agent):
@@ -24,7 +24,8 @@ class AtariAgent(Agent):
         self.config = config
         super(AtariAgent, self).__init__(algorithm)
 
-        self.lr_scheduler = PiecewiseScheduler(config['lr_scheduler'])
+        self.lr_scheduler = LinearDecayScheduler(config['start_lr'],
+                                                 config['max_sample_steps'])
         self.entropy_coeff_scheduler = PiecewiseScheduler(
             config['entropy_coeff_scheduler'])
@@ -150,7 +151,7 @@ class AtariAgent(Agent):
         advantages_np = advantages_np.astype('float32')
         target_values_np = target_values_np.astype('float32')
 
-        lr = self.lr_scheduler.step()
+        lr = self.lr_scheduler.step(step_num=obs_np.shape[0])
         entropy_coeff = self.entropy_coeff_scheduler.step()
 
         total_loss, pi_loss, vf_loss, entropy = self.learn_exe.run(
......
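Note the new `step(step_num=obs_np.shape[0])` call above: the learning-rate scheduler now advances by the number of transitions in each learn batch, so the decay tracks sample steps rather than train steps. For the linear scheduler this is equivalent to taking that many unit steps; a small illustrative check (the batch size here is an assumed stand-in for `obs_np.shape[0]`, not a value from the commit):

```python
# Illustrative check: one step of size N lands on the same learning rate
# as N single steps of the linear scheduler.
from parl.utils.scheduler import LinearDecayScheduler

batch_size = 100  # assumed stand-in for obs_np.shape[0]

bulk = LinearDecayScheduler(start_value=0.001, max_steps=int(1e7))
single = LinearDecayScheduler(start_value=0.001, max_steps=int(1e7))

lr_bulk = bulk.step(step_num=batch_size)
for _ in range(batch_size):
    lr_single = single.step()

assert abs(lr_bulk - lr_single) < 1e-12
```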
@@ -22,23 +22,16 @@ class AtariModel(Model):
     def __init__(self, act_dim):
         self.conv1 = layers.conv2d(
-            num_filters=16, filter_size=4, stride=2, padding=1, act='relu')
+            num_filters=32, filter_size=8, stride=4, padding=1, act='relu')
         self.conv2 = layers.conv2d(
-            num_filters=32, filter_size=4, stride=2, padding=2, act='relu')
+            num_filters=64, filter_size=4, stride=2, padding=2, act='relu')
         self.conv3 = layers.conv2d(
-            num_filters=256, filter_size=11, stride=1, padding=0, act='relu')
+            num_filters=64, filter_size=3, stride=1, padding=0, act='relu')
 
-        self.policy_conv = layers.conv2d(
-            num_filters=act_dim,
-            filter_size=1,
-            stride=1,
-            padding=0,
-            act=None,
-            param_attr=ParamAttr(initializer=fluid.initializer.Normal()))
-        self.value_fc = layers.fc(
-            size=1,
-            param_attr=ParamAttr(initializer=fluid.initializer.Normal()))
+        self.fc = layers.fc(size=512, act='relu')
+
+        self.policy_fc = layers.fc(size=act_dim)
+        self.value_fc = layers.fc(size=1)
 
     def policy(self, obs):
         """
@@ -53,8 +46,10 @@ class AtariModel(Model):
         conv2 = self.conv2(conv1)
         conv3 = self.conv3(conv2)
 
-        policy_conv = self.policy_conv(conv3)
-        policy_logits = layers.flatten(policy_conv, axis=1)
+        flatten = layers.flatten(conv3, axis=1)
+        fc_output = self.fc(flatten)
+        policy_logits = self.policy_fc(fc_output)
         return policy_logits
 
     def value(self, obs):
@@ -71,7 +66,9 @@ class AtariModel(Model):
         conv3 = self.conv3(conv2)
 
         flatten = layers.flatten(conv3, axis=1)
-        values = self.value_fc(flatten)
+        fc_output = self.fc(flatten)
+        values = self.value_fc(fc_output)
         values = layers.squeeze(values, axes=[1])
         return values
 
@@ -89,11 +86,12 @@ class AtariModel(Model):
         conv2 = self.conv2(conv1)
         conv3 = self.conv3(conv2)
 
-        policy_conv = self.policy_conv(conv3)
-        policy_logits = layers.flatten(policy_conv, axis=1)
-
         flatten = layers.flatten(conv3, axis=1)
-        values = self.value_fc(flatten)
+        fc_output = self.fc(flatten)
+
+        policy_logits = self.policy_fc(fc_output)
+        values = self.value_fc(fc_output)
         values = layers.squeeze(values, axes=[1])
         return policy_logits, values
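The model now shares one trunk (three conv layers plus a 512-unit fully connected layer) between separate `policy_fc` and `value_fc` heads, replacing the old per-head conv/fc layers. As a quick sanity check of the feature-map sizes, assuming the 84x84 input implied by `env_dim` above and the usual conv output formula:

```python
# Spatial sizes of the new trunk for an 84x84 input,
# using out = floor((in + 2*pad - filter) / stride) + 1.
def conv_out(size, filter_size, stride, pad):
    return (size + 2 * pad - filter_size) // stride + 1

s = conv_out(84, 8, 4, 1)  # conv1 -> 20
s = conv_out(s, 4, 2, 2)   # conv2 -> 11
s = conv_out(s, 3, 1, 0)   # conv3 -> 9
print(64 * s * s)          # 5184 features are flattened into the 512-unit fc
```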
@@ -209,5 +209,8 @@ class Learner(object):
             logger.info(metric)
             self.csv_logger.log_dict(metric)
 
+    def should_stop(self):
+        return self.sample_total_steps >= self.config['max_sample_steps']
+
     def close(self):
         self.csv_logger.close()
@@ -21,11 +21,12 @@ def main(config):
     assert config['log_metrics_interval_s'] > 0
 
     try:
-        while True:
+        while not learner.should_stop():
             start = time.time()
             while time.time() - start < config['log_metrics_interval_s']:
                 learner.step()
             learner.log_metrics()
+        learner.close()
     except KeyboardInterrupt:
         learner.close()
......
@@ -14,12 +14,15 @@
 
 import six
 
-__all__ = ['PiecewiseScheduler']
+__all__ = ['PiecewiseScheduler', 'LinearDecayScheduler']
 
 
 class PiecewiseScheduler(object):
+    """Set hyper parameters by a predefined step-based scheduler.
+    """
+
     def __init__(self, scheduler_list):
-        """ Piecewise scheduler of hyper parameter.
+        """Piecewise scheduler of hyper parameter.
 
         Args:
             scheduler_list: list of (step, value) pair. E.g. [(0, 0.001), (10000, 0.0005)]
@@ -38,18 +41,58 @@ class PiecewiseScheduler(object):
         self.scheduler_num = len(self.scheduler_list)
 
-    def step(self):
-        """ Step one and fetch value according to following rule:
+    def step(self, step_num=1):
+        """Step step_num and fetch value according to following rule:
 
         Given scheduler_list: [(step_0, value_0), (step_1, value_1), ..., (step_N, value_N)],
         function will return value_K which satisfies self.cur_step >= step_K and self.cur_step < step_K+1
+
+        Args:
+            step_num (int): number of steps (default: 1)
         """
+        assert isinstance(step_num, int) and step_num >= 1
+        self.cur_step += step_num
+
         if self.cur_index < self.scheduler_num - 1:
             if self.cur_step >= self.scheduler_list[self.cur_index + 1][0]:
                 self.cur_index += 1
                 self.cur_value = self.scheduler_list[self.cur_index][1]
 
-        self.cur_step += 1
         return self.cur_value
+
+
+class LinearDecayScheduler(object):
+    """Set hyper parameters by a step-based scheduler with linear decay values.
+    """
+
+    def __init__(self, start_value, max_steps):
+        """Linear decay scheduler of hyper parameter.
+
+        Args:
+            start_value (float): start value
+            max_steps (int): maximum steps
+        """
+        assert max_steps > 0
+        self.cur_step = 0
+        self.max_steps = max_steps
+        self.start_value = start_value
+
+    def step(self, step_num=1):
+        """Step step_num and fetch value according to following rule:
+
+        return_value = start_value * (1.0 - (cur_steps / max_steps))
+
+        Args:
+            step_num (int): number of steps (default: 1)
+
+        Returns:
+            value (float): current value
+        """
+        assert isinstance(step_num, int) and step_num >= 1
+        self.cur_step = min(self.cur_step + step_num, self.max_steps)
+
+        value = self.start_value * (1.0 - (
+            (self.cur_step * 1.0) / self.max_steps))
+
+        return value
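Both schedulers now share the `step(step_num=1)` interface: `PiecewiseScheduler` returns the value whose step threshold has most recently been crossed, while `LinearDecayScheduler` returns `start_value * (1 - cur_step / max_steps)` with `cur_step` clamped at `max_steps`. A short usage sketch, with values chosen to mirror the unit tests below:

```python
from parl.utils.scheduler import PiecewiseScheduler, LinearDecayScheduler

pw = PiecewiseScheduler([(0, 0.1), (3, 0.2), (7, 0.3)])
pw.step()      # cur_step = 1 -> 0.1
pw.step(2)     # cur_step = 3 -> 0.2
pw.step(4)     # cur_step = 7 -> 0.3

lin = LinearDecayScheduler(start_value=10, max_steps=10)
lin.step()     # 10 * (1 - 1/10) = 9.0
lin.step(5)    # 10 * (1 - 6/10) = 4.0
lin.step(100)  # cur_step clamped at 10 -> 0.0
```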
@@ -13,13 +13,14 @@
 # limitations under the License.
 
 import unittest
-from parl.utils.scheduler import PiecewiseScheduler
+import numpy as np
+from parl.utils.scheduler import *
 
 
 class TestScheduler(unittest.TestCase):
     def test_PiecewiseScheduler_with_multi_values(self):
         scheduler = PiecewiseScheduler([(0, 0.1), (3, 0.2), (7, 0.3)])
-        for i in range(10):
+        for i in range(1, 11):
             value = scheduler.step()
             if i < 3:
                 assert value == 0.1
@@ -39,6 +40,18 @@ class TestScheduler(unittest.TestCase):
             value = scheduler.step()
             assert value == 0.1
 
+    def test_PiecewiseScheduler_with_step_num(self):
+        scheduler = PiecewiseScheduler([(0, 0.1), (3, 0.2), (7, 0.3)])
+
+        value = scheduler.step()
+        assert value == 0.1
+
+        value = scheduler.step(2)
+        assert value == 0.2
+
+        value = scheduler.step(4)
+        assert value == 0.3
     def test_PiecewiseScheduler_with_empty(self):
         try:
             scheduler = PiecewiseScheduler([])
@@ -55,6 +68,24 @@ class TestScheduler(unittest.TestCase):
             return
         assert False
 
+    def test_LinearDecayScheduler(self):
+        scheduler = LinearDecayScheduler(start_value=10, max_steps=10)
+        for i in range(10):
+            value = scheduler.step()
+            np.testing.assert_almost_equal(value, 10 - (i + 1), 8)
+
+        for i in range(5):
+            value = scheduler.step()
+            np.testing.assert_almost_equal(value, 0, 8)
+
+    def test_LinearDecayScheduler_with_step_num(self):
+        scheduler = LinearDecayScheduler(start_value=10, max_steps=10)
+        value = scheduler.step(5)
+        np.testing.assert_almost_equal(value, 5, 8)
+
+        value = scheduler.step(3)
+        np.testing.assert_almost_equal(value, 2, 8)
 if __name__ == '__main__':
     unittest.main()