# test_imperative_reinforcement.py
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np

import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.optimizer import SGDOptimizer
import paddle.fluid.dygraph.nn as nn
from test_imperative_base import new_program_scope
from paddle.fluid.framework import _test_eager_guard


class Policy(fluid.dygraph.Layer):
    """Small two-layer policy network used by the equivalence test below.

    Architecture: Linear(input_size, 128) -> dropout(0.6) -> ReLU ->
    Linear(128, 2) -> softmax over axis 1.

    Args:
        input_size (int): length of one state vector; ``forward`` flattens
            its input to rows of this width.
    """

    def __init__(self, input_size):
        super(Policy, self).__init__()

        # Remembered so forward() reshapes to the configured width rather
        # than a hard-coded 4 that would disagree with any other input_size.
        self.input_size = input_size

        # Creation order of the two Linear layers determines their
        # parameter names, which the test compares across modes.
        self.affine1 = nn.Linear(input_size, 128)
        self.affine2 = nn.Linear(128, 2)
        self.dropout_ratio = 0.6

        # Per-episode buffers; nothing visible in this file reads them.
        self.saved_log_probs = []
        self.rewards = []

    def forward(self, inputs):
        """Return softmax scores over the two actions for each input row.

        Args:
            inputs: tensor/variable reshaped to ``[-1, input_size]`` before
                the first linear layer.

        Returns:
            Softmax (axis=1) of the ``[rows, 2]`` action scores.
        """
        x = fluid.layers.reshape(inputs, shape=[-1, self.input_size])
        x = self.affine1(x)
        # Called without an is_test flag, so Paddle's default dropout mode
        # applies in both dygraph and static runs.
        x = fluid.layers.dropout(x, self.dropout_ratio)
        x = fluid.layers.relu(x)
        action_scores = self.affine2(x)
        return fluid.layers.softmax(action_scores, axis=1)


