# test_bf16_utils.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest

import paddle
import paddle.fluid as fluid
import paddle.static.amp as amp
from paddle.fluid import core

# The tests below build ops and variables inside static-graph Programs and
# Blocks, so static mode is enabled once, before any test runs.
paddle.enable_static()


class AMPTest(unittest.TestCase):
    """Check how custom op sets reshape the default BF16 AMP op lists.

    ``setUp`` snapshots the default ``bf16_list``, ``fp32_list`` and
    ``gray_list`` from ``amp.bf16.amp_lists``. Each test edits those
    snapshots into the lists it expects, then stores the
    ``AutoMixedPrecisionListsBF16`` it built in ``self.amp_lists_``.
    ``tearDown`` compares expected and actual lists.

    In the numbered comments, ``w`` is the custom BF16 op set and ``b`` the
    custom FP32 op set passed to the constructor.
    """

    def setUp(self):
        # Shallow copies, so tests can edit the expected lists without
        # touching the module-level defaults.
        self.bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list)
        self.fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list)
        self.gray_list = copy.copy(amp.bf16.amp_lists.gray_list)
        self.amp_lists_ = None

    def tearDown(self):
        # Fail with a clear message rather than an AttributeError on None
        # when a test errored or never built the lists under test.
        self.assertIsNotNone(
            self.amp_lists_,
            "test did not build an AutoMixedPrecisionListsBF16 instance",
        )
        self.assertEqual(self.amp_lists_.bf16_list, self.bf16_list)
        self.assertEqual(self.amp_lists_.fp32_list, self.fp32_list)
        self.assertEqual(self.amp_lists_.gray_list, self.gray_list)

    def test_amp_lists(self):
        # 0. w=None, b=None: the defaults must be reproduced unchanged.
        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16()

    def test_amp_lists_1(self):
        # 1. w={'exp'}, b=None: 'exp' moves from the FP32 list to BF16.
        self.bf16_list.add('exp')
        self.fp32_list.remove('exp')

        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'exp'})

    def test_amp_lists_2(self):
        # 2. w={'tanh'}, b=None: 'tanh' moves from the FP32 list to BF16.
        self.fp32_list.remove('tanh')
        self.bf16_list.add('tanh')

        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'tanh'})

    def test_amp_lists_3(self):
        # 3. w={'lstm'}, b=None: 'lstm' is only added to the BF16 list.
        self.bf16_list.add('lstm')

        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'lstm'})

    def test_amp_lists_4(self):
        # 4. w=None, b={'matmul_v2'}: 'matmul_v2' moves from BF16 to FP32.
        self.bf16_list.remove('matmul_v2')
        self.fp32_list.add('matmul_v2')

        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'matmul_v2'}
        )

    def test_amp_lists_5(self):
        # 5. w=None, b={'matmul_v2'}: same expectation as test 4.
        # NOTE(review): this duplicates test_amp_lists_4 (only the order of
        # the two expected-list edits differs); consider covering another op.
        self.fp32_list.add('matmul_v2')
        self.bf16_list.remove('matmul_v2')

        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'matmul_v2'}
        )

    def test_amp_lists_6(self):
        # 6. w=None, b={'lstm'}: 'lstm' is only added to the FP32 list.
        self.fp32_list.add('lstm')

        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'lstm'}
        )

    def test_amp_lists_7(self):
        # 7. w=None, b={'reshape2'}: 'reshape2' moves from gray to FP32.
        self.fp32_list.add('reshape2')
        self.gray_list.remove('reshape2')

        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'reshape2'}
        )

    def test_amp_list_8(self):
        # 8. w={'reshape2'}, b=None: 'reshape2' moves from gray to BF16.
        self.bf16_list.add('reshape2')
        self.gray_list.remove('reshape2')

        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
            custom_bf16_list={'reshape2'}
        )


