Unverified commit 4191136f, authored by Bo Zhou, committed by GitHub

Merge pull request #1573 from zenghsh3/develop

make DQN models compatible with paddle>=1.00  
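The core of the change: each agent now builds three explicit `fluid.Program` objects (prediction, training, target-network sync) instead of cloning and pruning `fluid.default_main_program()`, which is what breaks under paddle>=1.0. A minimal usage sketch, assuming `model` is an already-constructed agent whose `_build_net` has run (the constructor and executor setup live in the collapsed parts of the files below):

import numpy as np

state = np.zeros((1, 4, 84, 84), dtype='float32')   # illustrative batch; the real shape comes from _get_inputs()

# Greedy action from the dedicated prediction program.
q_values = model.exe.run(model.predict_program,
                         feed={'state': state},
                         fetch_list=[model.pred_value])[0]
greedy_action = int(np.argmax(q_values, axis=1)[0])

# Gradient updates go through model.train(...) (the train program); the target
# network is refreshed by running the sync program.
model.sync_target_network()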
In the DQN agent (class `DQNModel`), the imports are reordered and `from utils import fluid_flatten` is dropped (the helper module is deleted at the end of this commit):

#-*- coding: utf-8 -*-
import math
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from tqdm import tqdm

class DQNModel(object):
`_build_net` previously laid everything out in `fluid.default_main_program()` (cloning it for `predict_program` and calling a separate `_build_sync_target_network()` helper); it now assembles three explicit programs:

@@ -39,34 +38,51 @@ class DQNModel(object):
            name='isOver', shape=[], dtype='bool')

    def _build_net(self):
        self.predict_program = fluid.Program()
        self.train_program = fluid.Program()
        self._sync_program = fluid.Program()

        with fluid.program_guard(self.predict_program):
            state, action, reward, next_s, isOver = self._get_inputs()
            self.pred_value = self.get_DQN_prediction(state)

        with fluid.program_guard(self.train_program):
            state, action, reward, next_s, isOver = self._get_inputs()
            pred_value = self.get_DQN_prediction(state)

            reward = fluid.layers.clip(reward, min=-1.0, max=1.0)

            action_onehot = fluid.layers.one_hot(action, self.action_dim)
            action_onehot = fluid.layers.cast(action_onehot, dtype='float32')

            pred_action_value = fluid.layers.reduce_sum(
                fluid.layers.elementwise_mul(action_onehot, pred_value), dim=1)

            targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
            best_v = fluid.layers.reduce_max(targetQ_predict_value, dim=1)
            best_v.stop_gradient = True

            target = reward + (1.0 - fluid.layers.cast(
                isOver, dtype='float32')) * self.gamma * best_v
            cost = fluid.layers.square_error_cost(pred_action_value, target)
            cost = fluid.layers.reduce_mean(cost)

            optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
            optimizer.minimize(cost)

        vars = list(self.train_program.list_vars())
        policy_vars = list(filter(
            lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
        target_vars = list(filter(
            lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
        policy_vars.sort(key=lambda x: x.name)
        target_vars.sort(key=lambda x: x.name)

        with fluid.program_guard(self._sync_program):
            sync_ops = []
            for i, var in enumerate(policy_vars):
                sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
                sync_ops.append(sync_op)

        # fluid exe
        place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
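A detail worth calling out in the training graph above: `one_hot` + `elementwise_mul` + `reduce_sum(dim=1)` is simply a way to pick out Q(s, a) for the action that was actually taken, and the clipped reward plus the masked, discounted `best_v` forms the usual TD target. A plain NumPy sketch with made-up numbers:

import numpy as np

q_pred = np.array([[1.0, 2.0, 3.0],      # Q(s, a) for a batch of 2 states
                   [0.5, 0.1, 0.2]], dtype='float32')
action = np.array([2, 0])                 # actions that were taken

# one_hot + elementwise_mul + reduce_sum(dim=1) == "pick Q(s, taken action)"
onehot = np.eye(q_pred.shape[1], dtype='float32')[action]
pred_action_value = (onehot * q_pred).sum(axis=1)   # -> [3.0, 0.5]

# TD target from the target network: r + (1 - isOver) * gamma * max_a Q_target(s', a)
reward = np.array([1.0, 0.0], dtype='float32')
is_over = np.array([0.0, 1.0], dtype='float32')
gamma = 0.99
q_target_next = np.array([[0.2, 0.7, 0.1],
                          [0.0, 0.0, 0.0]], dtype='float32')
target = reward + (1.0 - is_over) * gamma * q_target_next.max(axis=1)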
`get_DQN_prediction` now passes scalar `filter_size`/`stride`/`padding`/`pool_size` arguments (previously lists such as `[5, 5]`) and replaces the `fluid_flatten` helper with `fluid.layers.flatten`:

@@ -81,50 +97,50 @@ class DQNModel(object):
        conv1 = fluid.layers.conv2d(
            input=image,
            num_filters=32,
            filter_size=5,
            stride=1,
            padding=2,
            act='relu',
            param_attr=ParamAttr(name='{}_conv1'.format(variable_field)),
            bias_attr=ParamAttr(name='{}_conv1_b'.format(variable_field)))
        max_pool1 = fluid.layers.pool2d(
            input=conv1, pool_size=2, pool_stride=2, pool_type='max')

        conv2 = fluid.layers.conv2d(
            input=max_pool1,
            num_filters=32,
            filter_size=5,
            stride=1,
            padding=2,
            act='relu',
            param_attr=ParamAttr(name='{}_conv2'.format(variable_field)),
            bias_attr=ParamAttr(name='{}_conv2_b'.format(variable_field)))
        max_pool2 = fluid.layers.pool2d(
            input=conv2, pool_size=2, pool_stride=2, pool_type='max')

        conv3 = fluid.layers.conv2d(
            input=max_pool2,
            num_filters=64,
            filter_size=4,
            stride=1,
            padding=1,
            act='relu',
            param_attr=ParamAttr(name='{}_conv3'.format(variable_field)),
            bias_attr=ParamAttr(name='{}_conv3_b'.format(variable_field)))
        max_pool3 = fluid.layers.pool2d(
            input=conv3, pool_size=2, pool_stride=2, pool_type='max')

        conv4 = fluid.layers.conv2d(
            input=max_pool3,
            num_filters=64,
            filter_size=3,
            stride=1,
            padding=1,
            act='relu',
            param_attr=ParamAttr(name='{}_conv4'.format(variable_field)),
            bias_attr=ParamAttr(name='{}_conv4_b'.format(variable_field)))

        flatten = fluid.layers.flatten(conv4, axis=1)

        out = fluid.layers.fc(
            input=flatten,
@@ -133,23 +149,6 @@ class DQNModel(object):
            bias_attr=ParamAttr(name='{}_fc1_b'.format(variable_field)))
        return out
The `_build_sync_target_network` helper is removed; the sync program is now assembled directly in `_build_net`, without cloning and pruning the default program:

    def _build_sync_target_network(self):
        vars = list(fluid.default_main_program().list_vars())
        policy_vars = list(filter(
            lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
        target_vars = list(filter(
            lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
        policy_vars.sort(key=lambda x: x.name)
        target_vars.sort(key=lambda x: x.name)

        sync_program = fluid.default_main_program().clone()
        with fluid.program_guard(sync_program):
            sync_ops = []
            for i, var in enumerate(policy_vars):
                sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
                sync_ops.append(sync_op)
        sync_program = sync_program.prune(sync_ops)
        return sync_program
    def act(self, state, train_or_test):
        sample = np.random.random()
...
The Double DQN agent gets the same restructuring; its imports drop `from utils import fluid_argmax, fluid_flatten`:

#-*- coding: utf-8 -*-
import math
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from tqdm import tqdm

class DoubleDQNModel(object):
As in `DQNModel`, the single default program is split into predict/train/sync programs; the greedy next action is now computed with `fluid.layers.argmax` plus `unsqueeze` instead of the removed `fluid_argmax` helper:

@@ -39,41 +38,59 @@ class DoubleDQNModel(object):
            name='isOver', shape=[], dtype='bool')

    def _build_net(self):
        self.predict_program = fluid.Program()
        self.train_program = fluid.Program()
        self._sync_program = fluid.Program()

        with fluid.program_guard(self.predict_program):
            state, action, reward, next_s, isOver = self._get_inputs()
            self.pred_value = self.get_DQN_prediction(state)

        with fluid.program_guard(self.train_program):
            state, action, reward, next_s, isOver = self._get_inputs()
            pred_value = self.get_DQN_prediction(state)

            reward = fluid.layers.clip(reward, min=-1.0, max=1.0)

            action_onehot = fluid.layers.one_hot(action, self.action_dim)
            action_onehot = fluid.layers.cast(action_onehot, dtype='float32')

            pred_action_value = fluid.layers.reduce_sum(
                fluid.layers.elementwise_mul(action_onehot, pred_value), dim=1)

            targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)

            next_s_predcit_value = self.get_DQN_prediction(next_s)
            greedy_action = fluid.layers.argmax(next_s_predcit_value, axis=1)
            greedy_action = fluid.layers.unsqueeze(greedy_action, axes=[1])

            predict_onehot = fluid.layers.one_hot(greedy_action, self.action_dim)
            best_v = fluid.layers.reduce_sum(
                fluid.layers.elementwise_mul(predict_onehot, targetQ_predict_value),
                dim=1)
            best_v.stop_gradient = True

            target = reward + (1.0 - fluid.layers.cast(
                isOver, dtype='float32')) * self.gamma * best_v
            cost = fluid.layers.square_error_cost(pred_action_value, target)
            cost = fluid.layers.reduce_mean(cost)

            optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
            optimizer.minimize(cost)

        vars = list(self.train_program.list_vars())
        policy_vars = list(filter(
            lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
        target_vars = list(filter(
            lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
        policy_vars.sort(key=lambda x: x.name)
        target_vars.sort(key=lambda x: x.name)

        with fluid.program_guard(self._sync_program):
            sync_ops = []
            for i, var in enumerate(policy_vars):
                sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
                sync_ops.append(sync_op)

        # fluid exe
        place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
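The only difference from `DQNModel` is how `best_v` is computed: the greedy next action is selected with the online (policy) network but evaluated with the target network, which is the Double DQN decoupling. A NumPy sketch of the two variants, with made-up values:

import numpy as np

gamma = 0.99
reward = np.array([0.0], dtype='float32')
is_over = np.array([0.0], dtype='float32')

q_online_next = np.array([[0.1, 0.9, 0.3]], dtype='float32')   # Q(s', .) from the policy net
q_target_next = np.array([[0.4, 0.2, 0.8]], dtype='float32')   # Q(s', .) from the target net

# Vanilla DQN: select and evaluate with the target network.
dqn_best_v = q_target_next.max(axis=1)                          # -> [0.8]

# Double DQN: select with the online network, evaluate with the target network.
greedy_action = q_online_next.argmax(axis=1)                    # -> [1]
onehot = np.eye(q_target_next.shape[1], dtype='float32')[greedy_action]
double_best_v = (onehot * q_target_next).sum(axis=1)            # -> [0.2]

target = reward + (1.0 - is_over) * gamma * double_best_v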
The same scalar conv/pool arguments and `fluid.layers.flatten` replacement apply here:

@@ -88,50 +105,50 @@ class DoubleDQNModel(object):
        conv1 = fluid.layers.conv2d(
            input=image,
            num_filters=32,
            filter_size=5,
            stride=1,
            padding=2,
            act='relu',
            param_attr=ParamAttr(name='{}_conv1'.format(variable_field)),
            bias_attr=ParamAttr(name='{}_conv1_b'.format(variable_field)))
        max_pool1 = fluid.layers.pool2d(
            input=conv1, pool_size=2, pool_stride=2, pool_type='max')

        conv2 = fluid.layers.conv2d(
            input=max_pool1,
            num_filters=32,
            filter_size=5,
            stride=1,
            padding=2,
            act='relu',
            param_attr=ParamAttr(name='{}_conv2'.format(variable_field)),
            bias_attr=ParamAttr(name='{}_conv2_b'.format(variable_field)))
        max_pool2 = fluid.layers.pool2d(
            input=conv2, pool_size=2, pool_stride=2, pool_type='max')

        conv3 = fluid.layers.conv2d(
            input=max_pool2,
            num_filters=64,
            filter_size=4,
            stride=1,
            padding=1,
            act='relu',
            param_attr=ParamAttr(name='{}_conv3'.format(variable_field)),
            bias_attr=ParamAttr(name='{}_conv3_b'.format(variable_field)))
        max_pool3 = fluid.layers.pool2d(
            input=conv3, pool_size=2, pool_stride=2, pool_type='max')

        conv4 = fluid.layers.conv2d(
            input=max_pool3,
            num_filters=64,
            filter_size=3,
            stride=1,
            padding=1,
            act='relu',
            param_attr=ParamAttr(name='{}_conv4'.format(variable_field)),
            bias_attr=ParamAttr(name='{}_conv4_b'.format(variable_field)))

        flatten = fluid.layers.flatten(conv4, axis=1)

        out = fluid.layers.fc(
            input=flatten,
@@ -140,23 +157,6 @@ class DoubleDQNModel(object):
            bias_attr=ParamAttr(name='{}_fc1_b'.format(variable_field)))
        return out
The `_build_sync_target_network` helper is removed from the Double DQN agent as well:

    def _build_sync_target_network(self):
        vars = list(fluid.default_main_program().list_vars())
        policy_vars = list(filter(
            lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
        target_vars = list(filter(
            lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
        policy_vars.sort(key=lambda x: x.name)
        target_vars.sort(key=lambda x: x.name)

        sync_program = fluid.default_main_program().clone()
        with fluid.program_guard(sync_program):
            sync_ops = []
            for i, var in enumerate(policy_vars):
                sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
                sync_ops.append(sync_op)
        sync_program = sync_program.prune(sync_ops)
        return sync_program
    def act(self, state, train_or_test):
        sample = np.random.random()
...
The Dueling DQN agent follows the same pattern; `from utils import fluid_flatten` is dropped:

#-*- coding: utf-8 -*-
import math
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from tqdm import tqdm

class DuelingDQNModel(object):
The identical three-program split in `_build_net`:

@@ -39,34 +38,51 @@ class DuelingDQNModel(object):
            name='isOver', shape=[], dtype='bool')

    def _build_net(self):
        self.predict_program = fluid.Program()
        self.train_program = fluid.Program()
        self._sync_program = fluid.Program()

        with fluid.program_guard(self.predict_program):
            state, action, reward, next_s, isOver = self._get_inputs()
            self.pred_value = self.get_DQN_prediction(state)

        with fluid.program_guard(self.train_program):
            state, action, reward, next_s, isOver = self._get_inputs()
            pred_value = self.get_DQN_prediction(state)

            reward = fluid.layers.clip(reward, min=-1.0, max=1.0)

            action_onehot = fluid.layers.one_hot(action, self.action_dim)
            action_onehot = fluid.layers.cast(action_onehot, dtype='float32')

            pred_action_value = fluid.layers.reduce_sum(
                fluid.layers.elementwise_mul(action_onehot, pred_value), dim=1)

            targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
            best_v = fluid.layers.reduce_max(targetQ_predict_value, dim=1)
            best_v.stop_gradient = True

            target = reward + (1.0 - fluid.layers.cast(
                isOver, dtype='float32')) * self.gamma * best_v
            cost = fluid.layers.square_error_cost(pred_action_value, target)
            cost = fluid.layers.reduce_mean(cost)

            optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
            optimizer.minimize(cost)

        vars = list(self.train_program.list_vars())
        policy_vars = list(filter(
            lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
        target_vars = list(filter(
            lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
        policy_vars.sort(key=lambda x: x.name)
        target_vars.sort(key=lambda x: x.name)

        with fluid.program_guard(self._sync_program):
            sync_ops = []
            for i, var in enumerate(policy_vars):
                sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
                sync_ops.append(sync_op)

        # fluid exe
        place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
Same scalar-argument and flatten changes in the dueling network's convolutional trunk:

@@ -81,50 +97,50 @@ class DuelingDQNModel(object):
        conv1 = fluid.layers.conv2d(
            input=image,
            num_filters=32,
            filter_size=5,
            stride=1,
            padding=2,
            act='relu',
            param_attr=ParamAttr(name='{}_conv1'.format(variable_field)),
            bias_attr=ParamAttr(name='{}_conv1_b'.format(variable_field)))
        max_pool1 = fluid.layers.pool2d(
            input=conv1, pool_size=2, pool_stride=2, pool_type='max')

        conv2 = fluid.layers.conv2d(
            input=max_pool1,
            num_filters=32,
            filter_size=5,
            stride=1,
            padding=2,
            act='relu',
            param_attr=ParamAttr(name='{}_conv2'.format(variable_field)),
            bias_attr=ParamAttr(name='{}_conv2_b'.format(variable_field)))
        max_pool2 = fluid.layers.pool2d(
            input=conv2, pool_size=2, pool_stride=2, pool_type='max')

        conv3 = fluid.layers.conv2d(
            input=max_pool2,
            num_filters=64,
            filter_size=4,
            stride=1,
            padding=1,
            act='relu',
            param_attr=ParamAttr(name='{}_conv3'.format(variable_field)),
            bias_attr=ParamAttr(name='{}_conv3_b'.format(variable_field)))
        max_pool3 = fluid.layers.pool2d(
            input=conv3, pool_size=2, pool_stride=2, pool_type='max')

        conv4 = fluid.layers.conv2d(
            input=max_pool3,
            num_filters=64,
            filter_size=3,
            stride=1,
            padding=1,
            act='relu',
            param_attr=ParamAttr(name='{}_conv4'.format(variable_field)),
            bias_attr=ParamAttr(name='{}_conv4_b'.format(variable_field)))

        flatten = fluid.layers.flatten(conv4, axis=1)

        value = fluid.layers.fc(
            input=flatten,
@@ -143,24 +159,6 @@ class DuelingDQNModel(object):
            advantage, dim=1, keep_dim=True))
        return Q
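The head of the dueling network is collapsed above, but the visible tail (`advantage, dim=1, keep_dim=True))` feeding `return Q`) is consistent with the standard dueling aggregation Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)). A NumPy sketch of that aggregation, as an illustration rather than the file's exact code:

import numpy as np

value = np.array([[2.0]], dtype='float32')                 # V(s), one per state
advantage = np.array([[1.0, 0.0, -1.0]], dtype='float32')  # A(s, a), one per action

# Dueling aggregation: Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a))
q = value + (advantage - advantage.mean(axis=1, keepdims=True))   # -> [[3.0, 2.0, 1.0]]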
The dueling agent's `_build_sync_target_network`, which had already switched to the private `_prune` API, is removed as well:

    def _build_sync_target_network(self):
        vars = list(fluid.default_main_program().list_vars())
        policy_vars = list(filter(
            lambda x: 'GRAD' not in x.name and 'policy' in x.name, vars))
        target_vars = list(filter(
            lambda x: 'GRAD' not in x.name and 'target' in x.name, vars))
        policy_vars.sort(key=lambda x: x.name)
        target_vars.sort(key=lambda x: x.name)

        sync_program = fluid.default_main_program().clone()
        with fluid.program_guard(sync_program):
            sync_ops = []
            for i, var in enumerate(policy_vars):
                sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
                sync_ops.append(sync_op)
        # The prune API is deprecated, please don't use it any more.
        sync_program = sync_program._prune(sync_ops)
        return sync_program
    def act(self, state, train_or_test):
        sample = np.random.random()

In `train`, the backslash-continued `feed` argument becomes a multi-line dict literal:

@@ -186,12 +184,14 @@ class DuelingDQNModel(object):
        self.global_step += 1

        action = np.expand_dims(action, -1)
        self.exe.run(self.train_program,
                     feed={
                         'state': state.astype('float32'),
                         'action': action.astype('int32'),
                         'reward': reward,
                         'next_s': next_state.astype('float32'),
                         'isOver': isOver
                     })

    def sync_target_network(self):
        self.exe.run(self._sync_program)
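For context, `sync_target_network` is meant to be called periodically from the driver script, which is not part of this diff. A hypothetical loop, where the sync interval, batch shapes, and the `train(...)` argument order are all assumptions:

import numpy as np

UPDATE_TARGET_EVERY = 10000   # assumed interval between target-network syncs

# Fake replay batch, for shape illustration only.
batch_state = np.zeros((32, 4, 84, 84), dtype='float32')
batch_action = np.zeros(32, dtype='int32')
batch_reward = np.zeros(32, dtype='float32')
batch_next_s = np.zeros((32, 4, 84, 84), dtype='float32')
batch_isOver = np.zeros(32, dtype='bool')

for step in range(100000):
    model.train(batch_state, batch_action, batch_reward, batch_next_s, batch_isOver)
    if step % UPDATE_TARGET_EVERY == 0:
        model.sync_target_network()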
In the English README, the requirement is bumped from paddlepaddle-gpu>=0.12.0 to paddlepaddle-gpu>=1.0.0:

@@ -29,7 +29,7 @@ The average game rewards that can be obtained for the three models as the number
+ gym
+ tqdm
+ opencv-python
+ paddlepaddle-gpu>=1.0.0
+ ale_python_interface

### Install Dependencies:
...
The same requirement bump in the Chinese README:

@@ -28,7 +28,7 @@
+ gym
+ tqdm
+ opencv-python
+ paddlepaddle-gpu>=1.0.0
+ ale_python_interface

### Download Dependencies:
...
Finally, utils.py is deleted:

#-*- coding: utf-8 -*-
#File: utils.py

import paddle.fluid as fluid
import numpy as np


def fluid_argmax(x):
    """
    Get index of max value for the last dimension
    """
    _, max_index = fluid.layers.topk(x, k=1)
    return max_index


def fluid_flatten(x):
    """
    Flatten fluid variable along the first dimension
    """
    return fluid.layers.reshape(x, shape=[-1, np.prod(x.shape[1:])])
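Both helpers now have direct counterparts in `fluid.layers`, which is how the model files above replace them; roughly:

import paddle.fluid as fluid

# fluid_flatten(x)  ->  fluid.layers.flatten(x, axis=1): keep the batch dim, collapse the rest.
x = fluid.layers.data(name='x', shape=[64, 6, 6], dtype='float32')
flat = fluid.layers.flatten(x, axis=1)

# fluid_argmax(q)  ->  fluid.layers.argmax(q, axis=1), plus unsqueeze so the index keeps the
# trailing dimension that one_hot expects (topk returned shape [-1, 1] directly).
q = fluid.layers.data(name='q', shape=[4], dtype='float32')
greedy = fluid.layers.argmax(q, axis=1)
greedy = fluid.layers.unsqueeze(greedy, axes=[1])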