test_adamw_op.py 24.8 KB
Newer Older
M
MRXLT 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Z
zhaoyingli 已提交
15
import random
16 17 18
import unittest
from functools import partial

M
MRXLT 已提交
19
import numpy as np
Z
zhaoyingli 已提交
20
from op_test import OpTest
21 22 23

import paddle
import paddle.fluid as fluid
C
chentianyu03 已提交
24
from paddle.fluid.framework import _test_eager_guard
25
from paddle.framework import core
Z
zhaoyingli 已提交
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58


def adamw_step(inputs, attributes):
    """NumPy reference implementation of one AdamW update.

    Args:
        inputs: dict of arrays with keys ``Param``, ``Grad``, ``Moment1``,
            ``Moment2``, ``LearningRate``, ``Beta1Pow`` and ``Beta2Pow``. It
            must also hold ``Beta1Tensor`` / ``Beta2Tensor`` when ``beta1`` /
            ``beta2`` are not supplied as attributes; element ``[0]`` is used.
        attributes: dict with ``epsilon`` and optionally ``beta1``, ``beta2``,
            ``lr_ratio`` (multiplies the learning rate, including the one used
            for weight decay), ``with_decay`` (defaults to False when absent)
            and ``coeff`` (weight-decay coefficient, read only when
            ``with_decay`` is true).

    Returns:
        Tuple ``(param_out, moment1_out, moment2_out)``. The arrays in
        ``inputs`` are not modified.
    """
    param = inputs['Param']
    grad = inputs['Grad']
    moment1 = inputs['Moment1']
    moment2 = inputs['Moment2']
    lr = inputs['LearningRate']
    beta1_pow = inputs['Beta1Pow']
    beta2_pow = inputs['Beta2Pow']

    epsilon = attributes['epsilon']

    if 'lr_ratio' in attributes:
        lr = lr * attributes['lr_ratio']

    # Decoupled weight decay: shrink the parameter before the Adam step.
    # ``param * ...`` allocates a new array, so the caller's input is untouched.
    if attributes.get("with_decay", False):
        param = param * (1.0 - lr * attributes["coeff"])

    # Betas come from attributes when present, otherwise from tensor inputs.
    if 'beta1' in attributes:
        beta1 = attributes['beta1']
    else:
        beta1 = inputs['Beta1Tensor'][0]
    if 'beta2' in attributes:
        beta2 = attributes['beta2']
    else:
        beta2 = inputs['Beta2Tensor'][0]

    moment1_out = beta1 * moment1 + (1 - beta1) * grad
    moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)

    # Bias-corrected update; epsilon is added after the second-moment
    # correction, and the first-moment correction is folded into the step size.
    denom = (np.sqrt(moment2_out) / np.sqrt(1.0 - beta2_pow)) + epsilon
    param_out = param + ((moment1_out / denom) * (-(lr / (1.0 - beta1_pow))))

    return param_out, moment1_out, moment2_out


