#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import unittest
16

17
import numpy as np
18
from op_test import OpTest
19

20 21
import paddle

22 23 24

class TestAdamaxOp1(OpTest):
    def setUp(self):
        '''Check a single adamax step with explicitly supplied attributes.'''
        self.op_type = "adamax"
        shape = (102, 105)

        # Param, Grad and Moment are drawn (in that order) from U(-1, 1).
        param, grad, moment = (
            np.random.uniform(-1, 1, shape).astype("float32")
            for _ in range(3)
        )
        # InfNorm must be non-negative, so it is sampled from [0, 1).
        inf_norm = np.random.random(shape).astype("float32")

        beta1, beta2, epsilon = 0.78, 0.899, 1e-5
        self.attrs = {'beta1': beta1, 'beta2': beta2, 'epsilon': epsilon}

        def as_scalar_tensor(value):
            # One-element float32 array, as the operator expects.
            return np.array([value]).astype("float32")

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Moment': moment,
            'InfNorm': inf_norm,
            'LearningRate': as_scalar_tensor(0.002),
            # Beta1Pow as it would be after ten multiplications by beta1.
            'Beta1Pow': as_scalar_tensor(beta1**10),
        }

        expected = adamax_step(self.inputs, self.attrs)
        self.outputs = dict(
            zip(('ParamOut', 'MomentOut', 'InfNormOut'), expected)
        )

    def test_check_output(self):
        self.check_output()


class TestAdamaxOp2(OpTest):
    '''Check a single adamax step when the operator uses its default attributes.

    setUp never assigns ``self.attrs``, so the operator runs with its own
    beta1/beta2/epsilon. The NumPy reference uses 0.9 / 0.999 / 1e-8, which
    are presumed to equal those defaults.
    '''

    def setUp(self):
        self.op_type = "adamax"
        dims = (102, 105)

        def uniform_pm1():
            # Sample a float32 tensor uniformly from [-1, 1).
            return np.random.uniform(-1, 1, dims).astype("float32")

        param = uniform_pm1()
        grad = uniform_pm1()
        moment = uniform_pm1()
        # The infinity norm has to be non-negative: sample from [0, 1).
        inf_norm = np.random.random(dims).astype("float32")

        default_beta1 = 0.9
        reference_attrs = {
            'beta1': default_beta1,
            'beta2': 0.999,
            'epsilon': 1e-8,
        }

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Moment': moment,
            'InfNorm': inf_norm,
            'LearningRate': np.array([0.002]).astype("float32"),
            'Beta1Pow': np.array([default_beta1**8]).astype("float32"),
        }

        # Only the reference computation sees the attributes.
        param_out, moment_out, inf_norm_out = adamax_step(
            self.inputs, reference_attrs
        )
        self.outputs = dict(
            ParamOut=param_out,
            MomentOut=moment_out,
            InfNormOut=inf_norm_out,
        )

    def test_check_output(self):
        self.check_output()


class TestAdamaxOpMultipleSteps(OpTest):
    def setUp(self):
        '''Test the Adamax operator over several chained steps.

        Each step's ParamOut/MomentOut/InfNormOut is fed back as the next
        step's Param/Moment/InfNorm, with a fresh random gradient.
        '''
        self.op_type = "adamax"
        self.num_steps = 10

        shape = (102, 105)
        param = np.random.uniform(-1, 1, shape).astype("float32")
        grad = np.random.uniform(-1, 1, shape).astype("float32")
        moment = np.random.uniform(-1, 1, shape).astype("float32")
        # The infinity norm is positive
        inf_norm = np.random.random(shape).astype("float32")

        learning_rate = 0.002
        beta1 = 0.8
        beta2 = 0.99
        epsilon = 1e-5
        # Beta1Pow is the bias-correction term beta1**t for the step being
        # run; the update scales lr by 1 / (1 - Beta1Pow). Starting at
        # beta1**1 (not 1) avoids a division by zero that would turn every
        # ParamOut into +/-inf and make the per-step param check meaningless.
        beta1_pow = beta1

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Moment': moment,
            'InfNorm': inf_norm,
            'LearningRate': np.array([learning_rate]).astype("float32"),
            'Beta1Pow': np.array([beta1_pow]).astype("float32"),
        }

        self.attrs = {'beta1': beta1, 'beta2': beta2, 'epsilon': epsilon}

    def test_check_output(self):
        '''Run num_steps chained steps, checking the operator after each.'''
        for _ in range(self.num_steps):
            param_out, moment_out, inf_norm_out = adamax_step(
                self.inputs, self.attrs
            )

            self.outputs = {
                'ParamOut': param_out,
                'MomentOut': moment_out,
                'InfNormOut': inf_norm_out,
            }

            # Verify output for this step
            self.check_output()

            # Output of this step becomes input for next step
            self.inputs['Param'] = param_out
            self.inputs['Moment'] = moment_out
            self.inputs['InfNorm'] = inf_norm_out

            # Update Beta1 Power accumulator for next step
            self.inputs['Beta1Pow'] *= self.attrs['beta1']

            # Randomize gradient for next step (same shape as before)
            grad_shape = self.inputs['Grad'].shape
            self.inputs['Grad'] = np.random.uniform(
                -1, 1, grad_shape
            ).astype("float32")


