diff --git a/fluid/DeepQNetwork/DQN_agent.py b/fluid/DeepQNetwork/DQN_agent.py
index 67eb3ce6a29bb723b481d6b1c2f517f037d52942..5b474325f656533b91965fd59d70c2d421e16fc3 100644
--- a/fluid/DeepQNetwork/DQN_agent.py
+++ b/fluid/DeepQNetwork/DQN_agent.py
@@ -1,11 +1,10 @@
 #-*- coding: utf-8 -*-
 
+import math
+import numpy as np
 import paddle.fluid as fluid
 from paddle.fluid.param_attr import ParamAttr
-import numpy as np
-import math
 from tqdm import tqdm
-from utils import fluid_flatten
 
 
 class DQNModel(object):
@@ -39,34 +38,51 @@ class DQNModel(object):
                 name='isOver', shape=[], dtype='bool')
 
     def _build_net(self):
-        state, action, reward, next_s, isOver = self._get_inputs()
-        self.pred_value = self.get_DQN_prediction(state)
-        self.predict_program = fluid.default_main_program().clone()
+        self.predict_program = fluid.Program()
+        self.train_program = fluid.Program()
+        self._sync_program = fluid.Program()
 
-        reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
+        with fluid.program_guard(self.predict_program):
+            state, action, reward, next_s, isOver = self._get_inputs()
+            self.pred_value = self.get_DQN_prediction(state)
 
-        action_onehot = fluid.layers.one_hot(action, self.action_dim)
-        action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
+        with fluid.program_guard(self.train_program):
+            state, action, reward, next_s, isOver = self._get_inputs()
+            pred_value = self.get_DQN_prediction(state)
 
-        pred_action_value = fluid.layers.reduce_sum(
-            fluid.layers.elementwise_mul(action_onehot, self.pred_value), dim=1)
+            reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
 
-        targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
-        best_v = fluid.layers.reduce_max(targetQ_predict_value, dim=1)
-        best_v.stop_gradient = True
+            action_onehot = fluid.layers.one_hot(action, self.action_dim)
+            action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
 
-        target = reward + (1.0 - fluid.layers.cast(
-            isOver, dtype='float32')) * self.gamma * best_v
-        cost = fluid.layers.square_error_cost(pred_action_value, target)
-        cost = fluid.layers.reduce_mean(cost)
+            pred_action_value = fluid.layers.reduce_sum(
+                fluid.layers.elementwise_mul(action_onehot, pred_value), dim=1)
 
-        self._sync_program = self._build_sync_target_network()
+            targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
+            best_v = fluid.layers.reduce_max(targetQ_predict_value, dim=1)
+            best_v.stop_gradient = True
 
-        optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
-        optimizer.minimize(cost)
+            target = reward + (1.0 - fluid.layers.cast(
+                isOver, dtype='float32')) * self.gamma * best_v
+            cost = fluid.layers.square_error_cost(pred_action_value, target)
+            cost = fluid.layers.reduce_mean(cost)
 