class TestImperativeMnist(unittest.TestCase):
    """Cross-mode equivalence test for one SGD step on ``Policy``.

    The same seeded network, input and loss are run in legacy dygraph,
    eager dygraph (``_test_eager_guard``) and static-graph mode. The loss,
    the initial parameters and the updated parameters must match the
    static-graph results exactly.

    NOTE(review): despite the "Mnist" naming no MNIST data is involved; the
    class and method names are kept so existing test selection still works.
    """

    def _assert_params_equal(self, expected, actual):
        """Assert every parameter array in ``expected`` equals ``actual``.

        Args:
            expected (dict): parameter name -> numpy array (reference values).
            actual (dict): parameter name -> numpy array under test; must
                contain every key of ``expected`` (a missing key raises
                ``KeyError``).
        """
        for name, value in expected.items():
            # Exact element-wise equality, intentionally with no tolerance.
            self.assertTrue(np.equal(value, actual[name]).all(),
                            msg="parameter %s differs" % name)

    def test_mnist_float32(self):
        """Compare dygraph, eager and static results of one SGD update."""
        seed = 90

        # One fixed transition shared by all three runs. These draws are not
        # seeded, but every mode consumes the very same arrays.
        state = np.random.normal(size=4).astype("float32")
        reward = np.random.random(size=[1, 1]).astype("float32")
        # Mask that keeps only output column 1 of the two action scores.
        mask = np.array([[0, 1]]).astype("float32")

        def run_dygraph():
            """Run one update imperatively.

            Must be called inside a dygraph guard.

            Returns:
                (loss value, {param name: initial array},
                 {param name: updated array})
            """
            paddle.seed(seed)
            paddle.framework.random._manual_program_seed(seed)

            policy = Policy(input_size=4)

            dy_state = fluid.dygraph.base.to_variable(state)
            dy_state.stop_gradient = True
            loss_probs = policy(dy_state)

            dy_mask = fluid.dygraph.base.to_variable(mask)
            dy_mask.stop_gradient = True

            # loss = sum(reward * sum(log(probs) * mask, dim=-1))
            # NOTE(review): there is no negation, so minimizing this is not a
            # standard policy-gradient objective; harmless for an
            # equivalence check since all modes use the same formula.
            loss_probs = fluid.layers.log(loss_probs)
            loss_probs = fluid.layers.elementwise_mul(loss_probs, dy_mask)
            loss_probs = fluid.layers.reduce_sum(loss_probs, dim=-1)

            dy_reward = fluid.dygraph.base.to_variable(reward)
            dy_reward.stop_gradient = True

            loss_probs = fluid.layers.elementwise_mul(dy_reward, loss_probs)
            loss = fluid.layers.reduce_sum(loss_probs)

            sgd = SGDOptimizer(learning_rate=1e-3,
                               parameter_list=policy.parameters())

            dy_out = loss.numpy()
            # Snapshot before the update, to compare with the values the
            # static startup program produces.
            dy_param_init_value = {
                param.name: param.numpy() for param in policy.parameters()
            }

            loss.backward()
            sgd.minimize(loss)
            policy.clear_gradients()

            dy_param_value = {
                param.name: param.numpy() for param in policy.parameters()
            }
            return dy_out, dy_param_init_value, dy_param_value

        with fluid.dygraph.guard():
            dy_out, dy_param_init_value, dy_param_value = run_dygraph()

        with fluid.dygraph.guard():
            with _test_eager_guard():
                eager_out, eager_param_init_value, eager_param_value = (
                    run_dygraph())

        with new_program_scope():
            paddle.seed(seed)
            paddle.framework.random._manual_program_seed(seed)

            place = (fluid.CUDAPlace(0)
                     if core.is_compiled_with_cuda() else fluid.CPUPlace())
            exe = fluid.Executor(place)

            policy = Policy(input_size=4)

            st_sgd = SGDOptimizer(learning_rate=1e-3)

            st_state = fluid.layers.data(name='st_state',
                                         shape=[4],
                                         dtype='float32')
            st_reward = fluid.layers.data(name='st_reward',
                                          shape=[1],
                                          dtype='float32')
            st_mask = fluid.layers.data(name='st_mask',
                                        shape=[2],
                                        dtype='float32')

            # Same loss graph as run_dygraph(), built symbolically.
            st_loss_probs = policy(st_state)

            st_loss_probs = fluid.layers.log(st_loss_probs)
            st_loss_probs = fluid.layers.elementwise_mul(st_loss_probs, st_mask)
            st_loss_probs = fluid.layers.reduce_sum(st_loss_probs, dim=-1)

            st_loss_probs = fluid.layers.elementwise_mul(
                st_reward, st_loss_probs)
            st_loss = fluid.layers.reduce_sum(st_loss_probs)

            st_sgd.minimize(st_loss)

            static_param_name_list = [
                param.name for param in policy.parameters()
            ]

            # Running the startup program initializes the parameters; the
            # fetched values are the pre-update reference.
            out = exe.run(fluid.default_startup_program(),
                          fetch_list=static_param_name_list)
            static_param_init_value = dict(zip(static_param_name_list, out))

            # One training step; fetch the loss first, then every parameter.
            fetch_list = [st_loss.name] + static_param_name_list
            out = exe.run(fluid.default_main_program(),
                          feed={
                              "st_state": state,
                              "st_reward": reward,
                              "st_mask": mask
                          },
                          fetch_list=fetch_list)

            static_out = out[0]
            static_param_value = dict(zip(static_param_name_list, out[1:]))

        # Legacy dygraph vs static.
        self._assert_params_equal(static_param_init_value, dy_param_init_value)
        self.assertTrue(np.equal(static_out, dy_out).all())
        self._assert_params_equal(static_param_value, dy_param_value)

        # Eager dygraph vs static.
        self._assert_params_equal(static_param_init_value,
                                  eager_param_init_value)
        self.assertTrue(np.equal(static_out, eager_out).all())
        self._assert_params_equal(static_param_value, eager_param_value)

if __name__ == '__main__':
    # Make static-graph mode the default. The dygraph parts of the test
    # enter dygraph explicitly via fluid.dygraph.guard(). NOTE(review):
    # runners that import this module (e.g. pytest) skip this call.
    paddle.enable_static()
    unittest.main()