#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from eager_op_test import OpTest

import paddle
import paddle.fluid as fluid


def adadelta_wrapper(
    Param,
    Grad,
    AvgSquaredGrad,
    AvgSquaredUpdate,
    master_weight=None,
    rho=0.95,
    epsilon=1e-6,
):
    """Python API used by the ``adadelta`` OpTest cases to run ``adadelta_``.

    The argument names mirror the keys used in the tests' ``self.inputs``
    (``Param``, ``Grad``, ``AvgSquaredGrad``, ``AvgSquaredUpdate``) and
    ``self.attrs`` (``rho``, ``epsilon``), presumably so the test harness can
    bind them by name.

    Args:
        Param, Grad, AvgSquaredGrad, AvgSquaredUpdate: tensors passed to the
            op. The same objects are returned, so this relies on
            ``adadelta_`` updating them in place.
        master_weight: optional master copy of ``Param``. When given, it is
            forwarded to the op and multi-precision mode is requested.
        rho: decay rate of the running averages.
        epsilon: numerical-stability term.

    Returns:
        ``(Param, AvgSquaredGrad, AvgSquaredUpdate)``.
    """
    paddle._C_ops.adadelta_(
        Param,
        Grad,
        AvgSquaredGrad,
        AvgSquaredUpdate,
        # Forward the caller's master weight instead of a hard-coded None.
        master_weight,
        rho,
        epsilon,
        # Multi-precision only makes sense when a master copy is supplied.
        master_weight is not None,
    )
    return Param, AvgSquaredGrad, AvgSquaredUpdate


class TestAdadeltaOp1(OpTest):
    """Check one ``adadelta`` op step against a NumPy reference."""

    def setUp(self):
        self.op_type = "adadelta"
        self.python_api = adadelta_wrapper
        # NOTE(review): the op's outputs are ParamOut / AvgSquaredGradOut /
        # AvgSquaredUpdateOut; confirm how OpTest uses this 'Out' signature.
        self.python_out_sig = ['Out']
        param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
        grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
        # Running average of squared gradients: np.random.random draws from
        # [0, 1), so the values are non-negative like a real average of squares.
        avg_squared_grad = np.random.random((102, 105)).astype("float32")
        # Running average of squared updates: also non-negative.
        avg_squared_update = np.random.random((102, 105)).astype("float32")

        rho = 0.95
        epsilon = 1e-6

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'AvgSquaredGrad': avg_squared_grad,
            'AvgSquaredUpdate': avg_squared_update,
        }

        self.attrs = {'rho': rho, 'epsilon': epsilon}

        # NumPy reference for a single Adadelta step:
        #   E[g^2]'  = rho * E[g^2] + (1 - rho) * g^2
        #   dx       = -sqrt((E[dx^2] + eps) / (E[g^2]' + eps)) * g
        #   E[dx^2]' = rho * E[dx^2] + (1 - rho) * dx^2
        #   param'   = param + dx
        avg_squared_grad_out = rho * avg_squared_grad + (1 - rho) * np.square(
            grad
        )
        update = -np.multiply(
            np.sqrt(
                np.divide(
                    avg_squared_update + epsilon, avg_squared_grad_out + epsilon
                )
            ),
            grad,
        )

        avg_squared_update_out = rho * avg_squared_update + (
            1 - rho
        ) * np.square(update)

        param_out = param + update

        self.outputs = {
            'ParamOut': param_out,
            'AvgSquaredGradOut': avg_squared_grad_out,
            'AvgSquaredUpdateOut': avg_squared_update_out,
        }

    def test_check_output(self):
        self.check_output()