class TestAdamW(OpTest):
    """Compare the ``adamw`` operator with the NumPy ``adamw_step`` reference.

    Uses a 102x105 float32 problem with weight decay enabled (``coeff=0.5``)
    and no ``lr_ratio`` scaling.
    """

    def setUp(self):
        """Build random inputs, the op attributes and the expected outputs."""
        self.op_type = "adamw"
        shape = (102, 105)

        def to_f32(value):
            # Wrap a Python scalar in a 1-element float32 array.
            return np.array([value]).astype("float32")

        # Keep the draw order (param, grad, moment1, moment2) stable.
        param = np.random.uniform(-1, 1, shape).astype("float32")
        grad = np.random.uniform(-1, 1, shape).astype("float32")
        moment1 = np.random.uniform(-1, 1, shape).astype("float32")
        # Second moment must be non-negative: sample it from [0, 1).
        moment2 = np.random.random(shape).astype("float32")

        learning_rate, epsilon = 0.004, 1e-4
        beta1, beta2 = 0.78, 0.836
        # Bias-correction powers as if ten steps had already been taken.
        beta1_pow, beta2_pow = beta1**10, beta2**10

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Moment1': moment1,
            'Moment2': moment2,
            'LearningRate': to_f32(learning_rate),
            'Beta1Pow': to_f32(beta1_pow),
            'Beta2Pow': to_f32(beta2_pow),
        }
        self.attrs = {
            'epsilon': epsilon,
            'beta1': beta1,
            'beta2': beta2,
            "coeff": 0.5,
            "with_decay": True,
        }

        expected_param, expected_m1, expected_m2 = adamw_step(
            self.inputs, self.attrs
        )
        self.outputs = {
            'Moment1Out': expected_m1,
            'Moment2Out': expected_m2,
            'ParamOut': expected_param,
            'Beta1PowOut': to_f32(beta1_pow) * beta1,
            'Beta2PowOut': to_f32(beta2_pow) * beta2,
        }

    def test_check_output(self):
        """Run the op and compare against ``self.outputs``."""
        self.check_output()


115 116 117
@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestAdamW2(OpTest):
    """Check the ``adamw`` op on ``CUDAPlace(0)`` with ``lr_ratio`` scaling.

    Mirrors ``TestAdamW`` on a 2x2 problem and additionally sets
    ``lr_ratio=0.1``; expected values come from ``adamw_step``.
    """

    def setUp(self):
        """Build a 2x2 AdamW problem and its NumPy-computed expected outputs."""
        self.op_type = "adamw"
        shape = (2, 2)

        def to_f32(value):
            # Wrap a Python scalar in a 1-element float32 array.
            return np.array([value]).astype("float32")

        # Keep the draw order (param, grad, moment1, moment2) stable.
        param = np.random.uniform(-1, 1, shape).astype("float32")
        grad = np.random.uniform(-1, 1, shape).astype("float32")
        moment1 = np.random.uniform(-1, 1, shape).astype("float32")
        # Second moment must be non-negative: sample it from [0, 1).
        moment2 = np.random.random(shape).astype("float32")

        learning_rate, epsilon = 0.004, 1e-4
        beta1, beta2 = 0.78, 0.836
        # Bias-correction powers as if ten steps had already been taken.
        beta1_pow, beta2_pow = beta1**10, beta2**10

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Moment1': moment1,
            'Moment2': moment2,
            'LearningRate': to_f32(learning_rate),
            'Beta1Pow': to_f32(beta1_pow),
            'Beta2Pow': to_f32(beta2_pow),
        }
        self.attrs = {
            'epsilon': epsilon,
            'beta1': beta1,
            'beta2': beta2,
            "lr_ratio": 0.1,
            "coeff": 0.5,
            "with_decay": True,
        }

        expected_param, expected_m1, expected_m2 = adamw_step(
            self.inputs, self.attrs
        )
        self.outputs = {
            'Moment1Out': expected_m1,
            'Moment2Out': expected_m2,
            'ParamOut': expected_param,
            'Beta1PowOut': to_f32(beta1_pow) * beta1,
            'Beta2PowOut': to_f32(beta2_pow) * beta2,
        }

    def test_check_output(self):
        """Run the op on the first GPU and compare against ``self.outputs``."""
        gpu = core.CUDAPlace(0)
        self.check_output_with_place(gpu)
M
MRXLT 已提交
168 169 170 171 172 173