def adamax_step(inputs, attributes):
    '''
    Simulate one step of the adamax optimizer (NumPy reference).

    :param inputs: dict of operator inputs with keys 'Param', 'Grad',
        'Moment', 'InfNorm', 'LearningRate' and 'Beta1Pow'
    :param attributes: dict of operator attributes with keys 'beta1',
        'beta2' and 'epsilon'
    :return tuple: (param_out, moment_out, inf_norm_out). The beta1 power
        accumulator is NOT updated or returned; callers advance it
        themselves.
    '''
    param = inputs['Param']
    grad = inputs['Grad']
    moment = inputs['Moment']
    inf_norm = inputs['InfNorm']
    lr = inputs['LearningRate']
    beta1_pow = inputs['Beta1Pow']

    beta1 = attributes['beta1']
    beta2 = attributes['beta2']
    epsilon = attributes['epsilon']

    # First moment: exponential moving average of the gradient.
    moment_out = beta1 * moment + (1 - beta1) * grad
    # Exponentially weighted infinity norm. Adding epsilon before the max
    # keeps the divisor below positive when inf_norm >= 0 and epsilon > 0.
    inf_norm_out = np.maximum(beta2 * inf_norm + epsilon, np.abs(grad))
    # Bias-corrected step size; diverges when beta1_pow == 1.
    lr_t = lr / (1 - beta1_pow)
    param_out = param - lr_t * np.divide(moment_out, inf_norm_out)

    return param_out, moment_out, inf_norm_out


class TestAdamaxOpV2(unittest.TestCase):
    def test_adamax_op_invalid_input(self):
        import paddle

        paddle.disable_static()
        linear = paddle.nn.Linear(10, 10)
        # Each negative hyper-parameter must be rejected at construction.
        invalid_settings = (
            {'beta1': -1},
            {'beta2': -1},
            {'epsilon': -1},
        )
        for bad_kwargs in invalid_settings:
            with self.assertRaises(ValueError):
                paddle.optimizer.Adamax(
                    0.1, parameters=linear.parameters(), **bad_kwargs
                )


class TestAdamaxOpMultiPrecison(unittest.TestCase):
    '''Smoke-test paddle.optimizer.Adamax in dygraph mode on every available
    place, with and without pure-fp16 (O2) AMP on GPU.'''

    def _test_adamax_op_dygraph_place_amp(self, place, use_amp=False):
        '''Run two Adamax training steps of a 5x5 Linear layer on ``place``.

        AMP (O2 decoration, auto_cast and GradScaler) is only used when
        ``place == 'gpu'`` and ``use_amp`` is true. Static mode is
        re-enabled on exit, even if a step raises.
        '''
        paddle.disable_static()
        try:
            paddle.seed(10)
            paddle.set_device(place)
            data = paddle.randn((5, 5))

            model = paddle.nn.Linear(5, 5)
            optimizer = paddle.optimizer.Adamax(
                0.1, beta1=0.1, parameters=model.parameters()
            )
            optimizer._multi_precision = use_amp

            amp_on = place == 'gpu' and use_amp
            if amp_on:
                # Decorate the model and build the scaler once, before
                # training. Doing it inside the loop re-decorated the model
                # and reset the loss-scaling state on every step.
                model = paddle.amp.decorate(models=model, level='O2')
                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

            for _ in range(2):
                if amp_on:
                    with paddle.amp.auto_cast(level='O2'):
                        output = model(data)
                        loss = paddle.mean(output)
                    scaled = scaler.scale(loss)
                    scaled.backward()
                    scaler.step(optimizer)
                    # The scaler is now shared across steps: update() must
                    # follow step() before the next step() call.
                    scaler.update()
                    optimizer.clear_grad()
                else:
                    output = model(data)
                    loss = paddle.mean(output)
                    loss.backward()
                    optimizer.step()
                    optimizer.clear_grad()
        finally:
            paddle.enable_static()

    def _get_places(self):
        '''Return 'cpu' plus 'gpu' when Paddle was compiled with CUDA.'''
        places = ['cpu']
        if paddle.is_compiled_with_cuda():
            places.append('gpu')
        return places

    def test_main(self):
        for place in self._get_places():
            for use_amp in (True, False):
                self._test_adamax_op_dygraph_place_amp(place, use_amp)