class TestAdadeltaOp2(OpTest):
    '''Test Adadelta op with default attribute values.

    rho=0.95 and epsilon=1e-6 equal the defaults of ``adadelta_wrapper``
    (presumably also the op's own defaults -- confirm). They are still set
    explicitly in ``self.attrs``.
    '''

    def setUp(self):
        self.op_type = "adadelta"
        self.python_api = adadelta_wrapper
        # NOTE(review): the op's outputs are ParamOut / AvgSquaredGradOut /
        # AvgSquaredUpdateOut; confirm how OpTest uses this 'Out' signature.
        self.python_out_sig = ['Out']
        param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
        grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
        # Running average of squared gradients: values in [0, 1), non-negative.
        avg_squared_grad = np.random.random((102, 105)).astype("float32")
        # Running average of squared updates: values in [0, 1), non-negative.
        avg_squared_update = np.random.random((102, 105)).astype("float32")

        rho = 0.95
        epsilon = 1e-6

        self.attrs = {'rho': rho, 'epsilon': epsilon}

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'AvgSquaredGrad': avg_squared_grad,
            'AvgSquaredUpdate': avg_squared_update,
        }

        # NumPy reference for a single Adadelta step (see the formulas in
        # TestAdadeltaOp1's docstring-equivalent comments: running averages
        # of g^2 and dx^2, dx = -sqrt((E[dx^2]+eps)/(E[g^2]'+eps)) * g).
        avg_squared_grad_out = rho * avg_squared_grad + (1 - rho) * np.square(
            grad
        )
        update = -np.multiply(
            np.sqrt(
                np.divide(
                    avg_squared_update + epsilon, avg_squared_grad_out + epsilon
                )
            ),
            grad,
        )

        avg_squared_update_out = rho * avg_squared_update + (
            1 - rho
        ) * np.square(update)

        param_out = param + update

        self.outputs = {
            'ParamOut': param_out,
            'AvgSquaredGradOut': avg_squared_grad_out,
            'AvgSquaredUpdateOut': avg_squared_update_out,
        }

    def test_check_output(self):
        self.check_output()


class TestAdadeltaV2(unittest.TestCase):
    """Exercise ``paddle.optimizer.Adadelta`` in dygraph and static mode."""

    def test_adadelta_dygraph(self):
        """Take one dygraph step with weight decay on a single Linear layer."""
        paddle.disable_static(paddle.CPUPlace())
        value = np.arange(26).reshape(2, 13).astype("float32")
        a = paddle.to_tensor(value)
        linear = paddle.nn.Linear(13, 5)
        # This can be any optimizer supported by dygraph.
        optimizer = paddle.optimizer.Adadelta(
            learning_rate=0.01,
            parameters=linear.parameters(),
            weight_decay=0.01,
        )
        out = linear(a)
        out.backward()
        optimizer.step()
        optimizer.clear_gradients()

    def test_adadelta(self):
        """Train a one-unit fc regressor on uci_housing in static mode."""
        paddle.enable_static()
        place = fluid.CPUPlace()
        main = fluid.Program()
        with fluid.program_guard(main):
            x = paddle.static.data(name='x', shape=[-1, 13], dtype='float32')
            y = paddle.static.data(name='y', shape=[-1, 1], dtype='float32')
            y_predict = paddle.static.nn.fc(x, size=1)
            cost = paddle.nn.functional.square_error_cost(
                input=y_predict, label=y
            )
            avg_cost = paddle.mean(cost)

            optimizer = paddle.optimizer.Adadelta(learning_rate=0.1)
            optimizer.minimize(avg_cost)

            fetch_list = [avg_cost]
            train_reader = paddle.batch(
                paddle.dataset.uci_housing.train(), batch_size=1
            )
            feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
            exe = fluid.Executor(place)
            # No startup program was passed to program_guard, so the fc
            # parameters' initializers live in the default startup program.
            exe.run(fluid.default_startup_program())
            for data in train_reader():
                exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)

    def test_raise_error(self):
        """A ``None`` learning_rate, rho or epsilon must raise ValueError."""
        self.assertRaises(ValueError, paddle.optimizer.Adadelta, None)
        self.assertRaises(
            ValueError, paddle.optimizer.Adadelta, learning_rate=0.1, rho=None
        )
        self.assertRaises(
            ValueError,
            paddle.optimizer.Adadelta,
            learning_rate=0.1,
            epsilon=None,
        )


