# -*- coding: utf-8 -*-

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr


class DoubleDQNModel(object):
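    """Double DQN agent built on PaddlePaddle Fluid.

    A policy network selects greedy actions and a periodically synced
    target network evaluates them (van Hasselt et al., 2016), reducing
    the Q-value overestimation of vanilla DQN.
    """
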
    def __init__(self, state_dim, action_dim, gamma, hist_len, use_cuda=False):
        self.img_height = state_dim[0]
        self.img_width = state_dim[1]
        self.action_dim = action_dim
        self.gamma = gamma
        self.exploration = 1.1
        self.update_target_steps = 10000 // 4
        self.hist_len = hist_len
        self.use_cuda = use_cuda

        self.global_step = 0
        self._build_net()

    def _get_inputs(self):
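        """Declare input layers: stacked-frame state, action, reward,
        next state, and the episode-termination flag."""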
        return fluid.layers.data(
                   name='state',
                   shape=[self.hist_len, self.img_height, self.img_width],
                   dtype='float32'), \
               fluid.layers.data(
                   name='action', shape=[1], dtype='int32'), \
               fluid.layers.data(
                   name='reward', shape=[], dtype='float32'), \
               fluid.layers.data(
                   name='next_s',
                   shape=[self.hist_len, self.img_height, self.img_width],
                   dtype='float32'), \
               fluid.layers.data(
                   name='isOver', shape=[], dtype='bool')

    def _build_net(self):
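        """Build three fluid programs: Q-value prediction, Double DQN
        training, and target-network parameter sync."""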
        self.predict_program = fluid.Program()
        self.train_program = fluid.Program()
        self._sync_program = fluid.Program()

        with fluid.program_guard(self.predict_program):
            state, action, reward, next_s, isOver = self._get_inputs()
            self.pred_value = self.get_DQN_prediction(state)

        with fluid.program_guard(self.train_program):
            state, action, reward, next_s, isOver = self._get_inputs()
            pred_value = self.get_DQN_prediction(state)

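            # Clip rewards to [-1, 1] for training stability.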
            reward = fluid.layers.clip(reward, min=-1.0, max=1.0)

            action_onehot = fluid.layers.one_hot(action, self.action_dim)
            action_onehot = fluid.layers.cast(action_onehot, dtype='float32')

            pred_action_value = fluid.layers.reduce_sum(
                fluid.layers.elementwise_mul(action_onehot, pred_value), dim=1)

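            # Double DQN: the target network evaluates the action that the
            # policy network would greedily choose in the next state.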
            targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)

            next_s_predict_value = self.get_DQN_prediction(next_s)
            greedy_action = fluid.layers.argmax(next_s_predict_value, axis=1)
            greedy_action = fluid.layers.unsqueeze(greedy_action, axes=[1])

            predict_onehot = fluid.layers.one_hot(greedy_action,
                                                  self.action_dim)
            best_v = fluid.layers.reduce_sum(
                fluid.layers.elementwise_mul(predict_onehot,
                                             targetQ_predict_value),
                dim=1)
            best_v.stop_gradient = True

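            # Bellman target: reward plus discounted target-network value,
            # zeroed for terminal transitions.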
            target = reward + (1.0 - fluid.layers.cast(
                isOver, dtype='float32')) * self.gamma * best_v
            cost = fluid.layers.square_error_cost(pred_action_value, target)
            cost = fluid.layers.reduce_mean(cost)

            optimizer = fluid.optimizer.Adam(1e-3 * 0.5, epsilon=1e-3)
            optimizer.minimize(cost)

        # Pair policy parameters with their target counterparts by name.
        all_vars = list(self.train_program.list_vars())
        target_vars = list(
            filter(lambda x: 'GRAD' not in x.name and 'target' in x.name,
                   all_vars))

        policy_vars_name = [
            x.name.replace('target', 'policy') for x in target_vars
        ]
        policy_vars = list(filter(lambda x: x.name in policy_vars_name,
                                  all_vars))

        policy_vars.sort(key=lambda x: x.name)
        target_vars.sort(key=lambda x: x.name)

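        # Sync program: copy each policy parameter into its target counterpart.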
        with fluid.program_guard(self._sync_program):
            sync_ops = []
            for policy_var, target_var in zip(policy_vars, target_vars):
                sync_op = fluid.layers.assign(policy_var, target_var)
                sync_ops.append(sync_op)

        # Create the executor on GPU or CPU and initialize all parameters.
        place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
        self.exe = fluid.Executor(place)
        self.exe.run(fluid.default_startup_program())

    def get_DQN_prediction(self, image, target=False):
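        """Q-network: four conv layers (max pooling after the first three)
        followed by a fully connected output layer. `target` selects the
        target-network parameter set via its name prefix."""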
        image = image / 255.0

        variable_field = 'target' if target else 'policy'

        conv1 = fluid.layers.conv2d(
            input=image,
            num_filters=32,
            filter_size=5,
            stride=1,
            padding=2,
            act='relu',
            param_attr=ParamAttr(name='{}_conv1'.format(variable_field)),
            bias_attr=ParamAttr(name='{}_conv1_b'.format(variable_field)))
        max_pool1 = fluid.layers.pool2d(
            input=conv1, pool_size=2, pool_stride=2, pool_type='max')

        conv2 = fluid.layers.conv2d(
            input=max_pool1,
            num_filters=32,
            filter_size=5,
            stride=1,
            padding=2,
            act='relu',
            param_attr=ParamAttr(name='{}_conv2'.format(variable_field)),
            bias_attr=ParamAttr(name='{}_conv2_b'.format(variable_field)))
        max_pool2 = fluid.layers.pool2d(
            input=conv2, pool_size=2, pool_stride=2, pool_type='max')

        conv3 = fluid.layers.conv2d(
            input=max_pool2,
            num_filters=64,
            filter_size=4,
            stride=1,
            padding=1,
            act='relu',
            param_attr=ParamAttr(name='{}_conv3'.format(variable_field)),
            bias_attr=ParamAttr(name='{}_conv3_b'.format(variable_field)))
        max_pool3 = fluid.layers.pool2d(
            input=conv3, pool_size=2, pool_stride=2, pool_type='max')

        conv4 = fluid.layers.conv2d(
            input=max_pool3,
            num_filters=64,
            filter_size=3,
            stride=1,
            padding=1,
            act='relu',
            param_attr=ParamAttr(name='{}_conv4'.format(variable_field)),
            bias_attr=ParamAttr(name='{}_conv4_b'.format(variable_field)))

        flatten = fluid.layers.flatten(conv4, axis=1)

        out = fluid.layers.fc(
            input=flatten,
            size=self.action_dim,
            param_attr=ParamAttr(name='{}_fc1'.format(variable_field)),
            bias_attr=ParamAttr(name='{}_fc1_b'.format(variable_field)))
        return out

    def act(self, state, train_or_test):
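        """Epsilon-greedy action selection. During training, epsilon decays
        linearly from its initial value to 0.1; greedy selection keeps a
        1% random-action floor."""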
        sample = np.random.random()
        if train_or_test == 'train' and sample < self.exploration:
            act = np.random.randint(self.action_dim)
        else:
            if np.random.random() < 0.01:
                act = np.random.randint(self.action_dim)
            else:
                state = np.expand_dims(state, axis=0)
                pred_Q = self.exe.run(self.predict_program,
                                      feed={'state': state.astype('float32')},
                                      fetch_list=[self.pred_value])[0]
                pred_Q = np.squeeze(pred_Q, axis=0)
                act = np.argmax(pred_Q)
        if train_or_test == 'train':
            self.exploration = max(0.1, self.exploration - 1e-6)
        return act

    def train(self, state, action, reward, next_state, isOver):
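        """Run one training step on a batch of transitions, re-syncing the
        target network every `update_target_steps` steps."""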
        if self.global_step % self.update_target_steps == 0:
            self.sync_target_network()
        self.global_step += 1

        action = np.expand_dims(action, -1)
        self.exe.run(self.train_program,
                     feed={
                         'state': state.astype('float32'),
                         'action': action.astype('int32'),
                         'reward': reward,
                         'next_s': next_state.astype('float32'),
                         'isOver': isOver
                     })

    def sync_target_network(self):
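        """Copy the policy-network parameters into the target network."""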
        self.exe.run(self._sync_program)