#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
from __future__ import print_function

Q
Qiao Longfei 已提交
17
import unittest
Q
qijun 已提交
18
import numpy as np
19
import paddle.fluid as fluid
20 21
import paddle.fluid.core as core
from paddle.fluid.op import Operator
22
from op_test import OpTest
J
Jiawei Wang 已提交
23
import paddle
Z
zyfncg 已提交
24
from paddle.fluid.framework import _test_eager_guard
Q
Qiao Longfei 已提交
25

W
WangXi 已提交
26 27
# Default the whole module to static-graph mode. Dygraph tests below switch
# modes explicitly with paddle.disable_static() / paddle.enable_static().
paddle.enable_static()

Q
Qiao Longfei 已提交
28

29
class TestSGDOp(OpTest):
    """Dense SGD operator check: ParamOut = Param - LearningRate * Grad.

    Subclasses override ``conf`` to choose a different parameter shape.
    """

    def setUp(self):
        self.op_type = "sgd"
        self.conf()
        shape = (self.h, self.w)
        # Param is drawn before Grad, so the random stream is consumed in a
        # fixed order.
        param = np.random.random(shape).astype("float32")
        grad = np.random.random(shape).astype("float32")
        learning_rate = np.array([0.1]).astype("float32")

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'LearningRate': learning_rate,
        }
        expected = param - learning_rate * grad
        self.outputs = {'ParamOut': expected}

    def conf(self):
        """Set the height and width of Param and Grad."""
        self.h, self.w = 102, 105

    def test_check_output(self):
        self.check_output()

Q
Qiao Longfei 已提交
48

T
tensor-tang 已提交
49
class TestSGDOpCase8X(TestSGDOp):
    """Dense SGD check on a 10 x 64 parameter.

    Presumably this exercises a width that is a multiple of 8, as the class
    name suggests.
    """

    def conf(self):
        self.h, self.w = 10, 64


Q
qijun 已提交
56
class TestSparseSGDOp(unittest.TestCase):
    """SGD with a SelectedRows gradient applied to a dense parameter.

    Only the parameter rows listed in the gradient's ``rows`` should change;
    every other row must keep its initial value of 5.0.
    """

    def check_with_place(self, place):
        scope = core.Scope()
        height = 10
        rows = [0, 4, 7]
        self.conf()

        # Gradient: SelectedRows with rows 0, 4 and 7 of a height-10 tensor.
        # All entries are 1.0 except two marker values.
        grad_rows = scope.var('Grad').get_selected_rows()
        grad_rows.set_height(height)
        grad_rows.set_rows(rows)
        grad_values = np.ones((len(rows), self.row_numel)).astype("float32")
        grad_values[0, 0] = 2.0
        grad_values[2, 8] = 4.0
        grad_rows.get_tensor().set(grad_values, place)

        # Parameter: dense (height, row_numel) tensor filled with 5.0.
        param = scope.var('Param').get_tensor()
        param.set(
            np.full((height, self.row_numel), 5.0).astype("float32"), place)

        # Learning rate: one-element tensor holding 2.0.
        learning_rate = scope.var('LearningRate').get_tensor()
        learning_rate.set(np.full(1, 2.0).astype("float32"), place)

        # ParamOut names the same variable as Param, so the update is
        # written back into `param`.
        Operator("sgd",
                 Param='Param',
                 Grad='Grad',
                 ParamOut='Param',
                 LearningRate='LearningRate').run(scope, place)

        result = np.array(param)
        # (expected, row, column); touched entries are 5.0 - 2.0 * grad.
        expectations = [
            (1.0, rows[0], 0),  # 5.0 - 2.0 * 2.0
            (3.0, rows[0], 2),  # 5.0 - 2.0 * 1.0
            (5.0, 1, 0),  # row 1 has no gradient
            (3.0, rows[1], 10),  # 5.0 - 2.0 * 1.0
            (5.0, 5, 8),  # row 5 has no gradient
            (3.0, rows[2], 1),  # 5.0 - 2.0 * 1.0
            (-3.0, rows[2], 8),  # 5.0 - 2.0 * 4.0
        ]
        for expected, row, col in expectations:
            self.assertAlmostEqual(expected, result[row, col])

    def test_sparse_sgd(self):
        candidate_places = [core.CPUPlace()]
        if core.is_compiled_with_cuda():
            candidate_places += [core.CUDAPlace(0)]
        for candidate in candidate_places:
            self.check_with_place(candidate)

    def conf(self):
        """Set the number of elements per gradient/parameter row."""
        self.row_numel = 12