class TestAdadeltaV2Group(TestAdadeltaV2):
    """Run the ``TestAdadeltaV2`` suite, swapping in a parameter-group
    variant of the dygraph test."""

    def test_adadelta_dygraph(self):
        """One dygraph step over two Linear layers in separate param groups."""
        paddle.disable_static(paddle.CPUPlace())
        features = paddle.to_tensor(
            np.arange(26).reshape(2, 13).astype("float32")
        )
        first = paddle.nn.Linear(13, 5)
        second = paddle.nn.Linear(5, 5)
        # The second group carries its own weight_decay (0.001); the
        # optimizer-level value (0.1) is passed below.
        param_groups = [
            {'params': first.parameters()},
            {'params': second.parameters(), 'weight_decay': 0.001},
        ]
        # This can be any optimizer supported by dygraph.
        opt = paddle.optimizer.Adadelta(
            learning_rate=0.01, parameters=param_groups, weight_decay=0.1
        )
        result = second(first(features))
        result.backward()
        opt.step()
        opt.clear_gradients()


class TestAdadeltaOpMultiPrecison(unittest.TestCase):
    """Smoke-test dygraph Adadelta on every available place, with and
    without pure-fp16 (AMP O2) training."""

    def _test_adadelta_op_dygraph_place_amp(self, place, use_amp=False):
        """Run two Adadelta steps on ``place`` ('cpu' or 'gpu').

        AMP O2 (decorated model + GradScaler) is used only when ``place`` is
        'gpu' and ``use_amp`` is true. Otherwise ``use_amp`` only sets the
        optimizer's ``_multi_precision`` flag. Static mode is re-enabled
        before returning.
        """
        paddle.disable_static()
        paddle.seed(10)
        paddle.set_device(place)
        x = paddle.randn((5, 5))

        model = paddle.nn.Linear(5, 5)
        optimizer = paddle.optimizer.Adadelta(
            learning_rate=0.01,
            parameters=model.parameters(),
            weight_decay=0.1,
        )
        optimizer._multi_precision = use_amp

        amp_enabled = place == 'gpu' and use_amp
        if amp_enabled:
            # Decorate once and keep one scaler for all steps. Re-creating
            # them inside the loop re-decorated the model and discarded the
            # loss-scaling state every iteration.
            model = paddle.amp.decorate(models=model, level='O2')
            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

        for _ in range(2):
            if amp_enabled:
                with paddle.amp.auto_cast(level='O2'):
                    output = model(x)
                    loss = paddle.mean(output)
                scaled = scaler.scale(loss)
                scaled.backward()
                scaler.step(optimizer)
                # A scaler reused across steps must be updated after each
                # step() before step() can be called again.
                scaler.update()
                optimizer.clear_grad()
            else:
                output = model(x)
                loss = paddle.mean(output)
                loss.backward()
                optimizer.step()
                optimizer.clear_grad()
        paddle.enable_static()

    def _get_places(self):
        """Return 'cpu', plus 'gpu' when Paddle was built with CUDA."""
        places = ['cpu']
        if paddle.is_compiled_with_cuda():
            places.append('gpu')
        return places

    def test_main(self):
        for place in self._get_places():
            for use_amp in (True, False):
                with self.subTest(place=place, use_amp=use_amp):
                    self._test_adadelta_op_dygraph_place_amp(place, use_amp)


