# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
16
import struct
17
import unittest
18

19
import numpy as np
20
from amp_base_models import AmpTestBase, build_add_model, build_embedding_model
21 22 23

import paddle
from paddle import fluid
24
from paddle.fluid import core
25
from paddle.static import amp
26 27 28 29

# Switch Paddle to static-graph mode at import time: every test below builds
# fluid/static Programs and runs them through an Executor.
paddle.enable_static()


def copy_bits_from_float_to_uint16(f):
    """Return the upper 16 bits of the float32 encoding of ``f``.

    ``f`` is packed as a little-endian IEEE-754 float32 and the four bytes are
    read back as a little-endian unsigned integer; shifting off the low 16
    bits keeps the sign, exponent and top 7 mantissa bits (truncation, no
    rounding).
    """
    packed = struct.pack('<f', f)
    return int.from_bytes(packed, byteorder='little') >> 16


def convert_float_to_uint16(in_list):
    """Convert a float32 array to its bfloat16 bit patterns stored as uint16.

    Each element keeps the upper 16 bits of its float32 encoding; the low
    mantissa bits are truncated, not rounded.

    Args:
        in_list: a NumPy array (or NumPy scalar). Only ``float32`` input is
            converted; any other dtype is returned unchanged.

    Returns:
        A new ``uint16`` array with the same shape as ``in_list`` for float32
        input, otherwise ``in_list`` itself.
    """
    if in_list.dtype != np.float32:
        return in_list
    # float32 and uint32 share an itemsize, so this view is valid for any
    # memory layout (including transposed/strided inputs). Doing the shift in
    # bulk keeps the exact input shape, empty arrays included.
    bits32 = np.asarray(in_list).view(np.uint32)
    return (bits32 >> 16).astype(np.uint16)


def convert_uint16_to_float(in_list):
    """Expand bfloat16 bit patterns stored as uint16 back to float32 values.

    Each uint16 value is placed in the upper 16 bits of a uint32 (low bits
    zero) and that 32-bit word is reinterpreted as an IEEE-754 float32.

    Args:
        in_list: a NumPy array (or NumPy scalar). Only ``uint16`` input is
            converted; any other dtype is returned unchanged.

    Returns:
        A new ``float32`` array with the same shape as ``in_list`` for uint16
        input, otherwise ``in_list`` itself.
    """
    if in_list.dtype != np.uint16:
        return in_list
    # Widen before shifting so the bits cannot overflow the 16-bit type, then
    # reinterpret in bulk. A per-element round trip through Python floats can
    # normalize NaN payloads; the view copies every bit pattern verbatim.
    bits32 = np.asarray(in_list).astype(np.uint32) << 16
    return bits32.view(np.float32)


# Short alias used by the assertions below to decode bf16 results (uint16
# bit patterns) into float32 before comparing them.
cutf = convert_uint16_to_float


