#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
from __future__ import print_function

17 18
import unittest
import numpy as np
19
from op_test import OpTest
20 21
from paddle.fluid import core
from paddle.fluid.op import Operator
22
import paddle.fluid as fluid
M
MRXLT 已提交
23
import paddle
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55


class TestAdamOp1(OpTest):
    def setUp(self):
        '''Check a single dense Adam step with explicitly supplied attributes.

        epsilon, beta1 and beta2 are stored in ``self.attrs``; the expected
        outputs are computed with the numpy reference ``adam_step``.
        '''
        self.op_type = "adam"
        shape = (102, 105)

        # Random draws happen in the order param, grad, moment1, moment2.
        param, grad, moment1 = [
            np.random.uniform(-1, 1, shape).astype("float32")
            for _ in range(3)
        ]
        # The second moment is drawn from [0, 1), i.e. it is never negative.
        moment2 = np.random.random(shape).astype("float32")

        beta1, beta2 = 0.78, 0.836
        # Accumulated beta powers, as float32 arrays with a single element.
        beta1_pow = np.array([beta1**10]).astype("float32")
        beta2_pow = np.array([beta2**10]).astype("float32")

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Moment1': moment1,
            'Moment2': moment2,
            'LearningRate': np.array([0.004]).astype("float32"),
            'Beta1Pow': beta1_pow,
            'Beta2Pow': beta2_pow
        }

        self.attrs = {'epsilon': 1e-4, 'beta1': beta1, 'beta2': beta2}

        expected = adam_step(self.inputs, self.attrs)
        param_out, moment1_out, moment2_out = expected

        self.outputs = {
            'Moment1Out': moment1_out,
            'Moment2Out': moment2_out,
            'ParamOut': param_out,
            # Each beta power is expected to advance by one factor of beta.
            'Beta1PowOut': beta1_pow * beta1,
            'Beta2PowOut': beta2_pow * beta2
        }

    def test_check_output(self):
        '''Compare the operator outputs against ``self.outputs``.'''
        self.check_output()


class TestAdamOp2(OpTest):
    def setUp(self):
        '''Check a single dense Adam step without setting ``self.attrs``.

        The hyper-parameters below are only handed to the numpy reference
        ``adam_step``; they are not stored on ``self.attrs``.
        NOTE(review): presumably the values match the operator defaults -
        confirm against the op definition.
        '''
        self.op_type = "adam"
        shape = (102, 105)

        # Random draws happen in the order param, grad, moment1, moment2.
        param, grad, moment1 = [
            np.random.uniform(-1, 1, shape).astype("float32")
            for _ in range(3)
        ]
        # The second moment is drawn from [0, 1), i.e. it is never negative.
        moment2 = np.random.random(shape).astype("float32")

        beta1, beta2 = 0.9, 0.999
        # Accumulated beta powers, as float32 arrays with a single element.
        beta1_pow = np.array([beta1**10]).astype("float32")
        beta2_pow = np.array([beta2**10]).astype("float32")

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Moment1': moment1,
            'Moment2': moment2,
            'LearningRate': np.array([0.001]).astype("float32"),
            'Beta1Pow': beta1_pow,
            'Beta2Pow': beta2_pow
        }

        reference_attrs = {'epsilon': 1e-8, 'beta1': beta1, 'beta2': beta2}

        expected = adam_step(self.inputs, reference_attrs)
        param_out, moment1_out, moment2_out = expected

        self.outputs = {
            'Moment1Out': moment1_out,
            'Moment2Out': moment2_out,
            'ParamOut': param_out,
            # Each beta power is expected to advance by one factor of beta.
            'Beta1PowOut': beta1_pow * beta1,
            'Beta2PowOut': beta2_pow * beta2
        }

    def test_check_output(self):
        '''Compare the operator outputs against ``self.outputs``.'''
        self.check_output()