class TestAdadeltaMultiPrecision2_0(unittest.TestCase):
    """Compare pure-fp16 + multi-precision ``paddle.optimizer.Adadelta``
    against a plain fp32 run, in dygraph and static mode (GPU only)."""

    @staticmethod
    def _assert_close(actual, desired):
        """Loose comparison shared by all checks: fp16 and fp32 runs only
        agree approximately."""
        np.testing.assert_allclose(actual, desired, rtol=1e-05, atol=0.1)

    def dygraph_adadelta_mp(self, mp, use_amp):
        """Run 5 dygraph training steps on GPU.

        Args:
            mp: value for the optimizer's ``_multi_precision`` flag.
            use_amp: train with AMP O2 (decorated model + GradScaler).

        Returns:
            The last step's output tensor and the model's parameters.
        """
        paddle.disable_static()
        paddle.seed(100)
        paddle.set_device('gpu')
        x = paddle.randn((2, 2))
        model = paddle.nn.Linear(2, 2)
        optimizer = paddle.optimizer.Adadelta(
            0.5, parameters=model.parameters()
        )
        optimizer._multi_precision = mp
        if use_amp:
            model = paddle.amp.decorate(models=model, level='O2')
            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

        for _ in range(5):
            if use_amp:
                with paddle.amp.auto_cast(level='O2'):
                    output = model(x)
                    loss = paddle.mean(output)
                scaled = scaler.scale(loss)
                scaled.backward()
                scaler.minimize(optimizer, scaled)
                optimizer.clear_grad()
            else:
                output = model(x)
                loss = paddle.mean(output)
                loss.backward()
                optimizer.step()
                optimizer.clear_grad()

        return output, model.parameters()

    def static_adadelta_mp(self, mp, use_amp):
        """Run 5 static-graph training steps on GPU.

        Returns:
            The list of fetched loss values, one per step.
        """
        paddle.enable_static()
        paddle.seed(100)
        np.random.seed(100)
        exe = paddle.static.Executor('gpu')
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        optimizer = paddle.optimizer.Adadelta(0.1)
        optimizer._multi_precision = mp

        if use_amp:
            optimizer = paddle.static.amp.decorate(
                optimizer,
                init_loss_scaling=128.0,
                use_dynamic_loss_scaling=True,
                use_pure_fp16=True,
                use_fp16_guard=False,
            )
        dtype = 'float16' if use_amp else 'float32'
        with paddle.static.program_guard(train_program, startup_program):
            data = paddle.static.data(shape=[2, 2], name='X', dtype=dtype)
            hidden = paddle.static.nn.fc(x=data, size=10)
            loss = paddle.mean(hidden)
            optimizer.minimize(loss)
        exe.run(startup_program)

        if use_amp:
            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
        x = np.random.random(size=(2, 2)).astype(dtype)
        out = []
        for _ in range(5):
            (loss_data,) = exe.run(
                train_program, feed={"X": x}, fetch_list=[loss.name]
            )
            out.append(loss_data)
        return out

    def test_main(self):
        if not paddle.is_compiled_with_cuda():
            # Report a skip instead of a silent pass.
            self.skipTest("requires Paddle compiled with CUDA")

        # Dygraph mode: fp16 + multi-precision vs. plain fp32.
        output1_dy, params1_dy = self.dygraph_adadelta_mp(use_amp=True, mp=True)
        output2_dy, params2_dy = self.dygraph_adadelta_mp(
            use_amp=False, mp=False
        )
        self._assert_close(
            output1_dy.astype('float32').numpy(),
            output2_dy.astype('float32').numpy(),
        )
        self.assertEqual(len(params1_dy), len(params2_dy))
        for param1, param2 in zip(params1_dy, params2_dy):
            self._assert_close(
                param1.astype('float32').numpy(),
                param2.astype('float32').numpy(),
            )

        # Static mode: compare the per-step losses.
        output1_st = self.static_adadelta_mp(use_amp=True, mp=True)
        output2_st = self.static_adadelta_mp(use_amp=False, mp=False)
        self.assertEqual(len(output1_st), len(output2_st))
        for loss1, loss2 in zip(output1_st, output2_st):
            self._assert_close(loss1.astype('float32'), loss2.astype('float32'))


