diff --git a/fluid/DeepQNetwork/DQN_agent.py b/fluid/DeepQNetwork/DQN_agent.py
index 67eb3ce6a29bb723b481d6b1c2f517f037d52942..b60a31938427ec4640eca626c22f86589110cd9d 100644
--- a/fluid/DeepQNetwork/DQN_agent.py
+++ b/fluid/DeepQNetwork/DQN_agent.py
@@ -1,9 +1,9 @@
 #-*- coding: utf-8 -*-
 
+import math
+import numpy as np
 import paddle.fluid as fluid
 from paddle.fluid.param_attr import ParamAttr
-import numpy as np
-import math
 from tqdm import tqdm
 from utils import fluid_flatten
 
@@ -39,34 +39,52 @@ class DQNModel(object):
             name='isOver', shape=[], dtype='bool')
 
     def _build_net(self):
-        state, action, reward, next_s, isOver = self._get_inputs()
-        self.pred_value = self.get_DQN_prediction(state)
-        self.predict_program = fluid.default_main_program().clone()
+        self.predict_program = fluid.Program()
+        self.train_program = fluid.Program()
+        self._sync_program = fluid.Program()
 
-        reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
+        with fluid.program_guard(self.predict_program):
+            state, action, reward, next_s, isOver = self._get_inputs()
+            self.pred_value = self.get_DQN_prediction(state)
 
-        action_onehot = fluid.layers.one_hot(action, self.action_dim)
-        action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
+        with fluid.program_guard(self.train_program):
+            state, action, reward, next_s, isOver = self._get_inputs()
+            pred_value = self.get_DQN_prediction(state)
 
-        pred_action_value = fluid.layers.reduce_sum(
-            fluid.layers.elementwise_mul(action_onehot, self.pred_value), dim=1)
+            reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
 
-        targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
-        best_v = fluid.layers.reduce_max(targetQ_predict_value, dim=1)
-        best_v.stop_gradient = True
+            action_onehot = fluid.layers.one_hot(action, self.action_dim)
+            action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
 
-        target = reward + (1.0 - fluid.layers.cast(
-            isOver, dtype='float32')) * self.gamma * best_v
-        cost = fluid.layers.square_error_cost(pred_action_value, target)
-        cost = fluid.layers.reduce_mean(cost)
+            pred_action_value = fluid.layers.reduce_sum(
+                fluid.layers.elementwise_mul(action_onehot, pred_value), dim=1)
 
-        self._sync_program = self._build_sync_target_network()
+            targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
+            best_v = fluid.layers.reduce_max(targetQ_predict_value, dim=1)
+            best_v.stop_gradient = True
 
-        optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
-        optimizer.minimize(cost)
+            target = reward + (1.0 - fluid.layers.cast(
+                isOver, dtype='float32')) * self.gamma * best_v
+            cost = fluid.layers.square_error_cost(pred_action_value, target)
+            cost = fluid.layers.reduce_mean(cost)
 