class TestAdamOpMultipleSteps(OpTest):
    def setUp(self):
        '''Prepare the initial state for several consecutive Adam steps.

        Only inputs and attributes are set here; the expected outputs are
        rebuilt for every step inside ``test_check_output``.
        '''
        self.op_type = "adam"
        self.num_steps = 10
        shape = (102, 105)

        # Random draws happen in the order param, grad, moment1, moment2.
        param, grad, moment1 = [
            np.random.uniform(-1, 1, shape).astype("float32")
            for _ in range(3)
        ]
        # The second moment is drawn from [0, 1), i.e. it is never negative.
        moment2 = np.random.random(shape).astype("float32")

        self.beta1 = 0.9
        self.beta2 = 0.999
        self.beta1_pow = self.beta1**10
        self.beta2_pow = self.beta2**10

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Moment1': moment1,
            'Moment2': moment2,
            'LearningRate': np.array([0.001]).astype("float32"),
            'Beta1Pow': np.array([self.beta1_pow]).astype("float32"),
            'Beta2Pow': np.array([self.beta2_pow]).astype("float32")
        }

        self.attrs = {
            'epsilon': 1e-8,
            'beta1': self.beta1,
            'beta2': self.beta2
        }

    def test_check_output(self):
        '''Run ``num_steps`` steps, checking the operator after each one.'''
        for _ in range(self.num_steps):
            expected = adam_step(self.inputs, self.attrs)
            param_out, moment1_out, moment2_out = expected

            beta1_pow_out = self.inputs['Beta1Pow'] * self.beta1
            beta2_pow_out = self.inputs['Beta2Pow'] * self.beta2
            self.outputs = {
                'Moment1Out': moment1_out,
                'Moment2Out': moment2_out,
                'ParamOut': param_out,
                'Beta1PowOut': beta1_pow_out,
                'Beta2PowOut': beta2_pow_out
            }

            # Verify output for this step
            self.check_output()

            # Feed this step's results back in as the next step's inputs and
            # draw a fresh random gradient.
            self.inputs.update({
                'Param': param_out,
                'Moment1': moment1_out,
                'Moment2': moment2_out,
                'Beta1Pow': beta1_pow_out,
                'Beta2Pow': beta2_pow_out,
                'Grad': np.random.uniform(
                    -1, 1, (102, 105)).astype("float32")
            })


def adam_step(inputs, attributes):
    '''
    Numpy reference for one dense step of the adam optimizer
    :param inputs: dict with 'Param', 'Grad', 'Moment1', 'Moment2',
    'LearningRate', 'Beta1Pow', 'Beta2Pow' and, when beta1/beta2 are not
    given as attributes, 'Beta1Tensor'/'Beta2Tensor'
    :param attributes: dict with 'epsilon' and optionally 'beta1', 'beta2'
    :return tuple: tuple of output param, moment1 and moment2
    '''
    weights = inputs['Param']
    gradient = inputs['Grad']
    first_moment = inputs['Moment1']
    second_moment = inputs['Moment2']
    base_lr = inputs['LearningRate']
    b1_power = inputs['Beta1Pow']
    b2_power = inputs['Beta2Pow']

    eps = attributes['epsilon']

    # An attribute takes precedence; the tensor input is only looked up
    # (first element) when the attribute is absent.
    b1 = (attributes['beta1']
          if 'beta1' in attributes else inputs['Beta1Tensor'][0])
    b2 = (attributes['beta2']
          if 'beta2' in attributes else inputs['Beta2Tensor'][0])

    new_first = b1 * first_moment + (1 - b1) * gradient
    new_second = b2 * second_moment + (1 - b2) * np.square(gradient)
    # Learning rate scaled by the bias-correction terms.
    step_size = base_lr * np.sqrt(1 - b2_power) / (1 - b1_power)
    denominator = np.sqrt(new_second) + eps
    new_weights = weights - step_size * (new_first / denominator)
    return new_weights, new_first, new_second
216 217


