# test_imperative_reinforcement.py
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
16

17
import numpy as np
18
from test_imperative_base import new_program_scope
19 20 21

import paddle
import paddle.fluid as fluid
22 23
from paddle.fluid import core
from paddle.fluid.optimizer import SGDOptimizer
24 25 26


class Policy(fluid.dygraph.Layer):
    """Small policy network: Linear -> dropout -> ReLU -> Linear -> softmax.

    Args:
        input_size: number of features per input row; it is the input width
            of the first linear layer and the row width ``forward`` reshapes
            its argument to.
    """

    def __init__(self, input_size):
        super().__init__()

        # Remembered so forward() reshapes to the width affine1 expects
        # instead of a hard-coded 4.
        self._input_size = input_size

        self.affine1 = paddle.nn.Linear(input_size, 128)
        self.affine2 = paddle.nn.Linear(128, 2)
        self.dropout_ratio = 0.6

        # Episode buffers; nothing in the visible file reads them.
        self.saved_log_probs = []
        self.rewards = []

    def forward(self, inputs):
        """Return softmax scores over the two outputs of ``affine2``.

        Args:
            inputs: tensor whose element count is a multiple of the
                ``input_size`` given at construction; it is reshaped to
                ``[-1, input_size]`` before the first linear layer.

        Returns:
            Tensor of softmax values taken along axis 1 of the ``affine2``
            output.
        """
        x = paddle.reshape(inputs, shape=[-1, self._input_size])
        x = self.affine1(x)
        # NOTE(review): no ``training`` flag is passed, so dropout does not
        # follow this layer's train/eval mode -- confirm that is intended.
        x = paddle.nn.functional.dropout(x, self.dropout_ratio)
        x = fluid.layers.relu(x)
        action_scores = self.affine2(x)
        return paddle.nn.functional.softmax(action_scores, axis=1)


