Unverified commit 21a9efed, authored by H Haonan, committed by GitHub

added test_simple_games (#15)

added test_simple_games
Parent 4b4b5824
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from parl.framework.algorithm import RLAlgorithm
from parl.framework.algorithm import Algorithm
from paddle.fluid.initializer import ConstantInitializer
import parl.layers as layers
import parl.framework.policy_distribution as pd
from parl.layers import common_functions as comf
@@ -20,7 +21,7 @@ import paddle.fluid as fluid
from copy import deepcopy
-class SimpleAC(RLAlgorithm):
class SimpleAC(Algorithm):
"""
A simple Actor-Critic that has a feedforward policy network and
a single discrete action.
@@ -37,7 +38,7 @@ class SimpleAC(RLAlgorithm):
super(SimpleAC, self).__init__(model, hyperparas, gpu_id)
self.discount_factor = discount_factor
-def learn(self, inputs, next_inputs, states, next_states, episode_end,
def learn(self, inputs, next_inputs, states, next_states, next_episode_end,
actions, rewards):
action = actions["action"]
@@ -46,7 +47,8 @@ class SimpleAC(RLAlgorithm):
values = self.model.value(inputs, states)
next_values = self.model.value(next_inputs, next_states)
value = values["v_value"]
-next_value = next_values["v_value"] * episode_end["episode_end"]
next_value = next_values["v_value"] * next_episode_end[
"next_episode_end"]
next_value.stop_gradient = True
assert value.shape[1] == next_value.shape[1]
@@ -60,12 +62,16 @@ class SimpleAC(RLAlgorithm):
pg_cost = 0 - dist.loglikelihood(action)
avg_cost = layers.mean(x=value_cost + pg_cost * td_error)
-optimizer = fluid.optimizer.SGD(learning_rate=self.hp["lr"])
optimizer = fluid.optimizer.DecayedAdagradOptimizer(
learning_rate=self.hp["lr"])
optimizer.minimize(avg_cost)
return dict(cost=avg_cost)
def predict(self, inputs, states):
return self._rl_predict(self.model, inputs, states)
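## Illustrative sketch, not part of this commit: the learn() above fits the
## critic with a squared one-step TD error and scales the policy-gradient cost
## by that same TD error. The toy numpy values below are assumptions chosen
## only to make the quantities concrete.
import numpy as np
reward, discount = np.array([[1.0]]), 0.99
value = np.array([[0.5]])              ## V(s) predicted by the model
next_value = np.array([[0.7]])         ## V(s'); multiplied by 0 if the episode ended
critic_value = reward + discount * next_value   ## one-step TD target
td_error = critic_value - value                 ## about 1.193
value_cost = td_error ** 2                      ## critic regression loss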
-class SimpleQ(RLAlgorithm):
class SimpleQ(Algorithm):
"""
A simple Q-learning that has a feedforward policy network and a single discrete action.
@@ -77,6 +83,8 @@ class SimpleQ(RLAlgorithm):
hyperparas=dict(lr=1e-4),
gpu_id=-1,
discount_factor=0.99,
exploration_end_batches=0,
exploration_end_rate=0.1,
update_ref_interval=100):
super(SimpleQ, self).__init__(model, hyperparas, gpu_id)
@@ -87,13 +95,45 @@ class SimpleQ(RLAlgorithm):
self.total_batches = 0
## create a reference model
self.ref_model = deepcopy(model)
## setup exploration
self.explore = (exploration_end_batches > 0)
if self.explore:
self.exploration_counter = layers.create_persistable_variable(
dtype="float32",
shape=[1],
is_bias=True,
default_initializer=ConstantInitializer(0.))
### the rate decays linearly for exploration_end_batches batches, then stays fixed at exploration_end_rate
self.total_exploration_batches = exploration_end_batches
self.exploration_rate_delta \
= (1 - exploration_end_rate) / self.total_exploration_batches
def before_every_batch(self):
if self.total_batches % self.update_ref_interval == 0:
self.model.sync_paras_to(self.ref_model, self.gpu_id)
self.total_batches += 1
-def learn(self, inputs, next_inputs, states, next_states, episode_end,
def predict(self, inputs, states):
"""
Override the base predict() function to apply the current exploration rate to the action distribution
"""
rate = 0
if self.explore:
counter = self.exploration_counter()
## first compute the current exploration rate
rate = 1 - counter * self.exploration_rate_delta
distributions, states = self.model.policy(inputs, states)
for dist in distributions.values():
assert dist.__class__.__name__ == "CategoricalDistribution"
dist.add_uniform_exploration(rate)
actions = {}
for key, dist in distributions.iteritems():
actions[key] = dist()
return actions, states
def learn(self, inputs, next_inputs, states, next_states, next_episode_end,
actions, rewards):
action = actions["action"]
@@ -102,7 +142,8 @@ class SimpleQ(RLAlgorithm):
values = self.model.value(inputs, states)
next_values = self.ref_model.value(next_inputs, next_states)
q_value = values["q_value"]
-next_q_value = next_values["q_value"] * episode_end["episode_end"]
next_q_value = next_values["q_value"] * next_episode_end[
"next_episode_end"]
next_q_value.stop_gradient = True
next_value = layers.reduce_max(next_q_value, dim=-1)
assert q_value.shape[1] == next_q_value.shape[1]
@@ -113,6 +154,20 @@ class SimpleQ(RLAlgorithm):
td_error = critic_value - value
avg_cost = layers.mean(x=layers.square(td_error))
-optimizer = fluid.optimizer.SGD(learning_rate=self.hp["lr"])
optimizer = fluid.optimizer.DecayedAdagradOptimizer(
learning_rate=self.hp["lr"])
optimizer.minimize(avg_cost)
self._increment_exploration_counter()
return dict(cost=avg_cost)
def _increment_exploration_counter(self):
if self.explore:
counter = self.exploration_counter()
exploration_counter_ = counter + 1
switch = layers.cast(
x=(exploration_counter_ > self.total_exploration_batches),
dtype="float32")
## if the counter has already hit the limit, we keep it unchanged
layers.assign(switch * counter +
(1 - switch) * exploration_counter_, counter)
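## Illustrative sketch, not part of this commit: the persistable counter plus
## the cast/assign "switch" above implement a capped linear decay of the
## exploration rate. A plain-Python equivalent (function name assumed, for
## illustration only):
def sketch_exploration_rate(batch_id, end_batches, end_rate):
    delta = (1.0 - end_rate) / end_batches
    return 1.0 - min(batch_id, end_batches) * delta
## e.g. with end_batches=100, end_rate=0.1:
##   sketch_exploration_rate(0, 100, 0.1)   -> 1.0
##   sketch_exploration_rate(50, 100, 0.1)  -> 0.55
##   sketch_exploration_rate(500, 100, 0.1) -> 0.1 (the counter no longer grows)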
@@ -35,7 +35,7 @@ def check_duplicate_spec_names(model):
class Model(Network):
"""
-A Model is owned by an Algorithm. It implements all the network model of
A Model is owned by an Algorithm. It implements the entire network model of
a specific problem.
"""
__metaclass__ = ABCMeta
@@ -142,11 +142,28 @@ class Algorithm(object):
def predict(self, inputs, states):
"""
Given the inputs and states, this function does forward prediction and updates states.
Input: inputs(dict), states(dict)
Output: actions(dict), states(dict)
Optional: an algorithm might not implement predict()
"""
pass
-def learn(self, inputs, next_inputs, states, next_states, episode_end,
def _rl_predict(self, behavior_model, inputs, states):
"""
Given a behavior model (not necessarily equal to self.model), this function
performs a normal RL prediction according to inputs and states.
A behavior model different from self.model indicates off-policy training.
The user can choose to call this function for convenience.
"""
distributions, states = behavior_model.policy(inputs, states)
actions = {}
for key, dist in distributions.iteritems():
actions[key] = dist()
return actions, states
def learn(self, inputs, next_inputs, states, next_states, next_episode_end,
actions, rewards):
"""
This function computes a learning cost to be optimized.
@@ -156,40 +173,3 @@ class Algorithm(object):
Optional: an algorithm might not implement learn()
"""
pass
-class RLAlgorithm(Algorithm):
-"""
-A derived Algorithm class specially for RL problems.
-"""
-def __init__(self, model, hyperparas, gpu_id):
-super(RLAlgorithm, self).__init__(model, hyperparas, gpu_id)
-def get_behavior_model(self):
-"""
-Return the behavior model to compute actions. The behavior model could be different
-from the training model, which is common in off-policy RL algorithms.
-The default behavior model is set to the training model. The user can override this
-function to specify another different model.
-"""
-return self.model
-def predict(self, inputs, states):
-"""
-Implementation of Algorithm.predict()
-Given the inputs and states, this function predicts actions and updates states.
-Input: inputs(dict), states(dict)
-Output: actions(dict), states(dict)
-"""
-behavior_model = self.get_behavior_model()
-distributions, states = behavior_model.policy(inputs, states)
-actions = {}
-for key, dist in distributions.iteritems():
-assert isinstance(
-dist, pd.PolicyDistribution
-), "behavior_model.policy must return PolicyDist!"
-actions[key] = dist()
-return actions, states
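## Illustrative sketch, not part of this commit: with RLAlgorithm removed, an
## RL algorithm derives from Algorithm directly and forwards predict() to the
## _rl_predict() helper, as SimpleAC and TestAlgorithm do elsewhere in this
## diff. MyAlgorithm is an assumed name used only for illustration:
class MyAlgorithm(Algorithm):
    def __init__(self, model):
        super(MyAlgorithm, self).__init__(model, hyperparas=dict(), gpu_id=-1)

    def predict(self, inputs, states):
        ## self.model is the default behavior model; pass a different model
        ## here for off-policy prediction
        return self._rl_predict(self.model, inputs, states)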
@@ -72,7 +72,7 @@ class ComputationTask(object):
next_state_specs = _get_next_specs(state_specs)
action_specs = self.alg.get_action_specs()
reward_specs = self.alg.get_reward_specs()
-episode_end_specs = [("episode_end", dict(shape=[1]))]
next_episode_end_specs = [("next_episode_end", dict(shape=[1]))]
self.action_names = sorted([name for name, _ in action_specs])
self.state_names = sorted([name for name, _ in state_specs])
@@ -96,7 +96,8 @@ class ComputationTask(object):
data_layer_dict.update(self._create_data_layers(next_state_specs))
data_layer_dict.update(self._create_data_layers(action_specs))
data_layer_dict.update(self._create_data_layers(reward_specs))
-data_layer_dict.update(self._create_data_layers(episode_end_specs))
data_layer_dict.update(
self._create_data_layers(next_episode_end_specs))
self.learn_feed_names = sorted(data_layer_dict.keys())
inputs = _select_data(data_layer_dict, input_specs)
@@ -105,12 +106,13 @@ class ComputationTask(object):
next_states = _select_data(data_layer_dict, next_state_specs)
actions = _select_data(data_layer_dict, action_specs)
rewards = _select_data(data_layer_dict, reward_specs)
-episode_end = _select_data(data_layer_dict, episode_end_specs)
next_episode_end = _select_data(data_layer_dict,
next_episode_end_specs)
## call alg learn()
### TODO: implement a recurrent layer to strip the sequence information
self.cost = self.alg.learn(inputs, next_inputs, states,
-next_states, episode_end, actions,
next_states, next_episode_end, actions,
rewards)
def predict(self, inputs, states=dict()):
@@ -149,7 +151,7 @@ class ComputationTask(object):
def learn(self,
inputs,
next_inputs,
-episode_end,
next_episode_end,
actions,
rewards,
states=dict(),
@@ -164,7 +166,7 @@ class ComputationTask(object):
data.update(next_inputs)
data.update(states)
data.update(next_states)
-data.update(episode_end)
data.update(next_episode_end)
data.update(actions)
data.update(rewards)
assert sorted(data.keys()) == self.learn_feed_names, \
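## Illustrative sketch, not part of this commit: for SimpleQ with a single
## "sensor" input, the feed assembled above carries exactly the keys checked
## against self.learn_feed_names. The toy shapes below are assumptions chosen
## only for illustration (batch_size=16, dims=4); next_episode_end is 1.0 while
## the episode continues and 0.0 once it has ended.
import numpy as np
batch_size, dims = 16, 4
example_feed = dict(
    sensor=np.zeros((batch_size, dims), dtype="float32"),
    next_sensor=np.zeros((batch_size, dims), dtype="float32"),
    action=np.zeros((batch_size, 1), dtype="int64"),
    reward=np.zeros((batch_size, 1), dtype="float32"),
    next_episode_end=np.ones((batch_size, 1), dtype="float32"))
assert sorted(example_feed.keys()) == [
    "action", "next_episode_end", "next_sensor", "reward", "sensor"]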
@@ -23,8 +23,11 @@ class PolicyDistribution(object):
__metaclass__ = ABCMeta
def __init__(self, dist):
-assert len(dist.shape) == 2
-self.dim = dist.shape[1]
"""
self.dist represents the quantities that characterize the distribution.
For example, for a Normal distribution, this can be a tuple of (mean, std).
The actual form of self.dist is defined by the user.
"""
self.dist = dist
@abstractmethod
@@ -34,6 +37,8 @@ class PolicyDistribution(object):
"""
pass
@property
@abstractmethod
def dim(self):
"""
For discrete policies, this function returns the number of actions.
@@ -41,10 +46,14 @@ class PolicyDistribution(object):
For sequential policies (e.g., sentences), this function returns the number
of choices at each step.
"""
-return self.dim
pass
-def dist(self):
-return self.dist
def add_uniform_exploration(self, rate):
"""
Given a uniform exploration rate, this function modifies the distribution.
The rate can be a float or a Variable.
"""
raise NotImplementedError()
def loglikelihood(self, action):
"""
@@ -57,11 +66,25 @@ class PolicyDistribution(object):
class CategoricalDistribution(PolicyDistribution):
def __init__(self, dist):
super(CategoricalDistribution, self).__init__(dist)
assert isinstance(dist, Variable)
def __call__(self):
return comf.categorical_random(self.dist)
@property
def dim(self):
assert len(self.dist.shape) == 2
return self.dist.shape[1]
def add_uniform_exploration(self, rate):
if not (isinstance(rate, float) and rate == 0):
self.dist = self.dist * (1 - rate) + \
1 / float(self.dim) * rate
def loglikelihood(self, action):
assert isinstance(action, Variable)
assert action.dtype == convert_np_dtype_to_dtype_("int") \
or action.dtype == convert_np_dtype_to_dtype_("int64")
return 0 - layers.cross_entropy(input=self.dist, label=action)
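## Illustrative sketch, not part of this commit: add_uniform_exploration()
## above blends the categorical probabilities with a uniform distribution,
## p' = (1 - rate) * p + rate / num_actions, which keeps p' a valid
## probability vector. Toy numpy values, assumed for illustration:
import numpy as np
p = np.array([0.9, 0.1, 0.0])
rate = 0.3
p_explored = (1 - rate) * p + rate / p.size   ## -> [0.73, 0.17, 0.10]
assert abs(p_explored.sum() - 1.0) < 1e-6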
@@ -69,21 +92,23 @@ class Deterministic(PolicyDistribution):
def __init__(self, dist):
super(Deterministic, self).__init__(dist)
## For deterministic action, we only support continuous ones
assert isinstance(dist, Variable)
assert dist.dtype == convert_np_dtype_to_dtype_("float32") \
or dist.dtype == convert_np_dtype_to_dtype_("float64")
@property
def dim(self):
assert len(self.dist.shape) == 2
return self.dist.shape[1]
def __call__(self):
return self.dist
def loglikelihood(self, action):
assert False, "You cannot compute likelihood for a deterministic action!"
-def q_categorical_distribution(q_value, exploration_rate=0.0):
def q_categorical_distribution(q_value):
"""
Generate a PolicyDistribution object given a Q value.
-We first construct a one-hot distribution according to the Q value,
-and then add an exploration rate to get a probability.
We construct a one-hot distribution according to the Q value.
"""
assert len(q_value.shape) == 2, "[batch_size, num_actions]"
max_id = comf.argmax_layer(q_value)
@@ -91,8 +116,4 @@ def q_categorical_distribution(q_value, exploration_rate=0.0):
x=layers.one_hot(
input=max_id, depth=q_value.shape[-1]),
dtype="float32")
-### exploration_rate could be a Variable
-if not (isinstance(exploration_rate, float) and exploration_rate == 0):
-prob = exploration_rate / float(q_value.shape[-1]) \
-+ (1 - exploration_rate) * prob
return CategoricalDistribution(prob)
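## Illustrative sketch, not part of this commit: the helper above turns a
## batch of Q values into a one-hot "distribution" that always picks the
## greedy action; exploration is now added separately by the algorithm.
## A toy numpy equivalent (values assumed for illustration):
import numpy as np
q_value = np.array([[1.0, 3.0, 2.0]])                       ## [batch_size, num_actions]
greedy = q_value.argmax(axis=-1)                            ## -> [1]
prob = np.eye(q_value.shape[-1], dtype="float32")[greedy]   ## -> [[0., 1., 0.]]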
@@ -14,7 +14,7 @@
import paddle.fluid as fluid
import parl.layers as layers
-from parl.framework.algorithm import Model, RLAlgorithm
from parl.framework.algorithm import Model, Algorithm
from parl.layers import common_functions as comf
from parl.model_zoo.simple_models import SimpleModelDeterministic
import numpy as np
@@ -22,11 +22,14 @@ from copy import deepcopy
import unittest
-class TestAlgorithm(RLAlgorithm):
class TestAlgorithm(Algorithm):
def __init__(self, model):
super(TestAlgorithm, self).__init__(
model, hyperparas=dict(), gpu_id=-1)
def predict(self, inputs, states):
return self._rl_predict(self.model, inputs, states)
class TestAlgorithmParas(unittest.TestCase):
def test_sync_paras_in_one_program(self):
@@ -228,7 +228,7 @@ class TestComputationTask(unittest.TestCase):
cost = ct.learn(
inputs=dict(sensor=sensor),
next_inputs=dict(next_sensor=next_sensor),
-episode_end=dict(episode_end=np.ones(
next_episode_end=dict(next_episode_end=np.ones(
(batch_size, 1)).astype("float32")),
actions=dict(action=actions),
rewards=dict(reward=rewards))
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.computation_task import ComputationTask
from parl.algorithm_zoo.simple_algorithms import SimpleAC, SimpleQ
from parl.model_zoo.simple_models import SimpleModelAC, SimpleModelQ
import numpy as np
import unittest
import math
import gym
def unpack_exps(exps):
return [np.array(l).astype('int' if i==2 else 'float32') \
for i, l in enumerate(zip(*exps))]
def sample(past_exps, n):
indices = np.random.choice(len(past_exps), n)
return [past_exps[i] for i in indices]
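## Illustrative sketch, not part of this commit: each stored experience is the
## tuple (ob, next_ob, [action], [reward], [not next_is_over]), so index 2 is
## the only integer field, which is why unpack_exps() casts it to 'int'.
## Toy usage with assumed values:
_toy_exps = [([0.1, 0.2], [0.3, 0.4], [1], [0.01], [True])] * 16
_sensor, _next_sensor, _action, _reward, _next_episode_end = \
    unpack_exps(sample(_toy_exps, 16))
assert _action.dtype.kind == "i" and _sensor.shape == (16, 2)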
class TestGymGame(unittest.TestCase):
def test_gym_games(self):
"""
Test games in OpenAI gym.
"""
games = ["MountainCar-v0", "CartPole-v0"]
final_rewards_thresholds = [
-1.8, ## drive to the hilltop on the right within 180 steps (a timeout scores -2.0)
1.5 ## hold the pole for at least 150 steps
]
mlp_layer_confs = [
dict(
size=128, act="relu"),
dict(
size=128, act="relu"),
dict(
size=128, act="relu"),
]
for game, threshold in zip(games, final_rewards_thresholds):
for on_policy in [False, True]:
if on_policy and game != "CartPole-v0":
## SimpleAC has difficulty training mountain-car and acrobot
continue
env = gym.make(game)
state_shape = env.observation_space.shape[0]
num_actions = env.action_space.n
if on_policy:
alg = SimpleAC(
model=SimpleModelAC(
dims=state_shape,
num_actions=num_actions,
mlp_layer_confs=mlp_layer_confs +
[dict(
size=num_actions, act="softmax")]),
hyperparas=dict(lr=1e-3))
else:
alg = SimpleQ(
model=SimpleModelQ(
dims=state_shape,
num_actions=num_actions,
mlp_layer_confs=mlp_layer_confs +
[dict(size=num_actions)]),
hyperparas=dict(lr=1e-4),
exploration_end_batches=25000,
update_ref_interval=100)
print "algorithm: " + alg.__class__.__name__
ct = ComputationTask(algorithm=alg)
batch_size = 16
if not on_policy:
train_every_steps = batch_size / 4
buffer_size_limit = 100000
max_episode = 5000
average_episode_reward = []
past_exps = []
max_steps = env._max_episode_steps
for n in range(max_episode):
ob = env.reset()
episode_reward = 0
for t in range(max_steps):
res, _ = ct.predict(inputs=dict(sensor=np.array(
[ob]).astype("float32")))
pred_action = res["action"][0][0]
next_ob, reward, next_is_over, _ = env.step(
pred_action)
reward /= 100
episode_reward += reward
past_exps.append((ob, next_ob, [pred_action],
[reward], [not next_is_over]))
## only for off-policy training do we use a circular buffer
if (not on_policy
) and len(past_exps) > buffer_size_limit:
past_exps.pop(0)
## compute the learning condition
learn_cond = False
if on_policy:
learn_cond = (len(past_exps) >= batch_size)
exps = past_exps ## directly use all exps in the buffer
else:
learn_cond = (
t % train_every_steps == train_every_steps - 1)
exps = sample(past_exps,
batch_size) ## sample some exps
if learn_cond:
sensor, next_sensor, action, reward, next_episode_end \
= unpack_exps(exps)
cost = ct.learn(
inputs=dict(sensor=sensor),
next_inputs=dict(next_sensor=next_sensor),
next_episode_end=dict(
next_episode_end=next_episode_end),
actions=dict(action=action),
rewards=dict(reward=reward))
## we clear the exp buffer for on-policy
if on_policy:
past_exps = []
ob = next_ob
## stop one step early so that Gym does not wrongly report game_over=True on a timeout
if t == max_steps - 2 or next_is_over:
break
if n % 50 == 0:
print("episode reward: %f" % episode_reward)
average_episode_reward.append(episode_reward)
if len(average_episode_reward) > 20:
average_episode_reward.pop(0)
### compare the average episode reward over recent episodes to reduce variance
self.assertGreater(
sum(average_episode_reward) / len(average_episode_reward),
threshold)
if __name__ == "__main__":
unittest.main()
@@ -19,6 +19,7 @@ from paddle.fluid.executor import fetch_var
import paddle.fluid as fluid
from paddle.fluid.layers import *
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.framework import Variable
import paddle.fluid.layers as layers
import paddle.fluid.unique_name as unique_name
from copy import deepcopy
@@ -533,3 +534,30 @@ def row_conv(future_context_size, param_attr=None, act=None, name=None):
def layer_norm(**kwargs):
raise NotImplementedError()
def create_persistable_variable(shape,
dtype,
name=None,
attr=None,
is_bias=False,
default_initializer=None):
"""
Return a callable that creates a parameter which, unlike layer parameters, is not synchronized.
This function can be called inside an Algorithm, so we neither check the caller nor require
that the variable can be copied.
"""
default_name = "per_var"
attr = update_attr_name(name, default_name, attr, is_bias)
class CreateParameter_(object):
def __call__(self):
return layers.create_parameter(
shape=shape,
dtype=dtype,
attr=attr,
is_bias=is_bias,
default_initializer=default_initializer)
return CreateParameter_()
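## Illustrative sketch, not part of this commit: SimpleQ (in this same diff)
## uses the helper above to keep a float counter that persists across batches.
## Wrapped in a function with an assumed name so nothing is created at import:
def _sketch_make_counter():
    from paddle.fluid.initializer import ConstantInitializer
    factory = create_persistable_variable(
        dtype="float32",
        shape=[1],
        is_bias=True,
        default_initializer=ConstantInitializer(0.))
    return factory()  ## calling the factory creates the actual parameter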
@@ -63,17 +63,12 @@ class SimpleModelAC(Model):
class SimpleModelQ(Model):
-def __init__(self,
-dims,
-num_actions,
-mlp_layer_confs,
-estimated_total_num_batches=0):
def __init__(self, dims, num_actions, mlp_layer_confs):
super(SimpleModelQ, self).__init__()
self.dims = dims
self.num_actions = num_actions
assert "act" not in mlp_layer_confs[-1], "should be linear act"
self.mlp = comf.MLP(mlp_layer_confs)
-self.estimated_total_num_batches = estimated_total_num_batches
def get_input_specs(self):
return [("sensor", dict(shape=[self.dims]))]