Unverified commit 7ccf6b60 authored by arlesniak, committed by GitHub

[oneDNN] Initial bf16 amp integration (#31093)

Parent a501a7b0
......@@ -97,5 +97,6 @@ REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
ops::CastOpKernel<CPU, bool>,
ops::CastOpKernel<CPU, uint8_t>,
ops::CastOpKernel<CPU, paddle::platform::float16>,
ops::CastOpKernel<CPU, paddle::platform::bfloat16>,
ops::CastOpKernel<CPU, paddle::platform::complex64>,
ops::CastOpKernel<CPU, paddle::platform::complex128>);
......@@ -128,6 +128,8 @@ REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker,
REGISTER_OP_CPU_KERNEL(
scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, int16_t>,
......
......@@ -20,7 +20,10 @@ from . import fp16_lists
from .fp16_lists import *
from . import fp16_utils
from .fp16_utils import *
from . import bf16
from .bf16 import *
__all__ = decorator.__all__
__all__ += fp16_lists.__all__
__all__ += fp16_utils.__all__
__all__ += bf16.__all__
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from . import amp_lists
from .amp_lists import *
from . import amp_utils
from .amp_utils import *
__all__ = []
__all__ += amp_lists.__all__
__all__ += amp_utils.__all__
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from ..fp16_lists import white_list as white_list_fp16, black_list as black_list_fp16,\
gray_list as gray_list_fp16, unsupported_fp16_list
__all__ = ["AutoMixedPrecisionListsBF16"]
class AutoMixedPrecisionListsBF16(object):
"""
AutoMixedPrecisionListsBF16 is a class for the fp32/bf16 op-type lists. The lists are used by an
algorithm that determines each op's execution mode (fp32 or bf16). It can update the pre-defined
fp32 and bf16 lists according to users' custom fp32/bf16 lists. (A short sketch of how the
lists resolve is given after the list definitions below.)
Args:
custom_bf16_list (set): Users' custom bf16 list.
custom_fp32_list (set): Users' custom fp32 list.
custom_fp32_varnames (set): Users' custom fp32 variables' names.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
with paddle.static.amp.bf16_guard():
paddle.static.amp.AutoMixedPrecisionListsBF16(custom_fp32_list={'lstm'})
"""
def __init__(self,
custom_bf16_list=None,
custom_fp32_list=None,
custom_fp32_varnames=None):
self._custom_bf16_list = custom_bf16_list
self._custom_fp32_list = custom_fp32_list
self.bf16_list = copy.copy(bf16_list)
self.fp32_list = copy.copy(fp32_list)
self.gray_list = copy.copy(gray_list)
self.unsupported_list = copy.copy(unsupported_list)
self.fp32_varnames = copy.copy(custom_fp32_varnames)
self._update_list()
def _update_list(self):
"""
Update fp32 and bf16 list according to users' custom list.
"""
if self._custom_bf16_list and self._custom_fp32_list:
for op_name in self._custom_bf16_list:
if op_name in self._custom_fp32_list:
raise ValueError("Custom bf16 list overlaps "
"custom fp32 list")
if self._custom_bf16_list:
for op_name in self._custom_bf16_list:
if op_name in self.fp32_list:
self.fp32_list.remove(op_name)
elif op_name in self.gray_list:
self.gray_list.remove(op_name)
self.bf16_list.add(op_name)
if self._custom_fp32_list:
for op_name in self._custom_fp32_list:
if op_name in self.bf16_list:
self.bf16_list.remove(op_name)
elif op_name in self.gray_list:
self.gray_list.remove(op_name)
self.fp32_list.add(op_name)
self.unsupported_list.add(op_name)
# always bf16
bf16_list = {'elementwise_add', }
# depends on the prev_op type
gray_list = {
'reshape2',
'lookup_table',
}
unsupported_list = unsupported_fp16_list.copy()
fp32_list = black_list_fp16.copy()
fp32_list |= white_list_fp16
fp32_list |= gray_list_fp16
fp32_list -= bf16_list
fp32_list -= gray_list
unsupported_list -= bf16_list
unsupported_list -= gray_list
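As an illustration (a minimal sketch, not part of the patch itself), the list resolution above can be exercised directly; the import path assumes the package layout added by this commit:

# Sketch: how _update_list resolves a user-supplied fp32 op.
from paddle.fluid.contrib.mixed_precision.bf16.amp_lists import \
    AutoMixedPrecisionListsBF16

lists = AutoMixedPrecisionListsBF16(custom_fp32_list={'elementwise_add'})
# 'elementwise_add' is moved out of the bf16 list, added to the fp32 list,
# and also marked unsupported for bf16 execution.
assert 'elementwise_add' not in lists.bf16_list
assert 'elementwise_add' in lists.fp32_list
assert 'elementwise_add' in lists.unsupported_list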
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import struct
from .... import core
from .... import framework
from ....log_helper import get_logger
from ....wrapped_decorator import signature_safe_contextmanager
from .amp_lists import AutoMixedPrecisionListsBF16
from ..fp16_utils import find_true_prev_op, find_true_post_op, _rename_arg, find_op_index
import logging
import numpy as np
__all__ = ["bf16_guard", "rewrite_program_bf16", "convert_float_to_uint16"]
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
_valid_types = [
core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS,
core.VarDesc.VarType.LOD_TENSOR_ARRAY
]
_bf16_guard_pattern = "__use_bf16__"
def convert_float_to_uint16(in_list):
in_list = np.asarray(in_list)
out = np.vectorize(
lambda x: struct.unpack('<I', struct.pack('<f', x))[0] >> 16,
otypes=[np.uint16])(in_list.flat)
return np.reshape(out, in_list.shape)
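A worked example of the truncation above (a sketch; the constants follow from the IEEE-754 float32 layout, of which bf16 keeps the sign bit, the 8 exponent bits, and the top 7 mantissa bits):

# Sketch (relies on the struct/np imports already at the top of this module):
# float32 1.0 has the bit pattern 0x3F800000, so dropping the low 16 bits
# yields the bf16 pattern 0x3F80, i.e. 16256.
assert struct.unpack('<I', struct.pack('<f', 1.0))[0] >> 16 == 16256
assert (convert_float_to_uint16(np.ones([2], dtype='float32')) == 16256).all()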
def _dtype_to_str(dtype):
"""
Convert specific variable type to its corresponding string.
Args:
dtype (VarType): Variable type.
"""
if dtype == core.VarDesc.VarType.BF16:
return 'bf16'
else:
return 'fp32'
def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
"""
Insert cast op and rename args of input and output.
Args:
block (Program): The block in which the operator is.
op (Operator): The operator to insert cast op.
idx (int): The index of current operator.
src_dtype (VarType): The input variable dtype of cast op.
dest_dtype (VarType): The output variable dtype of cast op.
Returns:
num_cast_op (int): The number of cast ops that have been inserted.
"""
num_cast_ops = 0
for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and op.type in [
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
]:
if in_name not in {'X', 'Z'}:
continue
for in_var_name in op.input(in_name):
in_var = block.var(in_var_name)
if in_var.type not in _valid_types or in_var.dtype == dest_dtype:
continue
if in_var.dtype == src_dtype:
cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
out_var = block.vars.get(cast_name)
if out_var is None or out_var.dtype != dest_dtype:
out_var = block.create_var(
name=cast_name,
dtype=dest_dtype,
persistable=False,
stop_gradient=in_var.stop_gradient)
block._insert_op(
idx,
type="cast",
inputs={"X": in_var},
outputs={"Out": out_var},
attrs={
"in_dtype": in_var.dtype,
"out_dtype": out_var.dtype
})
num_cast_ops += 1
_rename_arg(op, in_var.name, out_var.name)
else:
if op.has_attr('in_dtype'):
op._set_attr('in_dtype', dest_dtype)
if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.BF16:
for out_name in op.output_names:
if op.type in [
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
] and out_name != 'Y':
continue
for out_var_name in op.output(out_name):
out_var = block.var(out_var_name)
if out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
if op.has_attr('out_dtype'):
op._set_attr('out_dtype', core.VarDesc.VarType.BF16)
return num_cast_ops
def _is_in_fp32_varnames(op, amp_lists):
for in_name in op.input_arg_names:
if in_name in amp_lists.fp32_varnames:
return True
for out_name in op.output_arg_names:
if out_name in amp_lists.fp32_varnames:
return True
return False
def _need_keep_fp32(op, unsupported_op_list, use_bf16_guard):
if op.type in unsupported_op_list:
# The highest-priority condition: ops without bf16 compute kernels
# must be executed in fp32.
return True
# process ops about learning rate
in_out_arg_names = []
in_out_arg_names.extend(list(op.input_arg_names))
in_out_arg_names.extend(list(op.output_arg_names))
for name in in_out_arg_names:
if "learning_rate" in name:
return True
if use_bf16_guard:
if op.has_attr("op_namescope") and \
(_bf16_guard_pattern in op.attr("op_namescope")):
# op in bf16 guard
return False
else:
# op not in bf16 guard
return True
else:
return False
@signature_safe_contextmanager
def bf16_guard():
"""
For pure bf16 training, if users set `use_bf16_guard` to True,
only the ops created inside the context manager `bf16_guard` will be
transformed to bfloat16 type.
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.nn.functional as F
paddle.enable_static()
data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
with paddle.static.amp.bf16_guard():
bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
pool = F.max_pool2d(bn, kernel_size=2, stride=2)
hidden = paddle.static.nn.fc(pool, size=10)
loss = paddle.mean(hidden)
"""
with framework.name_scope(prefix=_bf16_guard_pattern):
yield
def rewrite_program_bf16(main_prog, amp_lists=None, use_bf16_guard=False):
"""
Traverse all ops in the current block and insert cast ops according to
which set each op belongs to.
1. When an op belongs to the fp32 list, add it to the fp32 set.
2. When an op belongs to the bf16 list, add it to the bf16 set.
3. When an op belongs to the gray list: if one of its inputs is the
output of an fp32-set or fp32-list op, add it to the fp32 set; if
none of its previous ops are fp32 and one of its inputs is the
output of a bf16-set or bf16-list op, add it to the bf16 set.
4. When an op isn't in any of the lists, add it to the fp32 set.
5. Add the cast ops needed so that fp32-set ops are computed in fp32
mode while bf16-set ops are computed in bf16 mode.
A minimal usage sketch is given after the function body below.
Args:
main_prog (Program): The main program for training.
"""
if amp_lists is None:
amp_lists = AutoMixedPrecisionListsBF16()
block = main_prog.global_block()
ops = block.ops
bf16_op_set = set()
fp32_op_set = set()
for op in ops:
# NOTE(zhiqiu): 'create_py_reader' and 'read' are used in the non-iterable DataLoader.
# We don't need to handle the reader ops, and the input of 'create_py_reader' is not
# in the block, which may result in errors.
# See GeneratorLoader._init_non_iterable() for details.
if op.type == 'create_py_reader' or op.type == 'read':
continue
if amp_lists.fp32_varnames is not None and _is_in_fp32_varnames(
op, amp_lists):
fp32_op_set.add(op)
continue
if op.type in amp_lists.fp32_list or _need_keep_fp32(
op, amp_lists.unsupported_list, use_bf16_guard):
fp32_op_set.add(op)
elif op.type in amp_lists.bf16_list:
bf16_op_set.add(op)
elif op.type in amp_lists.gray_list:
is_fp32_op = False
is_bf16_op = False
for in_name in op.input_names:
# if this op has inputs
if in_name:
for in_var_name in op.input(in_name):
in_var = block.var(in_var_name)
# this in_var isn't the output of another op
if in_var.op is None:
continue
elif in_var.op is op:
prev_op = find_true_prev_op(ops, op, in_var_name)
if prev_op is None:
continue
else:
prev_op = in_var.op
# classify this input by the set/list of the op that produced it
if prev_op in fp32_op_set or \
prev_op.type in amp_lists.fp32_list:
is_fp32_op = True
elif prev_op in bf16_op_set or \
prev_op.type in amp_lists.bf16_list:
is_bf16_op = True
if is_fp32_op:
fp32_op_set.add(op)
elif is_bf16_op:
bf16_op_set.add(op)
else:
pass
else:
# For numerical safety, we apply fp32 computation to ops whose
# list membership cannot be determined.
fp32_op_set.add(op)
idx = 0
while idx < len(ops):
op = ops[idx]
num_cast_ops = 0
if op in fp32_op_set:
num_cast_ops = _insert_cast_op(block, op, idx,
core.VarDesc.VarType.BF16,
core.VarDesc.VarType.FP32)
elif op in bf16_op_set:
if op.has_attr('use_mkldnn'):
op._set_attr('use_mkldnn', True)
op._set_attr('mkldnn_data_type', 'bfloat16')
elif op.has_attr('dtype') and op.attr(
'dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.BF16)
num_cast_ops = _insert_cast_op(block, op, idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.BF16)
else:
pass
idx += num_cast_ops + 1
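A minimal usage sketch of the pass above (assuming static mode, a CPU place, and a PaddlePaddle build with oneDNN support; it mirrors the test added later in this patch):

import paddle
import paddle.fluid as fluid
import paddle.static.amp as amp

paddle.enable_static()
prog, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(prog, startup):
    t = fluid.layers.data(name='t', shape=[3, 3], dtype='float32')
    tt = fluid.layers.data(name='tt', shape=[3, 3], dtype='float32')
    with amp.bf16_guard():
        ret = fluid.layers.elementwise_add(t, tt)
    # Ops inside bf16_guard get use_mkldnn=True and mkldnn_data_type='bfloat16';
    # cast ops are inserted at the fp32/bf16 boundaries.
    amp.rewrite_program_bf16(prog, use_bf16_guard=True)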
......@@ -69,7 +69,7 @@ class AutoMixedPrecisionLists(object):
self.unsupported_list.add(op_name)
# The three sets listed below are changed dynamically. They don't contain all
# paddle ops currently.
# The set of ops that support fp16 calculation and are considered numerically-
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest
import paddle.fluid as fluid
import paddle.fluid.contrib.mixed_precision as amp
from paddle.fluid import core
import paddle
paddle.enable_static()
class AMPTest(unittest.TestCase):
def setUp(self):
self.bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list)
self.fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list)
self.gray_list = copy.copy(amp.bf16.amp_lists.gray_list)
self.amp_lists_ = None
def tearDown(self):
self.assertEqual(self.amp_lists_.bf16_list, self.bf16_list)
self.assertEqual(self.amp_lists_.fp32_list, self.fp32_list)
self.assertEqual(self.amp_lists_.gray_list, self.gray_list)
def test_amp_lists(self):
self.amp_lists_ = amp.AutoMixedPrecisionListsBF16()
def test_amp_lists_1(self):
# 1. w={'exp'}, b=None
self.bf16_list.add('exp')
self.fp32_list.remove('exp')
self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'exp'})
def test_amp_lists_2(self):
# 2. w={'tanh'}, b=None
self.fp32_list.remove('tanh')
self.bf16_list.add('tanh')
self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'tanh'})
def test_amp_lists_3(self):
# 3. w={'lstm'}, b=None
self.bf16_list.add('lstm')
self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'lstm'})
def test_amp_lists_4(self):
# 4. w=None, b={'elementwise_add'}
self.bf16_list.remove('elementwise_add')
self.fp32_list.add('elementwise_add')
self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
custom_fp32_list={'elementwise_add'})
def test_amp_lists_5(self):
# 5. w=None, b={'elementwise_add'}
self.fp32_list.add('elementwise_add')
self.bf16_list.remove('elementwise_add')
self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
custom_fp32_list={'elementwise_add'})
def test_amp_lists_6(self):
# 6. w=None, b={'lstm'}
self.fp32_list.add('lstm')
self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
custom_fp32_list={'lstm'})
def test_amp_lists_7(self):
self.fp32_list.add('reshape2')
self.gray_list.remove('reshape2')
self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
custom_fp32_list={'reshape2'})
def test_amp_list_8(self):
self.bf16_list.add('reshape2')
self.gray_list.remove('reshape2')
self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
custom_bf16_list={'reshape2'})
class AMPTest2(unittest.TestCase):
def test_amp_lists_(self):
# 7. w={'lstm'} b={'lstm'}
# raise ValueError
self.assertRaises(ValueError, amp.AutoMixedPrecisionListsBF16,
{'lstm'}, {'lstm'})
def test_find_op_index(self):
block = fluid.default_main_program().global_block()
op_desc = core.OpDesc()
idx = amp.bf16.amp_utils.find_op_index(block.desc, op_desc)
assert (idx == -1)
def test_is_in_fp32_varnames(self):
block = fluid.default_main_program().global_block()
var1 = block.create_var(name="X", shape=[3], dtype='float32')
var2 = block.create_var(name="Y", shape=[3], dtype='float32')
var3 = block.create_var(name="Z", shape=[3], dtype='float32')
op1 = block.append_op(
type="abs", inputs={"X": [var1]}, outputs={"Out": [var2]})
op2 = block.append_op(
type="abs", inputs={"X": [var2]}, outputs={"Out": [var3]})
amp_lists_1 = amp.AutoMixedPrecisionListsBF16(
custom_fp32_varnames={'X'})
assert amp.bf16.amp_utils._is_in_fp32_varnames(op1, amp_lists_1)
amp_lists_2 = amp.AutoMixedPrecisionListsBF16(
custom_fp32_varnames={'Y'})
assert amp.bf16.amp_utils._is_in_fp32_varnames(op2, amp_lists_2)
assert amp.bf16.amp_utils._is_in_fp32_varnames(op1, amp_lists_2)
def test_find_true_post_op(self):
block = fluid.default_main_program().global_block()
var1 = block.create_var(name="X", shape=[3], dtype='float32')
var2 = block.create_var(name="Y", shape=[3], dtype='float32')
var3 = block.create_var(name="Z", shape=[3], dtype='float32')
op1 = block.append_op(
type="abs", inputs={"X": [var1]}, outputs={"Out": [var2]})
op2 = block.append_op(
type="abs", inputs={"X": [var2]}, outputs={"Out": [var3]})
res = amp.bf16.amp_utils.find_true_post_op(block.ops, op1, "Y")
assert (res == [op2])
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import contextlib
import unittest
import numpy as np
import paddle.fluid.layers as layers
import paddle.static.amp as amp
from paddle.fluid import core
paddle.enable_static()
@unittest.skipIf(not core.supports_bfloat16(),
"place does not support BF16 evaluation")
class TestModelCastBF16(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.seed = 111
@classmethod
def tearDownClass(cls):
pass
@contextlib.contextmanager
def static_graph(self):
with self.scope_prog_guard():
paddle.seed(self.seed)
paddle.framework.random._manual_program_seed(self.seed)
yield
@contextlib.contextmanager
def scope_prog_guard(self):
prog = fluid.Program()
startup_prog = fluid.Program()
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(prog, startup_prog):
yield
def get_static_graph_result(self, feed, fetch_list, amp_fun,
with_lod=False):
exe = fluid.Executor(core.CPUPlace())
exe.run(fluid.default_startup_program())
prog = fluid.default_main_program()
if amp_fun is not None:
amp_fun(prog)
return exe.run(prog,
feed=feed,
fetch_list=fetch_list,
return_numpy=(not with_lod))
def test_graph_rewrite(self):
size = 3
n = np.ones([size, size], dtype='float32') * 3.2
nn = np.ones([size, size], dtype='float32') * -2.7
n_bf16 = amp.convert_float_to_uint16(n)
nn_bf16 = amp.convert_float_to_uint16(nn)
with self.static_graph():
t_bf16 = layers.data(
name='t_bf16', shape=[size, size], dtype=np.uint16)
tt_bf16 = layers.data(
name='tt_bf16', shape=[size, size], dtype=np.uint16)
t = layers.data(name='t', shape=[size, size], dtype='float32')
tt = layers.data(name='tt', shape=[size, size], dtype='float32')
ret = layers.elementwise_add(t, tt)
ret = layers.elementwise_mul(ret, t)
ret = layers.reshape(ret, [0, 0])
with amp.bf16_guard():
ret_bf16 = layers.elementwise_add(t_bf16, tt_bf16)
ret_bf16 = layers.elementwise_mul(ret_bf16, t_bf16)
ret_bf16 = layers.reshape(ret_bf16, [0, 0])
with amp.bf16_guard():
ret_fp32bf16 = layers.elementwise_add(t, tt)
ret_fp32bf16 = layers.elementwise_mul(ret_fp32bf16, t)
ret_fp32bf16 = layers.reshape(ret_fp32bf16, [0, 0])
static_ret_bf16, static_ret, ret_fp32bf16 = self.get_static_graph_result(
feed={
't': n,
'tt': nn,
't_bf16': n_bf16,
'tt_bf16': nn_bf16,
},
fetch_list=[ret_bf16, ret, ret_fp32bf16],
amp_fun=lambda prog: amp.rewrite_program_bf16(prog, use_bf16_guard=True))
self.assertTrue(np.allclose(static_ret_bf16, static_ret, 1e-2))
self.assertTrue(np.allclose(static_ret_bf16, ret_fp32bf16, 1e-2))
with self.static_graph():
t = layers.data(name='t', shape=[size, size], dtype='float32')
tt = layers.data(name='tt', shape=[size, size], dtype='float32')
with amp.bf16_guard():
ret = layers.elementwise_add(t, tt)
ret = layers.reshape(ret, [0, 0], act='elu')
ret = layers.elementwise_mul(ret, t)
ret = layers.elementwise_add(ret, tt)
static_ret_bf16 = \
self.get_static_graph_result(
feed={'t': n, 'tt': nn},
fetch_list=[ret],
amp_fun=lambda prog: amp.rewrite_program_bf16(
prog,
amp.AutoMixedPrecisionListsBF16(
custom_fp32_varnames={'elementwise_add_0.tmp_0'}),
use_bf16_guard=True
)
)
self.assertTrue(
np.allclose(static_ret_bf16, np.ones(
[size, size], dtype='float32') * -1.1, 1e-2))
if __name__ == '__main__':
unittest.main()
......@@ -29,6 +29,7 @@ __all__ = ['DataFeeder']
_PADDLE_DTYPE_2_NUMPY_DTYPE = {
core.VarDesc.VarType.BOOL: 'bool',
core.VarDesc.VarType.FP16: 'float16',
core.VarDesc.VarType.BF16: 'uint16',
core.VarDesc.VarType.FP32: 'float32',
core.VarDesc.VarType.FP64: 'float64',
core.VarDesc.VarType.INT8: 'int8',
......@@ -47,16 +48,18 @@ def convert_dtype(dtype):
return _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype]
elif isinstance(dtype, type):
if dtype in [
np.bool, np.float16, np.float32, np.float64, np.int8, np.int16,
np.int32, np.int64, np.uint8, np.complex64, np.complex128
np.bool, np.float16, np.uint16, np.float32, np.float64, np.int8,
np.int16, np.int32, np.int64, np.uint8, np.complex64,
np.complex128
]:
return dtype.__name__
else:
if dtype in [
'bool', 'float16', 'float32', 'float64', 'int8', 'int16',
'int32', 'int64', 'uint8', 'complex64', 'complex128', u'bool',
u'float16', u'float32', u'float64', u'int8', u'int16', u'int32',
u'int64', u'uint8', u'complex64', u'complex128'
'bool', 'float16', 'uint16', 'float32', 'float64', 'int8',
'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128',
u'bool', u'float16', u'uint16', u'float32', u'float64', u'int8',
u'int16', u'int32', u'int64', u'uint8', u'complex64',
u'complex128'
]:
# This code is a little dangerous, since an error could occur
# when casting non-ASCII text to str in Python 2.
......@@ -66,7 +69,7 @@ def convert_dtype(dtype):
return str(dtype)
raise TypeError(
"dtype must be any of [bool, float16, float32, float64, int8, int16, "
"dtype must be any of [bool, float16, uint16, float32, float64, int8, int16, "
"int32, int64, uint8, complex64, complex128], but received %s" % dtype)
......@@ -123,6 +126,12 @@ def check_dtype(input_dtype,
warnings.warn(
"The data type of '%s' in %s only support float16 in GPU now. %s" %
(input_name, op_name, extra_message))
if convert_dtype(input_dtype) in ['uint16'] and op_name not in [
'reshape', 'lookup_table', 'scale'
]:
warnings.warn(
"The data type of '%s' in %s only support bfloat16 in OneDNN now. %s"
% (input_name, op_name, extra_message))
if convert_dtype(input_dtype) not in expected_dtype:
raise TypeError(
"The data type of '%s' in %s must be %s, but received %s. %s" %
......
......@@ -6137,9 +6137,9 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
return dygraph_utils._append_activation_in_dygraph(out, act)
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64',
'bool'], 'reshape')
check_variable_and_dtype(x, 'x', [
'float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'uint16'
], 'reshape')
check_type(shape, 'shape', (list, tuple, Variable), 'reshape')
check_type(actual_shape, 'actual_shape', (Variable, type(None)), 'reshape')
......@@ -11354,9 +11354,11 @@ def _elementwise_op(helper):
assert x is not None, 'x cannot be None in {}'.format(op_type)
assert y is not None, 'y cannot be None in {}'.format(op_type)
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], op_type)
x, 'x', ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
op_type)
check_variable_and_dtype(
y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64'], op_type)
y, 'y', ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
op_type)
axis = helper.kwargs.get('axis', -1)
use_mkldnn = helper.kwargs.get('use_mkldnn', False)
......@@ -11428,8 +11430,8 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
return dygraph_utils._append_activation_in_dygraph(out)
check_variable_and_dtype(x, "x", [
'float16', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64',
'uint8'
'float16', 'uint16', 'float32', 'float64', 'int8', 'int16', 'int32',
'int64', 'uint8'
], "scale")
inputs = {'X': [x]}
attrs = {
......
......@@ -26,7 +26,7 @@ import os
paddle.enable_static()
def train(use_cuda, save_dirname, is_local):
def train(use_cuda, save_dirname, is_local, use_bf16):
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
......@@ -37,6 +37,8 @@ def train(use_cuda, save_dirname, is_local):
avg_cost = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
if use_bf16:
paddle.static.amp.rewrite_program_bf16(fluid.default_main_program())
sgd_optimizer.minimize(avg_cost)
BATCH_SIZE = 20
......@@ -133,14 +135,17 @@ def infer(use_cuda, save_dirname=None):
print("ground truth: ", test_label)
def main(use_cuda, is_local=True):
def main(use_cuda, is_local=True, use_bf16=False):
if use_cuda and not fluid.core.is_compiled_with_cuda():
return
if use_bf16 and not fluid.core.is_compiled_with_mkldnn():
return
# Directory for saving the trained model
save_dirname = "fit_a_line.inference.model"
train(use_cuda, save_dirname, is_local)
train(use_cuda, save_dirname, is_local, use_bf16)
infer(use_cuda, save_dirname)
......@@ -153,6 +158,12 @@ class TestFitALine(unittest.TestCase):
with self.program_scope_guard():
main(use_cuda=True)
@unittest.skipIf(not fluid.core.supports_bfloat16(),
"place does not support BF16 evaluation")
def test_bf16(self):
with self.program_scope_guard():
main(use_cuda=False, use_bf16=True)
@contextlib.contextmanager
def program_scope_guard(self):
prog = fluid.Program()
......
......@@ -39,7 +39,12 @@ def get_place(target):
format(target))
def train(target, is_sparse, is_parallel, save_dirname, is_local=True):
def train(target,
is_sparse,
is_parallel,
save_dirname,
is_local=True,
use_bf16=False):
PASS_NUM = 100
EMBED_SIZE = 32
HIDDEN_SIZE = 256
......@@ -101,6 +106,8 @@ def train(target, is_sparse, is_parallel, save_dirname, is_local=True):
raise NotImplementedError()
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
if use_bf16:
paddle.static.amp.rewrite_program_bf16(fluid.default_main_program())
sgd_optimizer.minimize(avg_cost)
train_reader = paddle.batch(
......@@ -239,12 +246,15 @@ def infer(target, save_dirname=None):
assert np.isclose(a, b, rtol=5e-5), "a: {}, b: {}".format(a, b)
def main(target, is_sparse, is_parallel):
def main(target, is_sparse, is_parallel, use_bf16):
if target == "cuda" and not fluid.core.is_compiled_with_cuda():
return
if target == "xpu" and not fluid.core.is_compiled_with_xpu():
return
if use_bf16 and not fluid.core.is_compiled_with_mkldnn():
return
if not is_parallel:
save_dirname = "word2vec.inference.model"
else:
......@@ -255,7 +265,7 @@ def main(target, is_sparse, is_parallel):
# so only inference is turned on.
train("cpu", is_sparse, is_parallel, save_dirname)
else:
train(target, is_sparse, is_parallel, save_dirname)
train(target, is_sparse, is_parallel, save_dirname, use_bf16=use_bf16)
infer(target, save_dirname)
......@@ -268,10 +278,11 @@ class W2VTest(unittest.TestCase):
pass
def inject_test_method(target, is_sparse, is_parallel):
fn_name = "test_{0}_{1}_{2}".format(target, "sparse"
if is_sparse else "dense", "parallel"
if is_parallel else "normal")
def inject_test_method(target, is_sparse, is_parallel, use_bf16=False):
fn_name = "test_{0}_{1}_{2}{3}".format(target, "sparse"
if is_sparse else "dense", "parallel"
if is_parallel else "normal", "_bf16"
if use_bf16 else "")
def __impl__(*args, **kwargs):
prog = fluid.Program()
......@@ -279,8 +290,7 @@ def inject_test_method(target, is_sparse, is_parallel):
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(prog, startup_prog):
main(
target=target, is_sparse=is_sparse, is_parallel=is_parallel)
main(target, is_sparse, is_parallel, use_bf16)
if (not fluid.core.is_compiled_with_cuda() or
target == "cuda") and is_sparse:
......@@ -297,6 +307,7 @@ for target in ("cuda", "cpu", "xpu"):
for is_sparse in (False, True):
for is_parallel in (False, ):
inject_test_method(target, is_sparse, is_parallel)
inject_test_method("cpu", False, False, use_bf16=True)
if __name__ == '__main__':
unittest.main()
......@@ -244,17 +244,12 @@ def convert_float_to_uint16(float_list, data_format="NCHW"):
return new_output
def copy_bits_from_uint16_to_float(i):
i = np.uint32(i) << 16
return struct.unpack('<f', struct.pack('<I', i))[0]
def convert_uint16_to_float(uint16_list):
new_output = []
for x in np.nditer(uint16_list):
new_output.append(np.float32(copy_bits_from_uint16_to_float(x)))
return np.reshape(new_output, uint16_list.shape).view(np.float32)
def convert_uint16_to_float(in_list):
in_list = np.asarray(in_list)
out = np.vectorize(
lambda x: struct.unpack('<f', struct.pack('<I', x << 16))[0],
otypes=[np.float32])(in_list.flat)
return np.reshape(out, in_list.shape)
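For reference, a small round-trip sketch (bf16 keeps only the upper 16 bits of the float32 representation, so values come back with reduced mantissa precision; the helper names below are illustrative only):

import struct

def _fp32_to_bf16_bits(x):
    # Keep the upper 16 bits of the float32 bit pattern (truncation).
    return struct.unpack('<I', struct.pack('<f', x))[0] >> 16

def _bf16_bits_to_fp32(b):
    # Zero-fill the low 16 bits to recover a float32 value.
    return struct.unpack('<f', struct.pack('<I', b << 16))[0]

assert _bf16_bits_to_fp32(_fp32_to_bf16_bits(1.0)) == 1.0  # exactly representable
assert abs(_bf16_bits_to_fp32(_fp32_to_bf16_bits(3.2)) - 3.2) < 0.02  # truncation error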
class OpTest(unittest.TestCase):
......
......@@ -14,5 +14,8 @@
from ...fluid.contrib import mixed_precision
from ...fluid.contrib.mixed_precision import *
from ...fluid.contrib.mixed_precision import bf16
from ...fluid.contrib.mixed_precision.bf16 import *
__all__ = mixed_precision.__all__
__all__ += bf16.__all__
......@@ -179,6 +179,7 @@ packages=['paddle',
'paddle.fluid.contrib.utils',
'paddle.fluid.contrib.extend_optimizer',
'paddle.fluid.contrib.mixed_precision',
'paddle.fluid.contrib.mixed_precision.bf16',
'paddle.fluid.contrib.layers',
'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details',
......
......@@ -219,6 +219,7 @@ CPU_PARALLEL_JOB = [
'test_full_op',
'test_framework_debug_str',
'test_fp16_utils',
'test_bf16_utils',
'test_fleet_rolemaker_4',
'test_flags_use_mkldnn',
'test_filter_by_instag_op',
......
......@@ -699,4 +699,5 @@ STATIC_MODE_TESTING_LIST = [
'test_slice_op_xpu',
'test_generate_proposals_v2_op',
'test_lamb_op_xpu',
'test_model_cast_to_bf16',
]