#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import unittest
16

17
import numpy as np
18 19 20

import paddle
import paddle.fluid as fluid
21 22
import paddle.fluid.core as core
from paddle.fluid.op import Operator
S
sneaxiy 已提交
23 24


25 26 27
def create_selected_rows_and_tensor(
    scope, place, height, row_num, embedding_size
):
    """Build a random SelectedRows gradient and its dense equivalent.

    The ``@selected_rows@`` variable of ``scope`` is filled with ``row_num``
    random row indices in ``[0, height)`` (repeats allowed) and random
    ``float32`` row values. The ``grad`` variable is filled with the dense
    ``[height, embedding_size]`` tensor obtained by summing those rows into
    their indices.

    Args:
        scope: scope in which both variables are created / looked up.
        place: device place the tensors are copied to.
        height: number of rows of the dense gradient.
        row_num: number of sparse rows to draw.
        embedding_size: width of every row.

    Returns:
        tuple: ``(tensor_val, sr_val)`` -- the dense float32 gradient of
        shape ``[height, embedding_size]`` and the sparse row values of
        shape ``[row_num, embedding_size]``.
    """
    sr = scope.var("@selected_rows@").get_selected_rows()
    tensor = scope.var("grad").get_tensor()

    # ``randint``'s upper bound is exclusive, so this samples the same
    # inclusive range [0, height - 1] as the deprecated
    # ``np.random.random_integers(low=0, high=height - 1)``.
    rows = np.random.randint(low=0, high=height, size=[row_num]).astype(
        'int64'
    )
    sr_val = np.random.random(size=[row_num, embedding_size]).astype('float32')

    sr.set_height(height)
    sr.set_rows(rows)
    sr.get_tensor().set(sr_val, place)

    # Repeated row indices must add up; ``np.add.at`` is unbuffered, so
    # every occurrence is accumulated, in order.
    tensor_val = np.zeros(shape=[height, embedding_size], dtype='float32')
    np.add.at(tensor_val, rows, sr_val)

    tensor.set(tensor_val, place)
    return tensor_val, sr_val