class TestAdamWOp(unittest.TestCase):
    """Smoke tests for the ``paddle.optimizer.AdamW`` Python API.

    Subclasses in this file reuse these tests and override
    ``test_adamw_op_dygraph`` (and, in one case, ``test_adamw_op``) to cover
    other configurations.
    """

    def test_adamw_op_dygraph(self):
        """Run two AdamW steps on a single ``Linear`` layer in dygraph mode."""
        paddle.disable_static()
        value = np.arange(26).reshape(2, 13).astype("float32")
        a = paddle.to_tensor(value)
        linear = paddle.nn.Linear(13, 5)
        adam = paddle.optimizer.AdamW(
            learning_rate=0.01,
            parameters=linear.parameters(),
            apply_decay_param_fun=lambda name: True,
            weight_decay=0.01,
        )

        for _ in range(2):
            out = linear(a)
            out.backward()
            adam.step()
            adam.clear_gradients()

    def test_adamw_op_coverage(self):
        """An AdamW with zero learning rate can be constructed and printed."""
        paddle.disable_static()
        linear = paddle.nn.Linear(13, 5)
        adam = paddle.optimizer.AdamW(
            learning_rate=0.0,
            parameters=linear.parameters(),
            apply_decay_param_fun=lambda name: True,
            weight_decay=0.01,
        )
        self.assertIsNotNone(str(adam))

    def test_adamw_op(self):
        """Build and run one static-graph training step with tensor betas."""
        paddle.enable_static()
        # Restore dygraph mode even when the check fails, so later tests in
        # the same process are not left running in static mode.
        try:
            place = fluid.CPUPlace()
            shape = [2, 3, 8, 8]
            exe = fluid.Executor(place)
            train_prog = fluid.Program()
            startup = fluid.Program()
            with fluid.program_guard(train_prog, startup):
                with fluid.unique_name.guard():
                    data = fluid.data(name="data", shape=shape)
                    conv = paddle.static.nn.conv2d(data, 8, 3)
                    loss = paddle.mean(conv)

                    # betas are passed as persistable global variables rather
                    # than Python floats.
                    beta1 = paddle.static.create_global_var(
                        shape=[1], value=0.85, dtype='float32', persistable=True
                    )
                    beta2 = paddle.static.create_global_var(
                        shape=[1], value=0.95, dtype='float32', persistable=True
                    )
                    opt = paddle.optimizer.AdamW(
                        learning_rate=1e-5,
                        beta1=beta1,
                        beta2=beta2,
                        weight_decay=0.01,
                        epsilon=1e-8,
                    )
                    opt.minimize(loss)

            exe.run(startup)
            data_np = np.random.random(shape).astype('float32')
            rets = exe.run(
                train_prog, feed={"data": data_np}, fetch_list=[loss]
            )
            self.assertIsNotNone(rets[0])
        finally:
            paddle.disable_static()

    def test_adamw_op_invalid_input(self):
        """Negative beta1, beta2 or epsilon is rejected with ValueError."""
        paddle.disable_static()
        linear = paddle.nn.Linear(10, 10)
        with self.assertRaises(ValueError):
            paddle.optimizer.AdamW(
                0.1, beta1=-1, parameters=linear.parameters()
            )
        with self.assertRaises(ValueError):
            paddle.optimizer.AdamW(
                0.1, beta2=-1, parameters=linear.parameters()
            )
        with self.assertRaises(ValueError):
            paddle.optimizer.AdamW(
                0.1, epsilon=-1, parameters=linear.parameters()
            )

    def test_api_eager_dygraph(self):
        """Re-run the dygraph checks under ``_test_eager_guard``."""
        with _test_eager_guard():
            self.test_adamw_op_dygraph()
            self.test_adamw_op_invalid_input()

M
MRXLT 已提交
258

259 260 261 262 263 264 265
class TestAdamWOpGroup(TestAdamWOp):
    """Re-run the ``TestAdamWOp`` checks with two parameter groups."""

    def test_adamw_op_dygraph(self):
        """Two AdamW steps over two stacked layers in separate param groups."""
        paddle.disable_static()
        a = paddle.to_tensor(np.arange(26).reshape(2, 13).astype("float32"))
        linear_1 = paddle.nn.Linear(13, 5)
        linear_2 = paddle.nn.Linear(5, 3)
        param_groups = [
            {'params': linear_1.parameters()},
            # This group overrides the optimizer-wide weight_decay.
            {'params': linear_2.parameters(), 'weight_decay': 0.001},
        ]
        adam = paddle.optimizer.AdamW(
            learning_rate=0.01,
            parameters=param_groups,
            apply_decay_param_fun=lambda name: True,
            weight_decay=0.01,
        )

        for _ in range(2):
            linear_2(linear_1(a)).backward()
            adam.step()
            adam.clear_gradients()