-        # define program
-        self.train_program = fluid.default_main_program()
+            optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
+            optimizer.minimize(cost)
+
+        vars = list(self.train_program.list_vars())
+        policy_vars = list(filter(
+            lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
+        target_vars = list(filter(
+            lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
+        policy_vars.sort(key=lambda x: x.name)
+        target_vars.sort(key=lambda x: x.name)
+
+        self._sync_program = fluid.Program()
+        with fluid.program_guard(self._sync_program):
+            sync_ops = []
+            for i, var in enumerate(policy_vars):
+                sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
+                sync_ops.append(sync_op)
 
         # fluid exe
         place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
@@ -133,23 +151,6 @@ class DQNModel(object):
             bias_attr=ParamAttr(name='{}_fc1_b'.format(variable_field)))
         return out
 
-    def _build_sync_target_network(self):
-        vars = list(fluid.default_main_program().list_vars())
-        policy_vars = list(filter(
-            lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
-        target_vars = list(filter(
-            lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
-        policy_vars.sort(key=lambda x: x.name)
-        target_vars.sort(key=lambda x: x.name)
-
-        sync_program = fluid.default_main_program().clone()
-        with fluid.program_guard(sync_program):
-            sync_ops = []
-            for i, var in enumerate(policy_vars):
-                sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
-                sync_ops.append(sync_op)
-        sync_program = sync_program.prune(sync_ops)
-        return sync_program
 
     def act(self, state, train_or_test):
         sample = np.random.random()
diff --git a/fluid/DeepQNetwork/DoubleDQN_agent.py b/fluid/DeepQNetwork/DoubleDQN_agent.py
index 09b4b2119bab3fbdfa9bb9cfb8fae40fa34f87e1..5beebffe4143fa8c7f913ae6bd8ea7a9ca67e87e 100644
--- a/fluid/DeepQNetwork/DoubleDQN_agent.py
+++ b/fluid/DeepQNetwork/DoubleDQN_agent.py
@@ -1,11 +1,11 @@
 #-*- coding: utf-8 -*-
 
+import math
+import numpy as np
 import paddle.fluid as fluid
 from paddle.fluid.param_attr import ParamAttr
-import numpy as np
 from tqdm import tqdm
-import math
-from utils import fluid_argmax, fluid_flatten
+from utils import fluid_flatten, fluid_argmax
 
 
 class DoubleDQNModel(object):
@@ -39,41 +39,59 @@ class DoubleDQNModel(object):
             name='isOver', shape=[], dtype='bool')
 
     def _build_net(self):
-        state, action, reward, next_s, isOver = self._get_inputs()
-        self.pred_value = self.get_DQN_prediction(state)
-        self.predict_program = fluid.default_main_program().clone()
+        self.predict_program = fluid.Program()
+        self.train_program = fluid.Program()
+        self._sync_program = fluid.Program()
 
-        reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
+        with fluid.program_guard(self.predict_program):
+            state, action, reward, next_s, isOver = self._get_inputs()
+            self.pred_value = self.get_DQN_prediction(state)
 
-        action_onehot = fluid.layers.one_hot(action, self.action_dim)
-        action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
+        with fluid.program_guard(self.train_program):
+            state, action, reward, next_s, isOver = self._get_inputs()
+            pred_value = self.get_DQN_prediction(state)
 
-        pred_action_value = fluid.layers.reduce_sum(
-            fluid.layers.elementwise_mul(action_onehot, self.pred_value), dim=1)
+            reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
 
-        targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
+            action_onehot = fluid.layers.one_hot(action, self.action_dim)
+            action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
 
-        next_s_predcit_value = self.get_DQN_prediction(next_s)
-        greedy_action = fluid_argmax(next_s_predcit_value)
+            pred_action_value = fluid.layers.reduce_sum(
+                fluid.layers.elementwise_mul(action_onehot, pred_value), dim=1)
 
-        predict_onehot = fluid.layers.one_hot(greedy_action, self.action_dim)
-        best_v = fluid.layers.reduce_sum(
-            fluid.layers.elementwise_mul(predict_onehot, targetQ_predict_value),
-            dim=1)
-        best_v.stop_gradient = True
+            targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
 
-        target = reward + (1.0 - fluid.layers.cast(
-            isOver, dtype='float32')) * self.gamma * best_v
-        cost = fluid.layers.square_error_cost(pred_action_value, target)
-        cost = fluid.layers.reduce_mean(cost)
+            next_s_predcit_value = self.get_DQN_prediction(next_s)
+            greedy_action = fluid_argmax(next_s_predcit_value)
 
-        self._sync_program = self._build_sync_target_network()
+            predict_onehot = fluid.layers.one_hot(greedy_action, self.action_dim)
+            best_v = fluid.layers.reduce_sum(
+                fluid.layers.elementwise_mul(predict_onehot, targetQ_predict_value),
+                dim=1)
+            best_v.stop_gradient = True
 
-        optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
-        optimizer.minimize(cost)
+            target = reward + (1.0 - fluid.layers.cast(
+                isOver, dtype='float32')) * self.gamma * best_v
+            cost = fluid.layers.square_error_cost(pred_action_value, target)
+            cost = fluid.layers.reduce_mean(cost)
 
-        # define program
-        self.train_program = fluid.default_main_program()
+            optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
+            optimizer.minimize(cost)
+
+        vars = list(self.train_program.list_vars())
+        policy_vars = list(filter(
+            lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
+        target_vars = list(filter(
+            lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
+        policy_vars.sort(key=lambda x: x.name)
+        target_vars.sort(key=lambda x: x.name)
+
+        self._sync_program = fluid.Program()
+        with fluid.program_guard(self._sync_program):
+            sync_ops = []
+            for i, var in enumerate(policy_vars):
+                sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
+                sync_ops.append(sync_op)
 
         # fluid exe
         place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
@@ -140,23 +158,6 @@ class DoubleDQNModel(object):
             bias_attr=ParamAttr(name='{}_fc1_b'.format(variable_field)))
         return out
 
-    def _build_sync_target_network(self):
-        vars = list(fluid.default_main_program().list_vars())
-        policy_vars = list(filter(
-            lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
-        target_vars = list(filter(
-            lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
-        policy_vars.sort(key=lambda x: x.name)
-        target_vars.sort(key=lambda x: x.name)
-
-        sync_program = fluid.default_main_program().clone()
-        with fluid.program_guard(sync_program):
-            sync_ops = []
-            for i, var in enumerate(policy_vars):
-                sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
-                sync_ops.append(sync_op)
-        sync_program = sync_program.prune(sync_ops)
-        return sync_program
 
     def act(self, state, train_or_test):
         sample = np.random.random()
diff --git a/fluid/DeepQNetwork/DuelingDQN_agent.py b/fluid/DeepQNetwork/DuelingDQN_agent.py
index 271a767b7b5841cf1abe213fc477859e3cf5dd05..58ac9e61f4db811519f4c4fe18bf08e789df48fe 100644
--- a/fluid/DeepQNetwork/DuelingDQN_agent.py
+++ b/fluid/DeepQNetwork/DuelingDQN_agent.py
@@ -1,10 +1,10 @@
 #-*- coding: utf-8 -*-
 
+import math
+import numpy as np
 import paddle.fluid as fluid
 from paddle.fluid.param_attr import ParamAttr
-import numpy as np
 from tqdm import tqdm
-import math
 from utils import fluid_flatten
 
 
@@ -39,34 +39,52 @@ class DuelingDQNModel(object):
             name='isOver', shape=[], dtype='bool')
 
     def _build_net(self):
-        state, action, reward, next_s, isOver = self._get_inputs()
-        self.pred_value = self.get_DQN_prediction(state)
-        self.predict_program = fluid.default_main_program().clone()
+        self.predict_program = fluid.Program()
+        self.train_program = fluid.Program()
+        self._sync_program = fluid.Program()
 
-        reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
+        with fluid.program_guard(self.predict_program):
+            state, action, reward, next_s, isOver = self._get_inputs()
+            self.pred_value = self.get_DQN_prediction(state)
 
-        action_onehot = fluid.layers.one_hot(action, self.action_dim)
-        action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
+        with fluid.program_guard(self.train_program):
+            state, action, reward, next_s, isOver = self._get_inputs()
+            pred_value = self.get_DQN_prediction(state)
 
-        pred_action_value = fluid.layers.reduce_sum(
-            fluid.layers.elementwise_mul(action_onehot, self.pred_value), dim=1)
+            reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
 
-        targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
-        best_v = fluid.layers.reduce_max(targetQ_predict_value, dim=1)
-        best_v.stop_gradient = True
+            action_onehot = fluid.layers.one_hot(action, self.action_dim)
+            action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
 
-        target = reward + (1.0 - fluid.layers.cast(
-            isOver, dtype='float32')) * self.gamma * best_v
-        cost = fluid.layers.square_error_cost(pred_action_value, target)
-        cost = fluid.layers.reduce_mean(cost)
+            pred_action_value = fluid.layers.reduce_sum(
+                fluid.layers.elementwise_mul(action_onehot, pred_value), dim=1)
 
-        self._sync_program = self._build_sync_target_network()
+            targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
+            best_v = fluid.layers.reduce_max(targetQ_predict_value, dim=1)
+            best_v.stop_gradient = True
 
-        optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
-        optimizer.minimize(cost)
+            target = reward + (1.0 - fluid.layers.cast(
+                isOver, dtype='float32')) * self.gamma * best_v
+            cost = fluid.layers.square_error_cost(pred_action_value, target)
+            cost = fluid.layers.reduce_mean(cost)
 
-        # define program
-        self.train_program = fluid.default_main_program()
+            optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
+            optimizer.minimize(cost)
+
+        vars = list(self.train_program.list_vars())
+        policy_vars = list(filter(
+            lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
+        target_vars = list(filter(
+            lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
+        policy_vars.sort(key=lambda x: x.name)
+        target_vars.sort(key=lambda x: x.name)
+
+        self._sync_program = fluid.Program()
+        with fluid.program_guard(self._sync_program):
+            sync_ops = []
+            for i, var in enumerate(policy_vars):
+                sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
+                sync_ops.append(sync_op)
 
         # fluid exe
         place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
@@ -143,24 +161,6 @@ class DuelingDQNModel(object):
             advantage, dim=1, keep_dim=True))
         return Q
 
-    def _build_sync_target_network(self):
-        vars = list(fluid.default_main_program().list_vars())
-        policy_vars = list(filter(
-            lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
-        target_vars = list(filter(
-            lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
-        policy_vars.sort(key=lambda x: x.name)
-        target_vars.sort(key=lambda x: x.name)
-
-        sync_program = fluid.default_main_program().clone()
-        with fluid.program_guard(sync_program):
-            sync_ops = []
-            for i, var in enumerate(policy_vars):
-                sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
-                sync_ops.append(sync_op)
-        # The prune API is deprecated, please don't use it any more.
-        sync_program = sync_program._prune(sync_ops)
-        return sync_program
 
     def act(self, state, train_or_test):
         sample = np.random.random()
@@ -186,12 +186,14 @@ class DuelingDQNModel(object):
         self.global_step += 1
 
         action = np.expand_dims(action, -1)
-        self.exe.run(self.train_program, \
-                feed={'state': state.astype('float32'), \
-                      'action': action.astype('int32'), \
-                      'reward': reward, \
-                      'next_s': next_state.astype('float32'), \
-                      'isOver': isOver})
+        self.exe.run(self.train_program,
+                     feed={
+                         'state': state.astype('float32'),
+                         'action': action.astype('int32'),
+                         'reward': reward,
+                         'next_s': next_state.astype('float32'),
+                         'isOver': isOver
+                     })
 
     def sync_target_network(self):
         self.exe.run(self._sync_program)
diff --git a/fluid/DeepQNetwork/README.md b/fluid/DeepQNetwork/README.md
index e72920bcad29ce7ffd78bfb90a1406654298248d..1edeaaa884318ec3a530ec4fdb7d031d07411b56 100644
--- a/fluid/DeepQNetwork/README.md
+++ b/fluid/DeepQNetwork/README.md
@@ -29,7 +29,7 @@ The average game rewards that can be obtained for the three models as the number of
 + gym
 + tqdm
 + opencv-python
-+ paddlepaddle-gpu>=0.12.0
++ paddlepaddle-gpu>=1.0.0
 + ale_python_interface
 
 ### Install Dependencies:
diff --git a/fluid/DeepQNetwork/README_cn.md b/fluid/DeepQNetwork/README_cn.md
index 68a65bffe8fab79ce563fefc894dd035c1572065..640d775ad8fed2be360d308b6c5df41c86d77c04 100644
--- a/fluid/DeepQNetwork/README_cn.md
+++ b/fluid/DeepQNetwork/README_cn.md
@@ -28,7 +28,7 @@
 + gym
 + tqdm
 + opencv-python
-+ paddlepaddle-gpu>=0.12.0
++ paddlepaddle-gpu>=1.0.0
 + ale_python_interface
 
 ### 下载依赖: