# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
import random
import numpy as np
import paddle.fluid as fluid
from op_test import OpTest
from functools import partial
from paddle.framework import core
from paddle.fluid.framework import _test_eager_guard


def adamw_step(inputs, attributes):
    """NumPy reference implementation of a single AdamW update step.

    Args:
        inputs: dict of arrays with keys 'Param', 'Grad', 'Moment1',
            'Moment2', 'LearningRate', 'Beta1Pow' and 'Beta2Pow'. It must
            also hold 'Beta1Tensor' / 'Beta2Tensor' when the matching
            'beta1' / 'beta2' attribute is absent; element 0 of those
            tensors is used.
        attributes: dict with required keys 'epsilon' and 'with_decay'
            ('coeff' is also required when 'with_decay' is true). Optional
            keys are 'beta1', 'beta2' and 'lr_ratio'.

    Returns:
        Tuple ``(param_out, moment1_out, moment2_out)``. The arrays in
        ``inputs`` are not modified.
    """
    param = inputs['Param']
    grad = inputs['Grad']
    moment1 = inputs['Moment1']
    moment2 = inputs['Moment2']
    lr = inputs['LearningRate']
    beta1_pow = inputs['Beta1Pow']
    beta2_pow = inputs['Beta2Pow']

    epsilon = attributes['epsilon']

    # Layer-wise scaling: the ratio scales the learning rate used by both the
    # weight decay and the Adam step below.
    if 'lr_ratio' in attributes:
        lr = lr * attributes['lr_ratio']

    # Decoupled weight decay, applied to the parameter before the Adam update.
    # `param * ...` allocates a new array, so the caller's 'Param' is untouched.
    if attributes["with_decay"]:
        param = param * (1.0 - lr * attributes["coeff"])

    # Betas come from attributes when given, otherwise from 1-element tensors.
    beta1 = (attributes['beta1']
             if 'beta1' in attributes else inputs['Beta1Tensor'][0])
    beta2 = (attributes['beta2']
             if 'beta2' in attributes else inputs['Beta2Tensor'][0])

    moment1_out = beta1 * moment1 + (1 - beta1) * grad
    moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)

    # Bias-corrected update: m / (1 - beta1^t) over sqrt(v) / sqrt(1 - beta2^t) + eps.
    denom = (np.sqrt(moment2_out) / np.sqrt(1.0 - beta2_pow)) + epsilon
    param_out = param + ((moment1_out / denom) * (-(lr / (1.0 - beta1_pow))))

    return param_out, moment1_out, moment2_out


class TestAdamW(OpTest):
    """Checks the ``adamw`` operator against the NumPy reference ``adamw_step``.

    Uses a (102, 105) float32 problem with decoupled weight decay enabled.
    """

    def setUp(self):
        """Draw random operands and compute the expected operator outputs."""
        self.op_type = "adamw"
        shape = (102, 105)
        param = np.random.uniform(-1, 1, shape).astype("float32")
        grad = np.random.uniform(-1, 1, shape).astype("float32")
        moment1 = np.random.uniform(-1, 1, shape).astype("float32")
        # Second-moment estimates must be non-negative: sample from [0, 1).
        moment2 = np.random.random(shape).astype("float32")

        lr, beta1, beta2, epsilon = 0.004, 0.78, 0.836, 1e-4
        # Accumulated beta powers as they would be after ten steps.
        beta1_pow, beta2_pow = beta1**10, beta2**10

        def as_f32(scalar):
            return np.array([scalar]).astype("float32")

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Moment1': moment1,
            'Moment2': moment2,
            'LearningRate': as_f32(lr),
            'Beta1Pow': as_f32(beta1_pow),
            'Beta2Pow': as_f32(beta2_pow),
        }
        self.attrs = dict(
            epsilon=epsilon,
            beta1=beta1,
            beta2=beta2,
            coeff=0.5,
            with_decay=True)

        expected = adamw_step(self.inputs, self.attrs)
        param_out, moment1_out, moment2_out = expected

        self.outputs = {
            'Moment1Out': moment1_out,
            'Moment2Out': moment2_out,
            'ParamOut': param_out,
            'Beta1PowOut': as_f32(beta1_pow) * beta1,
            'Beta2PowOut': as_f32(beta2_pow) * beta2,
        }

    def test_check_output(self):
        self.check_output()


