# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
import six
from fake_reader import fake_imdb_reader

# The static-graph tests below build explicit fluid.Program objects, so static
# mode is enabled once at import time; the dygraph tests switch modes locally
# via fluid.dygraph.guard().
paddle.enable_static()


def bow_net(data,
            label,
            dict_dim,
            emb_dim=128,
            hid_dim=128,
            hid_dim2=96,
            class_dim=2):
    """
    Build a bag-of-words text classifier and return its mean loss.

    The topology follows fluid/PaddleNLP/text_classification/nets.py in
    https://github.com/PaddlePaddle/models: sparse embedding -> sum
    sequence-pool -> tanh -> two tanh fc layers -> softmax fc layer, scored
    with cross entropy.

    Args:
        data: word-id input variable (created as int64, lod_level=1 by the
            caller in this file).
        label: int64 class-label input variable.
        dict_dim: vocabulary size, i.e. rows of the embedding table.
        emb_dim: embedding width.
        hid_dim: width of the first hidden fc layer.
        hid_dim2: width of the second hidden fc layer.
        class_dim: number of output classes.

    Returns:
        The mean cross-entropy loss variable.
    """
    layers = fluid.layers
    embedded = layers.embedding(
        input=data, is_sparse=True, size=[dict_dim, emb_dim])
    features = layers.tanh(layers.sequence_pool(input=embedded, pool_type='sum'))
    # Two stacked tanh hidden layers, created in the same order as before.
    for width in (hid_dim, hid_dim2):
        features = layers.fc(input=features, size=width, act="tanh")
    prediction = layers.fc(input=[features], size=class_dim, act="softmax")
    return layers.mean(
        x=layers.cross_entropy(input=prediction, label=label))


class TestGradientClip(unittest.TestCase):
    """Base harness for the static-graph gradient-clipping tests.

    Subclasses customise four hooks:
      * ``init`` sets clip hyper-parameters (called at the end of ``setUp``);
      * ``clip_gradient`` (assigned inside a test) maps a list of
        ``(param, grad)`` pairs to their clipped counterparts;
      * ``check_clip_result`` compares unclipped and clipped gradients;
      * ``backward_and_optimize`` appends backward/optimizer ops for the
        BOW network used by ``check_sparse_gradient_clip``.
    """

    def setUp(self):
        self.word_dict_len = 5147
        self.BATCH_SIZE = 2
        # NOTE(review): second argument presumably is the number of fake
        # samples to generate — confirm against fake_reader.
        reader = fake_imdb_reader(self.word_dict_len, self.BATCH_SIZE * 100)
        self.train_data = paddle.batch(reader, batch_size=self.BATCH_SIZE)
        # Placeholder; each test assigns the clipping function it exercises.
        self.clip_gradient = lambda x: None
        self.init()

    def init(self):
        """Hook for subclasses to set clip hyper-parameters."""
        pass

    def get_places(self):
        """Return the CPU place, plus CUDAPlace(0) when built with CUDA."""
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        return places

    def check_clip_result(self, out, out_clip):
        """Hook: compare unclipped gradients ``out`` with ``out_clip``."""
        pass

    def check_gradient_clip(self, place, dtype='float32'):
        """Compare MLP gradients computed with and without ``clip_gradient``.

        Builds a small MNIST classifier, clones it, appends backward ops to
        both copies, applies ``self.clip_gradient`` to the clone's
        ``(param, grad)`` pairs, runs one batch through each program and
        passes both fetched gradient lists to ``self.check_clip_result``.

        Args:
            place: device the executor runs on.
            dtype: compute dtype of the fc layers; any value other than
                'float32' casts the float32 input before the first fc layer.
        """
        prog = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(
                main_program=prog, startup_program=startup_program):
            image = fluid.data(name="a", shape=[-1, 784], dtype='float32')
            label = fluid.data(name="b", shape=[-1, 1], dtype='int64')
            if dtype != 'float32':
                image_cast = paddle.cast(image, dtype)
                hidden = fluid.layers.fc(input=image_cast, size=32, act='relu')
            else:
                hidden = fluid.layers.fc(input=image, size=32, act='relu')
            predict = fluid.layers.fc(input=hidden, size=10, act='softmax')

            cost = fluid.layers.cross_entropy(input=predict, label=label)
            avg_cost = fluid.layers.mean(cost)

        # Clone before appending backward so each program gets its own
        # gradient ops.
        prog_clip = prog.clone()
        avg_cost_clip = prog_clip.block(0).var(avg_cost.name)

        p_g = fluid.backward.append_backward(loss=avg_cost)
        p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip)

        # Sort by parameter name so the two lists line up pairwise.
        p_g = sorted(p_g, key=lambda x: x[0].name)
        p_g_clip = sorted(p_g_clip, key=lambda x: x[0].name)
        with fluid.program_guard(
                main_program=prog_clip, startup_program=startup_program):
            p_g_clip = self.clip_gradient(p_g_clip)

        grad_list = [elem[1] for elem in p_g]
        grad_clip_list = [elem[1] for elem in p_g_clip]

        train_reader = paddle.batch(paddle.dataset.mnist.train(), batch_size=3)
        exe = fluid.Executor(place)
        feeder = fluid.DataFeeder(feed_list=[image, label], place=place)
        exe.run(startup_program)

        data = next(train_reader())
        out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list)
        out_clip = exe.run(prog_clip,
                           feed=feeder.feed(data),
                           fetch_list=grad_clip_list)
        self.check_clip_result(out, out_clip)

    def check_sparse_gradient_clip(self, place):
        """Build the BOW net, call ``backward_and_optimize`` and run one step.

        Asserts the fetched loss has shape (1,) and contains no NaN.
        """
        prog = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(
                main_program=prog, startup_program=startup_program):
            words = fluid.data(
                name="words", shape=[-1, 1], dtype="int64", lod_level=1)
            label = fluid.data(name="label", shape=[-1, 1], dtype="int64")
            cost = bow_net(words, label, self.word_dict_len)

            self.backward_and_optimize(cost)

        exe = fluid.Executor(place)
        feeder = fluid.DataFeeder(feed_list=[words, label], place=place)
        exe.run(startup_program)

        batch = next(self.train_data())
        val = exe.run(prog, feed=feeder.feed(batch), fetch_list=[cost])[0]
        self.assertEqual((1, ), val.shape)
        self.assertFalse(np.isnan(val).any())

    def backward_and_optimize(self, cost):
        """Hook: append backward/optimizer ops for ``cost``."""
        pass


