Commit 255ef4f7 authored by Hongsheng Zeng, committed by Bo Zhou

refine A2C example (#80)

* refine A2C example

* fix unittest in python2; fix codestyle

* fix codestyle

* refine comment
Parent c3b34fd9
......@@ -7,10 +7,12 @@ A2C is a synchronous, deterministic variant of [Asynchronous Advantage Actor Cri
Please see [here](https://gym.openai.com/envs/#atari) to learn more about Atari games.
### Benchmark result
Results with one learner (on a P40 GPU) and 5 actors, trained for 10 million sample steps.
<img src=".benchmark/A2C_Pong.jpg" width = "400" height ="300" alt="A2C_Pong" /> <img src=".benchmark/A2C_Breakout.jpg" width = "400" height ="300" alt="A2C_Breakout"/>
<img src=".benchmark/A2C_BeamRider.jpg" width = "400" height ="300" alt="A2C_BeamRider" /> <img src=".benchmark/A2C_Qbert.jpg" width = "400" height ="300" alt="A2C_Qbert"/>
<img src=".benchmark/A2C_SpaceInvaders.jpg" width = "400" height ="300" alt="A2C_SpaceInvaders" />
Mean episode reward during training, measured after 10 million sample steps.
| Game (mean reward) | Game (mean reward) | Game (mean reward) | Game (mean reward) | Game (mean reward) |
|--------------|----------------|------------------|---------------|---------------------|
| Alien (1278) | Amidar (380) | Assault (4659) | Asterix (3883) | Atlantis (3040000) |
| Pong (20) | Breakout (405) | Beamrider (3394) | Qbert (14528) | SpaceInvaders (819) |
## How to use
### Dependencies
......
......@@ -19,7 +19,7 @@ config = {
#========== env config ==========
'env_name': 'PongNoFrameskip-v4',
'env_dim': 42,
'env_dim': 84,
#========== actor config ==========
'actor_num': 5,
......@@ -27,11 +27,12 @@ config = {
'sample_batch_steps': 20,
#========== learner config ==========
'max_sample_steps': int(1e7),
'gamma': 0.99,
'lambda': 1.0, # GAE
# learning rate adjustment schedule: (train_step, learning_rate)
'lr_scheduler': [(0, 0.001), (20000, 0.0005), (40000, 0.0001)],
# start learning rate
'start_lr': 0.001,
# coefficient of policy entropy adjustment schedule: (train_step, coefficient)
'entropy_coeff_scheduler': [(0, -0.01)],
......
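The config change above drops the hand-tuned piecewise `lr_scheduler` in favor of a single `start_lr` that is decayed linearly over `max_sample_steps`. Below is a minimal sketch of the decay rule this implies; it is just the arithmetic, not the PARL implementation, and the names `start_lr` and `max_sample_steps` are taken from the config above.

```python
# Sketch of the linear learning-rate decay implied by the new config.
start_lr = 0.001
max_sample_steps = int(1e7)

def linear_lr(cur_sample_steps):
    """Learning rate after cur_sample_steps environment samples."""
    frac = min(cur_sample_steps, max_sample_steps) / float(max_sample_steps)
    return start_lr * (1.0 - frac)

print(linear_lr(0))          # 0.001
print(linear_lr(5 * 10**6))  # 0.0005, i.e. halved at the midpoint of training
```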
......@@ -16,7 +16,7 @@ import numpy as np
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.agent_base import Agent
from parl.utils.scheduler import PiecewiseScheduler
from parl.utils.scheduler import PiecewiseScheduler, LinearDecayScheduler
class AtariAgent(Agent):
......@@ -24,7 +24,8 @@ class AtariAgent(Agent):
self.config = config
super(AtariAgent, self).__init__(algorithm)
self.lr_scheduler = PiecewiseScheduler(config['lr_scheduler'])
self.lr_scheduler = LinearDecayScheduler(config['start_lr'],
config['max_sample_steps'])
self.entropy_coeff_scheduler = PiecewiseScheduler(
config['entropy_coeff_scheduler'])
......@@ -150,7 +151,7 @@ class AtariAgent(Agent):
advantages_np = advantages_np.astype('float32')
target_values_np = target_values_np.astype('float32')
lr = self.lr_scheduler.step()
lr = self.lr_scheduler.step(step_num=obs_np.shape[0])
entropy_coeff = self.entropy_coeff_scheduler.step()
total_loss, pi_loss, vf_loss, entropy = self.learn_exe.run(
......
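With the linear scheduler, the agent now advances the schedule by the batch size on every `learn` call instead of by one, so the decay tracks the number of consumed samples rather than the number of training iterations. A rough back-of-the-envelope sketch, assuming each batch contains `actor_num * sample_batch_steps` samples as the config above suggests:

```python
# Assumed numbers from the config above (5 actors, 20 steps per batch).
# Each learn() call then consumes 100 samples, so the scheduler advances
# by 100 and the learning rate reaches zero after max_sample_steps / 100 calls.
actor_num = 5
sample_batch_steps = 20
samples_per_learn = actor_num * sample_batch_steps          # 100
learn_calls_until_zero_lr = int(1e7) // samples_per_learn   # 100000
print(samples_per_learn, learn_calls_until_zero_lr)
```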
......@@ -22,23 +22,16 @@ class AtariModel(Model):
def __init__(self, act_dim):
self.conv1 = layers.conv2d(
num_filters=16, filter_size=4, stride=2, padding=1, act='relu')
num_filters=32, filter_size=8, stride=4, padding=1, act='relu')
self.conv2 = layers.conv2d(
num_filters=32, filter_size=4, stride=2, padding=2, act='relu')
num_filters=64, filter_size=4, stride=2, padding=2, act='relu')
self.conv3 = layers.conv2d(
num_filters=256, filter_size=11, stride=1, padding=0, act='relu')
num_filters=64, filter_size=3, stride=1, padding=0, act='relu')
self.policy_conv = layers.conv2d(
num_filters=act_dim,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(initializer=fluid.initializer.Normal()))
self.fc = layers.fc(size=512, act='relu')
self.value_fc = layers.fc(
size=1,
param_attr=ParamAttr(initializer=fluid.initializer.Normal()))
self.policy_fc = layers.fc(size=act_dim)
self.value_fc = layers.fc(size=1)
def policy(self, obs):
"""
......@@ -53,8 +46,10 @@ class AtariModel(Model):
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
policy_conv = self.policy_conv(conv3)
policy_logits = layers.flatten(policy_conv, axis=1)
flatten = layers.flatten(conv3, axis=1)
fc_output = self.fc(flatten)
policy_logits = self.policy_fc(fc_output)
return policy_logits
def value(self, obs):
......@@ -71,7 +66,9 @@ class AtariModel(Model):
conv3 = self.conv3(conv2)
flatten = layers.flatten(conv3, axis=1)
values = self.value_fc(flatten)
fc_output = self.fc(flatten)
values = self.value_fc(fc_output)
values = layers.squeeze(values, axes=[1])
return values
......@@ -89,11 +86,12 @@ class AtariModel(Model):
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
policy_conv = self.policy_conv(conv3)
policy_logits = layers.flatten(policy_conv, axis=1)
flatten = layers.flatten(conv3, axis=1)
values = self.value_fc(flatten)
fc_output = self.fc(flatten)
policy_logits = self.policy_fc(fc_output)
values = self.value_fc(fc_output)
values = layers.squeeze(values, axes=[1])
return policy_logits, values
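The model now uses the familiar three-convolution Atari backbone followed by a shared 512-unit fully connected layer with separate policy and value heads, replacing the old convolutional heads. Below is a rough shape trace, assuming an 84x84 input (the new `env_dim`) and the usual floor-based convolution output formula; it is a sanity-check sketch, not code from the repository.

```python
# out = floor((in + 2*pad - kernel) / stride) + 1
def conv_out(size, kernel, stride, pad):
    return (size + 2 * pad - kernel) // stride + 1

s = 84                                       # input resolution ('env_dim')
s = conv_out(s, kernel=8, stride=4, pad=1)   # conv1: 32 filters -> 20 x 20
s = conv_out(s, kernel=4, stride=2, pad=2)   # conv2: 64 filters -> 11 x 11
s = conv_out(s, kernel=3, stride=1, pad=0)   # conv3: 64 filters -> 9 x 9
flatten_dim = 64 * s * s                     # 5184 features into the 512-unit fc
print(s, flatten_dim)                        # 9 5184
```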
......@@ -209,5 +209,8 @@ class Learner(object):
logger.info(metric)
self.csv_logger.log_dict(metric)
def should_stop(self):
return self.sample_total_steps >= self.config['max_sample_steps']
def close(self):
self.csv_logger.close()
......@@ -21,11 +21,12 @@ def main(config):
assert config['log_metrics_interval_s'] > 0
try:
while True:
while not learner.should_stop():
start = time.time()
while time.time() - start < config['log_metrics_interval_s']:
learner.step()
learner.log_metrics()
learner.close()
except KeyboardInterrupt:
learner.close()
......
......@@ -14,12 +14,15 @@
import six
__all__ = ['PiecewiseScheduler']
__all__ = ['PiecewiseScheduler', 'LinearDecayScheduler']
class PiecewiseScheduler(object):
"""Set hyper parameters by a predefined step-based scheduler.
"""
def __init__(self, scheduler_list):
""" Piecewise scheduler of hyper parameter.
"""Piecewise scheduler of hyper parameter.
Args:
scheduler_list: list of (step, value) pair. E.g. [(0, 0.001), (10000, 0.0005)]
......@@ -38,18 +41,58 @@ class PiecewiseScheduler(object):
self.scheduler_num = len(self.scheduler_list)
def step(self):
""" Step one and fetch value according to following rule:
def step(self, step_num=1):
"""Step step_num and fetch value according to following rule:
Given scheduler_list: [(step_0, value_0), (step_1, value_1), ..., (step_N, value_N)],
function will return value_K which satisfying self.cur_step >= step_K and self.cur_step < step_K+1
Args:
step_num (int): number of steps (default: 1)
"""
assert isinstance(step_num, int) and step_num >= 1
self.cur_step += step_num
if self.cur_index < self.scheduler_num - 1:
if self.cur_step >= self.scheduler_list[self.cur_index + 1][0]:
self.cur_index += 1
self.cur_value = self.scheduler_list[self.cur_index][1]
self.cur_step += 1
return self.cur_value
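With the new `step_num` argument, a single call can move the scheduler past a breakpoint. A minimal usage sketch, mirroring the unit test added later in this commit:

```python
scheduler = PiecewiseScheduler([(0, 0.1), (3, 0.2), (7, 0.3)])
print(scheduler.step())   # cur_step = 1 -> 0.1
print(scheduler.step(2))  # cur_step = 3 -> 0.2 (breakpoint at step 3 reached)
print(scheduler.step(4))  # cur_step = 7 -> 0.3 (breakpoint at step 7 reached)
```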
class LinearDecayScheduler(object):
"""Set hyper parameters by a step-based scheduler with linear decay values.
"""
def __init__(self, start_value, max_steps):
"""Linear decay scheduler of hyper parameter.
Args:
start_value (float): start value
max_steps (int): maximum steps
"""
assert max_steps > 0
self.cur_step = 0
self.max_steps = max_steps
self.start_value = start_value
def step(self, step_num=1):
"""Step step_num and fetch value according to following rule:
return_value = start_value * (1.0 - (cur_steps / max_steps))
Args:
step_num (int): number of steps (default: 1)
Returns:
value (float): current value
"""
assert isinstance(step_num, int) and step_num >= 1
self.cur_step = min(self.cur_step + step_num, self.max_steps)
value = self.start_value * (1.0 - (
(self.cur_step * 1.0) / self.max_steps))
return value
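A short usage sketch of `LinearDecayScheduler`, worked through with the formula above and matching the unit tests added in this commit:

```python
scheduler = LinearDecayScheduler(start_value=10, max_steps=10)
print(scheduler.step())   # 10 * (1 - 1/10)  = 9.0
print(scheduler.step(5))  # 10 * (1 - 6/10)  = 4.0
print(scheduler.step(9))  # cur_step clipped at max_steps -> 0.0
```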
......@@ -13,13 +13,14 @@
# limitations under the License.
import unittest
from parl.utils.scheduler import PiecewiseScheduler
import numpy as np
from parl.utils.scheduler import *
class TestScheduler(unittest.TestCase):
def test_PiecewiseScheduler_with_multi_values(self):
scheduler = PiecewiseScheduler([(0, 0.1), (3, 0.2), (7, 0.3)])
for i in range(10):
for i in range(1, 11):
value = scheduler.step()
if i < 3:
assert value == 0.1
......@@ -39,6 +40,18 @@ class TestScheduler(unittest.TestCase):
value = scheduler.step()
assert value == 0.1
def test_PiecewiseScheduler_with_step_num(self):
scheduler = PiecewiseScheduler([(0, 0.1), (3, 0.2), (7, 0.3)])
value = scheduler.step()
assert value == 0.1
value = scheduler.step(2)
assert value == 0.2
value = scheduler.step(4)
assert value == 0.3
def test_PiecewiseScheduler_with_empty(self):
try:
scheduler = PiecewiseScheduler([])
......@@ -55,6 +68,24 @@ class TestScheduler(unittest.TestCase):
return
assert False
def test_LinearDecayScheduler(self):
scheduler = LinearDecayScheduler(start_value=10, max_steps=10)
for i in range(10):
value = scheduler.step()
np.testing.assert_almost_equal(value, 10 - (i + 1), 8)
for i in range(5):
value = scheduler.step()
np.testing.assert_almost_equal(value, 0, 8)
def test_LinearDecayScheduler_with_step_num(self):
scheduler = LinearDecayScheduler(start_value=10, max_steps=10)
value = scheduler.step(5)
np.testing.assert_almost_equal(value, 5, 8)
value = scheduler.step(3)
np.testing.assert_almost_equal(value, 2, 8)
if __name__ == '__main__':
unittest.main()