@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestAdamW2(OpTest):
    """GPU check of the ``adamw`` operator with an ``lr_ratio`` attribute."""

    def setUp(self):
        """Build a small 2x2 problem and its NumPy-reference outputs."""
        self.op_type = "adamw"
        shape = (2, 2)
        param = np.random.uniform(-1, 1, shape).astype("float32")
        grad = np.random.uniform(-1, 1, shape).astype("float32")
        moment1 = np.random.uniform(-1, 1, shape).astype("float32")
        # The second moment has to be non-negative, hence [0, 1).
        moment2 = np.random.random(shape).astype("float32")

        learning_rate, beta1, beta2, epsilon = 0.004, 0.78, 0.836, 1e-4
        beta1_pow, beta2_pow = beta1**10, beta2**10

        def to_f32(scalar):
            return np.array([scalar]).astype("float32")

        self.inputs = dict(
            Param=param,
            Grad=grad,
            Moment1=moment1,
            Moment2=moment2,
            LearningRate=to_f32(learning_rate),
            Beta1Pow=to_f32(beta1_pow),
            Beta2Pow=to_f32(beta2_pow))
        self.attrs = dict(
            epsilon=epsilon,
            beta1=beta1,
            beta2=beta2,
            lr_ratio=0.1,
            coeff=0.5,
            with_decay=True)

        param_out, moment1_out, moment2_out = adamw_step(
            self.inputs, self.attrs)

        self.outputs = dict(
            Moment1Out=moment1_out,
            Moment2Out=moment2_out,
            ParamOut=param_out,
            Beta1PowOut=to_f32(beta1_pow) * beta1,
            Beta2PowOut=to_f32(beta2_pow) * beta2)

    def test_check_output(self):
        gpu = core.CUDAPlace(0)
        self.check_output_with_place(gpu)


class TestAdamWOp(unittest.TestCase):
    """Dygraph and static-graph API tests for ``paddle.optimizer.AdamW``.

    Subclasses in this file override ``test_adamw_op_dygraph`` and/or
    ``test_adamw_op`` and inherit the remaining cases, including the
    eager-mode rerun in ``test_api_eager_dygraph``.
    """

    def test_adamw_op_dygraph(self):
        """Run two AdamW steps on a single Linear layer in dygraph mode."""
        paddle.disable_static()
        value = np.arange(26).reshape(2, 13).astype("float32")
        a = paddle.to_tensor(value)
        linear = paddle.nn.Linear(13, 5)
        adam = paddle.optimizer.AdamW(
            learning_rate=0.01,
            parameters=linear.parameters(),
            apply_decay_param_fun=lambda name: True,
            weight_decay=0.01)

        for _ in range(2):
            out = linear(a)
            out.backward()
            adam.step()
            adam.clear_gradients()

    def test_adamw_op_coverage(self):
        """Construct AdamW with lr=0.0 and make sure it can be stringified."""
        paddle.disable_static()
        linear = paddle.nn.Linear(13, 5)
        adam = paddle.optimizer.AdamW(
            learning_rate=0.0,
            parameters=linear.parameters(),
            apply_decay_param_fun=lambda name: True,
            weight_decay=0.01)
        self.assertIsNotNone(str(adam))

    def test_adamw_op(self):
        """Build and run one static-graph training step with tensor betas."""
        paddle.enable_static()
        # Always switch back to dygraph, even when an assertion fails, so the
        # remaining tests in this process are not left in static mode.
        try:
            place = fluid.CPUPlace()
            shape = [2, 3, 8, 8]
            exe = fluid.Executor(place)
            train_prog = fluid.Program()
            startup = fluid.Program()
            with fluid.program_guard(train_prog, startup):
                with fluid.unique_name.guard():
                    data = fluid.data(name="data", shape=shape)
                    conv = fluid.layers.conv2d(data, 8, 3)
                    loss = paddle.mean(conv)

                    # beta1/beta2 passed as persistable global variables
                    # rather than Python floats.
                    beta1 = fluid.layers.create_global_var(
                        shape=[1], value=0.85, dtype='float32', persistable=True)
                    beta2 = fluid.layers.create_global_var(
                        shape=[1], value=0.95, dtype='float32', persistable=True)
                    opt = paddle.optimizer.AdamW(
                        learning_rate=1e-5,
                        beta1=beta1,
                        beta2=beta2,
                        weight_decay=0.01,
                        epsilon=1e-8)
                    opt.minimize(loss)

            exe.run(startup)
            data_np = np.random.random(shape).astype('float32')
            rets = exe.run(train_prog, feed={"data": data_np}, fetch_list=[loss])
            self.assertIsNotNone(rets[0])
        finally:
            paddle.disable_static()

    def test_adamw_op_invalid_input(self):
        """Negative beta1, beta2 or epsilon must raise ValueError."""
        paddle.disable_static()
        linear = paddle.nn.Linear(10, 10)
        with self.assertRaises(ValueError):
            paddle.optimizer.AdamW(
                0.1, beta1=-1, parameters=linear.parameters())
        with self.assertRaises(ValueError):
            paddle.optimizer.AdamW(
                0.1, beta2=-1, parameters=linear.parameters())
        with self.assertRaises(ValueError):
            paddle.optimizer.AdamW(
                0.1, epsilon=-1, parameters=linear.parameters())

    def test_api_eager_dygraph(self):
        """Repeat the dygraph cases under the eager-mode guard."""
        with _test_eager_guard():
            self.test_adamw_op_dygraph()
            self.test_adamw_op_invalid_input()