284 285 286 287 288 289 290 291 292 293
class TestAdamWOpMultiPrecison(unittest.TestCase):
    """Smoke-test AdamW with and without ``multi_precision`` on each device."""

    def _test_adamw_op_dygraph_place_amp(self, place, use_amp=False):
        """Run two AdamW steps on ``place``.

        Args:
            place: device string passed to ``paddle.set_device``
                (``'cpu'`` or ``'gpu'``).
            use_amp: forwarded as AdamW's ``multi_precision``; on ``'gpu'`` it
                also switches training to O2 auto-mixed precision with a
                ``GradScaler``.
        """
        paddle.disable_static()
        paddle.seed(10)
        paddle.set_device(place)

        x = paddle.randn((5, 5))

        model = paddle.nn.Linear(5, 5)

        optimizer = paddle.optimizer.AdamW(
            parameters=[
                {
                    'params': model.parameters(),
                    'weight_decay': 0.001,
                    'beta1': 0.1,
                    'beta2': 0.99,
                }
            ],
            multi_precision=use_amp,
        )

        use_gpu_amp = place == 'gpu' and use_amp
        if use_gpu_amp:
            # Decorate the model and create the scaler once, before training.
            # Doing it inside the loop re-decorated an already decorated model
            # and threw away the scaler's loss-scaling state after every step.
            model = paddle.amp.decorate(models=model, level='O2')
            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

        for _ in range(2):
            if use_gpu_amp:
                with paddle.amp.auto_cast(level='O2'):
                    output = model(x)
                    loss = paddle.mean(output)
                scaled = scaler.scale(loss)
                scaled.backward()
                scaler.step(optimizer)
                # Per Paddle's GradScaler docs, update() follows each step();
                # required now that the same scaler is reused across steps.
                scaler.update()
            else:
                output = model(x)
                loss = paddle.mean(output)
                loss.backward()
                optimizer.step()
            optimizer.clear_grad()

    def _get_places(self):
        """Return ``['cpu']``, plus ``'gpu'`` when built with CUDA."""
        places = ['cpu']
        if paddle.is_compiled_with_cuda():
            places.append('gpu')
        return places

    def test_main(self):
        """Exercise every available place with AMP enabled and disabled."""
        for place in self._get_places():
            for use_amp in (True, False):
                self._test_adamw_op_dygraph_place_amp(place, use_amp)


class TestAdamWOpError(unittest.TestCase):
    """AdamW rejects malformed constructor arguments."""

    def test_api_errors(self):
        """Each bad argument combination raises the expected exception type."""
        # An int weight_decay is rejected.
        with self.assertRaises(TypeError):
            linear = paddle.nn.Linear(13, 5)
            paddle.optimizer.AdamW(
                learning_rate=0.01,
                parameters=linear.parameters(),
                weight_decay=1,
            )

        # A bare tensor is not a valid ``parameters`` argument.
        with self.assertRaises(TypeError):
            paddle.optimizer.AdamW(
                learning_rate=0.01,
                parameters=paddle.randn((5, 5)),
                weight_decay=0.1,
            )

        # A single group dict (not wrapped in a list) is rejected.
        with self.assertRaises(TypeError):
            linear = paddle.nn.Linear(13, 5)
            paddle.optimizer.AdamW(
                learning_rate=0.01,
                parameters={'params': linear.parameters()},
                weight_decay=0.1,
            )

        # parameters=None fails with AttributeError in this test's setting.
        with self.assertRaises(AttributeError):
            paddle.optimizer.AdamW(
                learning_rate=0.01, parameters=None, weight_decay=0.1
            )

        # A group dict whose params are a set is rejected.
        with self.assertRaises(TypeError):
            linear = paddle.nn.Linear(13, 5)
            paddle.optimizer.AdamW(
                learning_rate=0.01,
                parameters={'params': set(linear.parameters())},
                weight_decay=0.1,
            )

        # An int learning_rate is rejected.
        with self.assertRaises(TypeError):
            linear = paddle.nn.Linear(13, 5)
            paddle.optimizer.AdamW(
                learning_rate=1,
                parameters=linear.parameters(),
                weight_decay=0.1,
            )

        # A float grad_clip is rejected.
        with self.assertRaises(TypeError):
            linear = paddle.nn.Linear(13, 5)
            paddle.optimizer.AdamW(
                learning_rate=0.01,
                parameters=linear.parameters(),
                weight_decay=0.1,
                grad_clip=0.1,
            )


