# test_imperative_reinforcement.py
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from test_imperative_base import new_program_scope

import paddle
import paddle.fluid as fluid
import paddle.nn.functional as F
from paddle.fluid import core
from paddle.fluid.optimizer import SGDOptimizer


class Policy(fluid.dygraph.Layer):
    """Small policy network used by the dygraph/static comparison test.

    Architecture: Linear(input_size -> 128) -> dropout -> ReLU
    -> Linear(128 -> 2) -> softmax, i.e. it maps an observation with
    ``input_size`` features to a probability distribution over 2 actions.
    """

    def __init__(self, input_size):
        super().__init__()

        # Remembered so forward() can flatten inputs to the width affine1
        # expects instead of relying on a hard-coded feature count.
        self.input_size = input_size

        self.affine1 = paddle.nn.Linear(input_size, 128)
        self.affine2 = paddle.nn.Linear(128, 2)
        # Dropout probability applied to the first hidden layer.
        self.dropout_ratio = 0.6

        # Per-episode buffers; nothing in this file reads or writes them.
        self.saved_log_probs = []
        self.rewards = []

    def forward(self, inputs):
        """Return action probabilities with shape ``[batch, 2]``.

        ``inputs`` is reshaped to ``[-1, input_size]``, so its total element
        count must be a multiple of ``input_size``. Softmax is taken over
        axis 1, the action axis.
        """
        x = paddle.reshape(inputs, shape=[-1, self.input_size])
        x = self.affine1(x)
        x = F.dropout(x, self.dropout_ratio)
        x = F.relu(x)
        action_scores = self.affine2(x)
        return F.softmax(action_scores, axis=1)


class TestImperativeMnist(unittest.TestCase):
    """Check that dygraph and static-graph training of ``Policy`` agree.

    The class and method names are kept for test-ID stability; no MNIST data
    is involved. A single random observation is pushed through ``Policy``.
    The loss ``sum(reward * sum(log(probs) * mask, axis=-1))`` is formed from
    the output, and one SGD step (lr=1e-3) is taken. Three runs are compared
    bit-for-bit, all seeded identically:
    - two independent dygraph runs;
    - one static-graph run.
    The comparison covers the initial parameters, the loss value, and the
    updated parameters.
    """

    def test_mnist_float32(self):
        seed = 90

        # One fixed sample, shared by every run below.
        state = np.random.normal(size=4).astype("float32")
        reward = np.random.random(size=[1, 1]).astype("float32")
        # Mask keeping only column 1 of the 2 action log-probabilities.
        mask = np.array([[0, 1]]).astype("float32")

        def run_dygraph():
            """Take one SGD step in dygraph mode.

            Returns ``(loss, params_before, params_after)``. Both parameter
            dicts map parameter name -> numpy array.
            """
            paddle.seed(seed)
            paddle.framework.random._manual_program_seed(seed)

            policy = Policy(input_size=4)

            dy_state = fluid.dygraph.base.to_variable(state)
            dy_state.stop_gradient = True
            loss_probs = policy(dy_state)

            dy_mask = fluid.dygraph.base.to_variable(mask)
            dy_mask.stop_gradient = True

            loss_probs = paddle.log(loss_probs)
            loss_probs = paddle.multiply(loss_probs, dy_mask)
            loss_probs = paddle.sum(loss_probs, axis=-1)

            dy_reward = fluid.dygraph.base.to_variable(reward)
            dy_reward.stop_gradient = True

            loss_probs = paddle.multiply(dy_reward, loss_probs)
            loss = paddle.sum(loss_probs)

            sgd = SGDOptimizer(
                learning_rate=1e-3, parameter_list=policy.parameters()
            )

            # Snapshot the loss and the parameters before the update.
            dy_out = loss.numpy()
            dy_param_init_value = {
                param.name: param.numpy() for param in policy.parameters()
            }

            loss.backward()
            sgd.minimize(loss)
            policy.clear_gradients()

            dy_param_value = {
                param.name: param.numpy() for param in policy.parameters()
            }
            return dy_out, dy_param_init_value, dy_param_value

        def run_static():
            """Build and run the same computation as a static program.

            Returns ``(loss, params_before, params_after)`` in the same
            format as ``run_dygraph``.
            """
            with new_program_scope():
                paddle.seed(seed)
                paddle.framework.random._manual_program_seed(seed)

                place = (
                    fluid.CUDAPlace(0)
                    if core.is_compiled_with_cuda()
                    else fluid.CPUPlace()
                )
                exe = fluid.Executor(place)

                policy = Policy(input_size=4)

                st_sgd = SGDOptimizer(learning_rate=1e-3)

                st_state = fluid.layers.data(
                    name='st_state', shape=[4], dtype='float32'
                )
                st_reward = fluid.layers.data(
                    name='st_reward', shape=[1], dtype='float32'
                )
                st_mask = fluid.layers.data(
                    name='st_mask', shape=[2], dtype='float32'
                )

                st_loss_probs = policy(st_state)

                st_loss_probs = paddle.log(st_loss_probs)
                st_loss_probs = paddle.multiply(st_loss_probs, st_mask)
                st_loss_probs = paddle.sum(st_loss_probs, axis=-1)

                st_loss_probs = paddle.multiply(st_reward, st_loss_probs)
                st_loss = paddle.sum(st_loss_probs)

                st_sgd.minimize(st_loss)

                param_names = [param.name for param in policy.parameters()]

                # Running the startup program initializes the parameters;
                # fetch them to get the pre-update values.
                init_out = exe.run(
                    fluid.default_startup_program(),
                    fetch_list=param_names,
                )
                static_param_init_value = dict(zip(param_names, init_out))

                # One training step. The loss comes back first, followed by
                # the parameters in ``param_names`` order.
                out = exe.run(
                    fluid.default_main_program(),
                    feed={
                        "st_state": state,
                        "st_reward": reward,
                        "st_mask": mask,
                    },
                    fetch_list=[st_loss.name, *param_names],
                )
                static_out, *param_values = out
                static_param_value = dict(zip(param_names, param_values))

            return static_out, static_param_init_value, static_param_value

        with fluid.dygraph.guard():
            dy_out, dy_param_init_value, dy_param_value = run_dygraph()

        # A second, independent dygraph run (reported as "eager" below).
        with fluid.dygraph.guard():
            (
                eager_out,
                eager_param_init_value,
                eager_param_value,
            ) = run_dygraph()

        static_out, static_param_init_value, static_param_value = run_static()

        def assert_matches_static(label, out, param_init_value, param_value):
            """Assert one dygraph run reproduced the static run exactly."""
            for key, value in static_param_init_value.items():
                np.testing.assert_array_equal(
                    value,
                    param_init_value[key],
                    err_msg=f"{label}: initial value of {key}",
                )

            np.testing.assert_array_equal(
                static_out, out, err_msg=f"{label}: loss"
            )

            for key, value in static_param_value.items():
                np.testing.assert_array_equal(
                    value,
                    param_value[key],
                    err_msg=f"{label}: updated value of {key}",
                )

        assert_matches_static(
            "dygraph", dy_out, dy_param_init_value, dy_param_value
        )
        assert_matches_static(
            "eager", eager_out, eager_param_init_value, eager_param_value
        )

if __name__ == '__main__':
    # Start in static-graph mode. The test switches to dygraph explicitly via
    # ``fluid.dygraph.guard()`` wherever it needs it.
    paddle.enable_static()
    unittest.main()