Q
Qiao Longfei 已提交
218
def adam_step_sparse(inputs, attributes, height, rows, row_numel, np_grad,
                     lazy_mode):
    '''
    Simulate one step of the adam optimizer for a row-sparse gradient
    :param inputs: dict with 'Param', 'Moment1', 'Moment2', 'LearningRate',
    'Beta1Pow' and 'Beta2Pow'
    :param attributes: dict with 'beta1', 'beta2' and 'epsilon'
    :param height: number of rows of the outputs
    :param rows: list of row ids; np_grad[i] is the gradient of row rows[i]
    :param row_numel: number of elements in each output row
    :param np_grad: gradient values, one entry per element of rows
    :param lazy_mode: if true, only the rows listed in rows are updated and
    all other output rows stay zero; otherwise every row is updated, rows
    missing from rows using a zero gradient
    :return tuple: tuple of output param, moment1 and moment2
    '''
    param = inputs['Param']
    moment1 = inputs['Moment1']
    moment2 = inputs['Moment2']
    lr = inputs['LearningRate']
    beta1_pow = inputs['Beta1Pow']
    beta2_pow = inputs['Beta2Pow']

    beta1 = attributes['beta1']
    beta2 = attributes['beta2']
    epsilon = attributes['epsilon']

    moment1_out = np.zeros(shape=[height, row_numel])
    moment2_out = np.zeros(shape=[height, row_numel])
    param_out = np.zeros(shape=[height, row_numel])

    # The bias-corrected step size does not depend on the row, so it is
    # computed once instead of once per updated row.
    lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)

    def update_row(row_id, update_value):
        moment1_out[row_id] = (beta1 * moment1[row_id] +
                               (1 - beta1) * update_value)
        moment2_out[row_id] = (beta2 * moment2[row_id] +
                               (1 - beta2) * np.square(update_value))
        param_out[row_id] = param[row_id] - lr_t * (moment1_out[row_id] / (
            np.sqrt(moment2_out[row_id]) + epsilon))

    if lazy_mode:
        for idx, row_id in enumerate(rows):
            update_row(row_id, np_grad[idx])
    else:
        # Map each row id to the index of its first occurrence in rows
        # (the same entry rows.index() would select).
        first_index = {}
        for idx, row_id in enumerate(rows):
            first_index.setdefault(row_id, idx)
        # Shape of a single gradient row, taken from the trailing dimensions
        # so that an empty gradient (no rows) is handled as well.
        zero_grad = np.zeros(np.shape(np_grad)[1:], dtype="float32")
        for row_id in range(height):
            idx = first_index.get(row_id)
            update_row(row_id, zero_grad if idx is None else np_grad[idx])

    return param_out, moment1_out, moment2_out


