#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
import paddle


def create_selected_rows_and_tensor(
    scope, place, height, row_num, embedding_size
):
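    # Build a SelectedRows gradient and an equivalent dense tensor in `scope`
    # so the sparse and dense paths of the rmsprop op can be checked against
    # the same values; duplicate rows are accumulated into the dense tensor.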
    sr = scope.var("@selected_rows@").get_selected_rows()
    tensor = scope.var("grad").get_tensor()

    # NOTE: np.random.random_integers is deprecated and removed in recent
    # NumPy releases; randint with an exclusive upper bound of `height`
    # draws the same inclusive range [0, height - 1].
    rows = np.random.randint(low=0, high=height, size=[row_num]).astype(
        'int64'
    )
    sr_val = np.random.random(size=[row_num, embedding_size]).astype('float32')

    sr.set_height(height)
    sr.set_rows(rows)
    sr.get_tensor().set(sr_val, place)

    tensor_val = np.zeros(shape=[height, embedding_size], dtype='float32')
    for i in range(row_num):
        row = rows[i]
        tensor_val[row, :] = tensor_val[row, :] + sr_val[i, :]

    tensor.set(tensor_val, place)
    return tensor_val, sr_val


class TestBase(unittest.TestCase):
    def setup(
        self, place, is_sparse, centered, size, row_num=None, epsilon=1e-6
    ):
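        # Fill the scope with the op's input tensors and precompute the
        # expected RMSProp outputs in NumPy; the tests compare the op's
        # results against these references after running it.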
        np.random.seed(5)  # fix seed

        self.scope = fluid.global_scope()
        self.place = place

        self.param_name = "param"
        self.param = np.random.random(size).astype("float32")

        self.mean_square_name = "mean_square"
        self.mean_square = np.random.uniform(low=1, high=2, size=size).astype(
            "float32"
        )

        self.mean_grad_name = "mean_grad"
        self.mean_grad = np.random.random(size).astype("float32")

        self.lr_name = "lr"
        self.learning_rate = np.array([0.01]).astype("float32")

        self.grad_name = "grad"

        self.is_sparse = is_sparse
        if self.is_sparse:
            self.grad_sr_name = "@selected_rows@"
            self.grad, self.grad_sr = create_selected_rows_and_tensor(
                self.scope, place, size[0], row_num, size[1]
            )
        else:
            self.grad = np.random.random(size).astype("float32")
            grad_tensor = self.scope.var(self.grad_name).get_tensor()
            grad_tensor.set(self.grad, place)

        self.moment_name = "moment"
        self.moment = np.random.uniform(low=0, high=1, size=size).astype(
            "float32"
        )

        self.epsilon = epsilon
        self.decay = 0.9
        self.momentum = 0.1
        self.centered = centered

        self.ms_out = (
            self.decay * self.mean_square
            + (1 - self.decay) * self.grad * self.grad
        )
        if centered:
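            # Centered RMSProp also tracks the mean of the gradients and
            # divides by the variance estimate E[g^2] - E[g]^2.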
            self.mg_out = (
                self.decay * self.mean_grad + (1 - self.decay) * self.grad
            )
            self.moment_out = (
                self.momentum * self.moment
                + self.learning_rate
                * self.grad
                / np.sqrt(self.ms_out - np.square(self.mg_out) + self.epsilon)
            )
        else:
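            # Plain RMSProp divides by the root of the raw mean square.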
            self.moment_out = (
                self.momentum * self.moment
                + self.learning_rate
                * self.grad
                / np.sqrt(self.ms_out + self.epsilon)
            )

        self.param_out = self.param - self.moment_out

        # create and initialize Param Variable
        self.param_tensor = self.scope.var(self.param_name).get_tensor()
        self.param_tensor.set(self.param, place)

        self.mean_square_tensor = self.scope.var(
            self.mean_square_name
        ).get_tensor()
        self.mean_square_tensor.set(self.mean_square, place)

        lr = self.scope.var(self.lr_name).get_tensor()
        lr.set(self.learning_rate, place)

        self.moment_tensor = self.scope.var(self.moment_name).get_tensor()
        self.moment_tensor.set(self.moment, place)

        if self.centered:
            self.mean_grad_tensor = self.scope.var(
                self.mean_grad_name
            ).get_tensor()
            self.mean_grad_tensor.set(self.mean_grad, place)

    def check(self, actual_t, expect_t, place, out_name, atol=1e-5):
        np.testing.assert_allclose(
            actual_t,
            expect_t,
            rtol=1e-05,
            atol=atol,
            err_msg=(
                f'Output ({out_name}) has diff at {place}\n'
                f'Expect {expect_t}\nBut Got {actual_t}'
            ),
        )


class TestRmspropOp(TestBase):
    def check_with_place(
        self, place, is_sparse, centered, size, row_num=None, epsilon=1e-6
    ):
        self.setup(place, is_sparse, centered, size, row_num, epsilon)
        self.run_and_check()

    def run_and_check(self):
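        # Assemble the raw 'rmsprop' operator over the scope variables set up
        # above, run it in place, and compare each output tensor with the
        # NumPy references.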
        grad_name = self.grad_sr_name if self.is_sparse else self.grad_name

        kwargs = {
            'Param': self.param_name,
            'Grad': grad_name,
            'MeanSquare': self.mean_square_name,
            'Moment': self.moment_name,
            'LearningRate': self.lr_name,
            'ParamOut': self.param_name,
            'MeanSquareOut': self.mean_square_name,
            'MomentOut': self.moment_name,
            'epsilon': self.epsilon,
            'decay': self.decay,
            'momentum': self.momentum,
            'centered': self.centered,
        }

        if self.centered:
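            # The centered variant additionally reads and updates the mean of
            # the gradients.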
            kwargs['MeanGrad'] = self.mean_grad_name
            kwargs['MeanGradOut'] = self.mean_grad_name

        rmsprop_op = Operator('rmsprop', **kwargs)
        atol = 1e-6

        rmsprop_op.run(self.scope, self.place)

        self.check(
            np.array(self.mean_square_tensor),
            self.ms_out,
            self.place,
            self.mean_square_name,
            atol=atol,
        )
        self.check(
            np.array(self.moment_tensor),
            self.moment_out,
            self.place,
            self.moment_name,
            atol=atol,
        )
        self.check(
            np.array(self.param_tensor),
            self.param_out,
            self.place,
            self.param_name,
            atol=atol,
        )

        if self.centered:
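            # MeanGrad is only produced in centered mode.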
            self.check(
                np.array(self.mean_grad_tensor),
                self.mg_out,
                self.place,
                self.mean_grad_name,
            )

    def test_rmsprop(self):
        places = [core.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(core.CUDAPlace(0))

        size = (128, 320)
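        # Cover dense gradients plus sparse gradients with more rows than the
        # tensor height (guaranteed duplicates) and with fewer rows, for both
        # the plain and centered variants.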
        for place in places:
            for centered in [False, True]:
                with fluid.scope_guard(core.Scope()):
                    self.check_with_place(
                        place, is_sparse=False, centered=centered, size=size
                    )

                with fluid.scope_guard(core.Scope()):
                    self.check_with_place(
                        place,
                        is_sparse=True,
                        centered=centered,
                        row_num=512,
                        size=size,
                    )

                with fluid.scope_guard(core.Scope()):
                    self.check_with_place(
                        place,
                        is_sparse=True,
                        centered=centered,
                        row_num=60,
                        size=size,
                    )


class TestRMSPropV2(unittest.TestCase):
    def test_rmsprop_dygraph(self):
        paddle.disable_static()
        value = np.arange(26).reshape(2, 13).astype("float32")
        a = paddle.to_tensor(value)
        linear = paddle.nn.Linear(13, 5)
        # This can be any optimizer supported by dygraph.
        rmsprop = paddle.optimizer.RMSProp(
            learning_rate=0.01,
            parameters=linear.parameters(),
            weight_decay=0.01,
        )
        out = linear(a)
        out.backward()
        rmsprop.step()
        rmsprop.clear_gradients()

    def test_rmsprop(self):
        paddle.enable_static()
        place = fluid.CPUPlace()
        main = fluid.Program()
        with fluid.program_guard(main):
            x = fluid.layers.data(name='x', shape=[13], dtype='float32')
            y = fluid.layers.data(name='y', shape=[1], dtype='float32')
            y_predict = fluid.layers.fc(input=x, size=1, act=None)
            cost = fluid.layers.square_error_cost(input=y_predict, label=y)
            avg_cost = paddle.mean(cost)

            rms_optimizer = paddle.optimizer.RMSProp(learning_rate=0.1)
            rms_optimizer.minimize(avg_cost)

            fetch_list = [avg_cost]
            train_reader = paddle.batch(
                paddle.dataset.uci_housing.train(), batch_size=1
            )
            feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())
            for data in train_reader():
                exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)

    def test_raise_error(self):
        self.assertRaises(ValueError, paddle.optimizer.RMSProp, None)
        self.assertRaises(
            ValueError, paddle.optimizer.RMSProp, learning_rate=0.1, rho=None
        )
        self.assertRaises(
            ValueError,
            paddle.optimizer.RMSProp,
            learning_rate=0.1,
            epsilon=None,
        )
        self.assertRaises(
            ValueError,
            paddle.optimizer.RMSProp,
            learning_rate=0.1,
            momentum=None,
        )

    def test_rmsprop_op_invalid_input(self):
        paddle.disable_static()
        linear = paddle.nn.Linear(10, 10)
        with self.assertRaises(ValueError):
            rmsprop = paddle.optimizer.RMSProp(
                0.1, epsilon=-1, parameters=linear.parameters()
            )
        with self.assertRaises(ValueError):
            rmsprop = paddle.optimizer.RMSProp(
                0.1, momentum=-1, parameters=linear.parameters()
            )
        with self.assertRaises(ValueError):
            rmsprop = paddle.optimizer.RMSProp(
                0.1, rho=-1, parameters=linear.parameters()
            )


class TestRMSPropV2Group(TestRMSPropV2):
    def test_rmsprop_dygraph(self):
        paddle.disable_static()
        value = np.arange(26).reshape(2, 13).astype("float32")
        a = paddle.to_tensor(value)
        linear_1 = paddle.nn.Linear(13, 5)
        linear_2 = paddle.nn.Linear(5, 3)
        # This can be any optimizer supported by dygraph.
        rmsprop = paddle.optimizer.RMSProp(
            learning_rate=0.01,
            parameters=[
                {'params': linear_1.parameters()},
                {'params': linear_2.parameters(), 'weight_decay': 0.001},
            ],
            weight_decay=0.01,
        )
        out = linear_1(a)
        out = linear_2(out)
        out.backward()
        rmsprop.step()
        rmsprop.clear_gradients()


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()