class TestAdamWOpGroup(TestAdamWOp):
    """TestAdamWOp variant whose dygraph case uses parameter groups.

    The second group overrides ``weight_decay``. All other cases are
    inherited unchanged.
    """

    def test_adamw_op_dygraph(self):
        paddle.disable_static()
        x = paddle.to_tensor(np.arange(26).reshape(2, 13).astype("float32"))
        layer1 = paddle.nn.Linear(13, 5)
        layer2 = paddle.nn.Linear(5, 3)
        groups = [
            {'params': layer1.parameters()},
            {'params': layer2.parameters(), 'weight_decay': 0.001},
        ]
        opt = paddle.optimizer.AdamW(
            learning_rate=0.01,
            parameters=groups,
            apply_decay_param_fun=lambda name: True,
            weight_decay=0.01)

        for _ in range(2):
            y = layer2(layer1(x))
            y.backward()
            opt.step()
            opt.clear_gradients()


class TestAdamWOpMultiPrecison(unittest.TestCase):
    """Smoke-tests AdamW with and without multi-precision / AMP O2."""

    def _test_adamw_op_dygraph_place_amp(self, place, use_amp=False):
        """Run two AdamW steps on a small Linear model on ``place``.

        Args:
            place: device string given to ``paddle.set_device`` ('cpu' or
                'gpu', as produced by ``_get_places``).
            use_amp: enable ``multi_precision`` on the optimizer. On 'gpu'
                this also decorates the model for AMP O2 and scales the loss.
        """
        paddle.disable_static()
        paddle.seed(10)
        # Restore the previous device afterwards so the setting does not leak
        # into other tests running in the same process.
        prev_device = paddle.get_device()
        paddle.set_device(place)
        try:
            x = paddle.randn((5, 5))

            model = paddle.nn.Linear(5, 5)

            optimizer = paddle.optimizer.AdamW(
                parameters=[{
                    'params': model.parameters(),
                    'weight_decay': 0.001,
                    'beta1': 0.1,
                    'beta2': 0.99
                }],
                multi_precision=use_amp)

            # AMP O2 is only exercised on GPU. Decorate the model and create
            # the loss scaler once, before training, not on every step.
            use_gpu_amp = place == 'gpu' and use_amp
            if use_gpu_amp:
                model = paddle.amp.decorate(models=model, level='O2')
                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

            for _ in range(2):
                if use_gpu_amp:
                    with paddle.amp.auto_cast(level='O2'):
                        output = model(x)
                        loss = paddle.mean(output)
                    scaled = scaler.scale(loss)
                    scaled.backward()
                    scaler.step(optimizer)
                else:
                    output = model(x)
                    loss = paddle.mean(output)
                    loss.backward()
                    optimizer.step()
                optimizer.clear_grad()
        finally:
            paddle.set_device(prev_device)

    def _get_places(self):
        """Return 'cpu', plus 'gpu' when Paddle was built with CUDA."""
        places = ['cpu']
        if paddle.is_compiled_with_cuda():
            places.append('gpu')
        return places

    def test_main(self):
        for place in self._get_places():
            for use_amp in (True, False):
                self._test_adamw_op_dygraph_place_amp(place, use_amp)


