Commit 84806a5e authored by: Z zenghsh3

make DeepQNetwork models support paddlepaddle>=1.0.0

Parent 3b4eb996
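The change is the same in all three models: instead of cloning fluid.default_main_program(), the predict, train, and sync graphs are each built inside their own fluid.Program under fluid.program_guard. A minimal sketch of that pattern, assuming a toy one-layer Q network (the layer names and sizes below are illustrative, not the repo's):

```python
# Minimal sketch of the per-purpose program layout (illustrative names/sizes).
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr

predict_program = fluid.Program()
train_program = fluid.Program()
startup_program = fluid.Program()

def q_net(state, prefix):
    # Named ParamAttr makes the predict and train programs resolve to the same
    # parameters in the executor scope, and lets 'policy'/'target' variables be
    # separated later purely by name.
    return fluid.layers.fc(
        input=state, size=2,
        param_attr=ParamAttr(name='{}_fc_w'.format(prefix)),
        bias_attr=ParamAttr(name='{}_fc_b'.format(prefix)))

with fluid.program_guard(predict_program, startup_program):
    state = fluid.layers.data(name='state', shape=[4], dtype='float32')
    pred_value = q_net(state, 'policy')

with fluid.program_guard(train_program, startup_program):
    state = fluid.layers.data(name='state', shape=[4], dtype='float32')
    target = fluid.layers.data(name='target', shape=[2], dtype='float32')
    q = q_net(state, 'policy')
    cost = fluid.layers.reduce_mean(fluid.layers.square_error_cost(q, target))
    fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3).minimize(cost)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
```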
#-*- coding: utf-8 -*-
import math
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
import numpy as np
import math
from tqdm import tqdm
from utils import fluid_flatten
@@ -39,34 +39,52 @@ class DQNModel(object):
name='isOver', shape=[], dtype='bool')
def _build_net(self):
state, action, reward, next_s, isOver = self._get_inputs()
self.pred_value = self.get_DQN_prediction(state)
self.predict_program = fluid.default_main_program().clone()
self.predict_program = fluid.Program()
self.train_program = fluid.Program()
self._sync_program = fluid.Program()
reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
with fluid.program_guard(self.predict_program):
state, action, reward, next_s, isOver = self._get_inputs()
self.pred_value = self.get_DQN_prediction(state)
action_onehot = fluid.layers.one_hot(action, self.action_dim)
action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
with fluid.program_guard(self.train_program):
state, action, reward, next_s, isOver = self._get_inputs()
pred_value = self.get_DQN_prediction(state)
pred_action_value = fluid.layers.reduce_sum(
fluid.layers.elementwise_mul(action_onehot, self.pred_value), dim=1)
reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
best_v = fluid.layers.reduce_max(targetQ_predict_value, dim=1)
best_v.stop_gradient = True
action_onehot = fluid.layers.one_hot(action, self.action_dim)
action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
target = reward + (1.0 - fluid.layers.cast(
isOver, dtype='float32')) * self.gamma * best_v
cost = fluid.layers.square_error_cost(pred_action_value, target)
cost = fluid.layers.reduce_mean(cost)
pred_action_value = fluid.layers.reduce_sum(
fluid.layers.elementwise_mul(action_onehot, pred_value), dim=1)
self._sync_program = self._build_sync_target_network()
targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
best_v = fluid.layers.reduce_max(targetQ_predict_value, dim=1)
best_v.stop_gradient = True
optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
optimizer.minimize(cost)
target = reward + (1.0 - fluid.layers.cast(
isOver, dtype='float32')) * self.gamma * best_v
cost = fluid.layers.square_error_cost(pred_action_value, target)
cost = fluid.layers.reduce_mean(cost)
# define program
self.train_program = fluid.default_main_program()
optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
optimizer.minimize(cost)
vars = list(self.train_program.list_vars())
policy_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
target_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
policy_vars.sort(key=lambda x: x.name)
target_vars.sort(key=lambda x: x.name)
self._sync_program = fluid.Program()
with fluid.program_guard(self._sync_program):
sync_ops = []
for i, var in enumerate(policy_vars):
sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
sync_ops.append(sync_op)
# fluid exe
place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
@@ -133,23 +151,6 @@ class DQNModel(object):
bias_attr=ParamAttr(name='{}_fc1_b'.format(variable_field)))
return out
def _build_sync_target_network(self):
vars = list(fluid.default_main_program().list_vars())
policy_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
target_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
policy_vars.sort(key=lambda x: x.name)
target_vars.sort(key=lambda x: x.name)
sync_program = fluid.default_main_program().clone()
with fluid.program_guard(sync_program):
sync_ops = []
for i, var in enumerate(policy_vars):
sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
sync_ops.append(sync_op)
sync_program = sync_program.prune(sync_ops)
return sync_program
def act(self, state, train_or_test):
sample = np.random.random()
......
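For reference, the sync step built above boils down to: list the training program's variables, pair up the 'policy' and 'target' parameters by sorted name, and emit one assign op per pair into a dedicated program, so a single exe.run(sync_program) copies policy weights into the target network. A sketch, where build_sync_program is a hypothetical helper name (the models do this inline in _build_net):

```python
import paddle.fluid as fluid

def build_sync_program(train_program):
    # Parameters (not gradients) are told apart purely by the 'policy'/'target'
    # prefixes baked into their ParamAttr names.
    all_vars = list(train_program.list_vars())
    policy_vars = sorted(
        [v for v in all_vars if 'GRAD' not in v.name and 'policy' in v.name],
        key=lambda v: v.name)
    target_vars = sorted(
        [v for v in all_vars if 'GRAD' not in v.name and 'target' in v.name],
        key=lambda v: v.name)

    sync_program = fluid.Program()
    with fluid.program_guard(sync_program):
        for policy_var, target_var in zip(policy_vars, target_vars):
            # assign(input, output): copy the policy parameter into its target twin
            fluid.layers.assign(policy_var, target_var)
    return sync_program
```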
#-*- coding: utf-8 -*-
import math
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
import numpy as np
from tqdm import tqdm
import math
from utils import fluid_argmax, fluid_flatten
from utils import fluid_flatten, fluid_argmax
class DoubleDQNModel(object):
@@ -39,41 +39,59 @@ class DoubleDQNModel(object):
name='isOver', shape=[], dtype='bool')
def _build_net(self):
state, action, reward, next_s, isOver = self._get_inputs()
self.pred_value = self.get_DQN_prediction(state)
self.predict_program = fluid.default_main_program().clone()
self.predict_program = fluid.Program()
self.train_program = fluid.Program()
self._sync_program = fluid.Program()
reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
with fluid.program_guard(self.predict_program):
state, action, reward, next_s, isOver = self._get_inputs()
self.pred_value = self.get_DQN_prediction(state)
action_onehot = fluid.layers.one_hot(action, self.action_dim)
action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
with fluid.program_guard(self.train_program):
state, action, reward, next_s, isOver = self._get_inputs()
pred_value = self.get_DQN_prediction(state)
pred_action_value = fluid.layers.reduce_sum(
fluid.layers.elementwise_mul(action_onehot, self.pred_value), dim=1)
reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
action_onehot = fluid.layers.one_hot(action, self.action_dim)
action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
next_s_predcit_value = self.get_DQN_prediction(next_s)
greedy_action = fluid_argmax(next_s_predcit_value)
pred_action_value = fluid.layers.reduce_sum(
fluid.layers.elementwise_mul(action_onehot, pred_value), dim=1)
predict_onehot = fluid.layers.one_hot(greedy_action, self.action_dim)
best_v = fluid.layers.reduce_sum(
fluid.layers.elementwise_mul(predict_onehot, targetQ_predict_value),
dim=1)
best_v.stop_gradient = True
targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
target = reward + (1.0 - fluid.layers.cast(
isOver, dtype='float32')) * self.gamma * best_v
cost = fluid.layers.square_error_cost(pred_action_value, target)
cost = fluid.layers.reduce_mean(cost)
next_s_predcit_value = self.get_DQN_prediction(next_s)
greedy_action = fluid_argmax(next_s_predcit_value)
self._sync_program = self._build_sync_target_network()
predict_onehot = fluid.layers.one_hot(greedy_action, self.action_dim)
best_v = fluid.layers.reduce_sum(
fluid.layers.elementwise_mul(predict_onehot, targetQ_predict_value),
dim=1)
best_v.stop_gradient = True
optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
optimizer.minimize(cost)
target = reward + (1.0 - fluid.layers.cast(
isOver, dtype='float32')) * self.gamma * best_v
cost = fluid.layers.square_error_cost(pred_action_value, target)
cost = fluid.layers.reduce_mean(cost)
# define program
self.train_program = fluid.default_main_program()
optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
optimizer.minimize(cost)
vars = list(self.train_program.list_vars())
policy_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
target_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
policy_vars.sort(key=lambda x: x.name)
target_vars.sort(key=lambda x: x.name)
self._sync_program = fluid.Program()
with fluid.program_guard(self._sync_program):
sync_ops = []
for i, var in enumerate(policy_vars):
sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
sync_ops.append(sync_op)
# fluid exe
place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
@@ -140,23 +158,6 @@ class DoubleDQNModel(object):
bias_attr=ParamAttr(name='{}_fc1_b'.format(variable_field)))
return out
def _build_sync_target_network(self):
vars = list(fluid.default_main_program().list_vars())
policy_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
target_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
policy_vars.sort(key=lambda x: x.name)
target_vars.sort(key=lambda x: x.name)
sync_program = fluid.default_main_program().clone()
with fluid.program_guard(sync_program):
sync_ops = []
for i, var in enumerate(policy_vars):
sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
sync_ops.append(sync_op)
sync_program = sync_program.prune(sync_ops)
return sync_program
def act(self, state, train_or_test):
sample = np.random.random()
......
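The DoubleDQNModel hunks differ from the plain DQN only in how the bootstrap value is formed: the online network picks the greedy next action and the target network scores it, instead of taking reduce_max over the target network. A sketch of just that piece (double_dqn_bootstrap is a hypothetical helper; fluid_argmax is the repo's utility imported above):

```python
import paddle.fluid as fluid
from utils import fluid_argmax  # repo helper, as imported in the diff above

def double_dqn_bootstrap(next_q_online, next_q_target, action_dim):
    # Double DQN: argmax over the online Q-values selects the next action ...
    greedy_action = fluid_argmax(next_q_online)
    onehot = fluid.layers.one_hot(greedy_action, action_dim)
    # ... and the target network's Q-value for that action is the bootstrap term.
    best_v = fluid.layers.reduce_sum(
        fluid.layers.elementwise_mul(onehot, next_q_target), dim=1)
    best_v.stop_gradient = True  # no gradient flows through the bootstrap value
    return best_v
```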
#-*- coding: utf-8 -*-
import math
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
import numpy as np
from tqdm import tqdm
import math
from utils import fluid_flatten
@@ -39,34 +39,52 @@ class DuelingDQNModel(object):
name='isOver', shape=[], dtype='bool')
def _build_net(self):
state, action, reward, next_s, isOver = self._get_inputs()
self.pred_value = self.get_DQN_prediction(state)
self.predict_program = fluid.default_main_program().clone()
self.predict_program = fluid.Program()
self.train_program = fluid.Program()
self._sync_program = fluid.Program()
reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
with fluid.program_guard(self.predict_program):
state, action, reward, next_s, isOver = self._get_inputs()
self.pred_value = self.get_DQN_prediction(state)
action_onehot = fluid.layers.one_hot(action, self.action_dim)
action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
with fluid.program_guard(self.train_program):
state, action, reward, next_s, isOver = self._get_inputs()
pred_value = self.get_DQN_prediction(state)
pred_action_value = fluid.layers.reduce_sum(
fluid.layers.elementwise_mul(action_onehot, self.pred_value), dim=1)
reward = fluid.layers.clip(reward, min=-1.0, max=1.0)
targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
best_v = fluid.layers.reduce_max(targetQ_predict_value, dim=1)
best_v.stop_gradient = True
action_onehot = fluid.layers.one_hot(action, self.action_dim)
action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
target = reward + (1.0 - fluid.layers.cast(
isOver, dtype='float32')) * self.gamma * best_v
cost = fluid.layers.square_error_cost(pred_action_value, target)
cost = fluid.layers.reduce_mean(cost)
pred_action_value = fluid.layers.reduce_sum(
fluid.layers.elementwise_mul(action_onehot, pred_value), dim=1)
self._sync_program = self._build_sync_target_network()
targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
best_v = fluid.layers.reduce_max(targetQ_predict_value, dim=1)
best_v.stop_gradient = True
optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
optimizer.minimize(cost)
target = reward + (1.0 - fluid.layers.cast(
isOver, dtype='float32')) * self.gamma * best_v
cost = fluid.layers.square_error_cost(pred_action_value, target)
cost = fluid.layers.reduce_mean(cost)
# define program
self.train_program = fluid.default_main_program()
optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
optimizer.minimize(cost)
vars = list(self.train_program.list_vars())
policy_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
target_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
policy_vars.sort(key=lambda x: x.name)
target_vars.sort(key=lambda x: x.name)
self._sync_program = fluid.Program()
with fluid.program_guard(self._sync_program):
sync_ops = []
for i, var in enumerate(policy_vars):
sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
sync_ops.append(sync_op)
# fluid exe
place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
@@ -143,24 +161,6 @@ class DuelingDQNModel(object):
advantage, dim=1, keep_dim=True))
return Q
def _build_sync_target_network(self):
vars = list(fluid.default_main_program().list_vars())
policy_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
target_vars = list(filter(
lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
policy_vars.sort(key=lambda x: x.name)
target_vars.sort(key=lambda x: x.name)
sync_program = fluid.default_main_program().clone()
with fluid.program_guard(sync_program):
sync_ops = []
for i, var in enumerate(policy_vars):
sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
sync_ops.append(sync_op)
# The prune API is deprecated, please don't use it any more.
sync_program = sync_program._prune(sync_ops)
return sync_program
def act(self, state, train_or_test):
sample = np.random.random()
@@ -186,12 +186,14 @@ class DuelingDQNModel(object):
self.global_step += 1
action = np.expand_dims(action, -1)
self.exe.run(self.train_program, \
feed={'state': state.astype('float32'), \
'action': action.astype('int32'), \
'reward': reward, \
'next_s': next_state.astype('float32'), \
'isOver': isOver})
self.exe.run(self.train_program,
feed={
'state': state.astype('float32'),
'action': action.astype('int32'),
'reward': reward,
'next_s': next_state.astype('float32'),
'isOver': isOver
})
def sync_target_network(self):
self.exe.run(self._sync_program)
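At runtime the three programs are driven separately through the executor that _build_net creates: predict_program for acting, train_program for updates, and _sync_program for the periodic target-network copy. A usage sketch, assuming a built model instance and numpy batch arrays in scope (the helper names here are hypothetical):

```python
import numpy as np

def predict_q(model, state):
    # acting: run only the predict program and fetch the Q-values
    return model.exe.run(model.predict_program,
                         feed={'state': state.astype('float32')},
                         fetch_list=[model.pred_value])[0]

def train_step(model, state, action, reward, next_state, isOver):
    # learning: the train program consumes one batch of transitions
    action = np.expand_dims(action, -1)
    model.exe.run(model.train_program,
                  feed={
                      'state': state.astype('float32'),
                      'action': action.astype('int32'),
                      'reward': reward,
                      'next_s': next_state.astype('float32'),
                      'isOver': isOver
                  })

def sync_target(model):
    # periodically copy policy parameters into the target network
    model.exe.run(model._sync_program)
```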
@@ -29,7 +29,7 @@ The average game rewards that can be obtained for the three models as the number
+ gym
+ tqdm
+ opencv-python
+ paddlepaddle-gpu>=0.12.0
+ paddlepaddle-gpu>=1.0.0
+ ale_python_interface
### Install Dependencies:
......
@@ -28,7 +28,7 @@
+ gym
+ tqdm
+ opencv-python
+ paddlepaddle-gpu>=0.12.0
+ paddlepaddle-gpu>=1.0.0
+ ale_python_interface
### Install Dependencies:
......