class TestBase(unittest.TestCase):
    """Shared fixture for the ``rmsprop`` operator tests.

    ``setup`` populates the global scope with random operator inputs and
    computes the expected outputs of one RMSProp step with NumPy;
    ``check`` compares an operator output with its expected value.
    """

    def setup(
        self, place, is_sparse, centered, size, row_num=None, epsilon=1e-6
    ):
        """Create the operator inputs in the global scope and the references.

        Args:
            place: device place the input tensors are copied to.
            is_sparse: feed the gradient as the SelectedRows variable
                ``@selected_rows@`` instead of the dense ``grad`` tensor.
            centered: also create ``mean_grad`` and compute the centered
                variant of the expected update.
            size: ``(rows, cols)`` shape of the parameter and accumulators.
            row_num: number of sparse gradient rows; required if
                ``is_sparse`` is True.
            epsilon: term added under the square root of the denominator.

        Raises:
            ValueError: if ``is_sparse`` is True and ``row_num`` is None.
        """
        if is_sparse and row_num is None:
            raise ValueError("row_num must be given when is_sparse is True")

        np.random.seed(5)  # fix seed

        self.scope = fluid.global_scope()
        self.place = place

        self.param_name = "param"
        self.param = np.random.random(size).astype("float32")

        self.mean_square_name = "mean_square"
        self.mean_square = np.random.uniform(low=1, high=2, size=size).astype(
            "float32"
        )

        self.mean_grad_name = "mean_grad"
        self.mean_grad = np.random.random(size).astype("float32")

        self.lr_name = "lr"
        self.learning_rate = np.array([0.01]).astype("float32")

        self.grad_name = "grad"

        self.is_sparse = is_sparse
        if self.is_sparse:
            # The helper also writes the dense accumulation of the sparse
            # rows into ``grad``; that dense value is the reference gradient.
            self.grad_sr_name = "@selected_rows@"
            self.grad, self.grad_sr = create_selected_rows_and_tensor(
                self.scope, place, size[0], row_num, size[1]
            )
        else:
            self.grad = np.random.random(size).astype("float32")
            grad_tensor = self.scope.var(self.grad_name).get_tensor()
            grad_tensor.set(self.grad, place)

        self.moment_name = "moment"
        self.moment = np.random.uniform(low=0, high=1, size=size).astype(
            "float32"
        )

        self.epsilon = epsilon
        self.decay = 0.9
        self.momentum = 0.1
        self.centered = centered

        # Reference RMSProp step:
        #   ms_out     = decay * ms + (1 - decay) * g^2
        #   mg_out     = decay * mg + (1 - decay) * g        (centered only)
        #   moment_out = momentum * moment + lr * g / sqrt(denom + epsilon)
        #   param_out  = param - moment_out
        # where denom = ms_out - mg_out^2 if centered else ms_out.
        self.ms_out = (
            self.decay * self.mean_square
            + (1 - self.decay) * self.grad * self.grad
        )
        if centered:
            self.mg_out = (
                self.decay * self.mean_grad + (1 - self.decay) * self.grad
            )
            self.moment_out = (
                self.momentum * self.moment
                + self.learning_rate
                * self.grad
                / np.sqrt(self.ms_out - np.square(self.mg_out) + self.epsilon)
            )
        else:
            self.moment_out = (
                self.momentum * self.moment
                + self.learning_rate
                * self.grad
                / np.sqrt(self.ms_out + self.epsilon)
            )

        self.param_out = self.param - self.moment_out

        # create and initialize Param Variable
        self.param_tensor = self.scope.var(self.param_name).get_tensor()
        self.param_tensor.set(self.param, place)

        self.mean_square_tensor = self.scope.var(
            self.mean_square_name
        ).get_tensor()
        self.mean_square_tensor.set(self.mean_square, place)

        lr = self.scope.var(self.lr_name).get_tensor()
        lr.set(self.learning_rate, place)

        self.moment_tensor = self.scope.var(self.moment_name).get_tensor()
        self.moment_tensor.set(self.moment, place)

        if self.centered:
            self.mean_grad_tensor = self.scope.var(
                self.mean_grad_name
            ).get_tensor()
            self.mean_grad_tensor.set(self.mean_grad, place)

    def check(self, actual_t, expect_t, place, out_name, atol=1e-5):
        """Assert ``actual_t`` is close to ``expect_t`` (rtol=1e-5, ``atol``).

        ``place`` and ``out_name`` are only used to label the failure
        message.
        """
        np.testing.assert_allclose(
            actual_t,
            expect_t,
            rtol=1e-05,
            atol=atol,
            err_msg='Output ('
            + out_name
            + ') has diff at '
            + str(place)
            + '\nExpect '
            + str(expect_t)
            + '\nBut Got '
            + str(actual_t),
        )