@unittest.skipIf(
    not core.supports_bfloat16(), "place does not support BF16 evaluation"
)
class TestModelCastBF16(unittest.TestCase):
    """Compare BF16-rewritten/cast static programs against FP32 results."""

    @classmethod
    def setUpClass(cls):
        # Fixed seed used by static_graph() so every built program is
        # reproducible.
        cls.seed = 111

    @classmethod
    def tearDownClass(cls):
        pass

    @contextlib.contextmanager
    def static_graph(self):
        """Enter a fresh scope/program pair and seed Paddle's RNGs."""
        with self.scope_prog_guard():
            paddle.seed(self.seed)
            paddle.framework.random._manual_program_seed(self.seed)
            yield

    @contextlib.contextmanager
    def scope_prog_guard(self):
        """Make new main/startup programs and a new scope the defaults."""
        prog = fluid.Program()
        startup_prog = fluid.Program()
        scope = fluid.core.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(prog, startup_prog):
                yield

    def get_static_graph_result(
        self, feed, fetch_list, amp_fun, with_lod=False, startup_prog=None
    ):
        """Run the startup program, optionally rewrite main, then run main.

        Args:
            feed: mapping from input variable names to arrays.
            fetch_list: variables to fetch from the main program.
            amp_fun: callable applied to the default main program before it
                runs, or None. Called as ``amp_fun(prog, startup_prog)`` when
                ``startup_prog`` is given, otherwise as ``amp_fun(prog)``.
            with_lod: when True, results are fetched with
                ``return_numpy=False``.
            startup_prog: program to run for initialization; defaults to the
                current default startup program.

        Returns:
            Whatever ``Executor.run`` returns for ``fetch_list``.
        """
        exe = fluid.Executor(core.CPUPlace())
        exe.run(
            fluid.default_startup_program()
            if startup_prog is None
            else startup_prog
        )
        prog = fluid.default_main_program()
        if amp_fun is not None:
            if startup_prog is not None:
                amp_fun(prog, startup_prog)
            else:
                amp_fun(prog)
        return exe.run(
            prog, feed=feed, fetch_list=fetch_list, return_numpy=(not with_lod)
        )

    def _graph_common(self, _amp_fun, startup_prog=None):
        """Build two small graphs and check their results after ``_amp_fun``.

        ``_amp_fun`` and ``startup_prog`` are forwarded unchanged to
        ``get_static_graph_result``.
        """
        size = 3
        n = np.ones([size, size], dtype='float32') * 3.2
        nn = np.ones([size, size], dtype='float32') * -2.7

        n_bf16 = amp.bf16.convert_float_to_uint16(n)
        nn_bf16 = amp.bf16.convert_float_to_uint16(nn)

        # Graph 1: (t + tt) * t -> reshape, computed on the bf16-encoded
        # inputs inside bf16_guard, on fp32 inputs outside the guard, and on
        # fp32 inputs inside the guard; all three must agree.
        with self.static_graph():
            t_bf16 = paddle.static.data(
                name='t_bf16', shape=[-1, size, size], dtype='int32'
            )
            t_bf16.desc.set_need_check_feed(False)
            tt_bf16 = paddle.static.data(
                name='tt_bf16', shape=[-1, size, size], dtype='int32'
            )
            tt_bf16.desc.set_need_check_feed(False)
            t = paddle.static.data(
                name='t', shape=[-1, size, size], dtype='float32'
            )
            t.desc.set_need_check_feed(False)
            tt = paddle.static.data(
                name='tt', shape=[-1, size, size], dtype='float32'
            )
            tt.desc.set_need_check_feed(False)

            ret = paddle.add(t, tt)
            ret = paddle.multiply(ret, t)
            ret = paddle.reshape(ret, [0, 0])

            with amp.bf16.bf16_guard():
                ret_bf16 = paddle.add(t_bf16, tt_bf16)
                ret_bf16 = paddle.multiply(ret_bf16, t_bf16)
                ret_bf16 = paddle.reshape(ret_bf16, [0, 0])

            with amp.bf16.bf16_guard():
                ret_fp32bf16 = paddle.add(t, tt)
                ret_fp32bf16 = paddle.multiply(ret_fp32bf16, t)
                ret_fp32bf16 = paddle.reshape(ret_fp32bf16, [0, 0])

            (
                static_ret_bf16,
                static_ret,
                ret_fp32bf16,
            ) = self.get_static_graph_result(
                feed={
                    't': n,
                    'tt': nn,
                    't_bf16': n_bf16,
                    'tt_bf16': nn_bf16,
                },
                fetch_list=[ret_bf16, ret, ret_fp32bf16],
                amp_fun=_amp_fun,
                startup_prog=startup_prog,
            )

        np.testing.assert_allclose(
            cutf(static_ret_bf16), cutf(static_ret), rtol=0.01
        )
        np.testing.assert_allclose(
            cutf(static_ret_bf16), cutf(ret_fp32bf16), rtol=0.01
        )

        # Graph 2: ((t + tt) -> reshape -> elu -> * t) inside bf16_guard,
        # followed by + tt outside it.
        with self.static_graph():
            t = paddle.static.data(
                name='t', shape=[-1, size, size], dtype='float32'
            )
            t.desc.set_need_check_feed(False)
            tt = paddle.static.data(
                name='tt', shape=[-1, size, size], dtype='float32'
            )
            tt.desc.set_need_check_feed(False)

            with amp.bf16.bf16_guard():
                ret = paddle.add(t, tt)
                ret = paddle.reshape(ret, [0, 0])
                ret = paddle.nn.functional.elu(ret)
                ret = paddle.multiply(ret, t)
            ret = paddle.add(ret, tt)

            static_ret_bf16 = self.get_static_graph_result(
                feed={'t': n, 'tt': nn},
                fetch_list=[ret],
                amp_fun=_amp_fun,
                startup_prog=startup_prog,
            )
        # Expected value: elu(3.2 - 2.7) * 3.2 - 2.7 = 0.5 * 3.2 - 2.7 = -1.1.
        # assertTrue(result, expected) would treat `expected` as the failure
        # message and pass for any non-empty result, so compare the values
        # with the same BF16 tolerance used above.
        np.testing.assert_allclose(
            cutf(static_ret_bf16[0]),
            np.ones([size, size], dtype='float32') * -1.1,
            rtol=0.01,
        )

    def test_graph_rewrite(self):
        """rewrite_program_bf16 with add forced to bf16 and one fp32 var."""
        self._graph_common(
            lambda prog: amp.bf16.rewrite_program_bf16(
                prog,
                amp.bf16.AutoMixedPrecisionListsBF16(
                    custom_bf16_list={'elementwise_add'},
                    custom_fp32_varnames={'elementwise_add_0.tmp_0'},
                ),
            )
        )

    def test_graph_cast(self):
        """cast_model_to_bf16 (guard-based) with add bf16 and mul fp32."""
        self._graph_common(
            lambda prog, startup_prog: amp.bf16.cast_model_to_bf16(
                prog,
                startup_prog,
                amp.bf16.AutoMixedPrecisionListsBF16(
                    custom_bf16_list={'elementwise_add'},
                    custom_fp32_list={'elementwise_mul'},
                ),
                use_bf16_guard=True,
            ),
            startup_prog=fluid.default_startup_program(),
        )