class TestSparseSGDOpCase8X(TestSparseSGDOp):
    """Sparse SGD check with 16 elements per row.

    Presumably this exercises a row width that is a multiple of 8, as the
    class name suggests.
    """

    def conf(self):
        self.row_numel = 16

Q
qiaolongfei 已提交
128 129

class TestSGDOpOptimizeSelectedRows(unittest.TestCase):
    """SGD where both Param and Grad are SelectedRows.

    The expected parameter is computed in NumPy and compared exactly with the
    operator's in-place output.
    """

    def check_with_place(self, place):
        scope = core.Scope()

        row_width = 12
        # create and initialize Grad Variable: rows 0, 4 and 7 of a
        # height-10 SelectedRows, all ones except two marker values.
        grad_height = 10
        grad_rows = [0, 4, 7]

        grad_selected_rows = scope.var('Grad').get_selected_rows()
        grad_selected_rows.set_height(grad_height)
        grad_selected_rows.set_rows(grad_rows)
        grad_array = np.ones((len(grad_rows), row_width)).astype("float32")
        grad_array[0, 0] = 2.0
        grad_array[2, 8] = 4.0

        grad_tensor = grad_selected_rows.get_tensor()
        grad_tensor.set(grad_array, place)

        # create and initialize Param Variable: a SelectedRows covering
        # rows 0..7, where row i is filled with the value i.
        param_rows = [0, 1, 2, 3, 4, 5, 6, 7]
        w_selected_rows = scope.var('Param').get_selected_rows()
        w_selected_rows.set_height(len(param_rows))
        w_selected_rows.set_rows(param_rows)
        w_selected_rows.sync_index()
        w_array = np.ones((len(param_rows), row_width)).astype("float32")
        w_array *= np.arange(len(param_rows), dtype="float32")[:, np.newaxis]
        w_tensor = w_selected_rows.get_tensor()
        w_tensor.set(w_array, place)

        w_before_optimize = np.array(w_tensor)

        # create and initialize LearningRate Variable
        lr_value = 0.1
        lr = scope.var('LearningRate').get_tensor()
        lr_array = np.full((1), lr_value).astype("float32")
        lr.set(lr_array, place)

        # Reference update in NumPy: only the rows named by grad_rows change.
        # grad_rows has no duplicates, so fancy-index assignment is exact.
        w_after_optimize = np.copy(w_before_optimize)
        w_after_optimize[grad_rows] = (w_before_optimize[grad_rows] -
                                       lr_value * grad_array)

        # create and run sgd operator; ParamOut names the same variable as
        # Param, so the result lands in w_tensor.
        sgd_op = Operator("sgd",
                          Param='Param',
                          Grad='Grad',
                          ParamOut='Param',
                          LearningRate='LearningRate')
        sgd_op.run(scope, place)

        # get and compare result. Unlike a bare `assert`, this is not
        # stripped under `python -O` and reports mismatching elements.
        result_array = np.array(w_tensor)
        np.testing.assert_array_equal(result_array, w_after_optimize)

    def test_sparse_parameter_sgd(self):
        places = [core.CPUPlace()]
        # do not support GPU kernel currently
        for place in places:
            self.check_with_place(place)

Q
qijun 已提交
196

197
class TestSGDOpWithLargeInput(unittest.TestCase):
    """Smoke test: one SGD step on a 10,000,000 x 150 embedding table."""

    def runTest(self):
        paddle.enable_static()
        # Build into fresh programs and run in a private scope. The network
        # therefore does not leak into the process-wide default programs.
        # The large embedding table is also released when the test ends
        # instead of staying in the global scope.
        main_program = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(main_program, startup_program):
            data = fluid.layers.fill_constant(shape=[1],
                                              value=128,
                                              dtype='int64')
            label = fluid.layers.fill_constant(shape=[1, 150],
                                               value=0.5,
                                               dtype='float32')
            emb = fluid.embedding(input=data,
                                  size=(10000000, 150),
                                  dtype='float32')
            out = fluid.layers.l2_normalize(x=emb, axis=-1)

            cost = fluid.layers.square_error_cost(input=out, label=label)
            avg_cost = paddle.mean(cost)
            sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
            sgd_optimizer.minimize(avg_cost)

        place = fluid.CPUPlace()
        with fluid.scope_guard(core.Scope()):
            exe = fluid.Executor(place)
            exe.run(startup_program)
            compiled_prog = fluid.compiler.CompiledProgram(main_program)
            loss, = exe.run(compiled_prog, fetch_list=[avg_cost])
        # The fetched loss was previously discarded; at least require that
        # the step produced a finite value.
        self.assertTrue(np.isfinite(loss).all())