W
wangguanzhong 已提交
403 404 405 406 407 408 409 410 411
class TestAdamWOpGroupWithLR(TestAdamWOp):
    """Re-run ``TestAdamWOp`` with param groups and an LR scheduler."""

    def test_adamw_op_dygraph(self):
        """Two steps with a PiecewiseDecay schedule and per-group overrides."""
        paddle.disable_static()
        a = paddle.to_tensor(np.arange(26).reshape(2, 13).astype("float32"))
        linear_1 = paddle.nn.Linear(13, 5)
        linear_2 = paddle.nn.Linear(5, 3)
        scheduler = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=[3, 6], values=[0.1, 0.2, 0.3]
        )
        param_groups = [
            # Group-level learning_rate override.
            {'params': linear_1.parameters(), 'learning_rate': 0.1},
            # Group-level weight_decay override.
            {'params': linear_2.parameters(), 'weight_decay': 0.001},
        ]
        adam = paddle.optimizer.AdamW(
            learning_rate=scheduler,
            parameters=param_groups,
            apply_decay_param_fun=lambda name: True,
            weight_decay=0.01,
        )

        for _ in range(2):
            linear_2(linear_1(a)).backward()
            adam.step()
            adam.clear_gradients()


436 437 438 439 440 441 442 443
def simple_lr_setting(param, decay_rate, n_layers):
    """Return the layer-wise learning-rate ratio for ``param``.

    The layer is recognised by substring match on ``param.name``:
    names containing ``fc_0``/``linear_1`` get offset 1, names containing
    ``fc_1``/``linear_2`` get offset 2. For those, depth is the integer after
    the second ``_`` in the name plus the offset; every other name has depth 0.
    The ratio is ``decay_rate ** (n_layers + 2 - depth)``.
    """
    name = param.name
    if any(tag in name for tag in ("fc_0", "linear_1")):
        offset = 1
    elif any(tag in name for tag in ("fc_1", "linear_2")):
        offset = 2
    else:
        offset = None

    # Only parse the name when it matched; unmatched names may not split.
    depth = 0 if offset is None else int(name.split("_")[2]) + offset
    return decay_rate ** (n_layers + 2 - depth)
445 446


