Unverified commit bc379ca3, authored by arlesniak, committed by GitHub

Added pure_bf16 mode (#32281)

Parent 9aad7527
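A minimal usage sketch of the new pure bf16 mode, pieced together from the docstring examples added in this commit; the tiny fc network, the SGD optimizer choice, and the wrapper function name are illustrative placeholders, not part of the change itself:

import paddle

paddle.enable_static()

def run_pure_bf16_sketch():
    # Assumed setup: a trivial static-graph network on CPU for illustration only.
    place = paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
    hidden = paddle.static.nn.fc(data, size=10)
    loss = paddle.mean(hidden)
    optimizer = paddle.optimizer.SGD(learning_rate=0.001)
    # Decorate a regular optimizer; use_pure_bf16=True enables the new pure bf16 path.
    optimizer = paddle.static.amp.bf16.decorate_bf16(
        optimizer,
        amp_lists=paddle.static.amp.bf16.AutoMixedPrecisionListsBF16(),
        use_pure_bf16=True)
    optimizer.minimize(loss)
    exe.run(paddle.static.default_startup_program())
    # amp_init casts the FP32 parameters created by the startup program to bf16.
    optimizer.amp_init(place, scope=paddle.static.global_scope())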
@@ -162,6 +162,7 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
ops::AssignKernel, int, ops::AssignKernel,
int64_t, ops::AssignKernel, bool,
ops::AssignKernel, plat::float16,
ops::AssignKernel, plat::bfloat16,
ops::AssignKernel);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
...
@@ -20,10 +20,7 @@ from . import fp16_lists
from .fp16_lists import *
from . import fp16_utils
from .fp16_utils import *
-from . import bf16
-from .bf16 import *
__all__ = decorator.__all__
__all__ += fp16_lists.__all__
__all__ += fp16_utils.__all__
-__all__ += bf16.__all__
@@ -18,7 +18,9 @@ from . import amp_lists
from .amp_lists import *
from . import amp_utils
from .amp_utils import *
from . import decorator
from .decorator import *
-__all__ = []
__all__ = decorator.__all__
__all__ += amp_lists.__all__
__all__ += amp_utils.__all__
@@ -13,8 +13,10 @@
# limitations under the License.
import copy
from paddle.fluid import core
from ..fp16_lists import white_list as white_list_fp16, black_list as black_list_fp16,\
-gray_list as gray_list_fp16, unsupported_fp16_list
gray_list as gray_list_fp16
__all__ = ["AutoMixedPrecisionListsBF16"]
@@ -82,11 +84,17 @@ bf16_list = {'elementwise_add', }
# depends on the prev_op type
gray_list = {
'cast',
'fill_constant',
'reduce_mean',
'reshape2',
-'lookup_table',
'scale',
}
-unsupported_list = unsupported_fp16_list.copy().copy()
_, _, _sys_unsupported_bf16_list = core.op_supported_infos(
'CPU', core.VarDesc.VarType.BF16)
unsupported_list = _sys_unsupported_bf16_list
fp32_list = black_list_fp16.copy().copy()
fp32_list |= white_list_fp16
fp32_list |= gray_list_fp16
...
@@ -14,18 +14,25 @@
# limitations under the License.
from __future__ import print_function
-import struct
from .... import core
from .... import framework
from .... import global_scope
from ....log_helper import get_logger
from ....wrapped_decorator import signature_safe_contextmanager
from .amp_lists import AutoMixedPrecisionListsBF16
-from ..fp16_utils import find_true_prev_op, find_true_post_op, _rename_arg, find_op_index
from ..fp16_utils import find_true_prev_op, find_true_post_op, _rename_arg, \
find_op_index, _rename_op_input
import collections
import struct
import logging
import numpy as np
-__all__ = ["bf16_guard", "rewrite_program_bf16", "convert_float_to_uint16"]
__all__ = [
"bf16_guard", "rewrite_program_bf16", "cast_model_to_bf16",
"cast_parameters_to_bf16", "convert_float_to_uint16"
]
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
@@ -126,7 +133,41 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
return num_cast_ops
def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
op_var_rename_map):
num_cast_ops = 0
target_var = block.var(target_name)
if target_var.type not in _valid_types or target_var.dtype == dest_dtype:
return num_cast_ops
assert target_var.dtype == src_dtype, \
"The real dtype({}) is not equal to the src dtype({})".format(_dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype))
cast_name = target_var.name + '.cast_' + _dtype_to_str(dest_dtype)
cast_var = block.vars.get(cast_name)
if cast_var is None or cast_var.dtype != dest_dtype:
cast_var = block.create_var(
name=cast_name,
dtype=dest_dtype,
persistable=False,
stop_gradient=target_var.stop_gradient)
block._insert_op(
idx,
type="cast",
inputs={"X": target_var},
outputs={"Out": cast_var},
attrs={"in_dtype": target_var.dtype,
"out_dtype": cast_var.dtype})
num_cast_ops += 1
op_var_rename_map[block.idx][target_var.name] = cast_var.name
return num_cast_ops
def _is_in_fp32_varnames(op, amp_lists):
if not amp_lists.fp32_varnames:
return False
for in_name in op.input_arg_names:
if in_name in amp_lists.fp32_varnames:
return True
@@ -191,7 +232,174 @@ def bf16_guard():
yield
-def rewrite_program_bf16(main_prog, amp_lists=None, use_bf16_guard=False):
def cast_model_to_bf16(program, amp_lists=None, use_bf16_guard=True):
"""
Traverse all ops in the whole model and set their inputs and outputs
to the bf16 data type. This function will do some special processing for
the batch normalization, which will keep the batchnorm's computations in FP32.
Args:
program (Program): The used program.
amp_lists (AutoMixedPrecisionListsBF16): An AutoMixedPrecisionListsBF16 object.
use_bf16_guard(bool): Determine whether to use `bf16_guard` when
constructing the program. Default True.
"""
if amp_lists is None:
amp_lists = AutoMixedPrecisionListsBF16()
global_block = program.global_block()
keep_fp32_ops = set()
to_bf16_var_names = set()
to_bf16_pre_cast_ops = set()
origin_ops = []
for block in program.blocks:
origin_ops.extend(block.ops)
for block in program.blocks:
ops = block.ops
for op in ops:
if op.type == 'create_py_reader' or op.type == 'read':
continue
if _need_keep_fp32(op, amp_lists.unsupported_list, use_bf16_guard):
keep_fp32_ops.add(op)
continue # processed below
for in_name in op.input_names:
if op.type in {
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
} and in_name not in {'X', 'Z'}:
continue
for in_var_name in op.input(in_name):
in_var = None
try:
in_var = block.var(in_var_name)
except ValueError as e:
_logger.debug(
"-- {}, try to get it in the global block --".
format(e))
in_var = global_block.var(in_var_name)
if in_var is not None:
_logger.debug(
"-- var {} is got in the global block --".
format(in_var_name))
if in_var is None or in_var.type not in _valid_types:
continue
if in_var.dtype == core.VarDesc.VarType.FP32:
in_var.desc.set_dtype(core.VarDesc.VarType.BF16)
to_bf16_var_names.add(in_var_name)
_logger.debug(
"-- op type: {}, in var name: {}, in var dtype: {} --".
format(op.type, in_var_name, in_var.dtype))
for out_name in op.output_names:
if op.type in {
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
} and out_name != 'Y':
continue
for out_var_name in op.output(out_name):
out_var = None
try:
out_var = block.var(out_var_name)
except ValueError as e:
_logger.debug(
"-- {}, try to get it in the global block --".
format(e))
out_var = global_block.var(out_var_name)
if out_var is not None:
_logger.debug(
"-- var {} is got in the global block --".
format(out_var_name))
if out_var is None or out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
_logger.debug(
"-- op type: {}, out var name: {}, out var dtype: {} --".
format(op.type, out_var_name, out_var.dtype))
for attr_name in ['in_dtype', 'out_dtype', 'dtype']:
if op.has_attr(attr_name) and op.attr(
attr_name) == core.VarDesc.VarType.FP32:
op._set_attr(attr_name, core.VarDesc.VarType.BF16)
if op.has_attr('use_mkldnn'):
op._set_attr('use_mkldnn', True)
if op.has_attr('mkldnn_data_type'):
op._set_attr('mkldnn_data_type', 'bfloat16')
# process ops in keep_fp32_ops
op_var_rename_map = [
collections.OrderedDict() for _ in range(len(program.blocks))
]
for block in program.blocks:
ops = block.ops
idx = 0
while idx < len(ops):
op = ops[idx]
num_cast_ops = 0
if op not in keep_fp32_ops:
if op in to_bf16_pre_cast_ops:
in_var_cast_num = _insert_cast_op(block, op, idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.BF16)
num_cast_ops += in_var_cast_num
else:
pre_cast_num = _insert_cast_op(block, op, idx,
core.VarDesc.VarType.BF16,
core.VarDesc.VarType.FP32)
num_cast_ops += pre_cast_num
for out_var_name in op.output_arg_names:
out_var = block.vars.get(out_var_name)
if out_var is None or out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.BF16:
out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
post_ops = find_true_post_op(ops, op, out_var_name)
for post_op in post_ops:
if post_op in keep_fp32_ops:
continue
post_cast_num = _insert_cast_post_op(
block, op, idx + pre_cast_num + 1,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.BF16, out_var_name,
op_var_rename_map)
num_cast_ops += post_cast_num
idx += num_cast_ops + 1
_rename_op_input(program, op_var_rename_map, origin_ops, keep_fp32_ops)
return to_bf16_var_names
def cast_parameters_to_bf16(place, program, scope=None, to_bf16_var_names=None):
"""
Traverse all parameters in the whole model and set them to the BF16 data type.
Whereas, this function will keep parameters of batchnorms in FP32.
Args:
place(fluid.CPUPlace|fluid.CUDAPlace): `place` is used to restore the BF16 weight tensors.
program (Program): The used program.
scope(fluid.Scope, optional): `scope` is used to get the FP32 weight tensor values.
Default is None.
to_bf16_var_names(set|list, optional): The data types of vars in `to_bf16_var_names`
will be set to BF16. Usually, it is the returned
value of `cast_model_to_bf16` API.
"""
all_parameters = []
for block in program.blocks:
all_parameters.extend(block.all_parameters())
bf16_var_names = to_bf16_var_names if to_bf16_var_names else set()
var_scope = scope if scope else global_scope()
for param in all_parameters:
if param.name in bf16_var_names:
_logger.debug("---- cast {} to bf16 dtype ----".format(param.name))
param_t = var_scope.find_var(param.name).get_tensor()
data = np.array(param_t)
param_t.set(convert_float_to_uint16(data), place)
def rewrite_program_bf16(main_prog, amp_lists=None):
"""
Traverse all ops in current block and insert cast op according to
which set current op belongs to.
@@ -231,8 +439,7 @@ def rewrite_program_bf16(main_prog, amp_lists=None, use_bf16_guard=False):
fp32_op_set.add(op)
continue
-if op.type in amp_lists.fp32_list or _need_keep_fp32(
-op, amp_lists.unsupported_list, use_bf16_guard):
if op.type in amp_lists.fp32_list:
fp32_op_set.add(op)
elif op.type in amp_lists.bf16_list:
bf16_op_set.add(op)
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import (core, default_main_program, layers, program_guard,
unique_name)
from .amp_utils import (rewrite_program_bf16, cast_model_to_bf16,
cast_parameters_to_bf16)
from .amp_lists import AutoMixedPrecisionListsBF16
import types
import warnings
__all__ = ["decorate_bf16"]
class OptimizerWithMixedPrecision(object):
"""
Optimizer with mixed-precision (MP) training. This is a wrapper of a common
optimizer, plus the support of mixed-precision pre-training. The object
of this class almost has the same behavior as the common optimizer, with the
methods `minimize()`, `backward()`, `apply_gradients()` implemented.
Additionally, it enables the MP training automatically, i.e., the creation
and maintenance of master parameters, scaling of loss, etc.
Args:
optimizer (Optimizer): A common Optimizer object.
amp_lists (CustomOpLists): A CustomOpLists object.
use_pure_bf16(bool): Whether to use the pure bf16 training.
use_bf16_guard(bool): Whether to use `bf16_guard` when constructing the program.
"""
def __init__(self, optimizer, amp_lists, use_pure_bf16, use_bf16_guard):
self._optimizer = optimizer
self._amp_lists = amp_lists
self._param_grads = None
self._train_program = None
self._learning_rate = optimizer._learning_rate
self._learning_rate_map = optimizer._learning_rate_map
self._use_pure_bf16 = use_pure_bf16
self._use_bf16_guard = use_bf16_guard
self._to_bf16_var_names = None
def _init_amp_var(self):
# Ensure the data type of learning rate vars is float32 (same as the
# master parameter dtype)
if isinstance(self._optimizer._learning_rate, float):
self._optimizer._learning_rate_map[default_main_program()] = \
layers.create_global_var(
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(self._optimizer._learning_rate),
dtype='float32',
persistable=True)
def backward(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None,
callbacks=None):
"""
Backward propagation or auto differentiation for gradients' computation.
Args:
loss (Variable): The loss Variable to minimize.
startup_program (Program|None): The startup Program for initializing
parameters in `parameter_list`.
parameter_list (list|None): A list of Variables to update.
no_grad_set (set|None): A set of Variables should be ignored.
callbacks (list|None): A list of callable objects to run when appending
backward operator for one parameter.
Returns:
A list of (param, grad), which is a tuple of a parameter and its
gradient respectively.
"""
train_program = loss.block.program
self._train_program = train_program
with program_guard(self._train_program, startup_program):
self._init_amp_var()
if self._use_pure_bf16:
self._to_bf16_var_names = cast_model_to_bf16(
self._train_program, self._amp_lists, self._use_bf16_guard)
else:
rewrite_program_bf16(self._train_program, self._amp_lists)
if loss.dtype != core.VarDesc.VarType.FP32:
loss = loss.astype('float32')
params_grads = self._optimizer.backward(
loss, startup_program, parameter_list, no_grad_set, callbacks)
return params_grads
def amp_init(self,
place,
scope=None,
test_program=None,
use_bf16_test=False):
"""
Init the amp training, such as cast fp32 parameters to bf16 type.
Args:
place(CPUPlace): place is used to initialize
bf16 parameters with fp32 values.
scope(Scope): The scope is used to find fp32 parameters.
test_program(Program): The program is used for testing.
use_bf16_test(bool): Whether to use bf16 testing.
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.nn.functional as F
paddle.enable_static()
def run_example_code():
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
# 1) Use bf16_guard to control the range of bf16 kernels used.
with paddle.static.amp.bf16_guard():
bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
pool = F.max_pool2d(bn, kernel_size=2, stride=2)
hidden = paddle.static.nn.fc(pool, size=10)
loss = paddle.mean(hidden)
# 2) Create the optimizer and set `multi_precision` to True.
# Setting `multi_precision` to True can help avoid poor accuracy
# or slow convergence in some cases.
optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
# 3) These ops in `custom_fp32_list` will keep in the float32 computation type.
amp_list = paddle.static.amp.CustomOpLists(
custom_fp32_list=['pool2d'])
# 4) The entry of Paddle AMP.
# Enable pure bf16 training by setting `use_pure_bf16` to True.
optimizer = paddle.static.amp.bf16.decorate_bf16(
optimizer,
amp_list,
use_pure_bf16=True)
# If you don't use the default_startup_program(), you should pass
# your defined `startup_program` into `minimize`.
optimizer.minimize(loss)
exe.run(paddle.static.default_startup_program())
# 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
# If you want to perform the testing process, you should pass `test_program` into `amp_init`.
optimizer.amp_init(place, scope=paddle.static.global_scope())
"""
assert self._train_program is not None, \
"Please call the minimize method first."
if self._use_pure_bf16:
cast_parameters_to_bf16(place, self._train_program, scope,
self._to_bf16_var_names)
if test_program is not None:
if self._use_pure_bf16:
cast_model_to_bf16(test_program, self._amp_lists,
self._use_bf16_guard)
elif use_bf16_test:
rewrite_program_bf16(test_program, self._amp_lists)
def apply_gradients(self, params_grads):
"""
Apply gradients.
Args:
params_grads (list): A list of params.
Returns:
A list of optimize operators.
"""
return self._optimizer.apply_gradients(params_grads)
def apply_optimize(self, loss, startup_program, params_grads):
program = loss.block.program
with program_guard(program, startup_program):
optimize_ops = self.apply_gradients(params_grads)
return optimize_ops
def minimize(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
"""
Perform optimization by minimizing the given loss.
Args:
loss (Variable): The loss Variable.
startup_program (Program): startup_program for initializing parameters
in `parameter_list`.
parameter_list (list): list of Variables to update.
no_grad_set (set|None): set of Variables should be ignored.
Returns:
The list of optimize ops and a list of (param, grad) tuples.
"""
opt_dict = self._optimizer.__class__.__dict__
if 'minimize' in opt_dict and isinstance(opt_dict['minimize'],
types.FunctionType):
warnings.warn(
"The decorated optimizer has its own `minimize` method, but it will not be executed."
)
params_grads = self.backward(
loss,
startup_program=startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set)
optimize_ops = self.apply_optimize(loss, startup_program, params_grads)
return optimize_ops, params_grads
def decorate_bf16(optimizer,
amp_lists=None,
use_pure_bf16=False,
use_bf16_guard=None):
"""
Decorate the given optimizer to adapt to the mixed-precision training.
Args:
optimizer(Optimizer): A common Optimizer.
amp_lists (CustomOpLists): A CustomOpLists object.
use_pure_bf16(bool): Whether to use the pure bf16 training. Default False.
use_bf16_guard(bool): Whether to use `bf16_guard` when constructing the program.
Default None, which means that its value equals to `use_pure_bf16`.
Returns:
An optimizer acting like a normal one but with mixed-precision training
enabled.
Examples 1:
.. code-block:: python
# fp32&bf16 list based strategy example
import paddle
import paddle.static as static
paddle.enable_static()
data = static.data(name='X', shape=[None, 1], dtype='float32')
hidden = static.nn.fc(x=data, size=10)
loss = paddle.mean(hidden)
optimizer = paddle.optimizer.Adam(learning_rate=0.001)
mp_optimizer = static.amp.decorate_bf16(optimizer=optimizer)
ops, param_grads = mp_optimizer.minimize(loss)
Examples 2:
.. code-block:: python
# pure bf16 training example
import numpy as np
import paddle
import paddle.nn.functional as F
def run_example_code():
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
# 1) Use bf16_guard to control the range of bf16 kernels used.
with paddle.static.amp.bf16_guard():
bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
pool = F.max_pool2d(bn, kernel_size=2, stride=2)
hidden = paddle.static.nn.fc(pool, size=10)
loss = paddle.mean(hidden)
# 2) Create the optimizer and set `multi_precision` to True.
# Setting `multi_precision` to True can help avoid poor accuracy
# or slow convergence in some cases.
optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
# 3) These ops in `custom_fp32_list` will keep in the float32 computation type.
amp_list = paddle.static.amp.CustomOpLists(
custom_fp32_list=['pool2d'])
# 4) The entry of Paddle AMP.
# Enable pure bf16 training by setting `use_pure_bf16` to True.
optimizer = paddle.static.amp.decorate_bf16(
optimizer,
amp_list,
use_pure_bf16=True)
# If you don't use the default_startup_program(), you should pass
# your defined `startup_program` into `minimize`.
optimizer.minimize(loss)
exe.run(paddle.static.default_startup_program())
# 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
# If you want to perform the testing process, you should pass `test_program` into `amp_init`.
optimizer.amp_init(place, scope=paddle.static.global_scope())
"""
if amp_lists is None:
amp_lists = AutoMixedPrecisionListsBF16()
if use_bf16_guard is None:
use_bf16_guard = use_pure_bf16
mp_optimizer = OptimizerWithMixedPrecision(optimizer, amp_lists,
use_pure_bf16, use_bf16_guard)
return mp_optimizer
@@ -14,7 +14,7 @@
import copy
import unittest
import paddle.fluid as fluid
-import paddle.fluid.contrib.mixed_precision as amp
import paddle.static.amp as amp
from paddle.fluid import core
import paddle
@@ -34,34 +34,34 @@ class AMPTest(unittest.TestCase):
self.assertEqual(self.amp_lists_.gray_list, self.gray_list)
def test_amp_lists(self):
-self.amp_lists_ = amp.AutoMixedPrecisionListsBF16()
self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16()
def test_amp_lists_1(self):
# 1. w={'exp}, b=None
self.bf16_list.add('exp')
self.fp32_list.remove('exp')
-self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'exp'})
self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'exp'})
def test_amp_lists_2(self):
# 2. w={'tanh'}, b=None
self.fp32_list.remove('tanh')
self.bf16_list.add('tanh')
-self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'tanh'})
self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'tanh'})
def test_amp_lists_3(self):
# 3. w={'lstm'}, b=None
self.bf16_list.add('lstm')
-self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'lstm'})
self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'lstm'})
def test_amp_lists_4(self):
# 4. w=None, b={'elementwise_add'}
self.bf16_list.remove('elementwise_add')
self.fp32_list.add('elementwise_add')
-self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
custom_fp32_list={'elementwise_add'})
def test_amp_lists_5(self):
@@ -69,28 +69,28 @@ class AMPTest(unittest.TestCase):
self.fp32_list.add('elementwise_add')
self.bf16_list.remove('elementwise_add')
-self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
custom_fp32_list={'elementwise_add'})
def test_amp_lists_6(self):
# 6. w=None, b={'lstm'}
self.fp32_list.add('lstm')
-self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
custom_fp32_list={'lstm'})
def test_amp_lists_7(self):
self.fp32_list.add('reshape2')
self.gray_list.remove('reshape2')
-self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
custom_fp32_list={'reshape2'})
def test_amp_list_8(self):
self.bf16_list.add('reshape2')
self.gray_list.remove('reshape2')
-self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
custom_bf16_list={'reshape2'})
@@ -98,7 +98,7 @@ class AMPTest2(unittest.TestCase):
def test_amp_lists_(self):
# 7. w={'lstm'} b={'lstm'}
# raise ValueError
-self.assertRaises(ValueError, amp.AutoMixedPrecisionListsBF16,
self.assertRaises(ValueError, amp.bf16.AutoMixedPrecisionListsBF16,
{'lstm'}, {'lstm'})
def test_find_op_index(self):
@@ -117,10 +117,10 @@ class AMPTest2(unittest.TestCase):
type="abs", inputs={"X": [var1]}, outputs={"Out": [var2]})
op2 = block.append_op(
type="abs", inputs={"X": [var2]}, outputs={"Out": [var3]})
-amp_lists_1 = amp.AutoMixedPrecisionListsBF16(
amp_lists_1 = amp.bf16.AutoMixedPrecisionListsBF16(
custom_fp32_varnames={'X'})
assert amp.bf16.amp_utils._is_in_fp32_varnames(op1, amp_lists_1)
-amp_lists_2 = amp.AutoMixedPrecisionListsBF16(
amp_lists_2 = amp.bf16.AutoMixedPrecisionListsBF16(
custom_fp32_varnames={'Y'})
assert amp.bf16.amp_utils._is_in_fp32_varnames(op2, amp_lists_2)
assert amp.bf16.amp_utils._is_in_fp32_varnames(op1, amp_lists_2)
...
@@ -65,13 +65,13 @@ class TestModelCastBF16(unittest.TestCase):
fetch_list=fetch_list,
return_numpy=(not with_lod))
-def test_graph_rewrite(self):
def _graph_common(self, _amp_fun):
size = 3
n = np.ones([size, size], dtype='float32') * 3.2
nn = np.ones([size, size], dtype='float32') * -2.7
-n_bf16 = amp.convert_float_to_uint16(n)
-nn_bf16 = amp.convert_float_to_uint16(nn)
n_bf16 = amp.bf16.convert_float_to_uint16(n)
nn_bf16 = amp.bf16.convert_float_to_uint16(nn)
with self.static_graph():
t_bf16 = layers.data(
@@ -85,12 +85,12 @@ class TestModelCastBF16(unittest.TestCase):
ret = layers.elementwise_mul(ret, t)
ret = layers.reshape(ret, [0, 0])
-with amp.bf16_guard():
with amp.bf16.bf16_guard():
ret_bf16 = layers.elementwise_add(t_bf16, tt_bf16)
ret_bf16 = layers.elementwise_mul(ret_bf16, t_bf16)
ret_bf16 = layers.reshape(ret_bf16, [0, 0])
-with amp.bf16_guard():
with amp.bf16.bf16_guard():
ret_fp32bf16 = layers.elementwise_add(t, tt)
ret_fp32bf16 = layers.elementwise_mul(ret_fp32bf16, t)
ret_fp32bf16 = layers.reshape(ret_fp32bf16, [0, 0])
@@ -103,7 +103,7 @@ class TestModelCastBF16(unittest.TestCase):
'tt_bf16': nn_bf16,
},
fetch_list=[ret_bf16, ret, ret_fp32bf16],
-amp_fun=lambda prog: amp.rewrite_program_bf16(prog, use_bf16_guard=True))
amp_fun=lambda prog: amp.bf16.rewrite_program_bf16(prog))
self.assertTrue(np.allclose(static_ret_bf16, static_ret, 1e-2))
self.assertTrue(np.allclose(static_ret_bf16, ret_fp32bf16, 1e-2))
@@ -112,7 +112,7 @@ class TestModelCastBF16(unittest.TestCase):
t = layers.data(name='t', shape=[size, size], dtype='float32')
tt = layers.data(name='tt', shape=[size, size], dtype='float32')
-with amp.bf16_guard():
with amp.bf16.bf16_guard():
ret = layers.elementwise_add(t, tt)
ret = layers.reshape(ret, [0, 0], act='elu')
ret = layers.elementwise_mul(ret, t)
@@ -122,17 +122,27 @@ class TestModelCastBF16(unittest.TestCase):
self.get_static_graph_result(
feed={'t': n, 'tt': nn},
fetch_list=[ret],
-amp_fun=lambda prog: amp.rewrite_program_bf16(
-prog,
-amp.AutoMixedPrecisionListsBF16(
-custom_fp32_varnames={'elementwise_add_0.tmp_0'}),
-use_bf16_guard=True
-)
amp_fun=_amp_fun
)
self.assertTrue(
static_ret_bf16, np.ones(
[size, size], dtype='float32') * -1.1)
def test_graph_rewrite(self):
self._graph_common(lambda prog: amp.bf16.rewrite_program_bf16(
prog,
amp.bf16.AutoMixedPrecisionListsBF16(
custom_fp32_varnames={'elementwise_add_0.tmp_0'}),
))
def test_graph_cast(self):
self._graph_common(lambda prog: amp.bf16.cast_model_to_bf16(
prog,
amp.bf16.AutoMixedPrecisionListsBF16(
custom_fp32_list={'elementwise_mul'}),
use_bf16_guard=True
))
if __name__ == '__main__':
unittest.main()
@@ -332,7 +332,8 @@ def fc(input,
for i, input_x in enumerate(input):
check_type(input_x, 'input[' + str(i) + ']', Variable, 'fc')
dtype = helper.input_dtype()
-check_dtype(dtype, 'input', ['float16', 'float32', 'float64'], 'fc')
check_dtype(dtype, 'input', ['float16', 'uint16', 'float32', 'float64'],
'fc')
mul_results = []
for input_var, param_attr in helper.iter_inputs_and_params():
input_shape = input_var.shape
...
@@ -582,10 +582,9 @@ def assign(input, output=None):
input = numpy.array(input)
if isinstance(input, Variable):
-check_dtype(
-input.dtype, 'input',
-['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
-'assign', '(When the type of input in assign is Variable.)')
check_dtype(input.dtype, 'input', [
'float16', 'uint16', 'float32', 'float64', 'int32', 'int64', 'bool'
], 'assign', '(When the type of input in assign is Variable.)')
if output is None:
output = helper.create_variable_for_type_inference(
dtype=input.dtype)
...
@@ -16,6 +16,8 @@ from __future__ import print_function
import paddle
import paddle.fluid as fluid
import paddle.static.amp as amp
import contextlib
import numpy
import unittest
@@ -26,19 +28,34 @@ import os
paddle.enable_static()
-def train(use_cuda, save_dirname, is_local, use_bf16):
def train(use_cuda, save_dirname, is_local, use_bf16, pure_bf16):
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
-y_predict = fluid.layers.fc(input=x, size=1, act=None)
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
-cost = fluid.layers.square_error_cost(input=y_predict, label=y)
-avg_cost = fluid.layers.mean(cost)
if use_bf16:
if not pure_bf16:
with amp.bf16.bf16_guard():
y_predict = fluid.layers.fc(input=x, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
else:
y_predict = fluid.layers.fc(input=x, size=1, act=None)
with amp.bf16.bf16_guard():
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
else:
y_predict = fluid.layers.fc(input=x, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
if use_bf16:
-paddle.static.amp.rewrite_program_bf16(fluid.default_main_program())
sgd_optimizer = amp.bf16.decorate_bf16(
sgd_optimizer,
amp_lists=amp.bf16.AutoMixedPrecisionListsBF16(),
use_bf16_guard=False,
use_pure_bf16=pure_bf16)
sgd_optimizer.minimize(avg_cost)
BATCH_SIZE = 20
@@ -54,6 +71,10 @@ def train(use_cuda, save_dirname, is_local, use_bf16):
def train_loop(main_program):
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
exe.run(fluid.default_startup_program())
test_prog = main_program.clone(for_test=True)
if pure_bf16:
sgd_optimizer.amp_init(
exe.place, test_program=test_prog, use_bf16_test=True)
PASS_NUM = 100
for pass_id in range(PASS_NUM):
@@ -61,9 +82,8 @@ def train(use_cuda, save_dirname, is_local, use_bf16):
avg_loss_value, = exe.run(main_program,
feed=feeder.feed(data),
fetch_list=[avg_cost])
-print(avg_loss_value)
-if avg_loss_value[0] < 10.0:
-if save_dirname is not None:
if avg_loss_value[0] < 10.0 or pure_bf16:
if save_dirname is not None and not pure_bf16:
fluid.io.save_inference_model(save_dirname, ['x'],
[y_predict], exe)
return
@@ -97,7 +117,7 @@ def train(use_cuda, save_dirname, is_local, use_bf16):
train_loop(t.get_trainer_program())
-def infer(use_cuda, save_dirname=None):
def infer(use_cuda, save_dirname=None, use_bf16=False):
if save_dirname is None:
return
@@ -135,7 +155,7 @@ def infer(use_cuda, save_dirname=None):
print("ground truth: ", test_label)
-def main(use_cuda, is_local=True, use_bf16=False):
def main(use_cuda, is_local=True, use_bf16=False, pure_bf16=False):
if use_cuda and not fluid.core.is_compiled_with_cuda():
return
@@ -145,11 +165,22 @@ def main(use_cuda, is_local=True, use_bf16=False):
# Directory for saving the trained model
save_dirname = "fit_a_line.inference.model"
-train(use_cuda, save_dirname, is_local, use_bf16)
train(use_cuda, save_dirname, is_local, use_bf16, pure_bf16)
-infer(use_cuda, save_dirname)
infer(use_cuda, save_dirname, use_bf16)
class TestFitALineBase(unittest.TestCase):
@contextlib.contextmanager
def program_scope_guard(self):
prog = fluid.Program()
startup_prog = fluid.Program()
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(prog, startup_prog):
yield
-class TestFitALine(unittest.TestCase):
class TestFitALine(TestFitALineBase):
def test_cpu(self):
with self.program_scope_guard():
main(use_cuda=False)
@@ -158,20 +189,17 @@ class TestFitALine(unittest.TestCase):
with self.program_scope_guard():
main(use_cuda=True)
-@unittest.skipIf(not fluid.core.supports_bfloat16(),
-"place does not support BF16 evaluation")
@unittest.skipIf(not fluid.core.supports_bfloat16(),
"place does not support BF16 evaluation")
class TestFitALineBF16(TestFitALineBase):
def test_bf16(self):
with self.program_scope_guard():
main(use_cuda=False, use_bf16=True)
def test_pure_bf16(self):
with self.program_scope_guard():
main(use_cuda=False, use_bf16=True, pure_bf16=True)
-@contextlib.contextmanager
-def program_scope_guard(self):
-prog = fluid.Program()
-startup_prog = fluid.Program()
-scope = fluid.core.Scope()
-with fluid.scope_guard(scope):
-with fluid.program_guard(prog, startup_prog):
-yield
if __name__ == '__main__':
...
@@ -44,7 +44,8 @@ def train(target,
is_parallel,
save_dirname,
is_local=True,
-use_bf16=False):
use_bf16=False,
pure_bf16=False):
PASS_NUM = 100
EMBED_SIZE = 32
HIDDEN_SIZE = 256
@@ -107,7 +108,13 @@ def train(target,
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
if use_bf16:
-paddle.static.amp.rewrite_program_bf16(fluid.default_main_program())
sgd_optimizer = paddle.static.amp.bf16.decorate_bf16(
sgd_optimizer,
amp_lists=paddle.static.amp.bf16.AutoMixedPrecisionListsBF16(
custom_fp32_list={'softmax', 'concat'}, ),
use_bf16_guard=False,
use_pure_bf16=pure_bf16)
sgd_optimizer.minimize(avg_cost)
train_reader = paddle.batch(
@@ -121,6 +128,8 @@ def train(target,
def train_loop(main_program):
exe.run(fluid.default_startup_program())
if pure_bf16:
sgd_optimizer.amp_init(exe.place)
for pass_id in range(PASS_NUM):
for data in train_reader():
@@ -128,7 +137,7 @@ def train(target,
feed=feeder.feed(data),
fetch_list=[avg_cost])
if avg_cost_np[0] < 5.0:
-if save_dirname is not None:
if save_dirname is not None and not pure_bf16:
fluid.io.save_inference_model(save_dirname, [
'firstw', 'secondw', 'thirdw', 'forthw'
], [predict_word], exe)
@@ -246,7 +255,7 @@ def infer(target, save_dirname=None):
assert np.isclose(a, b, rtol=5e-5), "a: {}, b: {}".format(a, b)
-def main(target, is_sparse, is_parallel, use_bf16):
def main(target, is_sparse, is_parallel, use_bf16, pure_bf16):
if target == "cuda" and not fluid.core.is_compiled_with_cuda():
return
if target == "xpu" and not fluid.core.is_compiled_with_xpu():
@@ -265,7 +274,13 @@ def main(target, is_sparse, is_parallel, use_bf16):
# so only inference is turned on.
train("cpu", is_sparse, is_parallel, save_dirname)
else:
-train(target, is_sparse, is_parallel, save_dirname, use_bf16=use_bf16)
train(
target,
is_sparse,
is_parallel,
save_dirname,
use_bf16=use_bf16,
pure_bf16=pure_bf16)
infer(target, save_dirname)
@@ -278,10 +293,15 @@ class W2VTest(unittest.TestCase):
pass
-def inject_test_method(target, is_sparse, is_parallel, use_bf16=False):
def inject_test_method(target,
is_sparse,
is_parallel,
use_bf16=False,
pure_bf16=False):
fn_name = "test_{0}_{1}_{2}{3}".format(target, "sparse" fn_name = "test_{0}_{1}_{2}{3}".format(target, "sparse"
if is_sparse else "dense", "parallel" if is_sparse else "dense", "parallel"
if is_parallel else "normal", "_bf16" if is_parallel else "normal",
"_purebf16" if pure_bf16 else "_bf16"
if use_bf16 else "") if use_bf16 else "")
def __impl__(*args, **kwargs): def __impl__(*args, **kwargs):
...@@ -290,7 +310,7 @@ def inject_test_method(target, is_sparse, is_parallel, use_bf16=False): ...@@ -290,7 +310,7 @@ def inject_test_method(target, is_sparse, is_parallel, use_bf16=False):
scope = fluid.core.Scope() scope = fluid.core.Scope()
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
with fluid.program_guard(prog, startup_prog): with fluid.program_guard(prog, startup_prog):
main(target, is_sparse, is_parallel, use_bf16) main(target, is_sparse, is_parallel, use_bf16, pure_bf16)
if (not fluid.core.is_compiled_with_cuda() or if (not fluid.core.is_compiled_with_cuda() or
target == "cuda") and is_sparse: target == "cuda") and is_sparse:
...@@ -307,7 +327,8 @@ for target in ("cuda", "cpu", "xpu"): ...@@ -307,7 +327,8 @@ for target in ("cuda", "cpu", "xpu"):
for is_sparse in (False, True): for is_sparse in (False, True):
for is_parallel in (False, ): for is_parallel in (False, ):
inject_test_method(target, is_sparse, is_parallel) inject_test_method(target, is_sparse, is_parallel)
inject_test_method("cpu", False, False, use_bf16=True) inject_test_method("cpu", False, False, True)
inject_test_method("cpu", False, False, True, True)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
@@ -64,7 +64,7 @@ class SimpleNetWithCond(object):
return grads
-def build_net(self, cond_i):
def build_net(self, cond_i, use_bf16=False):
"""
pseudo code:
sum_xy = x + y
@@ -122,13 +122,22 @@ class SimpleNetWithCond(object):
sum_cond = fluid.layers.cond(cond_i > 1.0, cond_true, cond_false)
sum_all = fluid.layers.sum([sum_xy, sub_yz, sum_cond])
mean_out = fluid.layers.mean(sum_all)
if use_bf16:
import paddle.static.amp as amp
self.optimizer = amp.bf16.decorate_bf16(
self.optimizer,
amp_lists=amp.bf16.AutoMixedPrecisionListsBF16(
custom_fp32_list={'elementwise_add'}),
use_bf16_guard=False,
use_pure_bf16=True)
self.optimizer.minimize(mean_out)
fetch_list = ["param_x", "param_z"] if self.y_no_grad else [
"param_x", "param_y", "param_z"
]
fetch_list += [_append_grad_suffix_(param) for param in fetch_list]
-return fetch_list
return fetch_list, self.optimizer
class TestOptimizer(unittest.TestCase):
@@ -180,7 +189,7 @@ class TestOptimizer(unittest.TestCase):
for key in ['x', 'y', 'z']:
self.param_attr[key] = self.attr.copy()
-def _check_grads(self):
def _check_grads(self, use_bf16=False):
"""
main logic code to check the validity of apply_optimize.
"""
@@ -204,10 +213,16 @@ class TestOptimizer(unittest.TestCase):
lambda: dict())
test_net = self.NetClass(self.optimizer, param_lr,
y_no_grad)
-fetch_list = test_net.build_net(cond_i)
fetch_list, decorated_optimizer = test_net.build_net(
cond_i, use_bf16)
if use_bf16:
self.optimizer = decorated_optimizer
exe = fluid.Executor(place)
exe.run(init_program)
if use_bf16:
self.optimizer.amp_init(exe.place)
# Train 2 steps to check validity
for batch_i in range(2):
@@ -222,6 +237,15 @@ class TestOptimizer(unittest.TestCase):
param_grads[i])
@unittest.skipIf(not fluid.core.supports_bfloat16(),
"place does not support BF16 evaluation")
class TestSGDOptimizer(TestOptimizer):
def test_optimizer_multiblock_except(self):
with self.assertRaisesRegexp(ValueError,
"var param_y not in this block"):
self._check_grads(use_bf16=True)
class TestAdamOptimizer(TestOptimizer):
"""
inherit TestOptimizer and shall override two functions as follows:
...
@@ -18,7 +18,4 @@ from ...fluid.contrib.mixed_precision import AutoMixedPrecisionLists # noqa: F4
from ...fluid.contrib.mixed_precision import fp16_guard # noqa: F401
from ...fluid.contrib.mixed_precision import cast_model_to_fp16 # noqa: F401
from ...fluid.contrib.mixed_precision import cast_parameters_to_fp16 # noqa: F401
-from ...fluid.contrib.mixed_precision import AutoMixedPrecisionListsBF16 # noqa: F401
-from ...fluid.contrib.mixed_precision import bf16_guard # noqa: F401
-from ...fluid.contrib.mixed_precision import rewrite_program_bf16 # noqa: F401
-from ...fluid.contrib.mixed_precision import convert_float_to_uint16 # noqa: F401
from ...fluid.contrib.mixed_precision import bf16 # noqa: F401