J
Jiawei Wang 已提交
221
class TestSGDV2(unittest.TestCase):
    """Tests for the 2.0 optimizer API, ``paddle.optimizer.SGD``."""

    def test_sgd_dygraph(self):
        """Run one dygraph SGD step on a Linear layer with weight decay."""
        paddle.disable_static()
        inputs = paddle.to_tensor(
            np.arange(26).reshape(2, 13).astype("float32"))
        linear = paddle.nn.Linear(13, 5)
        # This can be any optimizer supported by dygraph.
        sgd = paddle.optimizer.SGD(learning_rate=0.01,
                                   parameters=linear.parameters(),
                                   weight_decay=0.01)
        linear(inputs).backward()
        sgd.step()
        sgd.clear_gradients()

    def test_sgd(self):
        """Check the optimize ops minimize() emits for a per-parameter lr."""
        paddle.enable_static()

        def build_and_minimize(optimizer_attr):
            # Hand-build mean(mul(x, y)), with x carrying optimizer_attr, and
            # return the optimize ops appended by minimize().
            startup = paddle.static.Program()
            main = paddle.static.Program()
            blk = main.global_block()
            x = blk.create_parameter(dtype="float32",
                                     shape=[5, 10],
                                     lod_level=0,
                                     name="mul.x",
                                     optimize_attr=optimizer_attr)
            y = blk.create_var(dtype="float32",
                               shape=[10, 8],
                               lod_level=0,
                               name="mul.y")
            prod = blk.create_var(dtype="float32",
                                  shape=[5, 8],
                                  lod_level=0,
                                  name="mul.out")
            avg = blk.create_var(dtype="float32",
                                 shape=[1],
                                 lod_level=0,
                                 name="mean.out")
            blk.append_op(type="mul",
                          inputs={
                              "X": x,
                              "Y": y
                          },
                          outputs={"Out": prod},
                          attrs={"x_num_col_dims": 1})
            blk.append_op(type="mean",
                          inputs={"X": prod},
                          outputs={"Out": avg})
            optimize_ops, _ = paddle.optimizer.SGD(
                learning_rate=0.01).minimize(avg, startup)
            return optimize_ops

        # A per-parameter learning_rate of 1.1 is expected to yield
        # ["scale", "sgd"]; a value of 1.0 yields just ["sgd"].
        cases = [(1.1, ["scale", "sgd"]), (1.0, ["sgd"])]
        for lr_multiplier, expected_types in cases:
            optimize_ops = build_and_minimize({'learning_rate': lr_multiplier})
            self.assertEqual(len(optimize_ops), len(expected_types))
            self.assertEqual([op.type for op in optimize_ops], expected_types)

    def test_raise_error(self):
        self.assertRaises(ValueError, paddle.optimizer.SGD, learning_rate=None)

    def test_sgd_group_dygraph(self):
        """Run one dygraph SGD step with per-group lr / weight-decay overrides."""
        paddle.disable_static()
        inputs = paddle.to_tensor(
            np.arange(26).reshape(2, 13).astype("float32"))
        linear_1 = paddle.nn.Linear(13, 5)
        linear_2 = paddle.nn.Linear(5, 3)
        param_groups = [{
            'params': linear_1.parameters()
        }, {
            'params': linear_2.parameters(),
            'weight_decay': 0.001,
            'learning_rate': 0.1
        }]
        # This can be any optimizer supported by dygraph.
        sgd = paddle.optimizer.SGD(learning_rate=0.01,
                                   parameters=param_groups,
                                   weight_decay=0.01)
        linear_2(linear_1(inputs)).backward()
        sgd.step()
        sgd.clear_gradients()

    def test_eager(self):
        """Repeat the dygraph tests under the eager-mode guard."""
        with _test_eager_guard():
            self.test_sgd_dygraph()
            self.test_sgd_group_dygraph()

