test_adamw_op.py 13.4 KB
Newer Older
M
MRXLT 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
Z
zhaoyingli 已提交
17
import random
M
MRXLT 已提交
18 19
import numpy as np
import paddle.fluid as fluid
Z
zhaoyingli 已提交
20
from op_test import OpTest
21
from functools import partial
Z
zhaoyingli 已提交
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
from paddle.framework import core


def adamw_step(inputs, attributes):
    """NumPy reference for a single AdamW update.

    Args:
        inputs: mapping holding 'Param', 'Grad', 'Moment1', 'Moment2',
            'LearningRate', 'Beta1Pow' and 'Beta2Pow'. 'Beta1Tensor' /
            'Beta2Tensor' are read only when the matching 'beta1' / 'beta2'
            attribute is absent.
        attributes: mapping holding 'epsilon' and 'with_decay' ('coeff' is
            read when 'with_decay' is truthy); 'lr_ratio', 'beta1' and
            'beta2' are optional.

    Returns:
        The tuple ``(param_out, moment1_out, moment2_out)``.
    """
    weights, gradient, first_moment, second_moment = (
        inputs[key] for key in ('Param', 'Grad', 'Moment1', 'Moment2'))
    step_size = inputs['LearningRate']
    beta1_power, beta2_power = inputs['Beta1Pow'], inputs['Beta2Pow']
    eps = attributes['epsilon']

    # An optional ratio rescales the learning rate before anything uses it.
    if 'lr_ratio' in attributes:
        step_size = step_size * attributes['lr_ratio']

    # Decoupled weight decay: shrink the parameter before the Adam update.
    if attributes["with_decay"]:
        shrink = 1.0 - step_size * attributes["coeff"]
        weights = (weights * shrink).copy()

    # Betas come from the attributes when given, else from tensor inputs.
    b1 = (attributes['beta1']
          if 'beta1' in attributes else inputs['Beta1Tensor'][0])
    b2 = (attributes['beta2']
          if 'beta2' in attributes else inputs['Beta2Tensor'][0])

    new_first = b1 * first_moment + (1 - b1) * gradient
    new_second = b2 * second_moment + (1 - b2) * np.square(gradient)
    # Bias correction is folded into the effective learning rate.
    corrected_lr = step_size * np.sqrt(1 - beta2_power) / (1 - beta1_power)
    new_weights = weights - corrected_lr * (
        new_first / (np.sqrt(new_second) + eps))
    return new_weights, new_first, new_second


class TestAdamW(OpTest):
    """Compare the ``adamw`` op with ``adamw_step`` on a (102, 105) tensor."""

    def setUp(self):
        """Prepare the op inputs, its attributes and the reference outputs."""
        self.op_type = "adamw"

        dims = (102, 105)
        beta1, beta2 = 0.78, 0.836
        # Each beta raised to the 10th power, stored as a one-element tensor.
        beta1_pow = np.array([beta1**10]).astype("float32")
        beta2_pow = np.array([beta2**10]).astype("float32")

        self.inputs = {
            'Param': np.random.uniform(-1, 1, dims).astype("float32"),
            'Grad': np.random.uniform(-1, 1, dims).astype("float32"),
            'Moment1': np.random.uniform(-1, 1, dims).astype("float32"),
            # Drawn from [0, 1) so the second moment is never negative.
            'Moment2': np.random.random(dims).astype("float32"),
            'LearningRate': np.array([0.004]).astype("float32"),
            'Beta1Pow': beta1_pow,
            'Beta2Pow': beta2_pow
        }

        self.attrs = {
            'epsilon': 1e-4,
            'beta1': beta1,
            'beta2': beta2,
            "coeff": 0.5,
            "with_decay": True
        }

        # Expected results come from the NumPy reference implementation.
        new_param, new_moment1, new_moment2 = adamw_step(self.inputs,
                                                         self.attrs)

        self.outputs = {
            'Moment1Out': new_moment1,
            'Moment2Out': new_moment2,
            'ParamOut': new_param,
            'Beta1PowOut': beta1_pow * beta1,
            'Beta2PowOut': beta2_pow * beta2
        }

    def test_check_output(self):
        # Delegates the comparison against self.outputs to OpTest.
        self.check_output()