class TestAdamaxMultiPrecision2_0(unittest.TestCase):
    '''Compare paddle.optimizer.Adamax with and without pure-fp16 (O2)
    multi-precision training, in dygraph and static mode (GPU only).'''

    def dygraph_adamax_mp(self, mp, use_amp):
        '''Run 5 dygraph Adamax steps on a 2x2 Linear layer.

        Returns the last forward output and the model parameters.
        '''
        paddle.disable_static()
        paddle.seed(100)
        paddle.set_device('gpu')
        x = paddle.randn((2, 2))
        model = paddle.nn.Linear(2, 2)
        optimizer = paddle.optimizer.Adamax(0.5, parameters=model.parameters())
        optimizer._multi_precision = mp
        if use_amp:
            model = paddle.amp.decorate(models=model, level='O2')
            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

        for _ in range(5):
            if use_amp:
                with paddle.amp.auto_cast(level='O2'):
                    output = model(x)
                    loss = paddle.mean(output)
                scaled = scaler.scale(loss)
                scaled.backward()
                scaler.minimize(optimizer, scaled)
                optimizer.clear_grad()
            else:
                output = model(x)
                loss = paddle.mean(output)
                loss.backward()
                optimizer.step()
                optimizer.clear_grad()

        return output, model.parameters()

    def static_adamax_mp(self, mp, use_amp):
        '''Run 5 static-graph Adamax steps; return the fetched loss values.

        Leaves Paddle in static mode.
        '''
        paddle.enable_static()
        paddle.seed(100)
        np.random.seed(100)
        exe = paddle.static.Executor('gpu')
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        optimizer = paddle.optimizer.Adamax(0.1)
        optimizer._multi_precision = mp
        if use_amp:
            optimizer = paddle.static.amp.decorate(
                optimizer,
                init_loss_scaling=128.0,
                use_dynamic_loss_scaling=True,
                use_pure_fp16=True,
                use_fp16_guard=False,
            )
        with paddle.static.program_guard(train_program, startup_program):
            if use_amp:
                data = paddle.static.data(
                    shape=[2, 2], name='X', dtype='float16'
                )
            else:
                data = paddle.static.data(
                    shape=[2, 2], name='X', dtype='float32'
                )
            hidden = paddle.static.nn.fc(x=data, size=10)
            loss = paddle.mean(hidden)
            optimizer.minimize(loss)
        exe.run(startup_program)

        if use_amp:
            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
            x = np.random.random(size=(2, 2)).astype('float16')
        else:
            x = np.random.random(size=(2, 2)).astype('float32')
        out = []
        for _ in range(5):
            (loss_data,) = exe.run(
                train_program, feed={"X": x}, fetch_list=[loss.name]
            )
            out.append(loss_data)
        return out

    def test_main(self):
        if not paddle.is_compiled_with_cuda():
            # Report as skipped rather than silently passing.
            self.skipTest('multi-precision Adamax test requires CUDA')

        # Dygraph mode: AMP O2 + multi-precision vs. plain fp32.
        output1_dy, params1_dy = self.dygraph_adamax_mp(use_amp=True, mp=True)
        output2_dy, params2_dy = self.dygraph_adamax_mp(use_amp=False, mp=False)
        np.testing.assert_allclose(
            output1_dy.astype('float32').numpy(),
            output2_dy.astype('float32').numpy(),
            rtol=1e-05,
            atol=0.1,
        )
        self.assertEqual(len(params1_dy), len(params2_dy))
        for param1, param2 in zip(params1_dy, params2_dy):
            np.testing.assert_allclose(
                param1.astype('float32').numpy(),
                param2.astype('float32').numpy(),
                rtol=1e-05,
                atol=0.1,
            )

        # Static mode: same comparison on the per-step loss values.
        output1_st = self.static_adamax_mp(use_amp=True, mp=True)
        output2_st = self.static_adamax_mp(use_amp=False, mp=False)
        for loss1, loss2 in zip(output1_st, output2_st):
            np.testing.assert_allclose(
                loss1.astype('float32'),
                loss2.astype('float32'),
                rtol=1e-05,
                atol=0.1,
            )