313

314
class TestSGDMultiPrecision2_0(unittest.TestCase):
    """Compare ``paddle.optimizer.SGD`` with and without multi_precision.

    The FP16 (AMP O2) run and the FP32 run must stay within a loose
    tolerance of each other, in both dygraph and static mode. Requires CUDA.
    """

    def dygraph_sgd_mp(self, mp):
        """Train a 2x2 Linear layer for 5 steps in dygraph mode.

        Args:
            mp: if True, use multi_precision SGD with AMP O2 and a GradScaler;
                otherwise train in plain FP32.

        Returns:
            Tuple of (output of the last forward pass, model parameters).
        """
        paddle.disable_static()
        paddle.seed(10)
        paddle.set_device('gpu')
        x = paddle.randn((2, 2))
        model = paddle.nn.Linear(2, 2)
        optimizer = paddle.optimizer.SGD(parameters=model.parameters(),
                                         multi_precision=mp)
        if mp:
            model = paddle.amp.decorate(models=model, level='O2')
            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

        for _ in range(5):
            if mp:
                with paddle.amp.auto_cast(level='O2'):
                    output = model(x)
                    loss = paddle.mean(output)
                scaled = scaler.scale(loss)
                scaled.backward()
                scaler.minimize(optimizer, scaled)
            else:
                output = model(x)
                loss = paddle.mean(output)
                # step() applies the gradients produced by backward(). The
                # mixed-precision branch calls backward() on the scaled loss,
                # so the FP32 baseline must do the same to actually train.
                loss.backward()
                optimizer.step()
            optimizer.clear_grad()

        return output, model.parameters()

    def static_sgd_mp(self, mp):
        """Train an fc layer for 5 steps in static mode.

        Args:
            mp: if True, decorate multi_precision SGD for pure FP16 AMP and
                feed float16 data; otherwise use FP32.

        Returns:
            List of the loss values fetched at each step.
        """
        paddle.enable_static()
        paddle.seed(10)
        np.random.seed(10)
        exe = paddle.static.Executor('gpu')
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        optimizer = paddle.optimizer.SGD(multi_precision=mp)

        if mp:
            optimizer = paddle.static.amp.decorate(
                optimizer,
                init_loss_scaling=128.0,
                use_dynamic_loss_scaling=True,
                use_pure_fp16=True,
                use_fp16_guard=False)
        # The fed data's dtype matches the precision being tested.
        dtype = 'float16' if mp else 'float32'
        with paddle.static.program_guard(train_program, startup_program):
            data = paddle.static.data(shape=[2, 2], name='X', dtype=dtype)
            hidden = paddle.static.nn.fc(x=data, size=10)
            loss = paddle.mean(hidden)
            optimizer.minimize(loss)
        exe.run(startup_program)

        if mp:
            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
        x = np.random.random(size=(2, 2)).astype(dtype)
        out = []
        for _ in range(5):
            loss_data, = exe.run(train_program,
                                 feed={"X": x},
                                 fetch_list=[loss.name])
            out.append(loss_data)
        return out

    def test_main(self):
        if not paddle.is_compiled_with_cuda():
            self.skipTest("multi-precision SGD test requires CUDA")
        # Test dygraph mode
        output1_dy, params1_dy = self.dygraph_sgd_mp(mp=True)
        output2_dy, params2_dy = self.dygraph_sgd_mp(mp=False)
        np.testing.assert_allclose(output1_dy.astype('float32').numpy(),
                                   output2_dy.astype('float32').numpy(),
                                   rtol=1e-05,
                                   atol=0.1)
        for param_mp, param_fp32 in zip(params1_dy, params2_dy):
            np.testing.assert_allclose(param_mp.astype('float32').numpy(),
                                       param_fp32.astype('float32').numpy(),
                                       rtol=1e-05,
                                       atol=0.1)
        # Test static mode
        output1_st = self.static_sgd_mp(mp=True)
        output2_st = self.static_sgd_mp(mp=False)
        for loss_mp, loss_fp32 in zip(output1_st, output2_st):
            np.testing.assert_allclose(loss_mp.astype('float32'),
                                       loss_fp32.astype('float32'),
                                       rtol=1e-05,
                                       atol=0.1)
412 413 414


class TestSGDMultiPrecision1_0(unittest.TestCase):
    """Compare ``paddle.fluid.optimizer.SGD`` with and without multi_precision.

    The FP16 (AMP O2) run and the FP32 run must stay within a loose
    tolerance of each other, in both dygraph and static mode. Requires CUDA.
    """

    def dygraph_sgd_mp(self, mp):
        """Train a 2x2 Linear layer for 5 steps in dygraph mode.

        Args:
            mp: if True, use multi_precision SGD with AMP O2 and a GradScaler;
                otherwise train in plain FP32.

        Returns:
            Tuple of (output of the last forward pass, model parameters).
        """
        paddle.disable_static()
        paddle.seed(10)
        paddle.set_device('gpu')
        x = paddle.randn((2, 2))
        model = paddle.nn.Linear(2, 2)
        optimizer = paddle.fluid.optimizer.SGD(
            learning_rate=0.001,
            parameter_list=model.parameters(),
            multi_precision=mp)
        if mp:
            model = paddle.amp.decorate(models=model, level='O2')
            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

        for _ in range(5):
            if mp:
                with paddle.amp.auto_cast(level='O2'):
                    output = model(x)
                    loss = paddle.mean(output)
                scaled = scaler.scale(loss)
                scaled.backward()
                scaler.minimize(optimizer, scaled)
            else:
                output = model(x)
                loss = paddle.mean(output)
                # In dygraph mode minimize() applies gradients already
                # produced by backward(). The mixed-precision branch calls
                # backward() on the scaled loss, so the FP32 baseline must do
                # the same to actually train.
                loss.backward()
                optimizer.minimize(loss)
            optimizer.clear_gradients()

        return output, model.parameters()

    def static_sgd_mp(self, mp):
        """Train an fc layer for 5 steps in static mode.

        Args:
            mp: if True, decorate multi_precision SGD for pure FP16 AMP and
                feed float16 data; otherwise use FP32.

        Returns:
            List of the loss values fetched at each step.
        """
        paddle.enable_static()
        paddle.seed(10)
        np.random.seed(10)
        exe = paddle.static.Executor('gpu')
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001,
                                               multi_precision=mp)

        if mp:
            optimizer = paddle.static.amp.decorate(
                optimizer,
                init_loss_scaling=128.0,
                use_dynamic_loss_scaling=True,
                use_pure_fp16=True,
                use_fp16_guard=False)
        # The fed data's dtype matches the precision being tested.
        dtype = 'float16' if mp else 'float32'
        with paddle.static.program_guard(train_program, startup_program):
            data = paddle.static.data(shape=[2, 2], name='X', dtype=dtype)
            hidden = paddle.static.nn.fc(x=data, size=10)
            loss = paddle.mean(hidden)
            optimizer.minimize(loss)
        exe.run(startup_program)

        if mp:
            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
        x = np.random.random(size=(2, 2)).astype(dtype)
        out = []
        for _ in range(5):
            loss_data, = exe.run(train_program,
                                 feed={"X": x},
                                 fetch_list=[loss.name])
            out.append(loss_data)
        return out

    def test_main(self):
        if not paddle.is_compiled_with_cuda():
            self.skipTest("multi-precision SGD test requires CUDA")
        # Test dygraph mode
        output1_dy, params1_dy = self.dygraph_sgd_mp(mp=True)
        output2_dy, params2_dy = self.dygraph_sgd_mp(mp=False)
        np.testing.assert_allclose(output1_dy.astype('float32').numpy(),
                                   output2_dy.astype('float32').numpy(),
                                   rtol=1e-05,
                                   atol=0.1)
        for param_mp, param_fp32 in zip(params1_dy, params2_dy):
            np.testing.assert_allclose(param_mp.astype('float32').numpy(),
                                       param_fp32.astype('float32').numpy(),
                                       rtol=1e-05,
                                       atol=0.1)
        # Test static mode
        output1_st = self.static_sgd_mp(mp=True)
        output2_st = self.static_sgd_mp(mp=False)
        for loss_mp, loss_fp32 in zip(output1_st, output2_st):
            np.testing.assert_allclose(loss_mp.astype('float32'),
                                       loss_fp32.astype('float32'),
                                       rtol=1e-05,
                                       atol=0.1)
515 516


Q
Qiao Longfei 已提交
517 518
if __name__ == "__main__":
    # Run every test case in this module when executed as a script.
    unittest.main()