class TestSparseAdamOp(unittest.TestCase):
    '''Run the "adam" operator with a selected-rows gradient and compare its
    outputs with the numpy reference ``adam_step_sparse``.
    '''

    def setup(self, scope, place, lazy_mode):
        '''Build inputs, attributes and expected outputs for one run.

        :param scope: scope in which the sparse 'Grad' variable is created
        :param place: place on which the gradient tensor is set
        :param lazy_mode: forwarded to ``adam_step_sparse`` when computing
        the expected outputs
        '''
        beta1 = 0.78
        beta2 = 0.836
        epsilon = 1e-4
        beta1_pow = np.array([beta1**10]).astype("float32")
        beta2_pow = np.array([beta2**10]).astype("float32")

        height = 10
        rows = [0, 4, 7]
        self.rows = rows
        row_numel = 12
        self.row_numel = row_numel
        self.dense_inputs = {
            "Param": np.full((height, row_numel), 5.0).astype("float32"),
            "Moment1": np.full((height, row_numel), 5.0).astype("float32"),
            "Moment2": np.full((height, row_numel), 5.0).astype("float32"),
            'Beta1Pow': beta1_pow,
            'Beta2Pow': beta2_pow,
            "LearningRate": np.full((1), 2.0).astype("float32")
        }
        # Initial content written into every output variable before the run.
        self.init_output = np.full((height, row_numel), 0.0).astype("float32")
        self.attrs = {
            'epsilon': epsilon,
            'beta1': beta1,
            'beta2': beta2,
            'min_row_size_to_use_multithread': 2
        }

        # The gradient holds one row of values per entry of `rows`.
        grad_selected_rows = scope.var('Grad').get_selected_rows()
        grad_selected_rows.set_height(height)
        grad_selected_rows.set_rows(rows)
        np_array = np.ones((len(rows), row_numel)).astype("float32")
        # Two entries differ from 1.0 so the gradient is not uniform.
        np_array[0, 0] = 2.0
        np_array[2, 8] = 4.0

        grad_tensor = grad_selected_rows.get_tensor()
        grad_tensor.set(np_array, place)

        self.sparse_inputs = ["Grad"]

        param_out, mom1, mom2 = adam_step_sparse(self.dense_inputs, self.attrs,
                                                 height, rows, row_numel,
                                                 np_array, lazy_mode)
        self.outputs = {
            "ParamOut": param_out,
            "Moment1Out": mom1,
            "Moment2Out": mom2,
            'Beta1PowOut': beta1_pow * beta1,
            'Beta2PowOut': beta2_pow * beta2
        }

    def check_with_place(self, place, lazy_mode):
        '''Run the operator on ``place`` and compare every output.

        :param place: place on which tensors are set and the operator runs
        :param lazy_mode: value of the operator's 'lazy_mode' argument
        '''
        scope = core.Scope()
        self.setup(scope, place, lazy_mode)

        # Operator arguments: variable names for inputs/outputs plus the
        # attribute values.
        op_args = dict()
        op_args['lazy_mode'] = lazy_mode
        for key, np_array in self.dense_inputs.items():
            var = scope.var(key).get_tensor()
            var.set(np_array, place)
            op_args[key] = key
        for s in self.sparse_inputs:
            op_args[s] = s
        for s in self.outputs:
            var = scope.var(s).get_tensor()
            var.set(self.init_output, place)
            op_args[s] = s
        for k in self.attrs:
            op_args[k] = self.attrs[k]

        # create and run adam operator
        adam_op = Operator("adam", **op_args)
        adam_op.run(scope, place)

        for key, np_array in self.outputs.items():
            out_var = scope.var(key).get_tensor()
            actual = np.array(out_var)
            actual = actual.reshape([actual.size])
            expected = np_array.reshape([np_array.size])

            for i, expected_value in enumerate(expected):
                # Compare the absolute difference: a signed difference would
                # accept any output that is smaller than expected.
                self.assertLess(
                    abs(actual[i] - expected_value), 0.00001,
                    "%s differs at flat index %d" % (key, i))

    def test_sparse_adam(self):
        '''Check lazy and non-lazy mode on CPU and, if compiled in, CUDA.'''
        places = [core.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(core.CUDAPlace(0))
        for place in places:
            for lazy_mode in (True, False):
                self.check_with_place(place, lazy_mode)
T
wip  
typhoonzero 已提交
356 357


358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395
class TestAdamOpBetaVariable(OpTest):
    def setUp(self):
        '''Check a single dense Adam step with beta1/beta2 given as inputs.

        The betas are fed through the 'Beta1Tensor' and 'Beta2Tensor' inputs;
        no ``self.attrs`` is set, and epsilon is only handed to the numpy
        reference ``adam_step``.
        '''
        self.op_type = "adam"
        shape = (102, 105)

        # Random draws happen in the order param, grad, moment1, moment2.
        param, grad, moment1 = [
            np.random.uniform(-1, 1, shape).astype("float32")
            for _ in range(3)
        ]
        # The second moment is drawn from [0, 1), i.e. it is never negative.
        moment2 = np.random.random(shape).astype("float32")

        beta1, beta2 = 0.85, 0.95
        # Accumulated beta powers, as float32 arrays with a single element.
        beta1_pow = np.array([beta1**10]).astype("float32")
        beta2_pow = np.array([beta2**10]).astype("float32")

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Moment1': moment1,
            'Moment2': moment2,
            'LearningRate': np.array([0.001]).astype("float32"),
            'Beta1Pow': beta1_pow,
            'Beta2Pow': beta2_pow,
            "Beta1Tensor": np.array([beta1]).astype("float32"),
            "Beta2Tensor": np.array([beta2]).astype("float32"),
        }

        # Without 'beta1'/'beta2' keys the reference reads the tensor inputs.
        reference_attrs = {'epsilon': 1e-8}

        expected = adam_step(self.inputs, reference_attrs)
        param_out, moment1_out, moment2_out = expected

        self.outputs = {
            'Moment1Out': moment1_out,
            'Moment2Out': moment2_out,
            'ParamOut': param_out,
            # Each beta power is expected to advance by one factor of beta.
            'Beta1PowOut': beta1_pow * beta1,
            'Beta2PowOut': beta2_pow * beta2
        }

    def test_check_output(self):
        '''Compare the operator outputs against ``self.outputs``.'''
        self.check_output()