class TestRmspropOp(TestBase):
    """Run the ``rmsprop`` operator directly and compare with NumPy results."""

    def check_with_place(
        self, place, is_sparse, centered, size, row_num=None, epsilon=1e-6
    ):
        """Prepare inputs for one configuration, run the operator, check it."""
        self.setup(place, is_sparse, centered, size, row_num, epsilon)
        self.run_and_check()

    def run_and_check(self):
        """Run ``rmsprop`` on the scope prepared by ``setup`` and verify it.

        Each output name equals its input name, so the operator writes back
        into the same scope variables; the tensors captured by ``setup`` are
        then read and compared with the NumPy references.
        """
        grad_name = self.grad_sr_name if self.is_sparse else self.grad_name

        kwargs = {
            'Param': self.param_name,
            'Grad': grad_name,
            'MeanSquare': self.mean_square_name,
            'Moment': self.moment_name,
            'LearningRate': self.lr_name,
            'ParamOut': self.param_name,
            'MeanSquareOut': self.mean_square_name,
            'MomentOut': self.moment_name,
            'epsilon': self.epsilon,
            'decay': self.decay,
            'momentum': self.momentum,
            'centered': self.centered,
        }

        if self.centered:
            kwargs['MeanGrad'] = self.mean_grad_name
            kwargs['MeanGradOut'] = self.mean_grad_name

        rmsprop_op = Operator('rmsprop', **kwargs)
        rmsprop_op.run(self.scope, self.place)

        # (tensor, expected value, name) for every output the op writes;
        # all of them are checked with the same tolerance.
        atol = 1e-6
        outputs = [
            (self.mean_square_tensor, self.ms_out, self.mean_square_name),
            (self.moment_tensor, self.moment_out, self.moment_name),
            (self.param_tensor, self.param_out, self.param_name),
        ]
        if self.centered:
            outputs.append(
                (self.mean_grad_tensor, self.mg_out, self.mean_grad_name)
            )
        for tensor, expected, name in outputs:
            self.check(np.array(tensor), expected, self.place, name, atol=atol)

    def test_rmsprop(self):
        places = [core.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(core.CUDAPlace(0))

        size = (128, 320)
        # (is_sparse, row_num): a dense gradient, then SelectedRows
        # gradients with more rows than the parameter height (so indices
        # must repeat) and with fewer.
        grad_configs = [(False, None), (True, 512), (True, 60)]
        for place in places:
            for centered in [False, True]:
                for is_sparse, row_num in grad_configs:
                    with fluid.scope_guard(core.Scope()):
                        self.check_with_place(
                            place,
                            is_sparse=is_sparse,
                            centered=centered,
                            row_num=row_num,
                            size=size,
                        )


class TestRMSPropV2(unittest.TestCase):
    """Smoke tests for the ``paddle.optimizer.RMSProp`` 2.0 API."""

    def _restore_mode_on_cleanup(self):
        """Register a cleanup that restores the current static/dygraph mode.

        The tests below switch the global execution mode; restoring it keeps
        them from changing the mode seen by tests that run afterwards.
        """
        if paddle.in_dynamic_mode():
            self.addCleanup(paddle.disable_static)
        else:
            self.addCleanup(paddle.enable_static)

    def test_rmsprop_dygraph(self):
        self._restore_mode_on_cleanup()
        paddle.disable_static()
        value = np.arange(26).reshape(2, 13).astype("float32")
        a = paddle.to_tensor(value)
        linear = paddle.nn.Linear(13, 5)
        # This can be any optimizer supported by dygraph.
        optimizer = paddle.optimizer.RMSProp(
            learning_rate=0.01,
            parameters=linear.parameters(),
            weight_decay=0.01,
        )
        out = linear(a)
        out.backward()
        optimizer.step()
        optimizer.clear_gradients()

    def test_rmsprop(self):
        self._restore_mode_on_cleanup()
        paddle.enable_static()
        place = fluid.CPUPlace()
        main = fluid.Program()
        with fluid.program_guard(main):
            x = paddle.static.data(name='x', shape=[-1, 13], dtype='float32')
            y = paddle.static.data(name='y', shape=[-1, 1], dtype='float32')
            y_predict = paddle.static.nn.fc(x, size=1)
            cost = paddle.nn.functional.square_error_cost(
                input=y_predict, label=y
            )
            avg_cost = paddle.mean(cost)

            rms_optimizer = paddle.optimizer.RMSProp(learning_rate=0.1)
            rms_optimizer.minimize(avg_cost)

            fetch_list = [avg_cost]
            train_reader = paddle.batch(
                paddle.dataset.uci_housing.train(), batch_size=1
            )
            feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
            exe = fluid.Executor(place)
            # program_guard was given only ``main``, so parameter
            # initialisers live in the default startup program.
            exe.run(fluid.default_startup_program())
            for data in train_reader():
                exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)

    def test_raise_error(self):
        self.assertRaises(ValueError, paddle.optimizer.RMSProp, None)
        self.assertRaises(
            ValueError, paddle.optimizer.RMSProp, learning_rate=0.1, rho=None
        )
        self.assertRaises(
            ValueError,
            paddle.optimizer.RMSProp,
            learning_rate=0.1,
            epsilon=None,
        )
        self.assertRaises(
            ValueError,
            paddle.optimizer.RMSProp,
            learning_rate=0.1,
            momentum=None,
        )

    def test_rmsprop_op_invalid_input(self):
        self._restore_mode_on_cleanup()
        paddle.disable_static()
        linear = paddle.nn.Linear(10, 10)
        # Each negative hyper-parameter is expected to be rejected.
        for bad_arg in ({'epsilon': -1}, {'momentum': -1}, {'rho': -1}):
            with self.assertRaises(ValueError):
                paddle.optimizer.RMSProp(
                    0.1, parameters=linear.parameters(), **bad_arg
                )