@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestAdamW2(OpTest):
    """Compare the ``adamw`` op with ``adamw_step`` when ``lr_ratio`` is set.

    Skipped unless Paddle was compiled with CUDA; the output check is run
    on ``CUDAPlace(0)``.
    """

    def setUp(self):
        """Prepare the op inputs, its attributes and the reference outputs."""
        self.op_type = "adamw"

        dims = (2, 2)
        beta1, beta2 = 0.78, 0.836
        # Each beta raised to the 10th power, stored as a one-element tensor.
        beta1_pow = np.array([beta1**10]).astype("float32")
        beta2_pow = np.array([beta2**10]).astype("float32")

        self.inputs = {
            'Param': np.random.uniform(-1, 1, dims).astype("float32"),
            'Grad': np.random.uniform(-1, 1, dims).astype("float32"),
            'Moment1': np.random.uniform(-1, 1, dims).astype("float32"),
            # Drawn from [0, 1) so the second moment is never negative.
            'Moment2': np.random.random(dims).astype("float32"),
            'LearningRate': np.array([0.004]).astype("float32"),
            'Beta1Pow': beta1_pow,
            'Beta2Pow': beta2_pow
        }

        self.attrs = {
            'epsilon': 1e-4,
            'beta1': beta1,
            'beta2': beta2,
            "lr_ratio": 0.1,
            "coeff": 0.5,
            "with_decay": True
        }

        # Expected results come from the NumPy reference implementation.
        new_param, new_moment1, new_moment2 = adamw_step(self.inputs,
                                                         self.attrs)

        self.outputs = {
            'Moment1Out': new_moment1,
            'Moment2Out': new_moment2,
            'ParamOut': new_param,
            'Beta1PowOut': beta1_pow * beta1,
            'Beta2PowOut': beta2_pow * beta2
        }

    def test_check_output(self):
        # The check is pinned to the first CUDA device.
        gpu = core.CUDAPlace(0)
        self.check_output_with_place(gpu)
M
MRXLT 已提交
164 165 166 167 168 169


class TestAdamWOp(unittest.TestCase):
    """Smoke tests for ``paddle.optimizer.AdamW`` in dygraph and static mode."""

    def test_adamw_op_dygraph(self):
        # Runs two forward/backward/step cycles on one Linear layer; the test
        # passes as long as nothing raises.
        paddle.disable_static()
        value = np.arange(26).reshape(2, 13).astype("float32")
        a = paddle.to_tensor(value)
        linear = paddle.nn.Linear(13, 5)
        adam = paddle.optimizer.AdamW(
            learning_rate=0.01,
            parameters=linear.parameters(),
            apply_decay_param_fun=lambda name: True,
            weight_decay=0.01)

        for _ in range(2):
            out = linear(a)
            out.backward()
            adam.step()
            adam.clear_gradients()

    def test_adamw_op_coverage(self):
        # Only checks that the optimizer has a string representation.
        paddle.disable_static()
        linear = paddle.nn.Linear(13, 5)
        adam = paddle.optimizer.AdamW(
            learning_rate=0.0,
            parameters=linear.parameters(),
            apply_decay_param_fun=lambda name: True,
            weight_decay=0.01)
        # unittest assertion rather than a bare assert, which `python -O`
        # would strip.
        self.assertIsNotNone(str(adam))

    def test_adamw_op(self):
        # Builds a small static-graph conv program with beta1/beta2 passed as
        # global variables and runs a single training step on CPU.
        paddle.enable_static()
        try:
            place = fluid.CPUPlace()
            shape = [2, 3, 8, 8]
            exe = fluid.Executor(place)
            train_prog = fluid.Program()
            startup = fluid.Program()
            with fluid.program_guard(train_prog, startup):
                with fluid.unique_name.guard():
                    data = fluid.data(name="data", shape=shape)
                    conv = fluid.layers.conv2d(data, 8, 3)
                    loss = paddle.mean(conv)

                    beta1 = fluid.layers.create_global_var(
                        shape=[1],
                        value=0.85,
                        dtype='float32',
                        persistable=True)
                    beta2 = fluid.layers.create_global_var(
                        shape=[1],
                        value=0.95,
                        dtype='float32',
                        persistable=True)
                    opt = paddle.optimizer.AdamW(
                        learning_rate=1e-5,
                        beta1=beta1,
                        beta2=beta2,
                        weight_decay=0.01,
                        epsilon=1e-8)
                    opt.minimize(loss)

            exe.run(startup)
            data_np = np.random.random(shape).astype('float32')
            rets = exe.run(train_prog,
                           feed={"data": data_np},
                           fetch_list=[loss])
            self.assertIsNotNone(rets[0])
        finally:
            # Restore dygraph mode even when the body fails, so the global
            # mode switch cannot leak into tests that run afterwards.
            paddle.disable_static()

    def test_adamw_op_invalid_input(self):
        # A negative beta1, beta2 or epsilon must be rejected at construction.
        paddle.disable_static()
        linear = paddle.nn.Linear(10, 10)
        with self.assertRaises(ValueError):
            paddle.optimizer.AdamW(
                0.1, beta1=-1, parameters=linear.parameters())
        with self.assertRaises(ValueError):
            paddle.optimizer.AdamW(
                0.1, beta2=-1, parameters=linear.parameters())
        with self.assertRaises(ValueError):
            paddle.optimizer.AdamW(
                0.1, epsilon=-1, parameters=linear.parameters())