class TestAdamaxMultiPrecision1_0(unittest.TestCase):
    '''Compare paddle.fluid.optimizer.Adamax with and without pure-fp16 (O2)
    multi-precision training, in dygraph and static mode (GPU only).'''

    def dygraph_adamax_mp(self, use_amp, mp):
        '''Run 5 dygraph steps of the fluid Adamax optimizer on a 2x2 Linear
        layer; return the last forward output and the model parameters.'''
        paddle.disable_static()
        paddle.seed(10)
        paddle.set_device('gpu')
        x = paddle.randn((2, 2))
        model = paddle.nn.Linear(2, 2)
        optimizer = paddle.fluid.optimizer.Adamax(
            learning_rate=0.001, parameter_list=model.parameters()
        )
        optimizer._multi_precision = mp
        if use_amp:
            model = paddle.amp.decorate(models=model, level='O2')
            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

        for _ in range(5):
            if use_amp:
                with paddle.amp.auto_cast(level='O2'):
                    output = model(x)
                    loss = paddle.mean(output)
                scaled = scaler.scale(loss)
                scaled.backward()
                scaler.minimize(optimizer, scaled)
                optimizer.clear_gradients()
            else:
                output = model(x)
                loss = paddle.mean(output)
                # In dygraph mode the fluid optimizer's minimize() does not
                # run backward itself; it only applies gradients that already
                # exist. Without this call the reference model never trains.
                loss.backward()
                optimizer.minimize(loss)
                optimizer.clear_gradients()

        return output, model.parameters()

    def static_adamax_mp(self, use_amp, mp):
        '''Run 5 static-graph fluid Adamax steps; return the fetched losses.

        Leaves Paddle in static mode.
        '''
        paddle.enable_static()
        paddle.seed(100)
        np.random.seed(100)
        exe = paddle.static.Executor('gpu')
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        optimizer = paddle.fluid.optimizer.Adamax(learning_rate=0.001)
        optimizer._multi_precision = mp
        if use_amp:
            optimizer = paddle.static.amp.decorate(
                optimizer,
                init_loss_scaling=128.0,
                use_dynamic_loss_scaling=True,
                use_pure_fp16=True,
                use_fp16_guard=False,
            )
        with paddle.static.program_guard(train_program, startup_program):
            if use_amp:
                data = paddle.static.data(
                    shape=[2, 2], name='X', dtype='float16'
                )
            else:
                data = paddle.static.data(
                    shape=[2, 2], name='X', dtype='float32'
                )
            hidden = paddle.static.nn.fc(x=data, size=10)
            loss = paddle.mean(hidden)
            optimizer.minimize(loss)
        exe.run(startup_program)

        if use_amp:
            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
            x = np.random.random(size=(2, 2)).astype('float16')
        else:
            x = np.random.random(size=(2, 2)).astype('float32')
        out = []
        for _ in range(5):
            (loss_data,) = exe.run(
                train_program, feed={"X": x}, fetch_list=[loss.name]
            )
            out.append(loss_data)
        return out

    def test_main(self):
        if not paddle.is_compiled_with_cuda():
            # Report as skipped rather than silently passing.
            self.skipTest('multi-precision Adamax test requires CUDA')

        # Dygraph mode: AMP O2 + multi-precision vs. plain fp32.
        output1_dy, params1_dy = self.dygraph_adamax_mp(use_amp=True, mp=True)
        output2_dy, params2_dy = self.dygraph_adamax_mp(use_amp=False, mp=False)
        np.testing.assert_allclose(
            output1_dy.astype('float32').numpy(),
            output2_dy.astype('float32').numpy(),
            rtol=1e-05,
            atol=0.1,
        )
        self.assertEqual(len(params1_dy), len(params2_dy))
        for param1, param2 in zip(params1_dy, params2_dy):
            np.testing.assert_allclose(
                param1.astype('float32').numpy(),
                param2.astype('float32').numpy(),
                rtol=1e-05,
                atol=0.1,
            )

        # Static mode: same comparison on the per-step loss values.
        output1_st = self.static_adamax_mp(use_amp=True, mp=True)
        output2_st = self.static_adamax_mp(use_amp=False, mp=False)
        for loss1, loss2 in zip(output1_st, output2_st):
            np.testing.assert_allclose(
                loss1.astype('float32'),
                loss2.astype('float32'),
                rtol=1e-05,
                atol=0.1,
            )


if __name__ == "__main__":
    # Run every TestCase defined in this module.
    unittest.main()