447 448 449
@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestAdamWOpLayerwiseLR(TestAdamWOp):
    """Numerically check AdamW's ``lr_ratio`` (layer-wise learning rate).

    The dygraph and static-graph tests each train a two-layer ``Linear``
    network for five steps. After every step they compare each parameter with
    the NumPy reference ``adamw_step``, using ``simple_lr_setting`` for the
    per-parameter learning-rate ratio.
    """

    def setUp(self):
        # Seed every RNG the tests draw from so the comparison is reproducible.
        random.seed(2022)
        np.random.seed(2022)
        paddle.seed(2022)

    @staticmethod
    def _reference_step(
        param,
        grad,
        moment1,
        moment2,
        lr_ratio,
        t,
        learning_rate,
        beta1,
        beta2,
        epsilon,
        weight_decay,
    ):
        """Apply ``adamw_step`` to one parameter at (1-based) step ``t``.

        Weight decay is always enabled with coefficient ``weight_decay``; the
        bias-correction powers are ``beta1**t`` and ``beta2**t``.

        Returns:
            ``(param_out, moment1_out, moment2_out)``.
        """
        np_inputs = {
            'Param': param,
            'Grad': grad,
            'Moment1': moment1,
            'Moment2': moment2,
            'LearningRate': np.array([learning_rate]).astype("float32"),
            'Beta1Pow': np.array([beta1**t]).astype("float32"),
            'Beta2Pow': np.array([beta2**t]).astype("float32"),
        }
        np_attrs = {
            'epsilon': epsilon,
            'beta1': beta1,
            'beta2': beta2,
            "lr_ratio": lr_ratio,
            "coeff": weight_decay,
            "with_decay": True,
        }
        return adamw_step(np_inputs, np_attrs)

    def test_adamw_op_dygraph(self):
        """Dygraph: five AdamW steps must match the NumPy reference."""
        paddle.disable_static()
        linear1 = paddle.nn.Linear(
            13, 8, bias_attr=paddle.nn.initializer.Constant(value=1.0)
        )
        linear2 = paddle.nn.Linear(
            8, 5, bias_attr=paddle.nn.initializer.Constant(value=1.0)
        )

        # Fix the parameter names: simple_lr_setting derives the ratio from them.
        linear1.weight.name = "linear_1.w_0"
        linear1.bias.name = "linear_1.b_0"
        linear2.weight.name = "linear_2.w_0"
        linear2.bias.name = "linear_2.b_0"

        params = [linear1.weight, linear1.bias, linear2.weight, linear2.bias]
        # NumPy mirror per parameter: [value, first moment, second moment];
        # both moments start at zero.
        states = []
        for p in params:
            value = np.array(p)
            states.append([value, np.zeros_like(value), np.zeros_like(value)])

        simple_lr_fun = partial(simple_lr_setting, decay_rate=0.8, n_layers=2)
        learning_rate = 0.001
        weight_decay = 0.01
        beta1 = 0.9
        beta2 = 0.999
        epsilon = 1e-8

        opt = paddle.optimizer.AdamW(
            learning_rate=learning_rate,
            parameters=[
                {'params': linear1.parameters()},
                {
                    'params': linear2.parameters(),
                },
            ],
            apply_decay_param_fun=lambda name: True,
            weight_decay=weight_decay,
            lr_ratio=simple_lr_fun,
        )

        for i in range(5):
            a = paddle.to_tensor(
                np.random.uniform(-1, 1, (2, 13)).astype("float32")
            )
            out = paddle.mean(linear2(linear1(a)))
            out.backward()

            # Advance the NumPy mirror with the gradients Paddle just computed.
            for param, state in zip(params, states):
                state[0], state[1], state[2] = self._reference_step(
                    param=state[0],
                    grad=np.array(param.grad),
                    moment1=state[1],
                    moment2=state[2],
                    lr_ratio=simple_lr_fun(param),
                    t=i + 1,
                    learning_rate=learning_rate,
                    beta1=beta1,
                    beta2=beta2,
                    epsilon=epsilon,
                    weight_decay=weight_decay,
                )

            opt.step()
            opt.clear_gradients()

            for param, state in zip(params, states):
                np.testing.assert_allclose(param.numpy(), state[0], rtol=1e-6)

    def test_adamw_op(self):
        """Static graph: five AdamW steps must match the NumPy reference."""
        paddle.enable_static()
        # Always restore dygraph mode so a failure here cannot leak static
        # mode into later tests in the same process.
        try:
            self._check_static_layerwise_lr()
        finally:
            paddle.disable_static()

    def _check_static_layerwise_lr(self):
        """Body of ``test_adamw_op``; expects static mode to be enabled."""
        place = fluid.CUDAPlace(0)

        learning_rate = 0.0001
        beta1 = 0.85
        beta2 = 0.95
        weight_decay = 0.01
        epsilon = 1e-8

        train_prog = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(train_prog, startup):
            with fluid.unique_name.guard():
                x = fluid.data(name='x', shape=[None, 10], dtype='float32')
                y = fluid.data(name='y', shape=[None, 1], dtype='float32')

                # Explicit parameter names: they are fetched by name below and
                # also drive simple_lr_setting.
                weight_attr1 = paddle.framework.ParamAttr(name="linear_0.w_0")
                bias_attr1 = paddle.framework.ParamAttr(
                    name="linear_0.b_0",
                    initializer=paddle.nn.initializer.Constant(value=1.0),
                )
                weight_attr2 = paddle.framework.ParamAttr(name="linear_1.w_0")
                bias_attr2 = paddle.framework.ParamAttr(
                    name="linear_1.b_0",
                    initializer=paddle.nn.initializer.Constant(value=1.0),
                )
                linear1 = paddle.nn.Linear(
                    10, 32, weight_attr=weight_attr1, bias_attr=bias_attr1
                )
                linear2 = paddle.nn.Linear(
                    32, 1, weight_attr=weight_attr2, bias_attr=bias_attr2
                )

                out = linear2(linear1(x))
                cost = paddle.nn.functional.square_error_cost(
                    input=out, label=y
                )
                avg_cost = paddle.mean(cost)

                simple_lr_fun = partial(
                    simple_lr_setting, decay_rate=0.8, n_layers=2
                )

                opt = paddle.optimizer.AdamW(
                    learning_rate=learning_rate,
                    beta1=beta1,
                    beta2=beta2,
                    weight_decay=weight_decay,
                    epsilon=epsilon,
                    lr_ratio=simple_lr_fun,
                )
                opt.minimize(avg_cost)

        params = [linear1.weight, linear1.bias, linear2.weight, linear2.bias]
        param_names = [
            "linear_0.w_0",
            "linear_0.b_0",
            "linear_1.w_0",
            "linear_1.b_0",
        ]
        # Interleaved [value, gradient] fetch list: parameter k's value is at
        # index 2k and its gradient at index 2k + 1.
        fetch_names = []
        for name in param_names:
            fetch_names.extend([name, name + "@GRAD"])
        # NumPy [first moment, second moment] per parameter, starting at zero.
        moments = [
            [
                np.zeros(p.shape).astype("float32"),
                np.zeros(p.shape).astype("float32"),
            ]
            for p in params
        ]

        exe = fluid.Executor(place)
        exe.run(startup)
        test_prog = train_prog.clone(for_test=True)

        for i in range(5):
            inputs = np.random.random(size=[8, 10]).astype('float32')
            outputs = np.random.random(size=[8, 1]).astype('float32')

            # Read the parameters through the for_test clone first, then run
            # one training step and fetch parameters together with gradients.
            params_before = exe.run(
                test_prog,
                feed={"x": inputs, "y": outputs},
                fetch_list=param_names,
            )
            params_and_grads = exe.run(
                train_prog,
                feed={"x": inputs, "y": outputs},
                fetch_list=fetch_names,
            )

            expected = []
            for k, (param, moment) in enumerate(zip(params, moments)):
                param_out, moment[0], moment[1] = self._reference_step(
                    param=params_before[k],
                    grad=params_and_grads[2 * k + 1],
                    moment1=moment[0],
                    moment2=moment[1],
                    lr_ratio=simple_lr_fun(param),
                    t=i + 1,
                    learning_rate=learning_rate,
                    beta1=beta1,
                    beta2=beta2,
                    epsilon=epsilon,
                    weight_decay=weight_decay,
                )
                expected.append(param_out)

            for k, param_out in enumerate(expected):
                np.testing.assert_allclose(
                    params_and_grads[2 * k], param_out, rtol=1e-6
                )


M
MRXLT 已提交
753 754
if __name__ == "__main__":
    unittest.main()