M
MRXLT 已提交
241

242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267
class TestAdamWOpGroup(TestAdamWOp):
    """Re-runs the inherited tests with the dygraph case using param groups."""

    def test_adamw_op_dygraph(self):
        # Two Linear layers are handed to AdamW as separate parameter groups;
        # the second group carries its own 'weight_decay' entry.
        paddle.disable_static()
        features = paddle.to_tensor(
            np.arange(26).reshape(2, 13).astype("float32"))
        first_fc = paddle.nn.Linear(13, 5)
        second_fc = paddle.nn.Linear(5, 3)

        param_groups = [
            {
                'params': first_fc.parameters()
            },
            {
                'params': second_fc.parameters(),
                'weight_decay': 0.001
            },
        ]
        optimizer = paddle.optimizer.AdamW(
            learning_rate=0.01,
            parameters=param_groups,
            apply_decay_param_fun=lambda name: True,
            weight_decay=0.01)

        # Two optimisation steps; success means nothing raised.
        for _ in range(2):
            prediction = second_fc(first_fc(features))
            prediction.backward()
            optimizer.step()
            optimizer.clear_gradients()


W
wangguanzhong 已提交
268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295
class TestAdamWOpGroupWithLR(TestAdamWOp):
    """Re-runs the inherited tests; the dygraph case mixes an LR scheduler
    with per-group options."""

    def test_adamw_op_dygraph(self):
        # The optimizer gets a PiecewiseDecay schedule, while the first
        # parameter group also carries its own 'learning_rate' entry and the
        # second its own 'weight_decay' entry.
        paddle.disable_static()
        features = paddle.to_tensor(
            np.arange(26).reshape(2, 13).astype("float32"))
        first_fc = paddle.nn.Linear(13, 5)
        second_fc = paddle.nn.Linear(5, 3)

        schedule = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=[3, 6], values=[0.1, 0.2, 0.3])
        param_groups = [
            {
                'params': first_fc.parameters(),
                'learning_rate': 0.1,
            },
            {
                'params': second_fc.parameters(),
                'weight_decay': 0.001,
            },
        ]
        optimizer = paddle.optimizer.AdamW(
            learning_rate=schedule,
            parameters=param_groups,
            apply_decay_param_fun=lambda name: True,
            weight_decay=0.01)

        # Two optimisation steps; success means nothing raised.
        for _ in range(2):
            prediction = second_fc(first_fc(features))
            prediction.backward()
            optimizer.step()
            optimizer.clear_gradients()


296 297 298 299 300 301 302 303 304 305 306
def simple_lr_setting(param, decay_rate, n_layers):
    """Return a layer-wise learning-rate ratio derived from ``param.name``.

    The depth is the integer found in the third ``_``-separated piece of the
    name, plus 1 for names containing "fc_0" or "linear_1", or plus 2 for
    names containing "fc_1" or "linear_2". Any other name gets depth 0.

    Args:
        param: object exposing a ``name`` string attribute.
        decay_rate: base of the returned power.
        n_layers: layer count used in the exponent.

    Returns:
        ``decay_rate ** (n_layers + 2 - depth)``.
    """
    name = param.name
    # (name markers, depth offset) pairs, checked in order; first hit wins.
    offsets_by_marker = (
        (("fc_0", "linear_1"), 1),
        (("fc_1", "linear_2"), 2),
    )

    depth = 0
    for markers, offset in offsets_by_marker:
        if any(marker in name for marker in markers):
            depth = int(name.split("_")[2]) + offset
            break

    exponent = n_layers + 2 - depth
    return decay_rate**exponent