class TestRMSPropV2Group(TestRMSPropV2):
    """Re-run the RMSProp 2.0 tests, with a grouped-parameter dygraph case."""

    def test_rmsprop_dygraph(self):
        # Restore the caller's static/dygraph mode once this test finishes,
        # so switching to dygraph here does not leak into later tests.
        if paddle.in_dynamic_mode():
            self.addCleanup(paddle.disable_static)
        else:
            self.addCleanup(paddle.enable_static)
        paddle.disable_static()
        value = np.arange(26).reshape(2, 13).astype("float32")
        a = paddle.to_tensor(value)
        linear_1 = paddle.nn.Linear(13, 5)
        linear_2 = paddle.nn.Linear(5, 3)
        # This can be any optimizer supported by dygraph.
        # The second group overrides the optimizer-level weight decay.
        optimizer = paddle.optimizer.RMSProp(
            learning_rate=0.01,
            parameters=[
                {'params': linear_1.parameters()},
                {'params': linear_2.parameters(), 'weight_decay': 0.001},
            ],
            weight_decay=0.01,
        )
        out = linear_1(a)
        out = linear_2(out)
        out.backward()
        optimizer.step()
        optimizer.clear_gradients()


class TestRMSOpMultiPrecison(unittest.TestCase):
    """Smoke test: RMSProp steps in dygraph mode, optionally under AMP O2."""

    def _test_rms_op_dygraph_place_amp(self, place, use_amp=False):
        """Run two RMSProp steps on a small Linear layer.

        Args:
            place: device string passed to ``paddle.set_device``
                ('cpu' or 'gpu').
            use_amp: value assigned to ``optimizer._multi_precision``; on
                'gpu' it additionally runs the forward pass under AMP O2
                with a ``GradScaler``.

        The previous device is restored and static mode is re-enabled on
        exit, even if a step fails.
        """
        prev_device = paddle.get_device()
        paddle.disable_static()
        try:
            paddle.seed(10)
            paddle.set_device(place)

            x = paddle.randn((5, 5))
            model = paddle.nn.Linear(5, 5)
            optimizer = paddle.optimizer.RMSProp(
                learning_rate=0.01,
                parameters=model.parameters(),
                weight_decay=0.01,
            )
            optimizer._multi_precision = use_amp

            use_o2 = place == 'gpu' and use_amp
            if use_o2:
                # Decorate the model and build the scaler once, so the
                # scaler's loss-scaling state carries across steps.
                model = paddle.amp.decorate(models=model, level='O2')
                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

            for _ in range(2):
                if use_o2:
                    with paddle.amp.auto_cast(level='O2'):
                        output = model(x)
                        loss = paddle.mean(output)
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    # GradScaler API: each step() is paired with update()
                    # before the next step() on the same scaler.
                    scaler.update()
                else:
                    output = model(x)
                    loss = paddle.mean(output)
                    loss.backward()
                    optimizer.step()
                optimizer.clear_grad()
        finally:
            paddle.set_device(prev_device)
            paddle.enable_static()

    def _get_places(self):
        """Return the device strings to test: 'cpu', plus 'gpu' if built with CUDA."""
        places = ['cpu']
        if paddle.is_compiled_with_cuda():
            places.append('gpu')
        return places

    def test_main(self):
        for place in self._get_places():
            for use_amp in [True, False]:
                self._test_rms_op_dygraph_place_amp(place, use_amp)