class TestAdadeltaMultiPrecision1_0(unittest.TestCase):
    """Compare pure-fp16 + multi-precision ``paddle.fluid.optimizer.Adadelta``
    against a plain fp32 run, in dygraph and static mode (GPU only)."""

    @staticmethod
    def _assert_close(actual, desired):
        """Loose comparison shared by all checks: fp16 and fp32 runs only
        agree approximately."""
        np.testing.assert_allclose(actual, desired, rtol=1e-05, atol=0.1)

    def dygraph_adadelta_mp(self, use_amp, mp):
        """Run 5 dygraph training steps on GPU with the fluid optimizer.

        Args:
            use_amp: train with AMP O2 (decorated model + GradScaler).
            mp: value for the optimizer's ``_multi_precision`` flag.

        Returns:
            The last step's output tensor and the model's parameters.
        """
        paddle.disable_static()
        paddle.seed(10)
        paddle.set_device('gpu')
        x = paddle.randn((2, 2))
        model = paddle.nn.Linear(2, 2)
        optimizer = paddle.fluid.optimizer.Adadelta(
            learning_rate=0.001,
            parameter_list=model.parameters(),
        )
        optimizer._multi_precision = mp
        if use_amp:
            model = paddle.amp.decorate(models=model, level='O2')
            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

        for _ in range(5):
            if use_amp:
                with paddle.amp.auto_cast(level='O2'):
                    output = model(x)
                    loss = paddle.mean(output)
                scaled = scaler.scale(loss)
                scaled.backward()
                scaler.minimize(optimizer, scaled)
                optimizer.clear_gradients()
            else:
                output = model(x)
                loss = paddle.mean(output)
                # Compute gradients before minimize(), as the AMP branch does.
                # In dygraph mode fluid's minimize() applies gradients that
                # already exist; without backward() the fp32 baseline was
                # never updated.
                loss.backward()
                optimizer.minimize(loss)
                optimizer.clear_gradients()

        return output, model.parameters()

    def static_adadelta_mp(self, use_amp, mp):
        """Run 5 static-graph training steps on GPU with the fluid optimizer.

        Returns:
            The list of fetched loss values, one per step.
        """
        paddle.enable_static()
        paddle.seed(100)
        np.random.seed(100)
        exe = paddle.static.Executor('gpu')
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        optimizer = paddle.fluid.optimizer.Adadelta(learning_rate=0.001)
        optimizer._multi_precision = mp

        if use_amp:
            optimizer = paddle.static.amp.decorate(
                optimizer,
                init_loss_scaling=128.0,
                use_dynamic_loss_scaling=True,
                use_pure_fp16=True,
                use_fp16_guard=False,
            )
        dtype = 'float16' if use_amp else 'float32'
        with paddle.static.program_guard(train_program, startup_program):
            data = paddle.static.data(shape=[2, 2], name='X', dtype=dtype)
            hidden = paddle.static.nn.fc(x=data, size=10)
            loss = paddle.mean(hidden)
            optimizer.minimize(loss)
        exe.run(startup_program)

        if use_amp:
            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
        x = np.random.random(size=(2, 2)).astype(dtype)
        out = []
        for _ in range(5):
            (loss_data,) = exe.run(
                train_program, feed={"X": x}, fetch_list=[loss.name]
            )
            out.append(loss_data)
        return out

    def test_main(self):
        if not paddle.is_compiled_with_cuda():
            # Report a skip instead of a silent pass.
            self.skipTest("requires Paddle compiled with CUDA")

        # Dygraph mode: fp16 + multi-precision vs. plain fp32.
        output1_dy, params1_dy = self.dygraph_adadelta_mp(use_amp=True, mp=True)
        output2_dy, params2_dy = self.dygraph_adadelta_mp(
            use_amp=False, mp=False
        )
        self._assert_close(
            output1_dy.astype('float32').numpy(),
            output2_dy.astype('float32').numpy(),
        )
        self.assertEqual(len(params1_dy), len(params2_dy))
        for param1, param2 in zip(params1_dy, params2_dy):
            self._assert_close(
                param1.astype('float32').numpy(),
                param2.astype('float32').numpy(),
            )

        # Static mode: compare the per-step losses.
        output1_st = self.static_adadelta_mp(use_amp=True, mp=True)
        output2_st = self.static_adadelta_mp(use_amp=False, mp=False)
        self.assertEqual(len(output1_st), len(output2_st))
        for loss1, loss2 in zip(output1_st, output2_st):
            self._assert_close(loss1.astype('float32'), loss2.astype('float32'))


if __name__ == "__main__":
    # Discover and run every TestCase in this module when executed directly.
    unittest.main()