class AMPTest2(unittest.TestCase):
    """Tests for list validation and the graph helpers in ``amp_utils``.

    Every graph test builds its ops in a freshly created ``fluid.Program``,
    so no variables or ops leak into the global default programs. Leaked
    state would make these tests depend on execution order.
    """

    @staticmethod
    def _append_abs_chain(block):
        """Append ``X -abs-> Y -abs-> Z`` to *block* and return both ops."""
        var_x = block.create_var(name="X", shape=[3], dtype='float32')
        var_y = block.create_var(name="Y", shape=[3], dtype='float32')
        var_z = block.create_var(name="Z", shape=[3], dtype='float32')

        first_op = block.append_op(
            type="abs", inputs={"X": [var_x]}, outputs={"Out": [var_y]}
        )
        second_op = block.append_op(
            type="abs", inputs={"X": [var_y]}, outputs={"Out": [var_z]}
        )
        return first_op, second_op

    def test_amp_lists_(self):
        # w={'lstm'}, b={'lstm'}: the same op in both custom sets must be
        # rejected with ValueError.
        with self.assertRaises(ValueError):
            amp.bf16.AutoMixedPrecisionListsBF16({'lstm'}, {'lstm'})

    def test_find_op_index(self):
        # A newly created OpDesc was never added to the block, so the
        # lookup must report "not found" (-1).
        block = fluid.Program().global_block()
        op_desc = core.OpDesc()
        idx = amp.bf16.amp_utils.find_op_index(block.desc, op_desc)
        self.assertEqual(idx, -1)

    def test_is_in_fp32_varnames(self):
        is_in_fp32_varnames = amp.bf16.amp_utils._is_in_fp32_varnames
        block = fluid.Program().global_block()
        op1, op2 = self._append_abs_chain(block)

        # op1 reads 'X', so listing 'X' as an FP32 var name matches op1.
        amp_lists_1 = amp.bf16.AutoMixedPrecisionListsBF16(
            custom_fp32_varnames={'X'}
        )
        self.assertTrue(is_in_fp32_varnames(op1, amp_lists_1))

        # 'Y' links the two ops (op1 writes it, op2 reads it), so listing
        # 'Y' matches both of them.
        amp_lists_2 = amp.bf16.AutoMixedPrecisionListsBF16(
            custom_fp32_varnames={'Y'}
        )
        self.assertTrue(is_in_fp32_varnames(op2, amp_lists_2))
        self.assertTrue(is_in_fp32_varnames(op1, amp_lists_2))

    def test_find_true_post_op(self):
        block = fluid.Program().global_block()
        op1, op2 = self._append_abs_chain(block)

        # op2 is the only op that consumes op1's output 'Y'.
        res = amp.bf16.amp_utils.find_true_post_op(block.ops, op1, "Y")
        self.assertEqual(res, [op2])

    def test_find_true_post_op_with_search_all(self):
        program = fluid.Program()
        block = program.current_block()
        # A private startup program keeps the initializer below out of the
        # global default startup program.
        startup_block = fluid.Program().global_block()

        var1 = block.create_var(name="X", shape=[3], dtype='float32')
        var2 = block.create_var(name="Y", shape=[3], dtype='float32')

        initializer_op = startup_block._prepend_op(
            type="fill_constant",
            outputs={"Out": var1},
            attrs={"shape": var1.shape, "dtype": var1.dtype, "value": 1.0},
        )

        op1 = block.append_op(
            type="abs", inputs={"X": [var1]}, outputs={"Out": [var2]}
        )

        # initializer_op lives in the startup program, not in block.ops.
        # Without search_all no consumer of 'X' is found; with
        # search_all=True the reader op1 is found.
        result = amp.bf16.amp_utils.find_true_post_op(
            block.ops, initializer_op, "X", search_all=False
        )
        self.assertEqual(len(result), 0)

        result = amp.bf16.amp_utils.find_true_post_op(
            block.ops, initializer_op, "X", search_all=True
        )
        self.assertEqual(result, [op1])

if __name__ == '__main__':
    # Discover and run every TestCase defined in this module.
    unittest.main(module='__main__')