class TestRMSPropMultiPrecision2_0(unittest.TestCase):
    """Compare ``paddle.optimizer.RMSProp`` under AMP O2 against fp32 on GPU."""

    def dygraph_rmsprop_mp(self, mp, use_amp):
        """Train a 2x2 Linear layer for 5 steps in dygraph mode on GPU.

        Args:
            mp: value assigned to ``optimizer._multi_precision``.
            use_amp: decorate the model for AMP O2 and scale the loss.

        Returns:
            tuple: the output of the last step and ``model.parameters()``.
        """
        paddle.disable_static()
        paddle.seed(100)
        paddle.set_device('gpu')
        x = paddle.randn((2, 2))
        model = paddle.nn.Linear(2, 2)
        optimizer = paddle.optimizer.RMSProp(0.5, parameters=model.parameters())
        optimizer._multi_precision = mp
        if use_amp:
            model = paddle.amp.decorate(models=model, level='O2')
            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

        for _ in range(5):
            if use_amp:
                with paddle.amp.auto_cast(level='O2'):
                    output = model(x)
                    loss = paddle.mean(output)
                scaled = scaler.scale(loss)
                scaled.backward()
                scaler.minimize(optimizer, scaled)
            else:
                output = model(x)
                loss = paddle.mean(output)
                loss.backward()
                optimizer.step()
            optimizer.clear_grad()

        return output, model.parameters()

    def static_rmsprop_mp(self, mp, use_amp):
        """Train an fc layer for 5 steps in static mode on GPU.

        Args:
            mp: value assigned to ``optimizer._multi_precision``.
            use_amp: wrap the optimizer with pure-fp16 static AMP and feed
                float16 data.

        Returns:
            list: the fetched loss of every step.
        """
        paddle.enable_static()
        paddle.seed(100)
        np.random.seed(100)
        exe = paddle.static.Executor('gpu')
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        optimizer = paddle.optimizer.RMSProp(0.1)
        optimizer._multi_precision = mp

        if use_amp:
            optimizer = paddle.static.amp.decorate(
                optimizer,
                init_loss_scaling=128.0,
                use_dynamic_loss_scaling=True,
                use_pure_fp16=True,
                use_fp16_guard=False,
            )
        # Pure-fp16 runs take float16 input; the reference run uses float32.
        dtype = 'float16' if use_amp else 'float32'
        with paddle.static.program_guard(train_program, startup_program):
            data = paddle.static.data(shape=[2, 2], name='X', dtype=dtype)
            hidden = paddle.static.nn.fc(x=data, size=10)
            loss = paddle.mean(hidden)
            optimizer.minimize(loss)
        exe.run(startup_program)

        if use_amp:
            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
        x = np.random.random(size=(2, 2)).astype(dtype)
        losses = []
        for _ in range(5):
            (loss_data,) = exe.run(
                train_program, feed={"X": x}, fetch_list=[loss.name]
            )
            losses.append(loss_data)
        return losses

    def test_main(self):
        if not paddle.is_compiled_with_cuda():
            self.skipTest("multi-precision RMSProp test requires CUDA")

        # Dygraph mode: AMP O2 with master weights vs. plain fp32.
        output1_dy, params1_dy = self.dygraph_rmsprop_mp(use_amp=True, mp=True)
        output2_dy, params2_dy = self.dygraph_rmsprop_mp(
            use_amp=False, mp=False
        )
        np.testing.assert_allclose(
            output1_dy.astype('float32').numpy(),
            output2_dy.astype('float32').numpy(),
            rtol=1e-05,
            atol=0.1,
        )
        self.assertEqual(len(params1_dy), len(params2_dy))
        for param1, param2 in zip(params1_dy, params2_dy):
            np.testing.assert_allclose(
                param1.astype('float32').numpy(),
                param2.astype('float32').numpy(),
                rtol=1e-05,
                atol=0.1,
            )

        # Static mode: compare the per-step losses.
        output1_st = self.static_rmsprop_mp(use_amp=True, mp=True)
        output2_st = self.static_rmsprop_mp(use_amp=False, mp=False)
        self.assertEqual(len(output1_st), len(output2_st))
        for loss1, loss2 in zip(output1_st, output2_st):
            np.testing.assert_allclose(
                loss1.astype('float32'),
                loss2.astype('float32'),
                rtol=1e-05,
                atol=0.1,
            )