class TestProgramBF16(AmpTestBase):
    """Static bfloat16 AMP checks on the embedding model at levels O1/O2."""

    def _check_optimizer(self, program, expected_num_mp):
        """Assert how many optimizer ops have ``multi_precision`` set.

        An op is treated as an optimizer when its inputs include both
        ``Param`` and ``Grad``.
        """
        actual_num_mp = sum(
            1
            for block in program.blocks
            for op in block.ops
            if "Param" in op.input_names
            and "Grad" in op.input_names
            and op.has_attr("multi_precision")
            and op.attr("multi_precision")
        )
        self.assertEqual(
            actual_num_mp,
            expected_num_mp,
            f"The number of optimizers with multi_precison = True is expected to be {expected_num_mp}, but recieved {actual_num_mp}.",
        )

    def test_amp_bf16_o1(self):
        main_program, _ = build_embedding_model(True, "bfloat16", "O1")
        self.assertEqual(main_program.num_blocks, 1)
        # O1 expects no optimizer op with multi_precision enabled.
        self._check_optimizer(main_program, 0)

        amp.debugging.collect_operator_stats(main_program)
        stats_per_block = amp.debugging._get_op_stats_list(main_program)
        bf16_call_counts = dict(
            matmul_v2=1,
            elementwise_add=1,
            dropout=1,
            lookup_table_v2=0,
            squared_l2_norm=0,
            adamw=0,
        )
        self._check_op_calls(stats_per_block[0], bf16_call_counts)

    def test_amp_bf16_o2(self):
        main_program, _ = build_embedding_model(True, "bfloat16", "O2")
        self.assertEqual(main_program.num_blocks, 1)

        amp.debugging.collect_operator_stats(main_program)
        stats_per_block = amp.debugging._get_op_stats_list(main_program)
        bf16_call_counts = dict(
            matmul_v2=1,
            elementwise_add=1,
            dropout=1,
            lookup_table_v2=0,
            squared_l2_norm=2,
            adamw=2,
        )
        # The expected multi-precision optimizer count is the number of bf16
        # matmul_v2 plus elementwise_add calls.
        num_mp_optimizers = (
            bf16_call_counts["matmul_v2"] + bf16_call_counts["elementwise_add"]
        )
        self._check_optimizer(main_program, num_mp_optimizers)
        self._check_op_calls(stats_per_block[0], bf16_call_counts)


@unittest.skipIf(
    not core.is_compiled_with_cuda(),
    "TestStaticBF16 runs on paddle.CUDAPlace(0) and requires CUDA",
)
class TestStaticBF16(AmpTestBase):
    """Run the add model with bfloat16 AMP at O1 and O2 on CUDA device 0."""

    def _generate_feed_x(self):
        """Return a random 16x16 input as ``(float32, bf16-bits-as-uint16)``.

        The float32 array is decoded from the bf16 bits, so both arrays carry
        exactly the same values.
        """
        x = np.random.random(size=[16, 16]).astype("float32")
        x_bf16 = convert_float_to_uint16(x)
        x_fp32 = convert_uint16_to_float(x_bf16)
        return x_fp32, x_bf16

    def test_compare_o1_o2(self):
        def _run(place, exe, x_np, max_iters, level):
            # Build a fresh bf16 add model for `level` and train it for
            # `max_iters` steps, returning whatever run_program collects.
            (
                main_program,
                startup_program,
                optimizer,
                feed_vars,
                fetch_vars,
            ) = build_add_model(True, "bfloat16", level)

            losses = self.run_program(
                main_program,
                startup_program,
                optimizer,
                feed_vars,
                fetch_vars,
                place,
                exe,
                x_np,
                max_iters,
                level,
            )
            return losses

        max_iters = 2
        x_fp32, x_bf16 = self._generate_feed_x()
        place = paddle.CUDAPlace(0)
        exe = paddle.static.Executor(place)
        losses_o1 = _run(place, exe, x_fp32, max_iters, 'O1')
        losses_o2 = _run(place, exe, x_bf16, max_iters, 'O2')
        # NOTE(review): losses_o1 / losses_o2 are collected but never
        # compared, so this test only checks that both levels run. Add an
        # assertion once the expected O1/O2 agreement (exact or within a
        # tolerance) and the loss dtype returned by run_program are confirmed.


# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()