class TestGradientClipByGlobalNorm(TestGradientClip):
    """Static-graph tests for ``fluid.clip.GradientClipByGlobalNorm``."""

    def init(self):
        self.clip_norm = 0.2

    def check_clip_result(self, out, out_clip):
        """Assert ``out_clip`` equals ``out`` rescaled by the global-norm rule.

        Each expected gradient is ``g * clip_norm / max(clip_norm, N)``, where
        ``N`` is the L2 norm taken jointly over every array in ``out``.
        ``out`` itself is left unmodified.
        """
        global_norm = np.sqrt(sum(np.sum(np.square(v)) for v in out))
        scale = self.clip_norm / np.maximum(self.clip_norm, global_norm)
        expected = [scale * v for v in out]

        for u, v in zip(expected, out_clip):
            self.assertTrue(
                np.allclose(
                    a=u, b=v, rtol=1e-5, atol=1e-8),
                "gradient clip by global norm has wrong results!, \nu={}\nv={}\ndiff={}".
                format(u, v, u - v))

    # test whether the output is right when using 'set_gradient_clip'
    def test_old_gradient_clip(self):
        def func(params_grads):
            clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm)
            fluid.clip.set_gradient_clip(clip)
            return fluid.clip.append_gradient_clip_ops(params_grads)

        self.clip_gradient = func
        self.check_gradient_clip(fluid.CPUPlace())

    # test whether the output is right when calling the clip object directly
    def test_new_gradient_clip(self):
        def func(params_grads):
            clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm)
            return clip(params_grads)

        self.clip_gradient = func
        self.check_gradient_clip(fluid.CPUPlace())

    # same as above, but the fc layers compute in float64
    def test_new_gradient_clip_fp64(self):
        def func(params_grads):
            clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm)
            return clip(params_grads)

        self.clip_gradient = func
        self.check_gradient_clip(fluid.CPUPlace(), "float64")

    # invoke 'set_gradient_clip' in a wrong order
    def test_wrong_API_order(self):
        def backward_func(cost):
            clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0)
            fluid.clip.set_gradient_clip(clip)
            sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01,
                                                grad_clip=clip)
            # when 'set_gradient_clip' and 'optimizer(grad_clip=...)' are both
            # used, 'set_gradient_clip' is ineffective
            sgd_optimizer.minimize(cost)
            # 'set_gradient_clip' must be called before 'minimize'; calling
            # it afterwards is ineffective
            fluid.clip.set_gradient_clip(clip)

        self.backward_and_optimize = backward_func
        for place in self.get_places():
            self.check_sparse_gradient_clip(place)

    # optimizer(grad_clip=...) must be an instance of a GradientClipBase
    # subclass; anything else raises TypeError
    def test_tpyeError(self):
        with self.assertRaises(TypeError):
            fluid.optimizer.SGD(learning_rate=0.1, grad_clip="test")

    # if grad is None or not need clip
    def test_none_grad_fp32(self):
        ops = self._test_none_grad_helper("float32")
        self.assertListEqual(ops, [
            'squared_l2_norm', 'squared_l2_norm', 'sum', 'sqrt',
            'fill_constant', 'elementwise_max', 'elementwise_div',
            'elementwise_mul', 'elementwise_mul'
        ])

    def test_none_grad_fp16(self):
        ops = self._test_none_grad_helper("float16")
        self.assertListEqual(ops, [
            'square', 'reduce_sum', 'square', 'reduce_sum', 'sum', 'cast',
            'sqrt', 'fill_constant', 'elementwise_max', 'elementwise_div',
            'cast', 'elementwise_mul', 'cast', 'elementwise_mul'
        ])

    def _test_none_grad_helper(self, dtype):
        """Clip a hand-built ``(param, grad)`` list; return the emitted op types.

        The list contains one ``(x, None)`` pair, which must be filtered out,
        so exactly two pairs are expected back.
        """
        prog = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(
                main_program=prog, startup_program=startup_program):
            clip = fluid.clip.GradientClipByGlobalNorm(self.clip_norm)
            block = fluid.default_main_program().global_block()
            x = block.create_parameter(name="x", shape=[2, 3], dtype=dtype)
            y = block.create_parameter(name="y", shape=[2, 3], dtype=dtype)

            # (x, None) should not be returned
            params_grads = [(x, None), (x, y), (y, x)]
            params_grads = clip(params_grads)
            self.assertEqual(
                len(params_grads), 2,
                "ClipByGlobalNorm: when grad is None, it shouldn't be returned by gradient clip!"
            )

            ops = [op.type for op in x.block.ops]
        return ops