class TestAdamWOpError(unittest.TestCase):
    """Invalid AdamW constructor arguments must raise the expected errors."""

    def test_api_errors(self):
        # An int weight_decay is rejected with TypeError.
        with self.assertRaises(TypeError):
            linear = paddle.nn.Linear(13, 5)
            paddle.optimizer.AdamW(
                learning_rate=0.01,
                parameters=linear.parameters(),
                weight_decay=1)

        # A bare Tensor is not accepted as ``parameters``.
        with self.assertRaises(TypeError):
            paddle.optimizer.AdamW(
                learning_rate=0.01,
                parameters=paddle.randn((5, 5)),
                weight_decay=0.1)

        # A single group dict that is not wrapped in a list is rejected.
        with self.assertRaises(TypeError):
            linear = paddle.nn.Linear(13, 5)
            paddle.optimizer.AdamW(
                learning_rate=0.01,
                parameters={'params': linear.parameters()},
                weight_decay=0.1)

        # parameters=None surfaces as AttributeError.
        with self.assertRaises(AttributeError):
            paddle.optimizer.AdamW(
                learning_rate=0.01, parameters=None, weight_decay=0.1)

        # A group dict whose 'params' is a set is rejected.
        with self.assertRaises(TypeError):
            linear = paddle.nn.Linear(13, 5)
            paddle.optimizer.AdamW(
                learning_rate=0.01,
                parameters={'params': set(linear.parameters())},
                weight_decay=0.1)

        # An int learning_rate is rejected.
        with self.assertRaises(TypeError):
            linear = paddle.nn.Linear(13, 5)
            paddle.optimizer.AdamW(
                learning_rate=1,
                parameters=linear.parameters(),
                weight_decay=0.1)

        # A float grad_clip is rejected.
        with self.assertRaises(TypeError):
            linear = paddle.nn.Linear(13, 5)
            paddle.optimizer.AdamW(
                learning_rate=0.01,
                parameters=linear.parameters(),
                weight_decay=0.1,
                grad_clip=0.1)


class TestAdamWOpGroupWithLR(TestAdamWOp):
    """TestAdamWOp variant mixing an LR scheduler with per-group overrides.

    The first group overrides ``learning_rate`` and the second overrides
    ``weight_decay``.
    """

    def test_adamw_op_dygraph(self):
        paddle.disable_static()
        x = paddle.to_tensor(np.arange(26).reshape(2, 13).astype("float32"))
        layer1 = paddle.nn.Linear(13, 5)
        layer2 = paddle.nn.Linear(5, 3)
        scheduler = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=[3, 6], values=[0.1, 0.2, 0.3])
        groups = [
            {'params': layer1.parameters(), 'learning_rate': 0.1},
            {'params': layer2.parameters(), 'weight_decay': 0.001},
        ]
        opt = paddle.optimizer.AdamW(
            learning_rate=scheduler,
            parameters=groups,
            apply_decay_param_fun=lambda name: True,
            weight_decay=0.01)

        for _ in range(2):
            y = layer2(layer1(x))
            y.backward()
            opt.step()
            opt.clear_gradients()


def simple_lr_setting(param, decay_rate, n_layers):
    """Return the layer-wise learning-rate ratio for ``param``.

    The depth is derived from ``param.name``:

    * Names containing "fc_0" or "linear_1" get depth = (third "_"-separated
      token as int) + 1.
    * Names containing "fc_1" or "linear_2" get depth = that token + 2.
    * Any other name gets depth 0.

    The ratio is ``decay_rate ** (n_layers + 2 - depth)``.
    """
    name = param.name
    if any(tag in name for tag in ("fc_0", "linear_1")):
        offset = 1
    elif any(tag in name for tag in ("fc_1", "linear_2")):
        offset = 2
    else:
        return decay_rate**(n_layers + 2)

    depth = int(name.split("_")[2]) + offset
    return decay_rate**(n_layers + 2 - depth)


