# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest
import paddle.fluid as fluid
import paddle.static.amp as amp
from paddle.fluid import core
import paddle

paddle.enable_static()


class AMPTest(unittest.TestCase):
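    """Checks that AutoMixedPrecisionListsBF16 applies custom op lists.

    Each test mutates copies of the default bf16/fp32/gray lists into the
    expected outcome, then builds an AutoMixedPrecisionListsBF16 with the
    matching custom_* argument; tearDown() asserts that the two agree.
    """
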
    def setUp(self):
        self.bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list)
        self.fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list)
        self.gray_list = copy.copy(amp.bf16.amp_lists.gray_list)
        self.amp_lists_ = None

    def tearDown(self):
        self.assertEqual(self.amp_lists_.bf16_list, self.bf16_list)
        self.assertEqual(self.amp_lists_.fp32_list, self.fp32_list)
        self.assertEqual(self.amp_lists_.gray_list, self.gray_list)

    def test_amp_lists(self):
        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16()

    def test_amp_lists_1(self):
        # 1. w={'exp'}, b=None
        self.bf16_list.add('exp')
        self.fp32_list.remove('exp')

        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'exp'})

    def test_amp_lists_2(self):
        # 2. w={'tanh'}, b=None
        self.fp32_list.remove('tanh')
        self.bf16_list.add('tanh')

        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'tanh'})

    def test_amp_lists_3(self):
        # 3. w={'lstm'}, b=None
        self.bf16_list.add('lstm')

        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'lstm'})

    def test_amp_lists_4(self):
        # 4. w=None, b={'matmul_v2'}
        self.bf16_list.remove('matmul_v2')
        self.fp32_list.add('matmul_v2')

        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'matmul_v2'})

    def test_amp_lists_5(self):
        # 5. w=None, b={'matmul_v2'}
        self.fp32_list.add('matmul_v2')
        self.bf16_list.remove('matmul_v2')

        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'matmul_v2'})

    def test_amp_lists_6(self):
        # 6. w=None, b={'lstm'}
        self.fp32_list.add('lstm')

        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'lstm'})

    def test_amp_lists_7(self):
        self.fp32_list.add('reshape2')
        self.gray_list.remove('reshape2')

        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'reshape2'})

    def test_amp_lists_8(self):
        self.bf16_list.add('reshape2')
        self.gray_list.remove('reshape2')

        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
            custom_bf16_list={'reshape2'})


class AMPTest2(unittest.TestCase):
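    """Covers the overlapping-list ValueError and the amp_utils helpers."""
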
    def test_amp_lists_(self):
        # 7. w={'lstm'} b={'lstm'}
        # raise ValueError
        self.assertRaises(ValueError, amp.bf16.AutoMixedPrecisionListsBF16,
                          {'lstm'}, {'lstm'})

    def test_find_op_index(self):
        block = fluid.default_main_program().global_block()
        op_desc = core.OpDesc()
        idx = amp.bf16.amp_utils.find_op_index(block.desc, op_desc)
        assert idx == -1

    def test_is_in_fp32_varnames(self):
        block = fluid.default_main_program().global_block()

        var1 = block.create_var(name="X", shape=[3], dtype='float32')
        var2 = block.create_var(name="Y", shape=[3], dtype='float32')
        var3 = block.create_var(name="Z", shape=[3], dtype='float32')
        op1 = block.append_op(type="abs",
                              inputs={"X": [var1]},
                              outputs={"Out": [var2]})
        op2 = block.append_op(type="abs",
                              inputs={"X": [var2]},
                              outputs={"Out": [var3]})
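
        # An op matches when any of its input or output variable names is in
        # custom_fp32_varnames: op1 is X -> Y, op2 is Y -> Z.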
        amp_lists_1 = amp.bf16.AutoMixedPrecisionListsBF16(
            custom_fp32_varnames={'X'})
        assert amp.bf16.amp_utils._is_in_fp32_varnames(op1, amp_lists_1)
        amp_lists_2 = amp.bf16.AutoMixedPrecisionListsBF16(
            custom_fp32_varnames={'Y'})
        assert amp.bf16.amp_utils._is_in_fp32_varnames(op2, amp_lists_2)
        assert amp.bf16.amp_utils._is_in_fp32_varnames(op1, amp_lists_2)

    def test_find_true_post_op(self):
        block = fluid.default_main_program().global_block()

        var1 = block.create_var(name="X", shape=[3], dtype='float32')
        var2 = block.create_var(name="Y", shape=[3], dtype='float32')
        var3 = block.create_var(name="Z", shape=[3], dtype='float32')
        op1 = block.append_op(type="abs",
                              inputs={"X": [var1]},
                              outputs={"Out": [var2]})
        op2 = block.append_op(type="abs",
                              inputs={"X": [var2]},
                              outputs={"Out": [var3]})
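
        # op2 consumes op1's output "Y", so it is the only true post-op.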
        res = amp.bf16.amp_utils.find_true_post_op(block.ops, op1, "Y")
        assert res == [op2]

    def test_find_true_post_op_with_search_all(self):
        program = fluid.Program()
        block = program.current_block()
        startup_block = fluid.default_startup_program().global_block()

        var1 = block.create_var(name="X", shape=[3], dtype='float32')
        var2 = block.create_var(name="Y", shape=[3], dtype='float32')
        initializer_op = startup_block._prepend_op(type="fill_constant",
                                                   outputs={"Out": var1},
                                                   attrs={
                                                       "shape": var1.shape,
                                                       "dtype": var1.dtype,
                                                       "value": 1.0
                                                   })

        op1 = block.append_op(type="abs",
                              inputs={"X": [var1]},
                              outputs={"Out": [var2]})
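
        # The initializer was prepended to the startup block, so it never
        # appears in block.ops; a positional search (search_all=False) finds
        # no post-ops, while search_all=True scans every op in the list.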
        result = amp.bf16.amp_utils.find_true_post_op(block.ops,
                                                      initializer_op,
                                                      "X",
                                                      search_all=False)
        assert len(result) == 0
        result = amp.bf16.amp_utils.find_true_post_op(block.ops,
                                                      initializer_op,
                                                      "X",
                                                      search_all=True)
        assert result == [op1]

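# A minimal usage sketch for these lists (assuming the decorate_bf16 API of
# this Paddle release; it is not exercised by the tests above):
#
#   amp_lists = amp.bf16.AutoMixedPrecisionListsBF16(custom_fp32_list={'lstm'})
#   optimizer = amp.bf16.decorate_bf16(optimizer, amp_lists=amp_lists)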

if __name__ == '__main__':
    unittest.main()