-        # define program
-        self.train_program = fluid.default_main_program()
+            optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
+            optimizer.minimize(cost)
+
+        vars = list(self.train_program.list_vars())
+        policy_vars = list(filter(
+            lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
+        target_vars = list(filter(
+            lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
+        policy_vars.sort(key=lambda x: x.name)
+        target_vars.sort(key=lambda x: x.name)
+
+        with fluid.program_guard(self._sync_program):
+            sync_ops = []
+            for i, var in enumerate(policy_vars):
+                sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
+                sync_ops.append(sync_op)
 
         # fluid exe
         place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
@@ -81,50 +97,50 @@ class DQNModel(object):
         conv1 = fluid.layers.conv2d(
             input=image,
             num_filters=32,
-            filter_size=[5, 5],
-            stride=[1, 1],
-            padding=[2, 2],
+            filter_size=5,
+            stride=1,
+            padding=2,
             act='relu',
             param_attr=ParamAttr(name='{}_conv1'.format(variable_field)),
             bias_attr=ParamAttr(name='{}_conv1_b'.format(variable_field)))
         max_pool1 = fluid.layers.pool2d(
-            input=conv1, pool_size=[2, 2], pool_stride=[2, 2], pool_type='max')
+            input=conv1, pool_size=2, pool_stride=2, pool_type='max')
 
         conv2 = fluid.layers.conv2d(
             input=max_pool1,
             num_filters=32,
-            filter_size=[5, 5],
-            stride=[1, 1],
-            padding=[2, 2],
+            filter_size=5,
+            stride=1,
+            padding=2,
             act='relu',
             param_attr=ParamAttr(name='{}_conv2'.format(variable_field)),
             bias_attr=ParamAttr(name='{}_conv2_b'.format(variable_field)))
         max_pool2 = fluid.layers.pool2d(
-            input=conv2, pool_size=[2, 2], pool_stride=[2, 2], pool_type='max')
+            input=conv2, pool_size=2, pool_stride=2, pool_type='max')
 
         conv3 = fluid.layers.conv2d(
             input=max_pool2,
             num_filters=64,
-            filter_size=[4, 4],
-            stride=[1, 1],
-            padding=[1, 1],
+            filter_size=4,
+            stride=1,
+            padding=1,
             act='relu',
             param_attr=ParamAttr(name='{}_conv3'.format(variable_field)),
             bias_attr=ParamAttr(name='{}_conv3_b'.format(variable_field)))
         max_pool3 = fluid.layers.pool2d(
-            input=conv3, pool_size=[2, 2], pool_stride=[2, 2], pool_type='max')
+            input=conv3, pool_size=2, pool_stride=2, pool_type='max')
 
         conv4 = fluid.layers.conv2d(
             input=max_pool3,
             num_filters=64,
-            filter_size=[3, 3],
-            stride=[1, 1],
-            padding=[1, 1],
+            filter_size=3,
+            stride=1,
+            padding=1,
             act='relu',
             param_attr=ParamAttr(name='{}_conv4'.format(variable_field)),
             bias_attr=ParamAttr(name='{}_conv4_b'.format(variable_field)))
 
-        flatten = fluid_flatten(conv4)
+        flatten = fluid.layers.flatten(conv4, axis=1)
 
         out = fluid.layers.fc(
             input=flatten,
@@ -133,23 +149,6 @@ class DQNModel(object):
             bias_attr=ParamAttr(name='{}_fc1_b'.format(variable_field)))
         return out
 
-    def _build_sync_target_network(self):
-        vars = list(fluid.default_main_program().list_vars())
-        policy_vars = list(filter(
-            lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
-        target_vars = list(filter(
-            lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
-        policy_vars.sort(key=lambda x: x.name)
-        target_vars.sort(key=lambda x: x.name)
-
-        sync_program = fluid.default_main_program().clone()
-        with fluid.program_guard(sync_program):
-            sync_ops = []
-            for i, var in enumerate(policy_vars):
-                sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
-                sync_ops.append(sync_op)
-        sync_program = sync_program.prune(sync_ops)
-        return sync_program
 
     def act(self, state, train_or_test):
         sample = np.random.random()
diff --git a/fluid/DeepQNetwork/DoubleDQN_agent.py b/fluid/DeepQNetwork/DoubleDQN_agent.py
index 09b4b2119bab3fbdfa9bb9cfb8fae40fa34f87e1..c95ae5632fd2e904a625f680f4a9147d5615b765 100644
--- a/fluid/DeepQNetwork/DoubleDQN_agent.py
+++ b/fluid/DeepQNetwork/DoubleDQN_agent.py
@@ -1,11 +1,10 @@
 #-*- coding: utf-8 -*-
 
+import math
+import numpy as np
 import paddle.fluid as fluid
 from paddle.fluid.param_attr import ParamAttr
-import numpy as np
 from tqdm import tqdm
-import math
-from utils import fluid_argmax, fluid_flatten
 
 
 class DoubleDQNModel(object):
@@ -39,41 +38,59 @@ class DoubleDQNModel(object):
                 name='isOver', shape=[], dtype='bool')
 
     def _build_net(self):
-        state, action, reward, next_s, isOver = self._get_inputs()
-        self.pred_value = self.get_DQN_prediction(state)
-        self.predict_program = fluid.default_main_program().clone()
+        self.predict_program = fluid.Program()
+        self.train_program = fluid.Program()
+        self._sync_program = fluid.Program()
 
-        reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
+        with fluid.program_guard(self.predict_program):
+            state, action, reward, next_s, isOver = self._get_inputs()
+            self.pred_value = self.get_DQN_prediction(state)
 
-        action_onehot = fluid.layers.one_hot(action, self.action_dim)
-        action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
+        with fluid.program_guard(self.train_program):
+            state, action, reward, next_s, isOver = self._get_inputs()
+            pred_value = self.get_DQN_prediction(state)
 
-        pred_action_value = fluid.layers.reduce_sum(
-            fluid.layers.elementwise_mul(action_onehot, self.pred_value), dim=1)
+            reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
 
-        targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
+            action_onehot = fluid.layers.one_hot(action, self.action_dim)
+            action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
 
-        next_s_predcit_value = self.get_DQN_prediction(next_s)
-        greedy_action = fluid_argmax(next_s_predcit_value)
+            pred_action_value = fluid.layers.reduce_sum(
+                fluid.layers.elementwise_mul(action_onehot, pred_value), dim=1)
 
-        predict_onehot = fluid.layers.one_hot(greedy_action, self.action_dim)
-        best_v = fluid.layers.reduce_sum(
-            fluid.layers.elementwise_mul(predict_onehot, targetQ_predict_value),
-            dim=1)
-        best_v.stop_gradient = True
+            targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
 
-        target = reward + (1.0 - fluid.layers.cast(
-            isOver, dtype='float32')) * self.gamma * best_v
-        cost = fluid.layers.square_error_cost(pred_action_value, target)
-        cost = fluid.layers.reduce_mean(cost)
+            next_s_predcit_value = self.get_DQN_prediction(next_s)
+            greedy_action = fluid.layers.argmax(next_s_predcit_value, axis=1)
+            greedy_action = fluid.layers.unsqueeze(greedy_action, axes=[1])
 
-        self._sync_program = self._build_sync_target_network()
+            predict_onehot = fluid.layers.one_hot(greedy_action, self.action_dim)
+            best_v = fluid.layers.reduce_sum(
+                fluid.layers.elementwise_mul(predict_onehot, targetQ_predict_value),
+                dim=1)
+            best_v.stop_gradient = True
 
-        optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
-        optimizer.minimize(cost)
+            target = reward + (1.0 - fluid.layers.cast(
+                isOver, dtype='float32')) * self.gamma * best_v
+            cost = fluid.layers.square_error_cost(pred_action_value, target)
+            cost = fluid.layers.reduce_mean(cost)
 
-        # define program
-        self.train_program = fluid.default_main_program()
+            optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
+            optimizer.minimize(cost)
+
+        vars = list(self.train_program.list_vars())
+        policy_vars = list(filter(
+            lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
+        target_vars = list(filter(
+            lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
+        policy_vars.sort(key=lambda x: x.name)
+        target_vars.sort(key=lambda x: x.name)
+
+        with fluid.program_guard(self._sync_program):
+            sync_ops = []
+            for i, var in enumerate(policy_vars):
+                sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
+                sync_ops.append(sync_op)
 
         # fluid exe
         place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
@@ -88,50 +105,50 @@ class DoubleDQNModel(object):
         conv1 = fluid.layers.conv2d(
             input=image,
             num_filters=32,
-            filter_size=[5, 5],
-            stride=[1, 1],
-            padding=[2, 2],
+            filter_size=5,
+            stride=1,
+            padding=2,
             act='relu',
             param_attr=ParamAttr(name='{}_conv1'.format(variable_field)),
             bias_attr=ParamAttr(name='{}_conv1_b'.format(variable_field)))
         max_pool1 = fluid.layers.pool2d(
-            input=conv1, pool_size=[2, 2], pool_stride=[2, 2], pool_type='max')
+            input=conv1, pool_size=2, pool_stride=2, pool_type='max')
 
         conv2 = fluid.layers.conv2d(
             input=max_pool1,
             num_filters=32,
-            filter_size=[5, 5],
-            stride=[1, 1],
-            padding=[2, 2],
+            filter_size=5,
+            stride=1,
+            padding=2,
             act='relu',
             param_attr=ParamAttr(name='{}_conv2'.format(variable_field)),
             bias_attr=ParamAttr(name='{}_conv2_b'.format(variable_field)))
         max_pool2 = fluid.layers.pool2d(
-            input=conv2, pool_size=[2, 2], pool_stride=[2, 2], pool_type='max')
+            input=conv2, pool_size=2, pool_stride=2, pool_type='max')
 
         conv3 = fluid.layers.conv2d(
             input=max_pool2,
             num_filters=64,
-            filter_size=[4, 4],
-            stride=[1, 1],
-            padding=[1, 1],
+            filter_size=4,
+            stride=1,
+            padding=1,
             act='relu',
             param_attr=ParamAttr(name='{}_conv3'.format(variable_field)),
             bias_attr=ParamAttr(name='{}_conv3_b'.format(variable_field)))
         max_pool3 = fluid.layers.pool2d(
-            input=conv3, pool_size=[2, 2], pool_stride=[2, 2], pool_type='max')
+            input=conv3, pool_size=2, pool_stride=2, pool_type='max')
 
         conv4 = fluid.layers.conv2d(
             input=max_pool3,
             num_filters=64,
-            filter_size=[3, 3],
-            stride=[1, 1],
-            padding=[1, 1],
+            filter_size=3,
+            stride=1,
+            padding=1,
             act='relu',
             param_attr=ParamAttr(name='{}_conv4'.format(variable_field)),
             bias_attr=ParamAttr(name='{}_conv4_b'.format(variable_field)))
 
-        flatten = fluid_flatten(conv4)
+        flatten = fluid.layers.flatten(conv4, axis=1)
 
         out = fluid.layers.fc(
             input=flatten,
@@ -140,23 +157,6 @@ class DoubleDQNModel(object):
             bias_attr=ParamAttr(name='{}_fc1_b'.format(variable_field)))
         return out
 
-    def _build_sync_target_network(self):
-        vars = list(fluid.default_main_program().list_vars())
-        policy_vars = list(filter(
-            lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
-        target_vars = list(filter(
-            lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
-        policy_vars.sort(key=lambda x: x.name)
-        target_vars.sort(key=lambda x: x.name)
-
-        sync_program = fluid.default_main_program().clone()
-        with fluid.program_guard(sync_program):
-            sync_ops = []
-            for i, var in enumerate(policy_vars):
-                sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
-                sync_ops.append(sync_op)
-        sync_program = sync_program.prune(sync_ops)
-        return sync_program
 
     def act(self, state, train_or_test):
         sample = np.random.random()
diff --git a/fluid/DeepQNetwork/DuelingDQN_agent.py b/fluid/DeepQNetwork/DuelingDQN_agent.py
index 271a767b7b5841cf1abe213fc477859e3cf5dd05..cf2ff71bb811e5dce62be78beab1f0afb05d31f9 100644
--- a/fluid/DeepQNetwork/DuelingDQN_agent.py
+++ b/fluid/DeepQNetwork/DuelingDQN_agent.py
@@ -1,11 +1,10 @@
 #-*- coding: utf-8 -*-
 
+import math
+import numpy as np
 import paddle.fluid as fluid
 from paddle.fluid.param_attr import ParamAttr
-import numpy as np
 from tqdm import tqdm
-import math
-from utils import fluid_flatten
 
 
 class DuelingDQNModel(object):
@@ -39,34 +38,51 @@ class DuelingDQNModel(object):
                 name='isOver', shape=[], dtype='bool')
 
     def _build_net(self):
-        state, action, reward, next_s, isOver = self._get_inputs()
-        self.pred_value = self.get_DQN_prediction(state)
-        self.predict_program = fluid.default_main_program().clone()
+        self.predict_program = fluid.Program()
+        self.train_program = fluid.Program()
+        self._sync_program = fluid.Program()
 
-        reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
+        with fluid.program_guard(self.predict_program):
+            state, action, reward, next_s, isOver = self._get_inputs()
+            self.pred_value = self.get_DQN_prediction(state)
 
-        action_onehot = fluid.layers.one_hot(action, self.action_dim)
-        action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
+        with fluid.program_guard(self.train_program):
+            state, action, reward, next_s, isOver = self._get_inputs()
+            pred_value = self.get_DQN_prediction(state)
 
-        pred_action_value = fluid.layers.reduce_sum(
-            fluid.layers.elementwise_mul(action_onehot, self.pred_value), dim=1)
+            reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
 
-        targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
-        best_v = fluid.layers.reduce_max(targetQ_predict_value, dim=1)
-        best_v.stop_gradient = True
+            action_onehot = fluid.layers.one_hot(action, self.action_dim)
+            action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
 
-        target = reward + (1.0 - fluid.layers.cast(
-            isOver, dtype='float32')) * self.gamma * best_v
-        cost = fluid.layers.square_error_cost(pred_action_value, target)
-        cost = fluid.layers.reduce_mean(cost)
+            pred_action_value = fluid.layers.reduce_sum(
+                fluid.layers.elementwise_mul(action_onehot, pred_value), dim=1)
 
-        self._sync_program = self._build_sync_target_network()
+            targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
+            best_v = fluid.layers.reduce_max(targetQ_predict_value, dim=1)
+            best_v.stop_gradient = True
 
-        optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
-        optimizer.minimize(cost)
+            target = reward + (1.0 - fluid.layers.cast(
+                isOver, dtype='float32')) * self.gamma * best_v
+            cost = fluid.layers.square_error_cost(pred_action_value, target)
+            cost = fluid.layers.reduce_mean(cost)
 
-        # define program
-        self.train_program = fluid.default_main_program()
+            optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
+            optimizer.minimize(cost)
+
+        vars = list(self.train_program.list_vars())
+        policy_vars = list(filter(
+            lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
+        target_vars = list(filter(
+            lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
+        policy_vars.sort(key=lambda x: x.name)
+        target_vars.sort(key=lambda x: x.name)
+
+        with fluid.program_guard(self._sync_program):
+            sync_ops = []
+            for i, var in enumerate(policy_vars):
+                sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
+                sync_ops.append(sync_op)
 
         # fluid exe
         place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
@@ -81,50 +97,50 @@ class DuelingDQNModel(object):
         conv1 = fluid.layers.conv2d(
             input=image,
             num_filters=32,
-            filter_size=[5, 5],
-            stride=[1, 1],
-            padding=[2, 2],
+            filter_size=5,
+            stride=1,
+            padding=2,
             act='relu',
             param_attr=ParamAttr(name='{}_conv1'.format(variable_field)),
             bias_attr=ParamAttr(name='{}_conv1_b'.format(variable_field)))
         max_pool1 = fluid.layers.pool2d(
-            input=conv1, pool_size=[2, 2], pool_stride=[2, 2], pool_type='max')
+            input=conv1, pool_size=2, pool_stride=2, pool_type='max')
 
         conv2 = fluid.layers.conv2d(
             input=max_pool1,
             num_filters=32,
-            filter_size=[5, 5],
-            stride=[1, 1],
-            padding=[2, 2],
+            filter_size=5,
+            stride=1,
+            padding=2,
             act='relu',
             param_attr=ParamAttr(name='{}_conv2'.format(variable_field)),
             bias_attr=ParamAttr(name='{}_conv2_b'.format(variable_field)))
         max_pool2 = fluid.layers.pool2d(
-            input=conv2, pool_size=[2, 2], pool_stride=[2, 2], pool_type='max')
+            input=conv2, pool_size=2, pool_stride=2, pool_type='max')
 
         conv3 = fluid.layers.conv2d(
            input=max_pool2,
             num_filters=64,
-            filter_size=[4, 4],
-            stride=[1, 1],
-            padding=[1, 1],
+            filter_size=4,
+            stride=1,
+            padding=1,
             act='relu',
             param_attr=ParamAttr(name='{}_conv3'.format(variable_field)),
             bias_attr=ParamAttr(name='{}_conv3_b'.format(variable_field)))
         max_pool3 = fluid.layers.pool2d(
-            input=conv3, pool_size=[2, 2], pool_stride=[2, 2], pool_type='max')
+            input=conv3, pool_size=2, pool_stride=2, pool_type='max')
 
         conv4 = fluid.layers.conv2d(
             input=max_pool3,
             num_filters=64,
-            filter_size=[3, 3],
-            stride=[1, 1],
-            padding=[1, 1],
+            filter_size=3,
+            stride=1,
+            padding=1,
             act='relu',
             param_attr=ParamAttr(name='{}_conv4'.format(variable_field)),
             bias_attr=ParamAttr(name='{}_conv4_b'.format(variable_field)))
 
-        flatten = fluid_flatten(conv4)
+        flatten = fluid.layers.flatten(conv4, axis=1)
 
         value = fluid.layers.fc(
             input=flatten,
@@ -143,24 +159,6 @@ class DuelingDQNModel(object):
             advantage, dim=1, keep_dim=True))
         return Q
 
-    def _build_sync_target_network(self):
-        vars = list(fluid.default_main_program().list_vars())
-        policy_vars = list(filter(
-            lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
-        target_vars = list(filter(
-            lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
-        policy_vars.sort(key=lambda x: x.name)
-        target_vars.sort(key=lambda x: x.name)
-
-        sync_program = fluid.default_main_program().clone()
-        with fluid.program_guard(sync_program):
-            sync_ops = []
-            for i, var in enumerate(policy_vars):
-                sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
-                sync_ops.append(sync_op)
-        # The prune API is deprecated, please don't use it any more.
-        sync_program = sync_program._prune(sync_ops)
-        return sync_program
 
     def act(self, state, train_or_test):
         sample = np.random.random()
@@ -186,12 +184,14 @@ class DuelingDQNModel(object):
         self.global_step += 1
 
         action = np.expand_dims(action, -1)
-        self.exe.run(self.train_program, \
-            feed={'state': state.astype('float32'), \
-                  'action': action.astype('int32'), \
-                  'reward': reward, \
-                  'next_s': next_state.astype('float32'), \
-                  'isOver': isOver})
+        self.exe.run(self.train_program,
+                     feed={
+                         'state': state.astype('float32'),
+                         'action': action.astype('int32'),
+                         'reward': reward,
+                         'next_s': next_state.astype('float32'),
+                         'isOver': isOver
+                     })
 
     def sync_target_network(self):
         self.exe.run(self._sync_program)
diff --git a/fluid/DeepQNetwork/README.md b/fluid/DeepQNetwork/README.md
index e72920bcad29ce7ffd78bfb90a1406654298248d..1edeaaa884318ec3a530ec4fdb7d031d07411b56 100644
--- a/fluid/DeepQNetwork/README.md
+++ b/fluid/DeepQNetwork/README.md
@@ -29,7 +29,7 @@ The average game rewards that can be obtained for the three models as the number
 + gym
 + tqdm
 + opencv-python
-+ paddlepaddle-gpu>=0.12.0
++ paddlepaddle-gpu>=1.0.0
 + ale_python_interface
 
 ### Install Dependencies:
diff --git a/fluid/DeepQNetwork/README_cn.md b/fluid/DeepQNetwork/README_cn.md
index 68a65bffe8fab79ce563fefc894dd035c1572065..640d775ad8fed2be360d308b6c5df41c86d77c04 100644
--- a/fluid/DeepQNetwork/README_cn.md
+++ b/fluid/DeepQNetwork/README_cn.md
@@ -28,7 +28,7 @@
 + gym
 + tqdm
 + opencv-python
-+ paddlepaddle-gpu>=0.12.0
++ paddlepaddle-gpu>=1.0.0
 + ale_python_interface
 
 ### 下载依赖:
diff --git a/fluid/DeepQNetwork/utils.py b/fluid/DeepQNetwork/utils.py
deleted file mode 100644
index 26ed7fbdb54494c3cf9a983f8ecafdfbcd4d2719..0000000000000000000000000000000000000000
--- a/fluid/DeepQNetwork/utils.py
+++ /dev/null
@@ -1,20 +0,0 @@
-#-*- coding: utf-8 -*-
-#File: utils.py
-
-import paddle.fluid as fluid
-import numpy as np
-
-
-def fluid_argmax(x):
-    """
-    Get index of max value for the last dimension
-    """
-    _, max_index = fluid.layers.topk(x, k=1)
-    return max_index
-
-
-def fluid_flatten(x):
-    """
-    Flatten fluid variable along the first dimension
-    """
-    return fluid.layers.reshape(x, shape=[-1, np.prod(x.shape[1:])])