@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestAdamWOpLayerwiseLR(TestAdamWOp):
    """Checks AdamW's ``lr_ratio`` option (layer-wise learning-rate scaling).

    Each parameter is updated both by the real optimizer and by the NumPy
    reference ``adamw_step``. The per-parameter ratio comes from
    ``simple_lr_setting``. The two results must agree after every step.
    """

    def setUp(self):
        # Seed every RNG the test draws from so the comparison is reproducible.
        random.seed(2022)
        np.random.seed(2022)
        paddle.seed(2022)

    def test_adamw_op_dygraph(self):
        """Compare five dygraph AdamW steps with the NumPy reference."""
        paddle.disable_static()
        linear1 = paddle.nn.Linear(
            13, 8, bias_attr=paddle.nn.initializer.Constant(value=1.0))
        linear2 = paddle.nn.Linear(
            8, 5, bias_attr=paddle.nn.initializer.Constant(value=1.0))

        # Fix the parameter names: simple_lr_setting derives each parameter's
        # depth, and therefore its lr ratio, from the name.
        linear1.weight.name = "linear_1.w_0"
        linear1.bias.name = "linear_1.b_0"
        linear2.weight.name = "linear_2.w_0"
        linear2.bias.name = "linear_2.b_0"

        # NumPy shadow copies of every parameter and its two Adam moments.
        fc1_w = np.array(linear1.weight)
        fc1_w_mon1 = np.zeros_like(fc1_w)
        fc1_w_mon2 = np.zeros_like(fc1_w)
        fc1_b = np.array(linear1.bias)
        fc1_b_mon1 = np.zeros_like(fc1_b)
        fc1_b_mon2 = np.zeros_like(fc1_b)

        fc2_w = np.array(linear2.weight)
        fc2_w_mon1 = np.zeros_like(fc2_w)
        fc2_w_mon2 = np.zeros_like(fc2_w)
        fc2_b = np.array(linear2.bias)
        fc2_b_mon1 = np.zeros_like(fc2_b)
        fc2_b_mon2 = np.zeros_like(fc2_b)

        simple_lr_fun = partial(simple_lr_setting, decay_rate=0.8, n_layers=2)
        learning_rate = 0.001
        weight_decay = 0.01
        # Only the reference uses these betas; the optimizer below is built
        # without beta arguments, so they presumably must match AdamW's defaults.
        beta1 = 0.9
        beta2 = 0.999

        opt = paddle.optimizer.AdamW(
            learning_rate=learning_rate,
            parameters=[{
                'params': linear1.parameters()
            }, {
                'params': linear2.parameters(),
            }],
            apply_decay_param_fun=lambda name: True,
            weight_decay=weight_decay,
            lr_ratio=simple_lr_fun)

        def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t):
            """Apply one reference AdamW update at (1-based) step ``t``."""
            np_inputs = {
                'Param': param,
                'Grad': grad,
                'Moment1': moment1,
                'Moment2': moment2,
                'LearningRate': np.array([learning_rate]).astype("float32"),
                'Beta1Pow': np.array([beta1**t]).astype("float32"),
                'Beta2Pow': np.array([beta2**t]).astype("float32")
            }

            np_attrs = {
                'epsilon': 1e-8,
                'beta1': beta1,
                'beta2': beta2,
                "lr_ratio": lr_ratio,
                "coeff": weight_decay,
                "with_decay": True
            }
            param_out, moment1_out, moment2_out = adamw_step(np_inputs,
                                                             np_attrs)
            return param_out, moment1_out, moment2_out

        for i in range(5):
            a = paddle.to_tensor(
                np.random.uniform(-1, 1, (2, 13)).astype("float32"))
            a1 = linear1(a)
            out = linear2(a1)
            out = paddle.mean(out)
            out.backward()

            # Advance the reference with the gradients of this step, before
            # the real optimizer consumes them.
            fc1_w, fc1_w_mon1, fc1_w_mon2 = get_numpy_output(
                fc1_w,
                np.array(linear1.weight.grad), fc1_w_mon1, fc1_w_mon2,
                simple_lr_fun(linear1.weight), i + 1)
            fc1_b, fc1_b_mon1, fc1_b_mon2 = get_numpy_output(
                fc1_b,
                np.array(linear1.bias.grad), fc1_b_mon1, fc1_b_mon2,
                simple_lr_fun(linear1.bias), i + 1)
            fc2_w, fc2_w_mon1, fc2_w_mon2 = get_numpy_output(
                fc2_w,
                np.array(linear2.weight.grad), fc2_w_mon1, fc2_w_mon2,
                simple_lr_fun(linear2.weight), i + 1)
            fc2_b, fc2_b_mon1, fc2_b_mon2 = get_numpy_output(
                fc2_b,
                np.array(linear2.bias.grad), fc2_b_mon1, fc2_b_mon2,
                simple_lr_fun(linear2.bias), i + 1)

            opt.step()
            opt.clear_gradients()

            np.testing.assert_allclose(linear1.weight.numpy(), fc1_w, rtol=1e-6)
            np.testing.assert_allclose(linear1.bias.numpy(), fc1_b, rtol=1e-6)
            np.testing.assert_allclose(linear2.weight.numpy(), fc2_w, rtol=1e-6)
            np.testing.assert_allclose(linear2.bias.numpy(), fc2_b, rtol=1e-6)

    def test_adamw_op(self):
        """Compare five static-graph AdamW steps with the NumPy reference."""
        paddle.enable_static()
        # Always switch back to dygraph, even when an assertion fails, so the
        # remaining tests in this process are not left in static mode.
        try:
            place = fluid.CUDAPlace(0)

            learning_rate = 0.0001
            beta1 = 0.85
            beta2 = 0.95
            weight_decay = 0.01
            epsilon = 1e-8

            train_prog = fluid.Program()
            startup = fluid.Program()
            with fluid.program_guard(train_prog, startup):
                with fluid.unique_name.guard():
                    x = fluid.data(name='x', shape=[None, 10], dtype='float32')
                    y = fluid.data(name='y', shape=[None, 1], dtype='float32')

                    # Fixed parameter names: the fetch lists below read these
                    # variables by name, and simple_lr_setting derives the lr
                    # ratio from the name.
                    weight_attr1 = paddle.framework.ParamAttr(
                        name="linear_0.w_0")
                    bias_attr1 = paddle.framework.ParamAttr(
                        name="linear_0.b_0",
                        initializer=paddle.nn.initializer.Constant(value=1.0))
                    weight_attr2 = paddle.framework.ParamAttr(
                        name="linear_1.w_0")
                    bias_attr2 = paddle.framework.ParamAttr(
                        name="linear_1.b_0",
                        initializer=paddle.nn.initializer.Constant(value=1.0))
                    linear1 = paddle.nn.Linear(
                        10, 32, weight_attr=weight_attr1, bias_attr=bias_attr1)
                    linear2 = paddle.nn.Linear(
                        32, 1, weight_attr=weight_attr2, bias_attr=bias_attr2)

                    out = linear1(x)
                    out = linear2(out)

                    # Zero-initialised reference Adam moments per parameter.
                    fc1_w_mon1 = np.zeros(
                        (linear1.weight.shape)).astype("float32")
                    fc1_w_mon2 = np.zeros(
                        (linear1.weight.shape)).astype("float32")
                    fc1_b_mon1 = np.zeros(
                        (linear1.bias.shape)).astype("float32")
                    fc1_b_mon2 = np.zeros(
                        (linear1.bias.shape)).astype("float32")
                    fc2_w_mon1 = np.zeros(
                        (linear2.weight.shape)).astype("float32")
                    fc2_w_mon2 = np.zeros(
                        (linear2.weight.shape)).astype("float32")
                    fc2_b_mon1 = np.zeros(
                        (linear2.bias.shape)).astype("float32")
                    fc2_b_mon2 = np.zeros(
                        (linear2.bias.shape)).astype("float32")

                    cost = fluid.layers.square_error_cost(input=out, label=y)
                    avg_cost = fluid.layers.mean(cost)

                    simple_lr_fun = partial(
                        simple_lr_setting, decay_rate=0.8, n_layers=2)

                    opt = paddle.optimizer.AdamW(
                        learning_rate=learning_rate,
                        beta1=beta1,
                        beta2=beta2,
                        weight_decay=weight_decay,
                        epsilon=epsilon,
                        lr_ratio=simple_lr_fun)
                    opt.minimize(avg_cost)

            def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t):
                """Apply one reference AdamW update at (1-based) step ``t``."""
                np_inputs = {
                    'Param': param,
                    'Grad': grad,
                    'Moment1': moment1,
                    'Moment2': moment2,
                    'LearningRate': np.array([learning_rate]).astype("float32"),
                    'Beta1Pow': np.array([beta1**t]).astype("float32"),
                    'Beta2Pow': np.array([beta2**t]).astype("float32")
                }

                np_attrs = {
                    'epsilon': epsilon,
                    'beta1': beta1,
                    'beta2': beta2,
                    "lr_ratio": lr_ratio,
                    "coeff": weight_decay,
                    "with_decay": True
                }
                param_out, moment1_out, moment2_out = adamw_step(np_inputs,
                                                                 np_attrs)
                return param_out, moment1_out, moment2_out

            fetch_list1 = [
                "linear_0.w_0", "linear_0.b_0", "linear_1.w_0", "linear_1.b_0"
            ]
            fetch_list2 = [
                "linear_0.w_0", "linear_0.w_0@GRAD", "linear_0.b_0",
                "linear_0.b_0@GRAD", "linear_1.w_0", "linear_1.w_0@GRAD",
                "linear_1.b_0", "linear_1.b_0@GRAD"
            ]

            exe = fluid.Executor(place)
            exe.run(startup)
            # for_test clone, used to read the current parameter values
            # without applying an update.
            test_prog = train_prog.clone(for_test=True)

            for i in range(5):
                inputs = np.random.random(size=[8, 10]).astype('float32')
                outputs = np.random.random(size=[8, 1]).astype('float32')

                # Parameter values before this step's update.
                param = exe.run(test_prog,
                                feed={"x": inputs,
                                      "y": outputs},
                                fetch_list=fetch_list1)
                # Training step: fetch updated parameters plus the gradients.
                params_and_grads = exe.run(train_prog,
                                           feed={"x": inputs,
                                                 "y": outputs},
                                           fetch_list=fetch_list2)

                fc1_w = param[0]
                fc1_w_grad = params_and_grads[1]
                fc1_b = param[1]
                fc1_b_grad = params_and_grads[3]
                fc2_w = param[2]
                fc2_w_grad = params_and_grads[5]
                fc2_b = param[3]
                fc2_b_grad = params_and_grads[7]

                fc1_w, fc1_w_mon1, fc1_w_mon2 = get_numpy_output(
                    fc1_w, fc1_w_grad, fc1_w_mon1, fc1_w_mon2,
                    simple_lr_fun(linear1.weight), i + 1)
                fc1_b, fc1_b_mon1, fc1_b_mon2 = get_numpy_output(
                    fc1_b, fc1_b_grad, fc1_b_mon1, fc1_b_mon2,
                    simple_lr_fun(linear1.bias), i + 1)
                fc2_w, fc2_w_mon1, fc2_w_mon2 = get_numpy_output(
                    fc2_w, fc2_w_grad, fc2_w_mon1, fc2_w_mon2,
                    simple_lr_fun(linear2.weight), i + 1)
                fc2_b, fc2_b_mon1, fc2_b_mon2 = get_numpy_output(
                    fc2_b, fc2_b_grad, fc2_b_mon1, fc2_b_mon2,
                    simple_lr_fun(linear2.bias), i + 1)

                np.testing.assert_allclose(
                    params_and_grads[0], fc1_w, rtol=1e-6)
                np.testing.assert_allclose(
                    params_and_grads[2], fc1_b, rtol=1e-6)
                np.testing.assert_allclose(
                    params_and_grads[4], fc2_w, rtol=1e-6)
                np.testing.assert_allclose(
                    params_and_grads[6], fc2_b, rtol=1e-6)
        finally:
            paddle.disable_static()


if __name__ == "__main__":
    # Script entry point: hand control to unittest's command-line runner,
    # which collects the TestCase classes in this module.
    unittest.main()