M
MRXLT 已提交
405 406 407
class TestAdamOpV2(unittest.TestCase):
    '''Exercise ``paddle.optimizer.Adam`` through its Python API.'''

    def test_adam_op(self):
        '''Minimize a small conv network once with beta1/beta2 variables.'''
        place = fluid.CPUPlace()
        shape = [2, 3, 8, 8]
        exe = fluid.Executor(place)
        train_prog = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(train_prog, startup):
            with fluid.unique_name.guard():
                data = fluid.data(name="data", shape=shape)
                conv = fluid.layers.conv2d(data, 8, 3)
                loss = fluid.layers.reduce_mean(conv)

                # beta1 and beta2 are passed as global variables, not floats.
                beta1 = fluid.layers.create_global_var(
                    shape=[1], value=0.85, dtype='float32', persistable=True)
                beta2 = fluid.layers.create_global_var(
                    shape=[1], value=0.95, dtype='float32', persistable=True)
                opt = paddle.optimizer.Adam(
                    learning_rate=1e-5,
                    beta1=beta1,
                    beta2=beta2,
                    weight_decay=0.01,
                    epsilon=1e-8)
                opt.minimize(loss)

        exe.run(startup)
        data_np = np.random.random(shape).astype('float32')
        rets = exe.run(train_prog, feed={"data": data_np}, fetch_list=[loss])
        self.assertIsNotNone(rets[0])

    def test_adam_op_dygraph(self):
        '''Run one backward/step/clear_gradients cycle on a linear layer.'''
        paddle.disable_static()
        value = np.arange(26).reshape(2, 13).astype("float32")
        a = fluid.dygraph.to_variable(value)
        linear = fluid.Linear(13, 5, dtype="float32")

        adam = paddle.optimizer.Adam(
            learning_rate=0.01, parameters=linear.parameters())
        out = linear(a)
        out.backward()
        adam.step()
        adam.clear_gradients()

    def test_adam_op_with_state_dict(self):
        '''Round-trip the state dict and check learning-rate type handling.'''
        paddle.disable_static()
        emb = paddle.nn.Embedding([10, 10])

        adam = paddle.optimizer.Adam(0.001, parameters=emb.parameters())
        state_dict = adam.state_dict()
        adam.set_state_dict(state_dict)

        # learning_rate is _LRScheduler
        learning_rate = paddle.optimizer.CosineAnnealingLR(
            learning_rate=0.1, T_max=10)
        adam = paddle.optimizer.Adam(
            learning_rate=learning_rate,
            weight_decay=fluid.regularizer.L2Decay(0.001),
            parameters=emb.parameters())
        # Called only to exercise get_lr(); the value is not checked.
        adam.get_lr()
        state_dict = adam.state_dict()
        adam.set_state_dict(state_dict)

        # learning_rate is Tensor
        with self.assertRaises(TypeError):
            learning_rate = np.array([0.01]).astype("float32")
            learning_rate = paddle.to_tensor(learning_rate)
            adam = paddle.optimizer.Adam(
                learning_rate=learning_rate, parameters=emb.parameters())

        # The assignment above is expected to raise before rebinding `adam`,
        # so this queries the scheduler-based optimizer created earlier.
        params = adam.get_opti_var_name_list()
        self.assertIsNotNone(params)

    def test_adam_with_grad_clip(self):
        '''Run one optimizer step with global-norm gradient clipping.'''
        paddle.disable_static()
        value = np.arange(26).reshape(2, 13).astype("float32")
        a = fluid.dygraph.to_variable(value)
        linear = fluid.Linear(13, 5, dtype="float32")
        clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0)
        adam = paddle.optimizer.Adam(
            0.1, parameters=linear.parameters(), grad_clip=clip)
        out = linear(a)
        out.backward()
        adam.step()
        adam.clear_gradients()

    def test_adam_op_with_set_lr(self):
        '''set_lr accepts a float and rejects a variable with TypeError.'''
        paddle.disable_static()
        linear = paddle.nn.Linear(10, 10)
        adam = paddle.optimizer.Adam(0.1, parameters=linear.parameters())

        lr = 0.01
        adam.set_lr(lr)
        cur_lr = adam.get_lr()
        self.assertEqual(lr, cur_lr)
        with self.assertRaises(TypeError):
            lr_var = paddle.create_global_var(
                shape=[1], value=lr, dtype='float32')
            adam.set_lr(lr_var)

    def test_adam_op_invalid_input(self):
        '''Negative beta1, beta2 or epsilon must raise ValueError.'''
        paddle.disable_static()
        linear = paddle.nn.Linear(10, 10)
        with self.assertRaises(ValueError):
            paddle.optimizer.Adam(
                0.1, beta1=-1, parameters=linear.parameters())
        with self.assertRaises(ValueError):
            paddle.optimizer.Adam(
                0.1, beta2=-1, parameters=linear.parameters())
        with self.assertRaises(ValueError):
            paddle.optimizer.Adam(
                0.1, epsilon=-1, parameters=linear.parameters())

520

521 522
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()