class TestGradientClipByNorm(TestGradientClip):
    """Static-graph tests for ``fluid.clip.GradientClipByNorm``."""

    def init(self):
        self.clip_norm = 0.2

    def check_clip_result(self, out, out_clip):
        """Assert each gradient is rescaled so its own L2 norm is <= clip_norm."""
        for u, v in zip(out, out_clip):
            norm = np.sqrt(np.sum(np.power(u, 2)))
            scale = self.clip_norm / np.maximum(self.clip_norm, norm)
            expected = u * scale
            self.assertTrue(
                np.allclose(
                    a=expected, b=v, rtol=1e-5, atol=1e-8),
                "gradient clip by norm has wrong results!")

    # test whether the output is right when using grad_clip
    def test_gradient_clip(self):
        def func(params_grads):
            clip = fluid.clip.GradientClipByNorm(clip_norm=self.clip_norm)
            return clip(params_grads)

        self.clip_gradient = func
        self.check_gradient_clip(fluid.CPUPlace())

    # if grad is None or not need clip
    def test_none_grad(self):
        # Use a fresh Program so the parameters created here do not leak into
        # the global default program shared with other tests.
        prog = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(
                main_program=prog, startup_program=startup_program):
            clip = fluid.clip.GradientClipByNorm(self.clip_norm)
            block = fluid.default_main_program().global_block()
            x = block.create_parameter(
                name="x", shape=[2, 3], dtype="float32", need_clip=False)
            y = block.create_parameter(
                name="y", shape=[2, 3], dtype="float32", need_clip=False)

            # (x, None) should not be returned
            params_grads = [(x, None), (x, y)]
            params_grads = clip(params_grads)

        # Check the result of the single clip call above (not a re-clip).
        self.assertEqual(
            len(params_grads), 1,
            "ClipGradByNorm: when grad is None, it shouldn't be returned by gradient clip!"
        )
        self.assertEqual(
            params_grads[0][1].name, 'y',
            "ClipGradByNorm: grad should not be clipped when filtered out!")


