Unverified commit 71a513c2, authored by Zhang Ting, committed by GitHub

[AMP] support promote kernel for static graph (#52514)

* support promote dtype for static amp training

* unify o1 and o2

* update for unittest

* fix op_role

* add use_promote arg (see the usage sketch below)

* fix doc

* add promote unittest

* polish unittests

* fix control flow and tests
Parent 040f8aa5
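The bullets above add a `use_promote` argument to static-graph AMP: when enabled, an op outside the white/black lists is kept in float32 whenever any of its inputs is float32. A minimal usage sketch, modeled on the test helpers changed in this diff (the toy network and hyperparameters are illustrative, not part of the commit):

```python
import paddle
from paddle import nn

paddle.enable_static()

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    loss = paddle.mean(nn.Linear(16, 4)(x))
    optimizer = paddle.optimizer.AdamW(learning_rate=0.001)
    # Decorate the optimizer for static-graph AMP. With use_promote=True, an op
    # outside the white/black lists falls back to float32 if any input is float32.
    optimizer = paddle.static.amp.decorate(
        optimizer,
        None,  # amp_lists: use the default black/white lists
        level='O1',
        dtype='float16',
        use_promote=True,
    )
    optimizer.minimize(loss)
```

For O2 training, the decorated optimizer's `amp_init(place, scope=...)` is still called after running the startup program, as the updated unit tests below now do with `paddle.CUDAPlace(0)`.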
@@ -4982,8 +4982,8 @@ class PipelineOptimizer:
                 device = post_op.attr(self._op_device_key)
                 assert device, "The post op must have op_device set."
                 op._set_attr(self._op_device_key, device)
-            elif (op.type == "cast" or op.type == "scale") and self._is_backward_op(
-                op
+            elif (op.type == "cast" or op.type == "scale") and (
+                self._is_backward_op(op) or self._is_forward_op(op)
             ):
                 prev_op = self._find_prev_op(idx, op.desc.input("X")[0])
                 op._set_attr(self._op_device_key, prev_op.attr(self._op_device_key))
...
@@ -356,7 +356,9 @@ class TestAdadeltaMultiPrecision2_0(unittest.TestCase):
         exe.run(startup_program)
         if use_amp:
-            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+            optimizer.amp_init(
+                place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+            )
             x = np.random.random(size=(2, 2)).astype('float16')
         else:
             x = np.random.random(size=(2, 2)).astype('float32')
@@ -467,7 +469,9 @@ class TestAdadeltaMultiPrecision1_0(unittest.TestCase):
         exe.run(startup_program)
         if use_amp:
-            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+            optimizer.amp_init(
+                place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+            )
             x = np.random.random(size=(2, 2)).astype('float16')
         else:
             x = np.random.random(size=(2, 2)).astype('float32')
...
@@ -322,7 +322,9 @@ class TestAdagradMultiPrecision2_0(unittest.TestCase):
         exe.run(startup_program)
         if use_amp:
-            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+            optimizer.amp_init(
+                place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+            )
             x = np.random.random(size=(2, 2)).astype('float16')
         else:
             x = np.random.random(size=(2, 2)).astype('float32')
@@ -431,7 +433,9 @@ class TestAdagradMultiPrecision1_0(unittest.TestCase):
         exe.run(startup_program)
         if use_amp:
-            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+            optimizer.amp_init(
+                place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+            )
             x = np.random.random(size=(2, 2)).astype('float16')
         else:
             x = np.random.random(size=(2, 2)).astype('float32')
...
@@ -1235,7 +1235,9 @@ class TestMultiTensorAdam(unittest.TestCase):
         optimizer.minimize(loss)
         exe.run(startup_program)
         if use_amp:
-            optimizer.amp_init(place=place, scope=paddle.static.global_scope())
+            optimizer.amp_init(
+                place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+            )
             x = np.random.random(size=(2, 2)).astype('float16')
         else:
             x = np.random.random(size=(2, 2)).astype('float32')
...
@@ -352,7 +352,9 @@ class TestAdamaxMultiPrecision2_0(unittest.TestCase):
         exe.run(startup_program)
         if use_amp:
-            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+            optimizer.amp_init(
+                place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+            )
             x = np.random.random(size=(2, 2)).astype('float16')
         else:
             x = np.random.random(size=(2, 2)).astype('float32')
@@ -459,7 +461,9 @@ class TestAdamaxMultiPrecision1_0(unittest.TestCase):
         exe.run(startup_program)
         if use_amp:
-            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+            optimizer.amp_init(
+                place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+            )
             x = np.random.random(size=(2, 2)).astype('float16')
         else:
             x = np.random.random(size=(2, 2)).astype('float32')
...
@@ -1059,7 +1059,9 @@ class TestMultiTensorMomentumStatic(unittest.TestCase):
         optimizer.minimize(loss)
         exe.run(startup_program)
         if use_amp:
-            optimizer.amp_init(place=place, scope=paddle.static.global_scope())
+            optimizer.amp_init(
+                place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+            )
             x = numpy.random.random(size=(2, 2)).astype('float16')
         else:
             x = numpy.random.random(size=(2, 2)).astype('float32')
...
@@ -474,7 +474,9 @@ class TestRMSPropMultiPrecision2_0(unittest.TestCase):
         exe.run(startup_program)
         if use_amp:
-            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+            optimizer.amp_init(
+                place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+            )
             x = np.random.random(size=(2, 2)).astype('float16')
         else:
             x = np.random.random(size=(2, 2)).astype('float32')
@@ -585,7 +587,9 @@ class TestRMSPropMultiPrecision1_0(unittest.TestCase):
         exe.run(startup_program)
         if use_amp:
-            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+            optimizer.amp_init(
+                place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+            )
             x = np.random.random(size=(2, 2)).astype('float16')
         else:
             x = np.random.random(size=(2, 2)).astype('float32')
...
@@ -382,7 +382,9 @@ class TestSGDMultiPrecision2_0(unittest.TestCase):
         exe.run(startup_program)
         if mp:
-            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+            optimizer.amp_init(
+                place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+            )
             x = np.random.random(size=(2, 2)).astype('float16')
         else:
             x = np.random.random(size=(2, 2)).astype('float32')
@@ -492,7 +494,9 @@ class TestSGDMultiPrecision1_0(unittest.TestCase):
         exe.run(startup_program)
         if mp:
-            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+            optimizer.amp_init(
+                place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+            )
             x = np.random.random(size=(2, 2)).astype('float16')
         else:
             x = np.random.random(size=(2, 2)).astype('float32')
...
@@ -294,8 +294,8 @@ class PartialProgramLayer:
     def _create_amp_program(self, is_infer_mode=False):
         amp_program = self._origin_main_program.clone(for_test=is_infer_mode)
         with program_guard(amp_program):
-            paddle.static.amp.fp16_utils.rewrite_program(
-                amp_program, self._amp_list
+            paddle.static.amp.fp16_utils.cast_model_to_fp16(
+                amp_program, self._amp_list, use_fp16_guard=False, level='O1'
             )
         if is_infer_mode:
             if self._hooker:
...
@@ -29,7 +29,6 @@ from .fp16_lists import AutoMixedPrecisionLists, check_amp_dtype
 from .fp16_utils import (
     cast_model_to_fp16,
     cast_parameters_to_fp16,
-    rewrite_program,
     update_role_var_grad,
 )
 from .function_overload import FunctionType, overload
@@ -67,6 +66,7 @@ class OptimizerWithMixedPrecision:
                                            the loss scaling.
         use_amp_guard(bool): Whether to use `fp16_guard` when constructing the program.
                              Default None, which means that its value is equal to `use_pure_fp16`.
+        use_promote(bool): Whether to promote to fp32 when the op has any float32 inputs. Default is False.
     """

     def __init__(
@@ -82,6 +82,7 @@ class OptimizerWithMixedPrecision:
         incr_ratio,
         decr_ratio,
         use_amp_guard=None,
+        use_promote=False,
     ):
         self._optimizer = optimizer
         self._amp_lists = amp_lists
@@ -116,6 +117,7 @@ class OptimizerWithMixedPrecision:
         self._decr_ratio = decr_ratio
         self._num_good_steps = None
         self._num_bad_steps = None
+        self.use_promote = use_promote

     def _set_distributed(self, flag):
         # if distributed, all cards will communication with each other,
@@ -231,10 +233,18 @@ class OptimizerWithMixedPrecision:
                     self._amp_lists,
                     self._use_fp16_guard,
                     self._amp_vartype,
+                    level='O2',
+                    use_promote=self.use_promote,
                 )
             else:
-                rewrite_program(
-                    self._train_program, self._amp_lists, self._amp_vartype
+                # use_fp16_guard is not supported in amp-o1.
+                cast_model_to_fp16(
+                    self._train_program,
+                    self._amp_lists,
+                    use_fp16_guard=False,
+                    dest_type=self._amp_vartype,
+                    level='O1',
+                    use_promote=self.use_promote,
                 )

         if loss.dtype != core.VarDesc.VarType.FP32:
@@ -362,10 +372,18 @@ class OptimizerWithMixedPrecision:
                     self._amp_lists,
                     self._use_fp16_guard,
                     self._amp_vartype,
+                    level='O2',
+                    use_promote=self.use_promote,
                 )
             elif use_fp16_test:
-                rewrite_program(
-                    test_program, self._amp_lists, self._amp_vartype
+                # use_fp16_guard is not supported in amp-o1.
+                cast_model_to_fp16(
+                    test_program,
+                    self._amp_lists,
+                    use_fp16_guard=False,
+                    dest_type=self._amp_vartype,
+                    level='O1',
+                    use_promote=self.use_promote,
                 )

     def apply_gradients(self, params_grads):
@@ -624,6 +642,7 @@ def decorate(
     use_pure_fp16=False,
     use_fp16_guard=None,
     use_bf16=False,
+    use_promote=False,
 ):
     """
     Decorate the given optimizer to adapt to the mixed-precision training.
@@ -736,6 +755,7 @@ def decorate(
         incr_ratio=incr_ratio,
         decr_ratio=decr_ratio,
         use_amp_guard=use_fp16_guard,
+        use_promote=use_promote,
     )
     return mp_optimizer
@@ -754,6 +774,7 @@ def decorate(
     decr_ratio=0.8,
     use_dynamic_loss_scaling=True,
     use_amp_guard=False,
+    use_promote=False,
 ):
     """
     Decorate the given optimizer to adapt to the mixed-precision training.
@@ -781,6 +802,7 @@ def decorate(
         incr_ratio=incr_ratio,
         decr_ratio=decr_ratio,
         use_amp_guard=use_amp_guard,
+        use_promote=use_promote,
     )
     return mp_optimizer
@@ -98,6 +98,20 @@ def _get_sys_unsupported_list(dtype):
     else:
         device = 'GPU'
     _, _, sys_unsupported_list = core.op_supported_infos(device, var_type)
+
+    # sys_unsupported_list will include the following ops.
+    supported_fp16_list = {
+        "conditional_block",
+        "conditional_block_infer",
+        "select_input",
+        "while",
+        "cast",
+        "tensor_array_to_tensor",
+        "lod_array_length",
+        "write_to_array",
+    }
+    sys_unsupported_list -= supported_fp16_list
+
     return device, sys_unsupported_list
@@ -108,6 +122,29 @@ def _get_unsupported_list(dtype):
     return unsupported_list
+
+
+# The three sets listed below are changed dynamically. They don't contain all
+# paddle ops currently.
+
+# The set of ops that support fp16 calculation and are considered numerically-
+# safe and performance-critical. These ops are always converted to fp16.
+_only_supported_fp16_list = {'resnet_unit', 'fused_bn_add_activation'}
+
+white_list = {
+    'conv2d',
+    'matmul',
+    'matmul_v2',
+    'mul',
+}
+
+
+def _get_white_list(dtype):
+    white_list_for_dtype = copy.copy(white_list)
+    if dtype == 'float16':
+        white_list_for_dtype = white_list_for_dtype | _only_supported_fp16_list
+    return white_list_for_dtype
+
+
 class AutoMixedPrecisionLists:
     """
     AutoMixedPrecisionLists is a class for black/white list. It can update
@@ -132,7 +169,7 @@ class AutoMixedPrecisionLists:
         self.amp_dtype = check_amp_dtype(dtype)
         self._custom_white_list = custom_white_list
         self._custom_black_list = custom_black_list
-        self.white_list = copy.copy(white_list)
+        self.white_list = copy.copy(_get_white_list(self.amp_dtype))
         self.black_list = copy.copy(black_list)
         self.gray_list = copy.copy(gray_list)
         self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
@@ -143,6 +180,9 @@
         """
         Update black and white list according to users' custom list.
         """
+        _logger.debug(f"---- custom_white_list {self._custom_white_list} ---- ")
+        _logger.debug(f"---- custom_black_list {self._custom_black_list} ---- ")
+        _logger.debug(f"---- custom_black_varnames {self.black_varnames} ---- ")
         if self._custom_white_list and self._custom_black_list:
             for op_name in self._custom_white_list:
                 if op_name in self._custom_black_list:
@@ -177,18 +217,6 @@ class AutoMixedPrecisionLists:
                 )

-# The three sets listed below are changed dynamiclly. They don't contain all
-# paddle ops currently.
-# The set of ops that support fp16 calculation and are considered numerically-
-# safe and performance-critical. These ops are always converted to fp16.
-white_list = {
-    'conv2d',
-    'matmul',
-    'matmul_v2',
-    'mul',
-}
-
 # The set of ops that support fp16 calculation and are considered numerically-
 # dangerous and whose effects may also be observed in downstream ops.
 black_list = {
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import collections
 import logging

 import numpy as np
@@ -22,7 +21,11 @@ from paddle.fluid import core, framework, global_scope
 from paddle.fluid.log_helper import get_logger
 from paddle.fluid.wrapped_decorator import signature_safe_contextmanager

-from .fp16_lists import AutoMixedPrecisionLists, get_low_precision_dtypestr
+from .fp16_lists import (
+    AutoMixedPrecisionLists,
+    black_list,
+    get_low_precision_dtypestr,
+)

 _logger = get_logger(
     __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
@@ -144,7 +147,7 @@ def _keep_fp32_output(op, out_name):

 def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
     """
-    Insert cast op and rename args of input and output.
+    Insert cast op and rename op's input.

     Args:
         block (Program): The block in which the operator is.
@@ -167,8 +170,15 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
         in_var = block._find_var_recursive(in_var_name)
         if in_var.type not in _valid_types or in_var.dtype == dest_dtype:
             continue
-        if in_var.dtype == src_dtype:
-            cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
+
+        # op's input is already cast to dest_dtype before. Set the in_var.name to cast_name.
+        cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
+        casted_var = block._find_var_recursive(cast_name)
+        if casted_var and casted_var.dtype == dest_dtype:
+            _rename_arg(op, in_var.name, casted_var.name)
+            continue
+
+        # insert cast for op's input.
+        if in_var.dtype == src_dtype:
             out_var = block.vars.get(cast_name)
             if out_var is None or out_var.dtype != dest_dtype:
                 op_device = op.attr('op_device')
@@ -206,6 +216,13 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
                     stop_gradient=in_var.stop_gradient,
                 )
+                # Only forward program will be inserted cast op, but some ops
+                # have no op_role attr, so here set it directly. e.g. resnet_unit.
+                op_role = (
+                    int(core.op_proto_and_checker_maker.OpRole.Forward)
+                    if not op.has_attr('op_role')
+                    else op.attr('op_role')
+                )
                 block._insert_op_without_sync(
                     idx,
                     type="cast",
@@ -215,70 +232,15 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
                         "in_dtype": in_var.dtype,
                         "out_dtype": out_var.dtype,
                         "op_device": op_device,
-                        "op_role": op.attr("op_role"),
+                        "op_role": op_role,
                     },
                 )
                 num_cast_ops += 1
             _rename_arg(op, in_var.name, out_var.name)
-        else:
-            if op.has_attr('in_dtype'):
-                op._set_attr('in_dtype', dest_dtype)
-    if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype in [
-        core.VarDesc.VarType.FP16,
-        core.VarDesc.VarType.BF16,
-    ]:
-        for out_name in op.output_names:
-            if _keep_fp32_output(op, out_name):
-                continue
-            for out_var_name in op.output(out_name):
-                out_var = block.var(out_var_name)
-                if out_var.type not in _valid_types:
-                    continue
-                if out_var.dtype == core.VarDesc.VarType.FP32:
-                    out_var.desc.set_dtype(dest_dtype)
-                    if op.has_attr('out_dtype'):
-                        op._set_attr('out_dtype', dest_dtype)
-    return num_cast_ops
-
-
-def _insert_cast_post_op(
-    block, op, idx, src_dtype, dest_dtype, target_name, op_var_rename_map
-):
-    num_cast_ops = 0
-
-    target_var = block.var(target_name)
-    if target_var.type not in _valid_types or target_var.dtype == dest_dtype:
-        return num_cast_ops
-
-    assert (
-        target_var.dtype == src_dtype
-    ), "The real dtype({}) is not equal to the src dtype({})".format(
-        _dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype)
-    )
-
-    cast_name = target_var.name + '.cast_' + _dtype_to_str(dest_dtype)
-    cast_var = block.vars.get(cast_name)
-    if cast_var is None or cast_var.dtype != dest_dtype:
-        cast_var = block.create_var(
-            name=cast_name,
-            dtype=dest_dtype,
-            persistable=False,
-            stop_gradient=target_var.stop_gradient,
-        )
-        block._insert_op(
-            idx,
-            type="cast",
-            inputs={"X": target_var},
-            outputs={"Out": cast_var},
-            attrs={
-                "in_dtype": target_var.dtype,
-                "out_dtype": cast_var.dtype,
-                "op_device": op.attr("op_device"),
-                "op_role": op.attr("op_role"),
-            },
-        )
-        num_cast_ops += 1
-        op_var_rename_map[block.idx][target_var.name] = cast_var.name
-
+    for attr_name in ['in_dtype', 'out_dtype', 'dtype']:
+        if op.has_attr(attr_name) and is_float_dtype(op.attr(attr_name)):
+            op._set_attr(attr_name, dest_dtype)
     return num_cast_ops
@@ -420,11 +382,204 @@ def fp16_guard():
     yield


+def is_float_dtype(dtype):
+    return (
+        dtype == core.VarDesc.VarType.FP32
+        or dtype == core.VarDesc.VarType.FP16
+        or dtype == core.VarDesc.VarType.BF16
+        or dtype == core.VarDesc.VarType.FP64
+    )
+
+
+def set_var_dst_dtype(
+    op, var_names, block, global_block, dtype, need_set_dtype
+):
+    low_precison_var_names = set()
+    for var_name in var_names:
+        var = None
+        try:
+            var = block._var_recursive(var_name)
+        except ValueError as e:
+            _logger.debug(f"-- {e}, try to get it in the global block --")
+            var = global_block.var(var_name)
+            if var is not None:
+                _logger.debug(
+                    f"-- var {var_name} is got in the global block --"
+                )
+
+        if var is None or var.type not in _valid_types:
+            continue
+
+        if is_float_dtype(var.dtype):
+            low_precison_var_names.add(var_name)
+            if need_set_dtype:
+                var.desc.set_dtype(dtype)
+
+        _logger.debug(
+            "---- op type: {}, var name: {}, var dtype: {} ----".format(
+                op.type, var_name, var.dtype
+            )
+        )
+
+    return low_precison_var_names
+
+
+def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
+    if level == "O1":
+        return
+    keep_fp32_var_names = set()
+    all_parameters = []
+    for block in program.blocks:
+        all_parameters.extend(block.all_parameters())
+        ops = block.ops
+        for op in ops:
+            if op_need_keep_fp32(op, amp_lists, use_fp16_guard):
+                for in_name in op.input_names:
+                    keep_fp32_var_names = keep_fp32_var_names.union(
+                        op.input(in_name)
+                    )
+            else:
+                for in_name in op.input_names:
+                    if not core.is_compiled_with_ipu() and _keep_fp32_input(
+                        op, in_name
+                    ):
+                        keep_fp32_var_names = keep_fp32_var_names.union(
+                            op.input(in_name)
+                        )
+
+    for param in all_parameters:
+        if param.name not in keep_fp32_var_names:
+            _logger.debug(f"-- set param {param.name} to {dtype} --.")
+            param.desc.set_dtype(dtype)
+
+
+def op_need_keep_fp32(op, amp_lists, use_fp16_guard):
+    need_keep_fp32 = False
+    if _need_keep_fp32(
+        op,
+        amp_lists.unsupported_list,
+        use_fp16_guard,
+    ):
+        need_keep_fp32 = True
+    elif amp_lists.black_varnames is not None and _is_in_black_varnames(
+        op, amp_lists
+    ):
+        need_keep_fp32 = True
+    elif op.type in amp_lists.black_list:
+        need_keep_fp32 = True
+
+    return need_keep_fp32
+
+
+def get_promote_dtype(op, amp_dtype, block):
+    dst_dtype = amp_dtype
+    for in_name in op.input_names:
+        # for ipu, all inputs must be converted to fp16
+        if not core.is_compiled_with_ipu() and _keep_fp32_input(op, in_name):
+            _logger.debug(
+                "---- Input {} {} should be kept fp32 ----".format(
+                    in_name, op.input(in_name)
+                )
+            )
+            continue
+        # if this op has inputs
+        if in_name:
+            for in_var_name in op.input(in_name):
+                in_var = block._find_var_recursive(in_var_name)
+                if in_var and in_var.dtype == core.VarDesc.VarType.FP32:
+                    dst_dtype = core.VarDesc.VarType.FP32
+                    break
+        else:
+            dst_dtype = core.VarDesc.VarType.FP32
+
+    return dst_dtype
+
+
+def get_amp_dst_dtype(
+    op, amp_dtype, level, block, amp_lists, keep_fp32_ops, keep_fp16_ops
+):
+    if level == 'O2':
+        return amp_dtype
+
+    ops = block.ops
+    dst_dtype = amp_dtype
+    if op.type in amp_lists.gray_list:
+        keep_fp32 = False
+        keep_fp16 = False
+        for in_name in op.input_names:
+            # if this op has inputs
+            if in_name:
+                for in_var_name in op.input(in_name):
+                    in_var = block._find_var_recursive(in_var_name)
+                    # this in_var isn't the output of other op
+                    if in_var.op is None:
+                        continue
+                    elif in_var.op is op:
+                        prev_op = find_true_prev_op(ops, op, in_var_name)
+                        if prev_op is None:
+                            continue
+                    else:
+                        prev_op = in_var.op
+
+                    # if it's one of inputs
+                    if (
+                        prev_op in keep_fp32_ops
+                        or prev_op.type in amp_lists.black_list
+                    ):
+                        dst_dtype = core.VarDesc.VarType.FP32
+                    elif (
+                        prev_op in keep_fp16_ops
+                        or prev_op.type in amp_lists.white_list
+                    ):
+                        dst_dtype = amp_dtype
+    else:
+        # For numerical safe, we apply fp32 computation on ops that
+        # are not determined which list they should stay.
+        dst_dtype = core.VarDesc.VarType.FP32
+    return dst_dtype
+
+
+def process_op_input_and_outputs(op, block, global_block, dtype):
+    low_precison_var_names = set()
+    # Get the FP16 input because the low_precison_var_names is required for the parameter casting.
+    # The dtype of the input is not set to fp16, because it is done in step 3 of cast_model_to_fp16.
+    for in_name in op.input_names:
+        # for ipu, all inputs must be converted to fp16
+        if not core.is_compiled_with_ipu() and _keep_fp32_input(op, in_name):
+            continue
+        in_vars = set_var_dst_dtype(
+            op,
+            op.input(in_name),
+            block,
+            global_block,
+            dtype,
+            need_set_dtype=False,
+        )
+        low_precison_var_names = low_precison_var_names.union(in_vars)
+
+    # Set the output to FP16 because its consumer OP needs to determine if the dtype needs
+    # to be promoted.
+    for out_name in op.output_names:
+        # for ipu, all outputs must be converted to fp16
+        if not core.is_compiled_with_ipu() and _keep_fp32_output(op, out_name):
+            continue
+        set_var_dst_dtype(
+            op,
+            op.output(out_name),
+            block,
+            global_block,
+            dtype,
+            need_set_dtype=True,
+        )
+    return low_precison_var_names
+
+
 def cast_model_to_fp16(
     program,
     amp_lists=None,
     use_fp16_guard=True,
     dest_type=core.VarDesc.VarType.FP16,
+    level='O2',
+    use_promote=False,
 ):
     """
     Traverse all ops in the whole model and set their inputs and outputs
@@ -438,158 +593,132 @@ def cast_model_to_fp16(
                              constructing the program. Default True.
         dest_type(core.VarDesc.VarType): the cast type. such as core.VarDesc.VarType.FP16 and core.VarDesc.VarType.BF16.
     """
+    _logger.debug("---- before cast model to fp16 ----")
+    _logger.debug(program)
+
     if amp_lists is None:
         dtype = get_low_precision_dtypestr(dest_type)
         amp_lists = AutoMixedPrecisionLists(dtype)
-    amp_lists.unsupported_list -= {
-        "conditional_block_grad",
-        "conditional_block",
-        "conditional_block_infer",
-        "select_input",
-        "while",
-        "while_grad",
-        "cast",
-        "tensor_array_to_tensor",
-        "lod_array_length",
-        "write_to_array",
-    }
+
+    # For amp o2 there is no blacklist by default.
+    if level == 'O2':
+        amp_lists.black_list = amp_lists.black_list - black_list
+
     global_block = program.global_block()
     keep_fp32_ops = set()
+    keep_fp16_ops = set()
     to_fp16_var_names = set()
-    origin_ops = []
-    for block in program.blocks:
-        origin_ops.extend(block.ops)
+
+    # step 1: set params dtype.
+    set_param_dtype(
+        program,
+        dtype=dest_type,
+        amp_lists=amp_lists,
+        use_fp16_guard=use_fp16_guard,
+        level=level,
+    )
+
+    def need_process(op):
+        need_process = True
+        if op.type in ["cast", "create_py_reader", "read"]:
+            need_process = False
+        else:
+            for attr_name in ['out_dtype', 'dtype']:
+                if op.has_attr(attr_name) and is_float_dtype(
+                    op.attr(attr_name)
+                ):
+                    need_process = False
+
+        return need_process
+
+    # step 2: divide op into different sets according to the black/unsupported and white lists.
     for block in program.blocks:
         ops = block.ops
         for op in ops:
-            if op.type == 'create_py_reader' or op.type == 'read':
+            _logger.debug(f"-- process op: {op} --")
+            if not need_process(op):
+                _logger.debug("---- The op does not need to be processed ----.")
                 continue
-            if _need_keep_fp32(op, amp_lists.unsupported_list, use_fp16_guard):
+            if op_need_keep_fp32(op, amp_lists, use_fp16_guard):
                 keep_fp32_ops.add(op)
-                continue  # processed below
-            for in_name in op.input_names:
-                # for ipu, all inputs must be converted to fp16
-                if not core.is_compiled_with_ipu() and _keep_fp32_input(
-                    op, in_name
-                ):
-                    continue
-                for in_var_name in op.input(in_name):
-                    in_var = None
-                    try:
-                        in_var = block._var_recursive(in_var_name)
-                    except ValueError as e:
-                        _logger.debug(
-                            "-- {}, try to get it in the global block --".format(
-                                e
-                            )
-                        )
-                        in_var = global_block.var(in_var_name)
-                        if in_var is not None:
-                            _logger.debug(
-                                "-- var {} is got in the global block --".format(
-                                    in_var_name
-                                )
-                            )
-
-                    if in_var is None or in_var.type not in _valid_types:
-                        continue
-
-                    if in_var.dtype == core.VarDesc.VarType.FP32:
-                        in_var.desc.set_dtype(dest_type)
-                        to_fp16_var_names.add(in_var_name)
-
-                    _logger.debug(
-                        "-- op type: {}, in var name: {}, in var dtype: {} --".format(
-                            op.type, in_var_name, in_var.dtype
-                        )
-                    )
-
-            for out_name in op.output_names:
-                # for ipu, all outputs must be converted to fp16
-                if not core.is_compiled_with_ipu() and _keep_fp32_output(
-                    op, out_name
-                ):
-                    continue
-                for out_var_name in op.output(out_name):
-                    out_var = None
-                    try:
-                        out_var = block._var_recursive(out_var_name)
-                    except ValueError as e:
-                        _logger.debug(
-                            "-- {}, try to get it in the global block --".format(
-                                e
-                            )
-                        )
-                        out_var = global_block.var(out_var_name)
-                        if out_var is not None:
-                            _logger.debug(
-                                "-- var {} is got in the global block --".format(
-                                    out_var_name
-                                )
-                            )
-
-                    if out_var is None or out_var.type not in _valid_types:
-                        continue
-
-                    if out_var.dtype == core.VarDesc.VarType.FP32:
-                        out_var.desc.set_dtype(dest_type)
-
-                    _logger.debug(
-                        "-- op type: {}, out var name: {}, out var dtype: {} --".format(
-                            op.type, out_var_name, out_var.dtype
-                        )
-                    )
-            for attr_name in ['in_dtype', 'out_dtype', 'dtype']:
-                if (
-                    op.has_attr(attr_name)
-                    and op.attr(attr_name) == core.VarDesc.VarType.FP32
-                ):
-                    op._set_attr(attr_name, dest_type)
-
-    # process ops in keep_fp32_ops
-    op_var_rename_map = [
-        collections.OrderedDict() for _ in range(len(program.blocks))
-    ]
+                process_op_input_and_outputs(
+                    op, block, global_block, core.VarDesc.VarType.FP32
+                )
+                _logger.debug(
+                    "---- Add into keep_fp32_ops because the op needs to be kept fp32 ----"
+                )
+            elif op.type in amp_lists.white_list:
+                keep_fp16_ops.add(op)
+                # get fp16 inputs and set op's outputs to fp16 for promote judgments
+                fp16_var_names = process_op_input_and_outputs(
+                    op, block, global_block, dest_type
+                )
+                to_fp16_var_names = to_fp16_var_names.union(fp16_var_names)
+                _logger.debug(
+                    "---- Add into keep_fp16_ops because the op in white_list ----"
+                )
+            else:
+                # divide others ops into fp16/fp32 sets according to promoting principle.
+                dst_dtype = dest_type
+                if not use_promote:
+                    dst_dtype = get_amp_dst_dtype(
+                        op,
+                        dest_type,
+                        level,
+                        block,
+                        amp_lists,
+                        keep_fp32_ops,
+                        keep_fp16_ops,
+                    )
+                else:
+                    dst_dtype = get_promote_dtype(op, dest_type, block)
+
+                if dst_dtype == dest_type:
+                    keep_fp16_ops.add(op)
+                    fp16_var_names = process_op_input_and_outputs(
+                        op, block, global_block, dest_type
+                    )
+                    to_fp16_var_names = to_fp16_var_names.union(fp16_var_names)
+                    _logger.debug(
+                        "---- Add into keep_fp16_ops because it should be promoted to fp16 ----"
+                    )
+                else:
+                    keep_fp32_ops.add(op)
+                    process_op_input_and_outputs(
+                        op, block, global_block, core.VarDesc.VarType.FP32
+                    )
+                    _logger.debug(
+                        "---- Add into keep_fp32_ops because it should be promoted to fp32 ----"
+                    )
+
+    # step 3: insert cast op for op's inputs.
     for block in program.blocks:
         ops = block.ops
         idx = 0
         while idx < len(ops):
             op = ops[idx]
             num_cast_ops = 0
-            if op in keep_fp32_ops:
-                pre_cast_num = _insert_cast_op(
+            if op in keep_fp16_ops:
+                in_var_cast_num = _insert_cast_op(
                     block,
                     op,
                     idx,
-                    dest_type,
                     core.VarDesc.VarType.FP32,
+                    dest_type,
                 )
-                num_cast_ops += pre_cast_num
-                for out_var_name in op.output_arg_names:
-                    out_var = block.vars.get(out_var_name)
-                    if out_var is None or out_var.type not in _valid_types:
-                        continue
-                    if out_var.dtype == dest_type:
-                        out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
-                        post_ops = find_true_post_op(ops, op, out_var_name)
-                        for post_op in post_ops:
-                            if post_op in keep_fp32_ops:
-                                continue
-                            post_cast_num = _insert_cast_post_op(
-                                block,
-                                op,
-                                idx + pre_cast_num + 1,
-                                core.VarDesc.VarType.FP32,
-                                dest_type,
-                                out_var_name,
-                                op_var_rename_map,
-                            )
-                            num_cast_ops += post_cast_num
+                num_cast_ops += in_var_cast_num
+            if op in keep_fp32_ops:
+                in_var_cast_num = _insert_cast_op(
+                    block,
+                    op,
+                    idx,
+                    dest_type,
+                    core.VarDesc.VarType.FP32,
+                )
+                num_cast_ops += in_var_cast_num
+
             idx += num_cast_ops + 1

-    _rename_op_input(program, op_var_rename_map, origin_ops, keep_fp32_ops)
+    _logger.debug("---- after cast model to fp16 ----")
+    _logger.debug(program)
     return to_fp16_var_names
@@ -646,108 +775,6 @@ def cast_parameters_to_fp16(
             _logger.warning(f"Cannot find {param.name}")


-def rewrite_program(main_prog, amp_lists, dest_type=core.VarDesc.VarType.FP16):
-    """
-    Traverse all ops in current block and insert cast op according to
-    which set current op belongs to.
-
-    1. When an op belongs to the black list, add it to black set
-    2. When an op belongs to the white list, add it to white set
-    3. When an op belongs to the gray list. If one
-       of its inputs is the output of black set op or black list op,
-       add it to black set. If all of its previous ops are not black
-       op and one of its inputs is the output of white set op or
-       white list op, add it to white set.
-    4. When an op isn't in the lists, add it to black op set.
-    5. Add necessary cast ops to make sure that black set op will be
-       computed in fp32 mode, while white set op will be computed in
-       fp16 mode.
-
-    Args:
-        main_prog (Program): The main program for training.
-        dest_type(core.VarDesc.VarType): the cast type. such as core.VarDesc.VarType.FP16 and core.VarDesc.VarType.BF16.
-    """
-    block = main_prog.global_block()
-    block._sync_with_cpp()
-    ops = block.ops
-    white_op_set = set()
-    black_op_set = set()
-    for op in ops:
-        # NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder,
-        # we don't need to handle reader op and the input of 'create_py_reader' is not
-        # in block, which may result in errors.
-        # See GeneratorLoader._init_non_iterable() for details.
-        if op.type == 'create_py_reader' or op.type == 'read':
-            continue
-
-        if amp_lists.black_varnames is not None and _is_in_black_varnames(
-            op, amp_lists
-        ):
-            black_op_set.add(op)
-            continue
-
-        if op.type in amp_lists.black_list:
-            black_op_set.add(op)
-        elif op.type in amp_lists.white_list:
-            white_op_set.add(op)
-        elif op.type in amp_lists.gray_list:
-            is_black_op = False
-            is_white_op = False
-            for in_name in op.input_names:
-                # if this op has inputs
-                if in_name:
-                    for in_var_name in op.input(in_name):
-                        in_var = block.var(in_var_name)
-                        # this in_var isn't the output of other op
-                        if in_var.op is None:
-                            continue
-                        elif in_var.op is op:
-                            prev_op = find_true_prev_op(ops, op, in_var_name)
-                            if prev_op is None:
-                                continue
-                        else:
-                            prev_op = in_var.op
-                        # if it's one of inputs
-                        if (
-                            prev_op in black_op_set
-                            or prev_op.type in amp_lists.black_list
-                        ):
-                            is_black_op = True
-                        elif (
-                            prev_op in white_op_set
-                            or prev_op.type in amp_lists.white_list
-                        ):
-                            is_white_op = True
-            if is_black_op:
-                black_op_set.add(op)
-            elif is_white_op:
-                white_op_set.add(op)
-            else:
-                pass
-        else:
-            # For numerical safe, we apply fp32 computation on ops that
-            # are not determined which list they should stay.
-            black_op_set.add(op)
-
-    idx = 0
-    while idx < len(ops):
-        op = ops[idx]
-        num_cast_ops = 0
-        if op in black_op_set:
-            num_cast_ops = _insert_cast_op(
-                block, op, idx, dest_type, core.VarDesc.VarType.FP32
-            )
-        elif op in white_op_set:
-            num_cast_ops = _insert_cast_op(
-                block, op, idx, core.VarDesc.VarType.FP32, dest_type
-            )
-        else:
-            pass
-        idx += num_cast_ops + 1
-
-
 def update_role_var_grad(main_prog, params_grads):
     """
     Update op_role_var attr for some ops to make sure the gradients
...
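Taken together, the fp16_utils changes above replace the old per-block `rewrite_program` pass with a single `cast_model_to_fp16` pass that handles both O1 and O2 and the new promote mode. The rule implemented by `get_promote_dtype` can be summarized in a few lines; the sketch below is a simplified standalone illustration (it ignores the `_keep_fp32_input` exceptions handled in the real code and works on dtype strings rather than `VarDesc` types):

```python
def promote_dtype(input_dtypes, amp_dtype="float16"):
    """Simplified promote rule: if any input of an op is float32, the op runs
    in float32; otherwise it runs in the low-precision AMP dtype."""
    if any(dtype == "float32" for dtype in input_dtypes):
        return "float32"
    return amp_dtype


# e.g. an elementwise_add fed by a float16 conv output and a float32 bias
assert promote_dtype(["float16", "float32"]) == "float32"
assert promote_dtype(["float16", "float16"]) == "float16"
```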
@@ -29,6 +29,7 @@ def _build_optimizer(
     amp_level="O1",
     amp_lists=None,
     use_grad_clip=False,
+    use_promote=False,
 ):
     if use_grad_clip:
         grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
@@ -45,7 +46,11 @@ def _build_optimizer(
     )
     if use_amp:
         optimizer = paddle.static.amp.decorate(
-            optimizer, amp_lists, level=amp_level, dtype=amp_dtype
+            optimizer,
+            amp_lists,
+            level=amp_level,
+            dtype=amp_dtype,
+            use_promote=use_promote,
         )
     return optimizer
@@ -67,7 +72,9 @@ class SimpleAddNet(nn.Layer):
         return x + self.weight


-def build_add_model(use_amp, amp_dtype="float16", amp_level="O1"):
+def build_add_model(
+    use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
+):
     main_program = paddle.static.Program()
     startup_program = paddle.static.Program()
     with paddle.utils.unique_name.guard():
@@ -92,7 +99,11 @@ def build_add_model(use_amp, amp_dtype="float16", amp_level="O1"):
             else:
                 amp_lists = None
             optimizer = _build_optimizer(
-                use_amp, amp_dtype, amp_level, amp_lists
+                use_amp,
+                amp_dtype,
+                amp_level,
+                amp_lists,
+                use_promote=use_promote,
             )
             optimizer.minimize(loss)
     feed_vars = [x]
@@ -104,30 +115,37 @@ class SimpleConvNet(nn.Layer):
     def __init__(self):
         super().__init__()
         self.conv = nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
-        self.linear = nn.Linear(in_features=6, out_features=10)
+        self.linear = nn.Linear(in_features=96, out_features=4)

     def forward(self, x):
         out = self.conv(x)
         out = nn.functional.relu(out)
+        out = out.flatten(start_axis=1, stop_axis=3)
         out = self.linear(out)
         out = nn.functional.softmax(out)
         return out


-def build_conv_model(use_amp, amp_dtype="float16", amp_level="O1"):
+def build_conv_model(
+    use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
+):
     main_program = paddle.static.Program()
     startup_program = paddle.static.Program()
     with paddle.utils.unique_name.guard():
         with paddle.static.program_guard(main_program, startup_program):
             model = SimpleConvNet()
             x = paddle.static.data(
-                name='input', shape=[None, 1, 28, 28], dtype='float32'
+                name='input', shape=[None, 1, 6, 6], dtype='float32'
             )
             out = model(x)
             loss = paddle.mean(out)
-            optimizer = _build_optimizer(use_amp, amp_dtype, amp_level)
+            optimizer = _build_optimizer(
+                use_amp, amp_dtype, amp_level, use_promote=use_promote
+            )
             optimizer.minimize(loss)
-    return main_program, startup_program
+    feed_vars = [x]
+    fetch_vars = [loss]
+    return main_program, startup_program, optimizer, feed_vars, fetch_vars


 class SimpleEmbeddingNet(nn.Layer):
@@ -149,7 +167,9 @@ class SimpleEmbeddingNet(nn.Layer):
         return out


-def build_embedding_model(use_amp, amp_dtype="float16", amp_level="O1"):
+def build_embedding_model(
+    use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
+):
     main_program = paddle.static.Program()
     startup_program = paddle.static.Program()
     with paddle.utils.unique_name.guard():
@@ -159,7 +179,12 @@ def build_embedding_model(use_amp, amp_dtype="float16", amp_level="O1"):
             out = model(x)
             loss = paddle.mean(out)
             optimizer = _build_optimizer(
-                use_amp, amp_dtype, amp_level, None, True
+                use_amp,
+                amp_dtype,
+                amp_level,
+                None,
+                True,
+                use_promote=use_promote,
             )
             optimizer.minimize(loss)
     return main_program, startup_program
@@ -211,3 +236,48 @@ class AmpTestBase(unittest.TestCase):
     def setUp(self):
         self.amp_dtype = None
         self.amp_level = None
+
+    def _check_op_calls(
+        self, op_stats_dict, expected_bf16_calls={}, expected_fp16_calls={}
+    ):
+        for op_type, value in expected_bf16_calls.items():
+            self.assertEqual(
+                op_stats_dict[op_type].bf16_calls,
+                value,
+                f"The number of bf16 calls of operator < {op_type} > is expected to be {value}, but received {op_stats_dict[op_type].bf16_calls}.",
+            )
+        for op_type, value in expected_fp16_calls.items():
+            self.assertEqual(
+                op_stats_dict[op_type].fp16_calls,
+                value,
+                f"The number of fp16 calls of operator < {op_type} > is expected to be {value}, but received {op_stats_dict[op_type].fp16_calls}.",
+            )
+
+    def run_program(
+        self,
+        main_program,
+        startup_program,
+        optimizer,
+        feed_vars,
+        fetch_vars,
+        place,
+        exe,
+        x_np,
+        max_iters,
+        level,
+    ):
+        losses = []
+        scope = paddle.static.Scope()
+        with paddle.static.scope_guard(scope):
+            exe.run(startup_program)
+            if level == 'O2':
+                optimizer.amp_init(place)
+            for iter_id in range(max_iters):
+                results = exe.run(
+                    program=main_program,
+                    feed={feed_vars[0].name: x_np},
+                    fetch_list=fetch_vars,
+                )
+                print(f"-- [BF16 {level}] iter={iter_id}, loss={results[0]}")
+                losses.append(results[0])
+        return losses
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from amp_base_models import AmpTestBase, build_conv_model

import paddle
from paddle.static import amp

paddle.enable_static()


class TestAMPPromote(AmpTestBase):
    def check_promote_results(
        self, use_amp, dtype, level, use_promote, expected_op_calls
    ):
        (
            main_program,
            startup_program,
            optimizer,
            feed_vars,
            fetch_vars,
        ) = build_conv_model(use_amp, dtype, level, use_promote)
        self.assertEqual(main_program.num_blocks, 1)

        amp.debugging.collect_operator_stats(main_program)
        op_stats_list = amp.debugging._get_op_stats_list(main_program)

        self._check_op_calls(
            op_stats_list[0], expected_fp16_calls=expected_op_calls
        )

        place = paddle.CUDAPlace(0)
        exe = paddle.static.Executor(place)

        max_iters = 2
        x_fp32 = np.random.random(size=[1, 1, 6, 6]).astype("float32")
        print(main_program)
        losses_o1 = self.run_program(
            main_program,
            startup_program,
            optimizer,
            feed_vars,
            fetch_vars,
            place,
            exe,
            x_fp32,
            max_iters,
            level,
        )

    def test_static_amp_o1(self):
        expected_fp16_calls = {
            "conv2d": 1,
            "elementwise_add": 0,
            "relu": 0,
            "matmul_v2": 1,
            "softmax": 0,
            "reduce_mean": 0,
            "adamw": 0,
        }
        self.check_promote_results(
            True,
            'float16',
            'O1',
            use_promote=True,
            expected_op_calls=expected_fp16_calls,
        )

    def test_static_amp_o2(self):
        expected_fp16_calls = {
            "conv2d": 1,
            "elementwise_add": 2,
            "relu": 1,
            "matmul_v2": 1,
            "softmax": 1,
            "reduce_mean": 1,
            "adamw": 4,
        }
        self.check_promote_results(
            True,
            'float16',
            'O2',
            use_promote=True,
            expected_op_calls=expected_fp16_calls,
        )


if __name__ == '__main__':
    unittest.main()
@@ -221,14 +221,6 @@ class TestModelCastBF16(unittest.TestCase):


 class TestProgramBF16(AmpTestBase):
-    def _check_bf16_calls(self, op_stats_dict, expected_bf16_calls):
-        for op_type, value in expected_bf16_calls.items():
-            self.assertEqual(
-                op_stats_dict[op_type].bf16_calls,
-                value,
-                f"The number of bf16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].bf16_calls}.",
-            )
-
     def test_amp_bf16_o1(self):
         main_program, startup_program = build_embedding_model(
             True, "bfloat16", "O1"
@@ -245,7 +237,7 @@ class TestProgramBF16(AmpTestBase):
             "squared_l2_norm": 0,
             "adamw": 0,
         }
-        self._check_bf16_calls(op_stats_list[0], expected_bf16_calls)
+        self._check_op_calls(op_stats_list[0], expected_bf16_calls)

     def test_amp_bf16_o2(self):
         main_program, startup_program = build_embedding_model(
@@ -263,7 +255,7 @@ class TestProgramBF16(AmpTestBase):
             "squared_l2_norm": 2,
             "adamw": 2,
         }
-        self._check_bf16_calls(op_stats_list[0], expected_bf16_calls)
+        self._check_op_calls(op_stats_list[0], expected_bf16_calls)


 class TestStaticBF16(AmpTestBase):
@@ -274,60 +266,35 @@ class TestStaticBF16(AmpTestBase):
         return x_fp32, x_bf16

     def test_compare_o1_o2(self):
-        def _run_o1(place, exe, x_np, max_iters):
+        def _run(place, exe, x_np, max_iters, level):
             (
                 main_program,
                 startup_program,
                 optimizer,
                 feed_vars,
                 fetch_vars,
-            ) = build_add_model(True, "bfloat16", "O1")
-
-            losses = []
-            scope = paddle.static.Scope()
-            with paddle.static.scope_guard(scope):
-                exe.run(startup_program)
-                for iter_id in range(max_iters):
-                    results = exe.run(
-                        program=main_program,
-                        feed={feed_vars[0].name: x_np},
-                        fetch_list=fetch_vars,
-                    )
-                    print(f"-- [BF16 O1] iter={iter_id}, loss={results[0]}")
-                    losses.append(results[0])
-            return losses
-
-        def _run_o2(place, exe, x_np, max_iters):
-            (
-                main_program,
-                startup_program,
-                optimizer,
-                feed_vars,
-                fetch_vars,
-            ) = build_add_model(True, "bfloat16", "O2")
-
-            losses = []
-            scope = paddle.static.Scope()
-            with paddle.static.scope_guard(scope):
-                exe.run(startup_program)
-                optimizer.amp_init(place)
-                for iter_id in range(max_iters):
-                    results = exe.run(
-                        program=main_program,
-                        feed={feed_vars[0].name: x_np},
-                        fetch_list=fetch_vars,
-                    )
-                    print(f"-- [BF16 O2] iter={iter_id}, loss={results[0]}")
-                    losses.append(results[0])
+            ) = build_add_model(True, "bfloat16", level)
+
+            losses = self.run_program(
+                main_program,
+                startup_program,
+                optimizer,
+                feed_vars,
+                fetch_vars,
+                place,
+                exe,
+                x_np,
+                max_iters,
+                level,
+            )
             return losses

-        place = paddle.CUDAPlace(0)
-        exe = paddle.static.Executor(place)
         max_iters = 2
         x_fp32, x_bf16 = self._generate_feed_x()
-        losses_o1 = _run_o1(place, exe, x_fp32, max_iters)
-        losses_o2 = _run_o2(place, exe, x_bf16, max_iters)
+        place = paddle.CUDAPlace(0)
+        exe = paddle.static.Executor(place)
+        losses_o1 = _run(place, exe, x_fp32, max_iters, 'O1')
+        losses_o2 = _run(place, exe, x_bf16, max_iters, 'O2')


 if __name__ == '__main__':
...
@@ -314,7 +314,10 @@ class TestImageClassification(unittest.TestCase):
         # infer(use_cuda, save_dirname)

     def test_amp_lists(self):
-        white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
+        white_list = (
+            copy.copy(paddle.static.amp.fp16_lists.white_list)
+            | paddle.static.amp.fp16_lists._only_supported_fp16_list
+        )
         black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
@@ -324,7 +327,10 @@ class TestImageClassification(unittest.TestCase):
         self.assertEqual(amp_lists.gray_list, gray_list)

     def test_amp_lists_1(self):
-        white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
+        white_list = (
+            copy.copy(paddle.static.amp.fp16_lists.white_list)
+            | paddle.static.amp.fp16_lists._only_supported_fp16_list
+        )
         black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
@@ -338,7 +344,10 @@ class TestImageClassification(unittest.TestCase):
         self.assertEqual(amp_lists.gray_list, gray_list)

     def test_amp_lists_2(self):
-        white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
+        white_list = (
+            copy.copy(paddle.static.amp.fp16_lists.white_list)
+            | paddle.static.amp.fp16_lists._only_supported_fp16_list
+        )
         black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
@@ -352,7 +361,10 @@ class TestImageClassification(unittest.TestCase):
         self.assertEqual(amp_lists.gray_list, gray_list)

     def test_amp_lists_3(self):
-        white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
+        white_list = (
+            copy.copy(paddle.static.amp.fp16_lists.white_list)
+            | paddle.static.amp.fp16_lists._only_supported_fp16_list
+        )
         black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
@@ -365,7 +377,10 @@ class TestImageClassification(unittest.TestCase):
         self.assertEqual(amp_lists.gray_list, gray_list)

     def test_amp_lists_4(self):
-        white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
+        white_list = (
+            copy.copy(paddle.static.amp.fp16_lists.white_list)
+            | paddle.static.amp.fp16_lists._only_supported_fp16_list
+        )
         black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
@@ -381,7 +396,10 @@ class TestImageClassification(unittest.TestCase):
         self.assertEqual(amp_lists.gray_list, gray_list)

     def test_amp_lists_5(self):
-        white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
+        white_list = (
+            copy.copy(paddle.static.amp.fp16_lists.white_list)
+            | paddle.static.amp.fp16_lists._only_supported_fp16_list
+        )
         black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
@@ -397,7 +415,10 @@ class TestImageClassification(unittest.TestCase):
         self.assertEqual(amp_lists.gray_list, gray_list)

     def test_amp_lists_6(self):
-        white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
+        white_list = (
+            copy.copy(paddle.static.amp.fp16_lists.white_list)
+            | paddle.static.amp.fp16_lists._only_supported_fp16_list
+        )
         black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
...
@@ -39,7 +39,7 @@ class TestFuseResNetUnit(unittest.TestCase):
         startup_program = paddle.static.Program()
         with paddle.static.amp.fp16_guard():
             with paddle.static.program_guard(program, startup_program):
-                x = paddle.static.data("x", [1, 64, 64, 8])
+                x = paddle.static.data("x", [1, 64, 64, 8], dtype="float16")
                 conv2d = paddle.nn.Conv2D(
                     8, 32, 1, bias_attr=False, data_format='NHWC'
                 )
@@ -66,3 +66,7 @@ class TestFuseResNetUnit(unittest.TestCase):
         np.testing.assert_allclose(
             before_out[0], after_out[0], rtol=1e-05, atol=0.005
         )
+
+
+if __name__ == '__main__':
+    unittest.main()
@@ -25,10 +25,10 @@ paddle.enable_static()

 def build_resnet50(use_amp=False):
     main_program = paddle.static.Program()
     startup_program = paddle.static.Program()
+    dtype = 'float16' if use_amp else 'float32'
     with paddle.static.program_guard(main_program, startup_program):
         image = paddle.static.data(
-            name='image', shape=[32, 3, 224, 224], dtype='float32'
+            name='image', shape=[32, 3, 224, 224], dtype=dtype
         )
         label = paddle.static.data(name='label', shape=[32], dtype='int64')
         model = paddle.vision.models.resnet50()
...