class TestImperativeMnist(unittest.TestCase):
    """Compare one SGD step on ``Policy`` between dygraph and static graph.

    The class and test names mention MNIST, but the body trains the
    ``Policy`` network on a reward-weighted, masked log-probability loss.
    The names are kept so existing test IDs stay stable.
    """

    def _assert_params_match(self, expected, actual, label):
        """Assert each array in ``expected`` equals its namesake in ``actual``.

        Args:
            expected: dict mapping parameter name to the reference array.
            actual: dict mapping parameter name to the array under test; a
                name missing from it raises ``KeyError``.
            label: short tag for the comparison, used in failure messages.
        """
        for name, value in expected.items():
            self.assertTrue(
                np.equal(value, actual[name]).all(),
                msg=f"{label}: parameter {name!r} differs from static graph",
            )

    def test_mnist_float32(self):
        """Run one step in dygraph (twice) and in static graph, then compare
        the loss, the initial parameters and the updated parameters."""
        seed = 90

        # Inputs shared by every run below.
        state = np.random.normal(size=4).astype("float32")
        reward = np.random.random(size=[1, 1]).astype("float32")
        mask = np.array([[0, 1]]).astype("float32")

        def run_dygraph():
            """Build a fresh ``Policy``, take one SGD step, return the arrays.

            Returns:
                Tuple ``(loss, init_params, updated_params)`` where ``loss``
                is the loss as a numpy array and the other two are dicts
                mapping parameter name to its numpy value before and after
                the optimizer step.
            """
            paddle.seed(seed)
            paddle.framework.random._manual_program_seed(seed)

            policy = Policy(input_size=4)

            dy_state = fluid.dygraph.base.to_variable(state)
            dy_state.stop_gradient = True
            loss_probs = policy(dy_state)

            dy_mask = fluid.dygraph.base.to_variable(mask)
            dy_mask.stop_gradient = True

            # Masked log-probability, summed over the action axis.
            loss_probs = paddle.log(loss_probs)
            loss_probs = paddle.multiply(loss_probs, dy_mask)
            loss_probs = paddle.sum(loss_probs, axis=-1)

            dy_reward = fluid.dygraph.base.to_variable(reward)
            dy_reward.stop_gradient = True

            loss_probs = paddle.multiply(dy_reward, loss_probs)
            loss = paddle.sum(loss_probs)

            sgd = SGDOptimizer(
                learning_rate=1e-3, parameter_list=policy.parameters()
            )

            dy_out = loss.numpy()

            # Snapshot the parameters before the optimizer touches them.
            dy_param_init_value = {
                param.name: param.numpy() for param in policy.parameters()
            }

            loss.backward()
            sgd.minimize(loss)
            policy.clear_gradients()

            dy_param_value = {
                param.name: param.numpy() for param in policy.parameters()
            }

            return dy_out, dy_param_init_value, dy_param_value

        with fluid.dygraph.guard():
            dy_out, dy_param_init_value, dy_param_value = run_dygraph()

        # A second, identical dygraph run; results keep the ``eager_*`` names.
        with fluid.dygraph.guard():
            (
                eager_out,
                eager_param_init_value,
                eager_param_value,
            ) = run_dygraph()

        with new_program_scope():
            paddle.seed(seed)
            paddle.framework.random._manual_program_seed(seed)

            exe = fluid.Executor(
                fluid.CUDAPlace(0)
                if core.is_compiled_with_cuda()
                else fluid.CPUPlace()
            )

            policy = Policy(input_size=4)

            st_sgd = SGDOptimizer(learning_rate=1e-3)

            st_state = fluid.layers.data(
                name='st_state', shape=[4], dtype='float32'
            )
            st_reward = fluid.layers.data(
                name='st_reward', shape=[1], dtype='float32'
            )
            st_mask = fluid.layers.data(
                name='st_mask', shape=[2], dtype='float32'
            )

            st_loss_probs = policy(st_state)

            # Same loss construction as in run_dygraph().
            st_loss_probs = paddle.log(st_loss_probs)
            st_loss_probs = paddle.multiply(st_loss_probs, st_mask)
            st_loss_probs = paddle.sum(st_loss_probs, axis=-1)

            st_loss_probs = paddle.multiply(st_reward, st_loss_probs)
            st_loss = paddle.sum(st_loss_probs)

            st_sgd.minimize(st_loss)

            static_param_name_list = [
                param.name for param in policy.parameters()
            ]

            # Running the startup program initializes the parameters;
            # fetching them here yields their initial values.
            out = exe.run(
                fluid.default_startup_program(),
                fetch_list=static_param_name_list,
            )
            static_param_init_value = dict(zip(static_param_name_list, out))

            fetch_list = [st_loss.name]
            fetch_list.extend(static_param_name_list)

            out = exe.run(
                fluid.default_main_program(),
                feed={"st_state": state, "st_reward": reward, "st_mask": mask},
                fetch_list=fetch_list,
            )

            # out[0] is the loss; the rest follow static_param_name_list.
            static_out = out[0]
            static_param_value = dict(zip(static_param_name_list, out[1:]))

        # First dygraph run vs. static graph.
        self._assert_params_match(
            static_param_init_value, dy_param_init_value, "dygraph init"
        )
        self.assertTrue(
            np.equal(static_out, dy_out).all(),
            msg="dygraph loss differs from static graph",
        )
        self._assert_params_match(
            static_param_value, dy_param_value, "dygraph updated"
        )

        # check eager
        self._assert_params_match(
            static_param_init_value, eager_param_init_value, "eager init"
        )
        self.assertTrue(
            np.equal(static_out, eager_out).all(),
            msg="eager loss differs from static graph",
        )
        self._assert_params_match(
            static_param_value, eager_param_value, "eager updated"
        )


if __name__ == '__main__':
    # Switch to static-graph mode before running the tests; the dygraph
    # parts of the test enter fluid.dygraph.guard() explicitly.
    paddle.enable_static()
    unittest.main()