class TestGradientClipByValue(TestGradientClip):
    """Static-graph tests for ``fluid.clip.GradientClipByValue``."""

    def init(self):
        self.max = 0.2
        self.min = 0.1

    def check_clip_result(self, out, out_clip):
        """Assert each clipped gradient equals ``np.clip(g, min, max)``.

        ``out`` is left unmodified; each expected value is clipped once.
        """
        for u, v in zip(out, out_clip):
            expected = np.clip(u, self.min, self.max)
            self.assertTrue(
                np.allclose(
                    a=expected, b=v, rtol=1e-6, atol=1e-8),
                "gradient clip by value has wrong results!")

    # test whether the output is right when using grad_clip
    def test_gradient_clip(self):
        def func(params_grads):
            clip = fluid.clip.GradientClipByValue(max=self.max, min=self.min)
            return clip(params_grads)

        self.clip_gradient = func
        self.check_gradient_clip(fluid.CPUPlace())

    # if grad is None or not need clip
    def test_none_grad(self):
        # Use a fresh Program so the parameters created here do not leak into
        # the global default program shared with other tests.
        prog = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(
                main_program=prog, startup_program=startup_program):
            clip = fluid.clip.GradientClipByValue(self.max, self.min)
            block = fluid.default_main_program().global_block()
            x = block.create_parameter(
                name="x", shape=[2, 3], dtype="float32", need_clip=False)
            y = block.create_parameter(
                name="y", shape=[2, 3], dtype="float32", need_clip=False)

            # (x, None) should not be returned
            params_grads = [(x, None), (x, y)]
            params_grads = clip(params_grads)

        # Check the result of the single clip call above (not a re-clip).
        self.assertEqual(
            len(params_grads), 1,
            "ClipGradByValue: when grad is None, it shouldn't be returned by gradient clip!"
        )
        self.assertEqual(
            params_grads[0][1].name, 'y',
            "ClipGradByValue: grad should not be clipped when filtered out!")


class TestDygraphGradientClip(unittest.TestCase):
    """Dygraph-mode harness; subclasses override ``check_clip_result``."""

    def test_gradient_clip(self):
        with fluid.dygraph.guard():
            linear = fluid.dygraph.Linear(5, 5)
            random_input = fluid.layers.uniform_random(
                [16, 5], min=-10, max=10).astype('float32')
            loss = fluid.layers.reduce_mean(
                linear(fluid.dygraph.to_variable(random_input)))
            loss.backward()
            # NOTE(review): learning_rate=0.0 presumably keeps minimize()
            # from moving the weights while the subclass inspects grads.
            sgd = fluid.optimizer.SGD(
                learning_rate=0.0,
                parameter_list=linear.parameters(),
                grad_clip=fluid.clip.GradientClipByGlobalNorm(0.1))
            self.check_clip_result(loss, sgd)

    def check_clip_result(self, loss, optimizer):
        """Hook for subclasses; the base class performs no check."""
        pass