class TestRMSPropMultiPrecision1_0(unittest.TestCase):
    """Compare ``paddle.fluid.optimizer.RMSProp`` under AMP O2 against fp32 on GPU."""

    def dygraph_rmsprop_mp(self, use_amp, mp):
        """Train a 2x2 Linear layer for 5 steps in dygraph mode on GPU.

        Args:
            use_amp: decorate the model for AMP O2 and scale the loss.
            mp: value assigned to ``optimizer._multi_precision``.

        Returns:
            tuple: the output of the last step and ``model.parameters()``.
        """
        paddle.disable_static()
        paddle.seed(10)
        paddle.set_device('gpu')
        x = paddle.randn((2, 2))
        model = paddle.nn.Linear(2, 2)
        optimizer = paddle.fluid.optimizer.RMSProp(
            learning_rate=0.001,
            parameter_list=model.parameters(),
        )
        optimizer._multi_precision = mp
        if use_amp:
            model = paddle.amp.decorate(models=model, level='O2')
            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

        for _ in range(5):
            if use_amp:
                with paddle.amp.auto_cast(level='O2'):
                    output = model(x)
                    loss = paddle.mean(output)
                scaled = scaler.scale(loss)
                scaled.backward()
                scaler.minimize(optimizer, scaled)
            else:
                output = model(x)
                loss = paddle.mean(output)
                # As in the AMP branch, gradients must exist before
                # minimize(): in dygraph mode the fluid optimizer only
                # applies gradients already computed by backward().
                loss.backward()
                optimizer.minimize(loss)
            optimizer.clear_gradients()

        return output, model.parameters()

    def static_rmsprop_mp(self, use_amp, mp):
        """Train an fc layer for 5 steps in static mode on GPU.

        Args:
            use_amp: wrap the optimizer with pure-fp16 static AMP and feed
                float16 data.
            mp: value assigned to ``optimizer._multi_precision``.

        Returns:
            list: the fetched loss of every step.
        """
        paddle.enable_static()
        paddle.seed(100)
        np.random.seed(100)
        exe = paddle.static.Executor('gpu')
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        optimizer = paddle.fluid.optimizer.RMSProp(learning_rate=0.001)
        optimizer._multi_precision = mp

        if use_amp:
            optimizer = paddle.static.amp.decorate(
                optimizer,
                init_loss_scaling=128.0,
                use_dynamic_loss_scaling=True,
                use_pure_fp16=True,
                use_fp16_guard=False,
            )
        # Pure-fp16 runs take float16 input; the reference run uses float32.
        dtype = 'float16' if use_amp else 'float32'
        with paddle.static.program_guard(train_program, startup_program):
            data = paddle.static.data(shape=[2, 2], name='X', dtype=dtype)
            hidden = paddle.static.nn.fc(x=data, size=10)
            loss = paddle.mean(hidden)
            optimizer.minimize(loss)
        exe.run(startup_program)

        if use_amp:
            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
        x = np.random.random(size=(2, 2)).astype(dtype)
        losses = []
        for _ in range(5):
            (loss_data,) = exe.run(
                train_program, feed={"X": x}, fetch_list=[loss.name]
            )
            losses.append(loss_data)
        return losses

    def test_main(self):
        if not paddle.is_compiled_with_cuda():
            self.skipTest("multi-precision RMSProp test requires CUDA")

        # Dygraph mode: AMP O2 with master weights vs. plain fp32.
        output1_dy, params1_dy = self.dygraph_rmsprop_mp(use_amp=True, mp=True)
        output2_dy, params2_dy = self.dygraph_rmsprop_mp(
            use_amp=False, mp=False
        )
        np.testing.assert_allclose(
            output1_dy.astype('float32').numpy(),
            output2_dy.astype('float32').numpy(),
            rtol=1e-05,
            atol=0.1,
        )
        self.assertEqual(len(params1_dy), len(params2_dy))
        for param1, param2 in zip(params1_dy, params2_dy):
            np.testing.assert_allclose(
                param1.astype('float32').numpy(),
                param2.astype('float32').numpy(),
                rtol=1e-05,
                atol=0.1,
            )

        # Static mode: compare the per-step losses.
        output1_st = self.static_rmsprop_mp(use_amp=True, mp=True)
        output2_st = self.static_rmsprop_mp(use_amp=False, mp=False)
        self.assertEqual(len(output1_st), len(output2_st))
        for loss1, loss2 in zip(output1_st, output2_st):
            np.testing.assert_allclose(
                loss1.astype('float32'),
                loss2.astype('float32'),
                rtol=1e-05,
                atol=0.1,
            )


def _run_as_script():
    """Entry point when the file is executed directly.

    The operator-level tests build static programs, so static mode is
    switched on before handing control to unittest.
    """
    paddle.enable_static()
    unittest.main()


if __name__ == "__main__":
    _run_as_script()