Z
zhaoyingli 已提交
307 308
@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestAdamWOpLayerwiseLR(TestAdamWOp):
    """AdamW tests with a layer-wise ``lr_ratio`` callback.

    Skipped unless Paddle was compiled with CUDA. Both overridden tests
    compare per-step losses with hard-coded reference values, so every
    random source is seeded in ``setUp``.
    """

    def setUp(self):
        """Seed Python, NumPy and Paddle so the reference losses reproduce."""
        random.seed(2021)
        np.random.seed(2021)
        paddle.seed(2021)

    def test_adamw_op_dygraph(self):
        # Trains two stacked Linear layers for five steps in dygraph mode and
        # checks the loss of every step against loss_ref.
        paddle.disable_static()
        value = np.arange(26).reshape(2, 13).astype("float32")
        a = paddle.to_tensor(value)
        linear1 = paddle.nn.Linear(13, 8)
        linear2 = paddle.nn.Linear(8, 5)

        simple_lr_fun = partial(simple_lr_setting, decay_rate=0.8, n_layers=2)

        adam = paddle.optimizer.AdamW(
            learning_rate=0.01,
            parameters=[{
                'params': linear1.parameters()
            }, {
                'params': linear2.parameters(),
            }],
            apply_decay_param_fun=lambda name: True,
            weight_decay=0.01,
            lr_ratio=simple_lr_fun)

        loss_ref = np.array(
            [4.8383293, 3.0854003, 1.33299, -0.418993, -2.171043])
        # Iterating the references ties the step count to their number.
        for expected_loss in loss_ref:
            a1 = linear1(a)
            out = linear2(a1)
            out = paddle.mean(out)
            out.backward()
            adam.step()
            adam.clear_gradients()
            np.testing.assert_allclose(
                out[0].numpy(), expected_loss, rtol=1e-6)

    def test_adamw_op(self):
        # Trains a two-layer fc regression program for five steps in static
        # mode on CUDAPlace(0) and checks every loss against loss_ref.
        paddle.enable_static()
        try:
            place = fluid.CUDAPlace(0)
            train_prog = fluid.Program()
            startup = fluid.Program()
            with fluid.program_guard(train_prog, startup):
                with fluid.unique_name.guard():
                    x = fluid.data(
                        name='x', shape=[None, 10], dtype='float32')
                    y = fluid.data(name='y', shape=[None, 1], dtype='float32')

                    fc1 = fluid.layers.fc(input=x, size=32, act=None)
                    prediction = fluid.layers.fc(input=fc1, size=1, act=None)
                    cost = fluid.layers.square_error_cost(
                        input=prediction, label=y)
                    avg_cost = fluid.layers.mean(cost)

                    simple_lr_fun = partial(
                        simple_lr_setting, decay_rate=0.8, n_layers=2)

                    beta1 = fluid.layers.create_global_var(
                        shape=[1],
                        value=0.85,
                        dtype='float32',
                        persistable=True)
                    beta2 = fluid.layers.create_global_var(
                        shape=[1],
                        value=0.95,
                        dtype='float32',
                        persistable=True)
                    opt = paddle.optimizer.AdamW(
                        learning_rate=1e-5,
                        beta1=beta1,
                        beta2=beta2,
                        weight_decay=0.01,
                        epsilon=1e-8,
                        lr_ratio=simple_lr_fun)
                    opt.minimize(avg_cost)

            exe = fluid.Executor(place)
            exe.run(startup)

            loss_ref = np.array(
                [0.36120513, 0.2720821, 0.67208904, 0.14607805, 0.24098626])
            for expected_loss in loss_ref:
                # Draw order (inputs, then outputs) must not change: the
                # reference losses depend on the seeded random stream.
                inputs = np.random.random(size=[8, 10]).astype('float32')
                outputs = np.random.random(size=[8, 1]).astype('float32')
                rets = exe.run(train_prog,
                               feed={"x": inputs,
                                     "y": outputs},
                               fetch_list=[avg_cost])
                self.assertIsNotNone(rets[0])
                np.testing.assert_allclose(rets[0], expected_loss, rtol=1e-6)
        finally:
            # Restore dygraph mode even when the body fails, so the global
            # mode switch cannot leak into tests that run afterwards.
            paddle.disable_static()


M
MRXLT 已提交
396 397
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()