class TestDygraphGradientClipByGlobalNorm(TestDygraphGradientClip):
    """Dygraph tests for ``fluid.clip.GradientClipByGlobalNorm``."""

    def setUp(self):
        self.clip_norm = 0.8
        self.clip1 = fluid.clip.GradientClipByGlobalNorm(
            clip_norm=self.clip_norm)
        self.clip2 = fluid.clip.GradientClipByGlobalNorm(
            clip_norm=self.clip_norm)

    @staticmethod
    def _global_norm(tensors):
        """Return the L2 norm taken jointly over all dygraph ``tensors``."""
        return np.sqrt(sum(np.sum(np.power(t.numpy(), 2)) for t in tensors))

    def check_clip_result(self, loss, optimizer):
        """Check None grads are dropped and the clipped global norm is correct."""
        # if grad is None it must not be returned
        x = fluid.dygraph.to_variable(
            np.array([2, 3]).astype("float32"), name="x")
        y = fluid.dygraph.to_variable(
            np.array([3, 4]).astype("float32"), name="y")
        # unittest assertion instead of a bare `assert`, which -O would strip
        self.assertEqual(len(self.clip1([(x, x), (x, y), (x, None)])), 2)
        # get params and grads from network
        _, params_grads = optimizer.minimize(loss)
        _, grads = zip(*params_grads)
        params_grads = self.clip2(params_grads)
        _, grads_clip = zip(*params_grads)

        global_norm = self._global_norm(grads)
        global_norm_clip = self._global_norm(grads_clip)

        a = np.minimum(global_norm, self.clip_norm)
        b = global_norm_clip
        self.assertTrue(
            np.isclose(
                a=a, b=b, rtol=1e-6, atol=1e-8),
            "gradient clip by global norm has wrong results, expetcd:%f, but recieved:%f"
            % (a, b))


class TestDygraphGradientClipByNorm(TestDygraphGradientClip):
    """Dygraph tests for ``fluid.clip.GradientClipByNorm``."""

    def setUp(self):
        self.clip_norm = 0.8
        self.clip = fluid.clip.GradientClipByNorm(clip_norm=self.clip_norm)

    def check_clip_result(self, loss, optimizer):
        """Check None grads are dropped and each grad norm is capped."""
        # if grad is None it must not be returned (float32 parameter)
        x = fluid.dygraph.to_variable(np.array([2, 3]).astype("float32"))
        self.assertEqual(len(self.clip([(x, None)])), 0)
        # same check for an integer-typed parameter; the result used to be
        # discarded, so the call verified nothing
        int_param = fluid.dygraph.to_variable(np.array([2, 3]))
        self.assertEqual(len(self.clip([(int_param, None)])), 0)
        # get params and grads from network
        _, params_grads = optimizer.minimize(loss)
        _, grads = zip(*params_grads)
        params_grads = self.clip(params_grads)
        _, grads_clip = zip(*params_grads)

        for u, v in zip(grads, grads_clip):
            u = u.numpy()
            v = v.numpy()
            a = np.sqrt(np.sum(np.power(u, 2)))
            a = np.minimum(a, self.clip_norm)
            b = np.sqrt(np.sum(np.power(v, 2)))
            self.assertTrue(
                np.isclose(
                    a=a, b=b, rtol=1e-6, atol=1e-8),
                "gradient clip by norm has wrong results, expetcd:%f, but recieved:%f"
                % (a, b))


class TestDygraphGradientClipByValue(TestDygraphGradientClip):
    """Dygraph tests for ``fluid.clip.GradientClipByValue``."""

    def setUp(self):
        self.max = 0.2
        self.min = 0.1
        self.clip = fluid.clip.GradientClipByValue(max=self.max, min=self.min)

    def check_clip_result(self, loss, optimizer):
        """Check None grads are dropped and the rest are clamped to [min, max]."""
        # if grad is None it must not be returned; unittest assertion instead
        # of a bare `assert`, which -O would strip
        x = fluid.dygraph.to_variable(np.array([2, 3]).astype("float32"))
        self.assertEqual(len(self.clip([(x, None)])), 0)
        # get params and grads from network
        _, params_grads = optimizer.minimize(loss)
        _, grads = zip(*params_grads)
        params_grads = self.clip(params_grads)
        _, grads_clip = zip(*params_grads)
        for u, v in zip(grads, grads_clip):
            u = np.clip(u.numpy(), self.min, self.max)
            v = v.numpy()
            self.assertTrue(
                np.allclose(
                    a=u, b=v, rtol=1e-6, atol=1e-8),
                "gradient clip by value has wrong results!")


if __name__ == '__main__':
    # Run every TestCase in this module when the file is executed as a script.
    unittest.main()