diff --git a/parl/algorithm_zoo/simple_algorithms.py b/parl/algorithm_zoo/simple_algorithms.py
index a0040683401e25b7e3867afd7df90475e0b1e53d..bf8c5bc6cb4f1461db4d266fa1190a0f68533642 100644
--- a/parl/algorithm_zoo/simple_algorithms.py
+++ b/parl/algorithm_zoo/simple_algorithms.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from parl.framework.algorithm import RLAlgorithm
+from parl.framework.algorithm import Algorithm
+from paddle.fluid.initializer import ConstantInitializer
 import parl.layers as layers
 import parl.framework.policy_distribution as pd
 from parl.layers import common_functions as comf
@@ -20,7 +21,7 @@ import paddle.fluid as fluid
 from copy import deepcopy
 
 
-class SimpleAC(RLAlgorithm):
+class SimpleAC(Algorithm):
     """
     A simple Actor-Critic that has a feedforward policy network and
     a single discrete action.
@@ -37,7 +38,7 @@ class SimpleAC(RLAlgorithm):
         super(SimpleAC, self).__init__(model, hyperparas, gpu_id)
         self.discount_factor = discount_factor
 
-    def learn(self, inputs, next_inputs, states, next_states, episode_end,
+    def learn(self, inputs, next_inputs, states, next_states, next_episode_end,
               actions, rewards):
 
         action = actions["action"]
@@ -46,7 +47,8 @@ class SimpleAC(RLAlgorithm):
         values = self.model.value(inputs, states)
         next_values = self.model.value(next_inputs, next_states)
         value = values["v_value"]
-        next_value = next_values["v_value"] * episode_end["episode_end"]
+        next_value = next_values["v_value"] * next_episode_end[
+            "next_episode_end"]
         next_value.stop_gradient = True
 
         assert value.shape[1] == next_value.shape[1]
@@ -60,12 +62,16 @@ class SimpleAC(RLAlgorithm):
         pg_cost = 0 - dist.loglikelihood(action)
         avg_cost = layers.mean(x=value_cost + pg_cost * td_error)
 
-        optimizer = fluid.optimizer.SGD(learning_rate=self.hp["lr"])
+        optimizer = fluid.optimizer.DecayedAdagradOptimizer(
+            learning_rate=self.hp["lr"])
         optimizer.minimize(avg_cost)
         return dict(cost=avg_cost)
 
+    def predict(self, inputs, states):
+        return self._rl_predict(self.model, inputs, states)
 
-class SimpleQ(RLAlgorithm):
+
+class SimpleQ(Algorithm):
     """
     A simple Q-learning that has a feedforward policy network and
     a single discrete action.
@@ -77,6 +83,8 @@ class SimpleQ(RLAlgorithm):
                  hyperparas=dict(lr=1e-4),
                  gpu_id=-1,
                  discount_factor=0.99,
+                 exploration_end_batches=0,
+                 exploration_end_rate=0.1,
                  update_ref_interval=100):
 
         super(SimpleQ, self).__init__(model, hyperparas, gpu_id)
@@ -87,13 +95,45 @@ class SimpleQ(RLAlgorithm):
         self.total_batches = 0
         ## create a reference model
         self.ref_model = deepcopy(model)
+        ## setup exploration
+        self.explore = (exploration_end_batches > 0)
+        if self.explore:
+            self.exploration_counter = layers.create_persistable_variable(
+                dtype="float32",
+                shape=[1],
+                is_bias=True,
+                default_initializer=ConstantInitializer(0.))
+            ### after exploration_end_batches, the rate stays at exploration_end_rate
+            self.total_exploration_batches = exploration_end_batches
+            self.exploration_rate_delta \
+                = (1 - exploration_end_rate) / self.total_exploration_batches
 
     def before_every_batch(self):
         if self.total_batches % self.update_ref_interval == 0:
             self.model.sync_paras_to(self.ref_model, self.gpu_id)
         self.total_batches += 1
 
-    def learn(self, inputs, next_inputs, states, next_states, episode_end,
+    def predict(self, inputs, states):
+        """
+        Override the base predict() function to add uniform exploration to the action distribution.
+        """
+        rate = 0
+        if self.explore:
+            counter = self.exploration_counter()
+            ## first compute the current exploration rate
+            rate = 1 - counter * self.exploration_rate_delta
+
+        distributions, states = self.model.policy(inputs, states)
+        for dist in distributions.values():
+            assert dist.__class__.__name__ == "CategoricalDistribution"
+            dist.add_uniform_exploration(rate)
+
+        actions = {}
+        for key, dist in distributions.iteritems():
+            actions[key] = dist()
+        return actions, states
+
+    def learn(self, inputs, next_inputs, states, next_states, next_episode_end,
              actions, rewards):
 
         action = actions["action"]
@@ -102,7 +142,8 @@ class SimpleQ(RLAlgorithm):
         values = self.model.value(inputs, states)
         next_values = self.ref_model.value(next_inputs, next_states)
         q_value = values["q_value"]
-        next_q_value = next_values["q_value"] * episode_end["episode_end"]
+        next_q_value = next_values["q_value"] * next_episode_end[
+            "next_episode_end"]
         next_q_value.stop_gradient = True
         next_value = layers.reduce_max(next_q_value, dim=-1)
         assert q_value.shape[1] == next_q_value.shape[1]
@@ -113,6 +154,20 @@ class SimpleQ(RLAlgorithm):
         td_error = critic_value - value
         avg_cost = layers.mean(x=layers.square(td_error))
 
-        optimizer = fluid.optimizer.SGD(learning_rate=self.hp["lr"])
+        optimizer = fluid.optimizer.DecayedAdagradOptimizer(
+            learning_rate=self.hp["lr"])
         optimizer.minimize(avg_cost)
+
+        self._increment_exploration_counter()
         return dict(cost=avg_cost)
+
+    def _increment_exploration_counter(self):
+        if self.explore:
+            counter = self.exploration_counter()
+            exploration_counter_ = counter + 1
+            switch = layers.cast(
+                x=(exploration_counter_ > self.total_exploration_batches),
+                dtype="float32")
+            ## if the counter already hits the limit, we do not change the counter
+            layers.assign(switch * counter +
+                          (1 - switch) * exploration_counter_, counter)
diff --git a/parl/framework/algorithm.py b/parl/framework/algorithm.py
index c0318efd2a93688311798f48e7253b6aa5477774..ba291df5a8dfd28fa65688c9e2064ff292f56db5 100644
--- a/parl/framework/algorithm.py
+++ b/parl/framework/algorithm.py
@@ -35,7 +35,7 @@ def check_duplicate_spec_names(model):
 
 class Model(Network):
     """
-    A Model is owned by an Algorithm. It implements all the network model of
+    A Model is owned by an Algorithm. It implements the entire network model of
     a specific problem.
     """
     __metaclass__ = ABCMeta
@@ -142,11 +142,28 @@ class Algorithm(object):
     def predict(self, inputs, states):
         """
         Given the inputs and states, this function does forward prediction and updates states.
+        Input: inputs(dict), states(dict)
+        Output: actions(dict), states(dict)
+        Optional: an algorithm might not implement predict()
         """
         pass
 
 
-    def learn(self, inputs, next_inputs, states, next_states, episode_end,
+    def _rl_predict(self, behavior_model, inputs, states):
+        """
+        Given a behavior model (not necessarily equal to self.model), this function
+        performs a normal RL prediction according to inputs and states.
+        A behavior model different from self.model indicates off-policy training.
+
+        The user can choose to call this function for convenience.
+        """
+        distributions, states = behavior_model.policy(inputs, states)
+        actions = {}
+        for key, dist in distributions.iteritems():
+            actions[key] = dist()
+        return actions, states
+
+    def learn(self, inputs, next_inputs, states, next_states, next_episode_end,
               actions, rewards):
         """
         This function computes a learning cost to be optimized.
@@ -156,40 +173,3 @@ class Algorithm(object):
         Optional: an algorithm might not implement learn()
         """
         pass
-
-
-class RLAlgorithm(Algorithm):
-    """
-    A derived Algorithm class specially for RL problems.
-    """
-
-    def __init__(self, model, hyperparas, gpu_id):
-        super(RLAlgorithm, self).__init__(model, hyperparas, gpu_id)
-
-    def get_behavior_model(self):
-        """
-        Return the behavior model to compute actions. The behavior model could be different
-        from the training model, which is common in off-policy RL algorithms.
-
-        The default behavior model is set to the training model. The user can override this
-        function to specify another different model.
-        """
-        return self.model
-
-    def predict(self, inputs, states):
-        """
-        Implementation of Algorithm.predict()
-
-        Given the inputs and states, this function predicts actions and updates states.
-        Input: inputs(dict), states(dict)
-        Output: actions(dict), states(dict)
-        """
-        behavior_model = self.get_behavior_model()
-        distributions, states = behavior_model.policy(inputs, states)
-        actions = {}
-        for key, dist in distributions.iteritems():
-            assert isinstance(
-                dist, pd.PolicyDistribution
-            ), "behavior_model.policy must return PolicyDist!"
-            actions[key] = dist()
-        return actions, states
diff --git a/parl/framework/computation_task.py b/parl/framework/computation_task.py
index af8cdbfb9e2baf0e8f38967563371086dac7ba24..5fab538bd97edc87736288de22ebe65fdfaa69fa 100644
--- a/parl/framework/computation_task.py
+++ b/parl/framework/computation_task.py
@@ -72,7 +72,7 @@ class ComputationTask(object):
         next_state_specs = _get_next_specs(state_specs)
         action_specs = self.alg.get_action_specs()
         reward_specs = self.alg.get_reward_specs()
-        episode_end_specs = [("episode_end", dict(shape=[1]))]
+        next_episode_end_specs = [("next_episode_end", dict(shape=[1]))]
 
         self.action_names = sorted([name for name, _ in action_specs])
         self.state_names = sorted([name for name, _ in state_specs])
@@ -96,7 +96,8 @@ class ComputationTask(object):
             data_layer_dict.update(self._create_data_layers(next_state_specs))
             data_layer_dict.update(self._create_data_layers(action_specs))
             data_layer_dict.update(self._create_data_layers(reward_specs))
-            data_layer_dict.update(self._create_data_layers(episode_end_specs))
+            data_layer_dict.update(
+                self._create_data_layers(next_episode_end_specs))
             self.learn_feed_names = sorted(data_layer_dict.keys())
 
             inputs = _select_data(data_layer_dict, input_specs)
@@ -105,12 +106,13 @@ class ComputationTask(object):
             next_states = _select_data(data_layer_dict, next_state_specs)
             actions = _select_data(data_layer_dict, action_specs)
             rewards = _select_data(data_layer_dict, reward_specs)
-            episode_end = _select_data(data_layer_dict, episode_end_specs)
+            next_episode_end = _select_data(data_layer_dict,
+                                            next_episode_end_specs)
 
             ## call alg learn()
             ### TODO: implement a recurrent layer to strip the sequence information
             self.cost = self.alg.learn(inputs, next_inputs, states,
-                                       next_states, episode_end, actions,
+                                       next_states, next_episode_end, actions,
                                        rewards)
 
     def predict(self, inputs, states=dict()):
@@ -149,7 +151,7 @@ class ComputationTask(object):
     def learn(self,
               inputs,
              next_inputs,
-              episode_end,
+              next_episode_end,
              actions,
              rewards,
              states=dict(),
@@ -164,7 +166,7 @@ class ComputationTask(object):
         data.update(next_inputs)
         data.update(states)
         data.update(next_states)
-        data.update(episode_end)
+        data.update(next_episode_end)
         data.update(actions)
         data.update(rewards)
         assert sorted(data.keys()) == self.learn_feed_names, \
diff --git a/parl/framework/policy_distribution.py b/parl/framework/policy_distribution.py
index a45c0966aa570f5138ed8c2c22e334481fc786ce..35d682e2d96b89f7aad7ad4e11b5c4cea82d92bb 100644
--- a/parl/framework/policy_distribution.py
+++ b/parl/framework/policy_distribution.py
@@ -23,8 +23,11 @@ class PolicyDistribution(object):
     __metaclass__ = ABCMeta
 
     def __init__(self, dist):
-        assert len(dist.shape) == 2
-        self.dim = dist.shape[1]
+        """
+        self.dist represents the quantities that characterize the distribution.
+        For example, for a Normal distribution, this can be a tuple of (mean, std).
+        The actual form of self.dist is defined by the user.
+        """
         self.dist = dist
 
     @abstractmethod
@@ -34,6 +37,8 @@ class PolicyDistribution(object):
         """
         pass
 
+    @property
+    @abstractmethod
     def dim(self):
         """
         For discrete policies, this function returns the number of actions.
@@ -41,10 +46,14 @@ class PolicyDistribution(object):
         For sequential policies (e.g., sentences), this function returns the number
         of choices at each step.
         """
-        return self.dim
+        pass
 
-    def dist(self):
-        return self.dist
+    def add_uniform_exploration(self, rate):
+        """
+        Given a uniform exploration rate, this function modifies the distribution.
+        The rate could be a floating number or a Variable.
+        """
+        raise NotImplementedError()
 
     def loglikelihood(self, action):
         """
@@ -57,11 +66,25 @@ class PolicyDistribution(object):
 class CategoricalDistribution(PolicyDistribution):
     def __init__(self, dist):
         super(CategoricalDistribution, self).__init__(dist)
+        assert isinstance(dist, Variable)
 
     def __call__(self):
         return comf.categorical_random(self.dist)
 
+    @property
+    def dim(self):
+        assert len(self.dist.shape) == 2
+        return self.dist.shape[1]
+
+    def add_uniform_exploration(self, rate):
+        if not (isinstance(rate, float) and rate == 0):
+            self.dist = self.dist * (1 - rate) + \
+                        1 / float(self.dim) * rate
+
     def loglikelihood(self, action):
+        assert isinstance(action, Variable)
+        assert action.dtype == convert_np_dtype_to_dtype_("int") \
+            or action.dtype == convert_np_dtype_to_dtype_("int64")
         return 0 - layers.cross_entropy(input=self.dist, label=action)
 
 
@@ -69,21 +92,23 @@ class Deterministic(PolicyDistribution):
     def __init__(self, dist):
         super(Deterministic, self).__init__(dist)
         ## For deterministic action, we only support continuous ones
+        assert isinstance(dist, Variable)
         assert dist.dtype == convert_np_dtype_to_dtype_("float32") \
             or dist.dtype == convert_np_dtype_to_dtype_("float64")
 
+    @property
+    def dim(self):
+        assert len(self.dist.shape) == 2
+        return self.dist.shape[1]
+
     def __call__(self):
         return self.dist
 
-    def loglikelihood(self, action):
-        assert False, "You cannot compute likelihood for a deterministic action!"
-
 
-def q_categorical_distribution(q_value, exploration_rate=0.0):
+def q_categorical_distribution(q_value):
     """
     Generate a PolicyDistribution object given a Q value.
-    We first construct a one-hot distribution according to the Q value,
-    and then add an exploration rate to get a probability.
+    We construct a one-hot distribution according to the Q value.
""" assert len(q_value.shape) == 2, "[batch_size, num_actions]" max_id = comf.argmax_layer(q_value) @@ -91,8 +116,4 @@ def q_categorical_distribution(q_value, exploration_rate=0.0): x=layers.one_hot( input=max_id, depth=q_value.shape[-1]), dtype="float32") - ### exploration_rate could be a Variable - if not (isinstance(exploration_rate, float) and exploration_rate == 0): - prob = exploration_rate / float(q_value.shape[-1]) \ - + (1 - exploration_rate) * prob return CategoricalDistribution(prob) diff --git a/parl/framework/tests/test_algorithm.py b/parl/framework/tests/test_algorithm.py index cdae4b44a277d014e9237aeb97e2b292e20c5c08..595eecf812061babbb076f5839bc62bc974e943e 100644 --- a/parl/framework/tests/test_algorithm.py +++ b/parl/framework/tests/test_algorithm.py @@ -14,7 +14,7 @@ import paddle.fluid as fluid import parl.layers as layers -from parl.framework.algorithm import Model, RLAlgorithm +from parl.framework.algorithm import Model, Algorithm from parl.layers import common_functions as comf from parl.model_zoo.simple_models import SimpleModelDeterministic import numpy as np @@ -22,11 +22,14 @@ from copy import deepcopy import unittest -class TestAlgorithm(RLAlgorithm): +class TestAlgorithm(Algorithm): def __init__(self, model): super(TestAlgorithm, self).__init__( model, hyperparas=dict(), gpu_id=-1) + def predict(self, inputs, states): + return self._rl_predict(self.model, inputs, states) + class TestAlgorithmParas(unittest.TestCase): def test_sync_paras_in_one_program(self): diff --git a/parl/framework/tests/test_computation_task.py b/parl/framework/tests/test_computation_task.py index 8311e2163e08d6034145cbe934e696fd41d72e9a..470d2ba09cbe3980fe78a09d7302a64e9acebec6 100644 --- a/parl/framework/tests/test_computation_task.py +++ b/parl/framework/tests/test_computation_task.py @@ -228,7 +228,7 @@ class TestComputationTask(unittest.TestCase): cost = ct.learn( inputs=dict(sensor=sensor), next_inputs=dict(next_sensor=next_sensor), - episode_end=dict(episode_end=np.ones( + next_episode_end=dict(next_episode_end=np.ones( (batch_size, 1)).astype("float32")), actions=dict(action=actions), rewards=dict(reward=rewards)) diff --git a/parl/framework/tests/test_simple_games.py b/parl/framework/tests/test_simple_games.py new file mode 100644 index 0000000000000000000000000000000000000000..c37c69d149177a127e21c2d55c2441efe94d693b --- /dev/null +++ b/parl/framework/tests/test_simple_games.py @@ -0,0 +1,166 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import paddle.fluid as fluid
+import parl.layers as layers
+from parl.framework.computation_task import ComputationTask
+from parl.algorithm_zoo.simple_algorithms import SimpleAC, SimpleQ
+from parl.model_zoo.simple_models import SimpleModelAC, SimpleModelQ
+import numpy as np
+import unittest
+import math
+import gym
+
+
+def unpack_exps(exps):
+    return [np.array(l).astype('int' if i==2 else 'float32') \
+        for i, l in enumerate(zip(*exps))]
+
+
+def sample(past_exps, n):
+    indices = np.random.choice(len(past_exps), n)
+    return [past_exps[i] for i in indices]
+
+
+class TestGymGame(unittest.TestCase):
+    def test_gym_games(self):
+        """
+        Test games in OpenAI gym.
+        """
+
+        games = ["MountainCar-v0", "CartPole-v0"]
+        final_rewards_thresholds = [
+            -1.8,  ## drive to the right top in 180 steps (timeout is -2.0)
+            1.5  ## hold the pole for at least 150 steps
+        ]
+
+        mlp_layer_confs = [
+            dict(
+                size=128, act="relu"),
+            dict(
+                size=128, act="relu"),
+            dict(
+                size=128, act="relu"),
+        ]
+
+        for game, threshold in zip(games, final_rewards_thresholds):
+            for on_policy in [False, True]:
+
+                if on_policy and game != "CartPole-v0":
+                    ## SimpleAC has difficulty training mountain-car and acrobot
+                    continue
+
+                env = gym.make(game)
+                state_shape = env.observation_space.shape[0]
+                num_actions = env.action_space.n
+
+                if on_policy:
+                    alg = SimpleAC(
+                        model=SimpleModelAC(
+                            dims=state_shape,
+                            num_actions=num_actions,
+                            mlp_layer_confs=mlp_layer_confs +
+                            [dict(
+                                size=num_actions, act="softmax")]),
+                        hyperparas=dict(lr=1e-3))
+                else:
+                    alg = SimpleQ(
+                        model=SimpleModelQ(
+                            dims=state_shape,
+                            num_actions=num_actions,
+                            mlp_layer_confs=mlp_layer_confs +
+                            [dict(size=num_actions)]),
+                        hyperparas=dict(lr=1e-4),
+                        exploration_end_batches=25000,
+                        update_ref_interval=100)
+
+                print("algorithm: " + alg.__class__.__name__)
+
+                ct = ComputationTask(algorithm=alg)
+                batch_size = 16
+                if not on_policy:
+                    train_every_steps = batch_size / 4
+                    buffer_size_limit = 100000
+
+                max_episode = 5000
+
+                average_episode_reward = []
+                past_exps = []
+                max_steps = env._max_episode_steps
+                for n in range(max_episode):
+                    ob = env.reset()
+                    episode_reward = 0
+                    for t in range(max_steps):
+                        res, _ = ct.predict(inputs=dict(sensor=np.array(
+                            [ob]).astype("float32")))
+                        pred_action = res["action"][0][0]
+
+                        next_ob, reward, next_is_over, _ = env.step(
+                            pred_action)
+                        reward /= 100
+                        episode_reward += reward
+
+                        past_exps.append((ob, next_ob, [pred_action],
+                                          [reward], [not next_is_over]))
+                        ## only for off-policy training we use a circular buffer
+                        if (not on_policy
+                            ) and len(past_exps) > buffer_size_limit:
+                            past_exps.pop(0)
+
+                        ## compute the learning condition
+                        learn_cond = False
+                        if on_policy:
+                            learn_cond = (len(past_exps) >= batch_size)
+                            exps = past_exps  ## directly use all exps in the buffer
+                        else:
+                            learn_cond = (
+                                t % train_every_steps == train_every_steps - 1)
+                            exps = sample(past_exps,
+                                          batch_size)  ## sample some exps
+
+                        if learn_cond:
+                            sensor, next_sensor, action, reward, next_episode_end \
+                                = unpack_exps(exps)
+                            cost = ct.learn(
+                                inputs=dict(sensor=sensor),
+                                next_inputs=dict(next_sensor=next_sensor),
+                                next_episode_end=dict(
+                                    next_episode_end=next_episode_end),
+                                actions=dict(action=action),
+                                rewards=dict(reward=reward))
+                            ## we clear the exp buffer for on-policy
+                            if on_policy:
+                                past_exps = []
+
+                        ob = next_ob
+
+                        ## end before the Gym wrongly gives game_over=True for a timeout case
+                        if t == max_steps - 2 or next_is_over:
+                            break
+
+                    if n % 50 == 0:
+                        print("episode reward: %f" % episode_reward)
+
+                    average_episode_reward.append(episode_reward)
+                    if len(average_episode_reward) > 20:
+                        average_episode_reward.pop(0)
+
+                ### compare the average episode reward to reduce variance
+                self.assertGreater(
+                    sum(average_episode_reward) / len(average_episode_reward),
+                    threshold)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/parl/layers/layer_wrappers.py b/parl/layers/layer_wrappers.py
index 41c100015764f744aff66ecbf49b99ebe112ec0a..0a8959f9d543e9596ad6b84901f608954becc221 100644
--- a/parl/layers/layer_wrappers.py
+++ b/parl/layers/layer_wrappers.py
@@ -19,6 +19,7 @@ from paddle.fluid.executor import fetch_var
 import paddle.fluid as fluid
 from paddle.fluid.layers import *
 from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.framework import Variable
 import paddle.fluid.layers as layers
 import paddle.fluid.unique_name as unique_name
 from copy import deepcopy
@@ -533,3 +534,30 @@ def row_conv(future_context_size, param_attr=None, act=None, name=None):
 
 def layer_norm(**kwargs):
     raise NotImplementedError()
+
+
+def create_persistable_variable(shape,
+                                dtype,
+                                name=None,
+                                attr=None,
+                                is_bias=False,
+                                default_initializer=None):
+    """
+    Return a callable that creates a parameter which cannot be synchronized the way layer parameters are.
+
+    This function can be called inside an Algorithm, so we neither check the caller nor require
+    that the variable can be copied.
+    """
+    default_name = "per_var"
+    attr = update_attr_name(name, default_name, attr, is_bias)
+
+    class CreateParameter_(object):
+        def __call__(self):
+            return layers.create_parameter(
+                shape=shape,
+                dtype=dtype,
+                attr=attr,
+                is_bias=is_bias,
+                default_initializer=default_initializer)
+
+    return CreateParameter_()
diff --git a/parl/model_zoo/simple_models.py b/parl/model_zoo/simple_models.py
index 13d700742f0d82487ddcec6dd65b2f67594eb0cd..2a4b10fa3db166064762659cde42d2c390d657c0 100644
--- a/parl/model_zoo/simple_models.py
+++ b/parl/model_zoo/simple_models.py
@@ -63,17 +63,12 @@ class SimpleModelAC(Model):
 
 
 class SimpleModelQ(Model):
-    def __init__(self,
-                 dims,
-                 num_actions,
-                 mlp_layer_confs,
-                 estimated_total_num_batches=0):
+    def __init__(self, dims, num_actions, mlp_layer_confs):
         super(SimpleModelQ, self).__init__()
         self.dims = dims
         self.num_actions = num_actions
         assert "act" not in mlp_layer_confs[-1], "should be linear act"
         self.mlp = comf.MLP(mlp_layer_confs)
-        self.estimated_total_num_batches = estimated_total_num_batches
 
     def get_input_specs(self):
         return [("sensor", dict(shape=[self.dims]))]
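
For readers skimming the patch, the exploration machinery added to SimpleQ reduces to simple arithmetic: a persistable counter is bumped once per learn() call and capped at exploration_end_batches, the exploration rate decays linearly from 1.0 to exploration_end_rate, and CategoricalDistribution.add_uniform_exploration mixes that much uniform probability into the policy. The sketch below restates this in plain numpy; it is an illustration under those assumptions, not code from the patch, and the helper names are made up.

import numpy as np


def exploration_rate(counter, exploration_end_batches, exploration_end_rate):
    # Linear decay: the rate starts at 1.0 and reaches exploration_end_rate
    # after exploration_end_batches learn() calls; the counter stops growing
    # there, so the rate then stays put.
    delta = (1.0 - exploration_end_rate) / exploration_end_batches
    counter = min(counter, exploration_end_batches)
    return 1.0 - counter * delta


def add_uniform_exploration(probs, rate):
    # Mix a categorical distribution with the uniform one, mirroring the
    # formula dist * (1 - rate) + rate / dim used in the patch.
    probs = np.asarray(probs, dtype="float32")
    return probs * (1.0 - rate) + rate / float(probs.shape[-1])


print(exploration_rate(12500, 25000, 0.1))                 # 0.55, halfway through the decay
print(add_uniform_exploration([1.0, 0.0, 0.0, 0.0], 0.1))  # ~[0.925 0.025 0.025 0.025]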
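A note on the renamed next_episode_end input: as the new test fills it ([not next_is_over]), it is effectively a continuation flag, 1.0 while the next state is still inside the episode and 0.0 on the transition that ends it, and both learn() methods multiply the bootstrapped next-state value by it. Below is a small numpy sketch of the resulting one-step target under that convention; it is illustrative only and not part of the patch.

import numpy as np


def q_learning_target(reward, next_q_values, next_episode_end, discount=0.99):
    # target = r + discount * max_a' Q(s', a') * flag, where the flag
    # (called next_episode_end above) is 0.0 on a terminal transition,
    # so no bootstrapping happens past the end of an episode.
    next_value = np.max(next_q_values, axis=-1) * next_episode_end
    return reward + discount * next_value


rewards = np.array([0.01, -0.01], dtype="float32")
next_q = np.array([[0.2, 0.5], [0.3, 0.1]], dtype="float32")
flags = np.array([1.0, 0.0], dtype="float32")  # second transition ends the episode
print(q_learning_target(rewards, next_q, flags))  # [0.505, -0.01]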