Unverified · Commit 71a513c2 authored by Zhang Ting, committed by GitHub

[AMP] support promote kernel for static graph (#52514)

* support promote dtype for static amp training

* unify o1 and o2

* update for unittest

* fix op_role

* add use_promote arg

* fix doc

* add promote unittest

* polish unittests

* fix control flow and test
Parent 040f8aa5
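In practice, the new `use_promote` switch is threaded through `paddle.static.amp.decorate`. A minimal sketch, assuming a toy network and an AdamW optimizer (neither is part of this commit; the call shape mirrors the updated `_build_optimizer` test helper in this diff):

import paddle
from paddle import nn

paddle.enable_static()

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    loss = paddle.mean(nn.Linear(16, 4)(x))
    optimizer = paddle.optimizer.AdamW(learning_rate=0.01)
    # With use_promote=True, an op that receives any float32 input is kept in
    # float32 instead of being cast down to the AMP dtype.
    optimizer = paddle.static.amp.decorate(
        optimizer,
        None,              # amp_lists: use the default black/white lists
        level='O1',
        dtype='float16',
        use_promote=True,
    )
    optimizer.minimize(loss)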
@@ -4982,8 +4982,8 @@ class PipelineOptimizer:
     device = post_op.attr(self._op_device_key)
     assert device, "The post op must have op_device set."
     op._set_attr(self._op_device_key, device)
-elif (op.type == "cast" or op.type == "scale") and self._is_backward_op(
-    op
+elif (op.type == "cast" or op.type == "scale") and (
+    self._is_backward_op(op) or self._is_forward_op(op)
 ):
     prev_op = self._find_prev_op(idx, op.desc.input("X")[0])
     op._set_attr(self._op_device_key, prev_op.attr(self._op_device_key))
@@ -356,7 +356,9 @@ class TestAdadeltaMultiPrecision2_0(unittest.TestCase):
 exe.run(startup_program)
 if use_amp:
-    optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+    optimizer.amp_init(
+        place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+    )
     x = np.random.random(size=(2, 2)).astype('float16')
 else:
     x = np.random.random(size=(2, 2)).astype('float32')
@@ -467,7 +469,9 @@ class TestAdadeltaMultiPrecision1_0(unittest.TestCase):
 exe.run(startup_program)
 if use_amp:
-    optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+    optimizer.amp_init(
+        place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+    )
     x = np.random.random(size=(2, 2)).astype('float16')
 else:
     x = np.random.random(size=(2, 2)).astype('float32')
@@ -322,7 +322,9 @@ class TestAdagradMultiPrecision2_0(unittest.TestCase):
 exe.run(startup_program)
 if use_amp:
-    optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+    optimizer.amp_init(
+        place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+    )
     x = np.random.random(size=(2, 2)).astype('float16')
 else:
     x = np.random.random(size=(2, 2)).astype('float32')
@@ -431,7 +433,9 @@ class TestAdagradMultiPrecision1_0(unittest.TestCase):
 exe.run(startup_program)
 if use_amp:
-    optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+    optimizer.amp_init(
+        place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+    )
     x = np.random.random(size=(2, 2)).astype('float16')
 else:
     x = np.random.random(size=(2, 2)).astype('float32')
@@ -1235,7 +1235,9 @@ class TestMultiTensorAdam(unittest.TestCase):
 optimizer.minimize(loss)
 exe.run(startup_program)
 if use_amp:
-    optimizer.amp_init(place=place, scope=paddle.static.global_scope())
+    optimizer.amp_init(
+        place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+    )
     x = np.random.random(size=(2, 2)).astype('float16')
 else:
     x = np.random.random(size=(2, 2)).astype('float32')
@@ -352,7 +352,9 @@ class TestAdamaxMultiPrecision2_0(unittest.TestCase):
 exe.run(startup_program)
 if use_amp:
-    optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+    optimizer.amp_init(
+        place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+    )
     x = np.random.random(size=(2, 2)).astype('float16')
 else:
     x = np.random.random(size=(2, 2)).astype('float32')
@@ -459,7 +461,9 @@ class TestAdamaxMultiPrecision1_0(unittest.TestCase):
 exe.run(startup_program)
 if use_amp:
-    optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+    optimizer.amp_init(
+        place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+    )
     x = np.random.random(size=(2, 2)).astype('float16')
 else:
     x = np.random.random(size=(2, 2)).astype('float32')
@@ -1059,7 +1059,9 @@ class TestMultiTensorMomentumStatic(unittest.TestCase):
 optimizer.minimize(loss)
 exe.run(startup_program)
 if use_amp:
-    optimizer.amp_init(place=place, scope=paddle.static.global_scope())
+    optimizer.amp_init(
+        place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+    )
     x = numpy.random.random(size=(2, 2)).astype('float16')
 else:
     x = numpy.random.random(size=(2, 2)).astype('float32')
@@ -474,7 +474,9 @@ class TestRMSPropMultiPrecision2_0(unittest.TestCase):
 exe.run(startup_program)
 if use_amp:
-    optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+    optimizer.amp_init(
+        place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+    )
     x = np.random.random(size=(2, 2)).astype('float16')
 else:
     x = np.random.random(size=(2, 2)).astype('float32')
@@ -585,7 +587,9 @@ class TestRMSPropMultiPrecision1_0(unittest.TestCase):
 exe.run(startup_program)
 if use_amp:
-    optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+    optimizer.amp_init(
+        place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+    )
     x = np.random.random(size=(2, 2)).astype('float16')
 else:
     x = np.random.random(size=(2, 2)).astype('float32')
@@ -382,7 +382,9 @@ class TestSGDMultiPrecision2_0(unittest.TestCase):
 exe.run(startup_program)
 if mp:
-    optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+    optimizer.amp_init(
+        place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+    )
     x = np.random.random(size=(2, 2)).astype('float16')
 else:
     x = np.random.random(size=(2, 2)).astype('float32')
@@ -492,7 +494,9 @@ class TestSGDMultiPrecision1_0(unittest.TestCase):
 exe.run(startup_program)
 if mp:
-    optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+    optimizer.amp_init(
+        place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
+    )
     x = np.random.random(size=(2, 2)).astype('float16')
 else:
     x = np.random.random(size=(2, 2)).astype('float32')
@@ -294,8 +294,8 @@ class PartialProgramLayer:
 def _create_amp_program(self, is_infer_mode=False):
     amp_program = self._origin_main_program.clone(for_test=is_infer_mode)
     with program_guard(amp_program):
-        paddle.static.amp.fp16_utils.rewrite_program(
-            amp_program, self._amp_list
+        paddle.static.amp.fp16_utils.cast_model_to_fp16(
+            amp_program, self._amp_list, use_fp16_guard=False, level='O1'
         )
     if is_infer_mode:
         if self._hooker:
@@ -29,7 +29,6 @@ from .fp16_lists import AutoMixedPrecisionLists, check_amp_dtype
 from .fp16_utils import (
     cast_model_to_fp16,
     cast_parameters_to_fp16,
-    rewrite_program,
     update_role_var_grad,
 )
 from .function_overload import FunctionType, overload
@@ -67,6 +66,7 @@ class OptimizerWithMixedPrecision:
         the loss scaling.
     use_amp_guard(bool): Whether to use `fp16_guard` when constructing the program.
         Default None, which means that its value is equal to `use_pure_fp16`.
+    use_promote(bool): Whether to promotes to fp32 when op has any float32 inputs. Default is False.
 """
 def __init__(
@@ -82,6 +82,7 @@ class OptimizerWithMixedPrecision:
     incr_ratio,
     decr_ratio,
     use_amp_guard=None,
+    use_promote=False,
 ):
     self._optimizer = optimizer
     self._amp_lists = amp_lists
@@ -116,6 +117,7 @@ class OptimizerWithMixedPrecision:
     self._decr_ratio = decr_ratio
     self._num_good_steps = None
     self._num_bad_steps = None
+    self.use_promote = use_promote
 def _set_distributed(self, flag):
     # if distributed, all cards will communication with each other,
@@ -231,10 +233,18 @@ class OptimizerWithMixedPrecision:
         self._amp_lists,
         self._use_fp16_guard,
         self._amp_vartype,
+        level='O2',
+        use_promote=self.use_promote,
     )
 else:
-    rewrite_program(
-        self._train_program, self._amp_lists, self._amp_vartype
+    # use_fp16_guard is not support amp-o1.
+    cast_model_to_fp16(
+        self._train_program,
+        self._amp_lists,
+        use_fp16_guard=False,
+        dest_type=self._amp_vartype,
+        level='O1',
+        use_promote=self.use_promote,
     )
 if loss.dtype != core.VarDesc.VarType.FP32:
@@ -362,10 +372,18 @@ class OptimizerWithMixedPrecision:
         self._amp_lists,
         self._use_fp16_guard,
         self._amp_vartype,
+        level='O2',
+        use_promote=self.use_promote,
     )
 elif use_fp16_test:
-    rewrite_program(
-        test_program, self._amp_lists, self._amp_vartype
+    # use_fp16_guard is not support amp-o1.
+    cast_model_to_fp16(
+        test_program,
+        self._amp_lists,
+        use_fp16_guard=False,
+        dest_type=self._amp_vartype,
+        level='O1',
+        use_promote=self.use_promote,
     )
 def apply_gradients(self, params_grads):
@@ -624,6 +642,7 @@ def decorate(
     use_pure_fp16=False,
     use_fp16_guard=None,
     use_bf16=False,
+    use_promote=False,
 ):
     """
     Decorate the given optimizer to adapt to the mixed-precision training.
@@ -736,6 +755,7 @@ def decorate(
         incr_ratio=incr_ratio,
         decr_ratio=decr_ratio,
         use_amp_guard=use_fp16_guard,
+        use_promote=use_promote,
     )
     return mp_optimizer
@@ -754,6 +774,7 @@ def decorate(
     decr_ratio=0.8,
     use_dynamic_loss_scaling=True,
     use_amp_guard=False,
+    use_promote=False,
 ):
     """
     Decorate the given optimizer to adapt to the mixed-precision training.
@@ -781,6 +802,7 @@ def decorate(
     incr_ratio=incr_ratio,
     decr_ratio=decr_ratio,
     use_amp_guard=use_amp_guard,
+    use_promote=use_promote,
 )
 return mp_optimizer
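For reference, the unified O1 path above replaces `rewrite_program` with `cast_model_to_fp16(..., level='O1')`. A minimal, self-contained sketch of that call (the toy program is illustrative; only the `cast_model_to_fp16` arguments mirror the call sites in this commit):

import paddle
from paddle.static.amp.fp16_lists import AutoMixedPrecisionLists
from paddle.static.amp.fp16_utils import cast_model_to_fp16

paddle.enable_static()

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    loss = paddle.mean(paddle.nn.Linear(16, 4)(x))

# Rewrite the program in place for O1: per-op casts are inserted, fp16_guard
# is ignored at this level, and use_promote keeps an op in fp32 when any of
# its inputs is fp32.
cast_model_to_fp16(
    main_program,
    AutoMixedPrecisionLists(),
    use_fp16_guard=False,
    level='O1',
    use_promote=True,
)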
@@ -98,6 +98,20 @@ def _get_sys_unsupported_list(dtype):
 else:
     device = 'GPU'
 _, _, sys_unsupported_list = core.op_supported_infos(device, var_type)
+# sys_unsupported_list will include the following ops.
+supported_fp16_list = {
+    "conditional_block",
+    "conditional_block_infer",
+    "select_input",
+    "while",
+    "cast",
+    "tensor_array_to_tensor",
+    "lod_array_length",
+    "write_to_array",
+}
+sys_unsupported_list -= supported_fp16_list
 return device, sys_unsupported_list
@@ -108,6 +122,29 @@ def _get_unsupported_list(dtype):
 return unsupported_list
+# The three sets listed below are changed dynamiclly. They don't contain all
+# paddle ops currently.
+# The set of ops that support fp16 calculation and are considered numerically-
+# safe and performance-critical. These ops are always converted to fp16.
+_only_supported_fp16_list = {'resnet_unit', 'fused_bn_add_activation'}
+white_list = {
+    'conv2d',
+    'matmul',
+    'matmul_v2',
+    'mul',
+}
+def _get_white_list(dtype):
+    white_list_for_dtype = copy.copy(white_list)
+    if dtype == 'float16':
+        white_list_for_dtype = white_list_for_dtype | _only_supported_fp16_list
+    return white_list_for_dtype
 class AutoMixedPrecisionLists:
     """
     AutoMixedPrecisionLists is a class for black/white list. It can update
@@ -132,7 +169,7 @@ class AutoMixedPrecisionLists:
     self.amp_dtype = check_amp_dtype(dtype)
     self._custom_white_list = custom_white_list
     self._custom_black_list = custom_black_list
-    self.white_list = copy.copy(white_list)
+    self.white_list = copy.copy(_get_white_list(self.amp_dtype))
     self.black_list = copy.copy(black_list)
     self.gray_list = copy.copy(gray_list)
     self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
@@ -143,6 +180,9 @@ class AutoMixedPrecisionLists:
     """
     Update black and white list according to users' custom list.
     """
+    _logger.debug(f"---- custom_white_list {self._custom_white_list} ---- ")
+    _logger.debug(f"---- custom_black_list {self._custom_black_list} ---- ")
+    _logger.debug(f"---- custom_black_varnames {self.black_varnames} ---- ")
     if self._custom_white_list and self._custom_black_list:
         for op_name in self._custom_white_list:
             if op_name in self._custom_black_list:
@@ -177,18 +217,6 @@ class AutoMixedPrecisionLists:
 )
-# The three sets listed below are changed dynamiclly. They don't contain all
-# paddle ops currently.
-# The set of ops that support fp16 calculation and are considered numerically-
-# safe and performance-critical. These ops are always converted to fp16.
-white_list = {
-    'conv2d',
-    'matmul',
-    'matmul_v2',
-    'mul',
-}
 # The set of ops that support fp16 calculation and are considered numerically-
 # dangerous and whose effects may also be observed in downstream ops.
 black_list = {
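The practical effect of the white-list refactor above is that the fp16-only fused ops are merged into the white list only when the AMP dtype is float16. A small sketch against the internal (non-public) helper, assuming it stays importable from `paddle.static.amp.fp16_lists`:

from paddle.static.amp import fp16_lists

fp16_white = fp16_lists._get_white_list("float16")
bf16_white = fp16_lists._get_white_list("bfloat16")

# The base entries are shared between dtypes ...
assert "conv2d" in fp16_white and "conv2d" in bf16_white
# ... while the fp16-only fused ops are added for float16 only.
assert "resnet_unit" in fp16_white
assert "resnet_unit" not in bf16_white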
@@ -29,6 +29,7 @@ def _build_optimizer(
     amp_level="O1",
     amp_lists=None,
     use_grad_clip=False,
+    use_promote=False,
 ):
     if use_grad_clip:
         grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
@@ -45,7 +46,11 @@ def _build_optimizer(
     )
     if use_amp:
         optimizer = paddle.static.amp.decorate(
-            optimizer, amp_lists, level=amp_level, dtype=amp_dtype
+            optimizer,
+            amp_lists,
+            level=amp_level,
+            dtype=amp_dtype,
+            use_promote=use_promote,
         )
     return optimizer
@@ -67,7 +72,9 @@ class SimpleAddNet(nn.Layer):
         return x + self.weight
-def build_add_model(use_amp, amp_dtype="float16", amp_level="O1"):
+def build_add_model(
+    use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
+):
     main_program = paddle.static.Program()
     startup_program = paddle.static.Program()
     with paddle.utils.unique_name.guard():
@@ -92,7 +99,11 @@ def build_add_model(use_amp, amp_dtype="float16", amp_level="O1"):
         else:
             amp_lists = None
         optimizer = _build_optimizer(
-            use_amp, amp_dtype, amp_level, amp_lists
+            use_amp,
+            amp_dtype,
+            amp_level,
+            amp_lists,
+            use_promote=use_promote,
         )
         optimizer.minimize(loss)
     feed_vars = [x]
@@ -104,30 +115,37 @@ class SimpleConvNet(nn.Layer):
     def __init__(self):
         super().__init__()
         self.conv = nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
-        self.linear = nn.Linear(in_features=6, out_features=10)
+        self.linear = nn.Linear(in_features=96, out_features=4)
     def forward(self, x):
         out = self.conv(x)
         out = nn.functional.relu(out)
+        out = out.flatten(start_axis=1, stop_axis=3)
         out = self.linear(out)
         out = nn.functional.softmax(out)
         return out
-def build_conv_model(use_amp, amp_dtype="float16", amp_level="O1"):
+def build_conv_model(
+    use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
+):
     main_program = paddle.static.Program()
     startup_program = paddle.static.Program()
     with paddle.utils.unique_name.guard():
         with paddle.static.program_guard(main_program, startup_program):
             model = SimpleConvNet()
             x = paddle.static.data(
-                name='input', shape=[None, 1, 28, 28], dtype='float32'
+                name='input', shape=[None, 1, 6, 6], dtype='float32'
             )
             out = model(x)
             loss = paddle.mean(out)
-            optimizer = _build_optimizer(use_amp, amp_dtype, amp_level)
+            optimizer = _build_optimizer(
+                use_amp, amp_dtype, amp_level, use_promote=use_promote
+            )
             optimizer.minimize(loss)
-    return main_program, startup_program
+    feed_vars = [x]
+    fetch_vars = [loss]
+    return main_program, startup_program, optimizer, feed_vars, fetch_vars
 class SimpleEmbeddingNet(nn.Layer):
@@ -149,7 +167,9 @@ class SimpleEmbeddingNet(nn.Layer):
         return out
-def build_embedding_model(use_amp, amp_dtype="float16", amp_level="O1"):
+def build_embedding_model(
+    use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
+):
     main_program = paddle.static.Program()
     startup_program = paddle.static.Program()
     with paddle.utils.unique_name.guard():
@@ -159,7 +179,12 @@ def build_embedding_model(use_amp, amp_dtype="float16", amp_level="O1"):
             out = model(x)
             loss = paddle.mean(out)
             optimizer = _build_optimizer(
-                use_amp, amp_dtype, amp_level, None, True
+                use_amp,
+                amp_dtype,
+                amp_level,
+                None,
+                True,
+                use_promote=use_promote,
             )
             optimizer.minimize(loss)
     return main_program, startup_program
@@ -211,3 +236,48 @@ class AmpTestBase(unittest.TestCase):
     def setUp(self):
         self.amp_dtype = None
         self.amp_level = None
+    def _check_op_calls(
+        self, op_stats_dict, expected_bf16_calls={}, expected_fp16_calls={}
+    ):
+        for op_type, value in expected_bf16_calls.items():
+            self.assertEqual(
+                op_stats_dict[op_type].bf16_calls,
+                value,
+                f"The number of bf16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].bf16_calls}.",
+            )
+        for op_type, value in expected_fp16_calls.items():
+            self.assertEqual(
+                op_stats_dict[op_type].fp16_calls,
+                value,
+                f"The number of fp16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].fp16_calls}.",
+            )
+    def run_program(
+        self,
+        main_program,
+        startup_program,
+        optimizer,
+        feed_vars,
+        fetch_vars,
+        place,
+        exe,
+        x_np,
+        max_iters,
+        level,
+    ):
+        losses = []
+        scope = paddle.static.Scope()
+        with paddle.static.scope_guard(scope):
+            exe.run(startup_program)
+            if level == 'O2':
+                optimizer.amp_init(place)
+            for iter_id in range(max_iters):
+                results = exe.run(
+                    program=main_program,
+                    feed={feed_vars[0].name: x_np},
+                    fetch_list=fetch_vars,
+                )
+                print(f"-- [BF16 {level}] iter={iter_id}, loss={results[0]}")
+                losses.append(results[0])
+        return losses
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from amp_base_models import AmpTestBase, build_conv_model
import paddle
from paddle.static import amp
paddle.enable_static()
class TestAMPPromote(AmpTestBase):
    def check_promote_results(
        self, use_amp, dtype, level, use_promote, expected_op_calls
    ):
        (
            main_program,
            startup_program,
            optimizer,
            feed_vars,
            fetch_vars,
        ) = build_conv_model(use_amp, dtype, level, use_promote)
        self.assertEqual(main_program.num_blocks, 1)

        amp.debugging.collect_operator_stats(main_program)
        op_stats_list = amp.debugging._get_op_stats_list(main_program)
        self._check_op_calls(
            op_stats_list[0], expected_fp16_calls=expected_op_calls
        )

        place = paddle.CUDAPlace(0)
        exe = paddle.static.Executor(place)

        max_iters = 2
        x_fp32 = np.random.random(size=[1, 1, 6, 6]).astype("float32")
        print(main_program)
        losses_o1 = self.run_program(
            main_program,
            startup_program,
            optimizer,
            feed_vars,
            fetch_vars,
            place,
            exe,
            x_fp32,
            max_iters,
            level,
        )

    def test_static_amp_o1(self):
        expected_fp16_calls = {
            "conv2d": 1,
            "elementwise_add": 0,
            "relu": 0,
            "matmul_v2": 1,
            "softmax": 0,
            "reduce_mean": 0,
            "adamw": 0,
        }
        self.check_promote_results(
            True,
            'float16',
            'O1',
            use_promote=True,
            expected_op_calls=expected_fp16_calls,
        )

    def test_static_amp_o2(self):
        expected_fp16_calls = {
            "conv2d": 1,
            "elementwise_add": 2,
            "relu": 1,
            "matmul_v2": 1,
            "softmax": 1,
            "reduce_mean": 1,
            "adamw": 4,
        }
        self.check_promote_results(
            True,
            'float16',
            'O2',
            use_promote=True,
            expected_op_calls=expected_fp16_calls,
        )


if __name__ == '__main__':
    unittest.main()
@@ -221,14 +221,6 @@ class TestModelCastBF16(unittest.TestCase):
 class TestProgramBF16(AmpTestBase):
-    def _check_bf16_calls(self, op_stats_dict, expected_bf16_calls):
-        for op_type, value in expected_bf16_calls.items():
-            self.assertEqual(
-                op_stats_dict[op_type].bf16_calls,
-                value,
-                f"The number of bf16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].bf16_calls}.",
-            )
     def test_amp_bf16_o1(self):
         main_program, startup_program = build_embedding_model(
             True, "bfloat16", "O1"
@@ -245,7 +237,7 @@ class TestProgramBF16(AmpTestBase):
         "squared_l2_norm": 0,
         "adamw": 0,
     }
-    self._check_bf16_calls(op_stats_list[0], expected_bf16_calls)
+    self._check_op_calls(op_stats_list[0], expected_bf16_calls)
 def test_amp_bf16_o2(self):
     main_program, startup_program = build_embedding_model(
@@ -263,7 +255,7 @@ class TestProgramBF16(AmpTestBase):
         "squared_l2_norm": 2,
         "adamw": 2,
     }
-    self._check_bf16_calls(op_stats_list[0], expected_bf16_calls)
+    self._check_op_calls(op_stats_list[0], expected_bf16_calls)
 class TestStaticBF16(AmpTestBase):
@@ -274,60 +266,35 @@ class TestStaticBF16(AmpTestBase):
     return x_fp32, x_bf16
 def test_compare_o1_o2(self):
-    def _run_o1(place, exe, x_np, max_iters):
+    def _run(place, exe, x_np, max_iters, level):
         (
             main_program,
             startup_program,
             optimizer,
             feed_vars,
             fetch_vars,
-        ) = build_add_model(True, "bfloat16", "O1")
-        losses = []
-        scope = paddle.static.Scope()
-        with paddle.static.scope_guard(scope):
-            exe.run(startup_program)
-            for iter_id in range(max_iters):
-                results = exe.run(
-                    program=main_program,
-                    feed={feed_vars[0].name: x_np},
-                    fetch_list=fetch_vars,
-                )
-                print(f"-- [BF16 O1] iter={iter_id}, loss={results[0]}")
-                losses.append(results[0])
-        return losses
-    def _run_o2(place, exe, x_np, max_iters):
-        (
-            main_program,
-            startup_program,
-            optimizer,
-            feed_vars,
-            fetch_vars,
-        ) = build_add_model(True, "bfloat16", "O2")
-        losses = []
-        scope = paddle.static.Scope()
-        with paddle.static.scope_guard(scope):
-            exe.run(startup_program)
-            optimizer.amp_init(place)
-            for iter_id in range(max_iters):
-                results = exe.run(
-                    program=main_program,
-                    feed={feed_vars[0].name: x_np},
-                    fetch_list=fetch_vars,
-                )
-                print(f"-- [BF16 O2] iter={iter_id}, loss={results[0]}")
-                losses.append(results[0])
-        return losses
-    place = paddle.CUDAPlace(0)
-    exe = paddle.static.Executor(place)
-    max_iters = 2
-    x_fp32, x_bf16 = self._generate_feed_x()
-    losses_o1 = _run_o1(place, exe, x_fp32, max_iters)
-    losses_o2 = _run_o2(place, exe, x_bf16, max_iters)
+        ) = build_add_model(True, "bfloat16", level)
+        losses = self.run_program(
+            main_program,
+            startup_program,
+            optimizer,
+            feed_vars,
+            fetch_vars,
+            place,
+            exe,
+            x_np,
+            max_iters,
+            level,
+        )
+        return losses
+    max_iters = 2
+    x_fp32, x_bf16 = self._generate_feed_x()
+    place = paddle.CUDAPlace(0)
+    exe = paddle.static.Executor(place)
+    losses_o1 = _run(place, exe, x_fp32, max_iters, 'O1')
+    losses_o2 = _run(place, exe, x_bf16, max_iters, 'O2')
 if __name__ == '__main__':
@@ -314,7 +314,10 @@ class TestImageClassification(unittest.TestCase):
     # infer(use_cuda, save_dirname)
 def test_amp_lists(self):
-    white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
+    white_list = (
+        copy.copy(paddle.static.amp.fp16_lists.white_list)
+        | paddle.static.amp.fp16_lists._only_supported_fp16_list
+    )
     black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
     gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
@@ -324,7 +327,10 @@ class TestImageClassification(unittest.TestCase):
     self.assertEqual(amp_lists.gray_list, gray_list)
 def test_amp_lists_1(self):
-    white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
+    white_list = (
+        copy.copy(paddle.static.amp.fp16_lists.white_list)
+        | paddle.static.amp.fp16_lists._only_supported_fp16_list
+    )
     black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
     gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
@@ -338,7 +344,10 @@ class TestImageClassification(unittest.TestCase):
     self.assertEqual(amp_lists.gray_list, gray_list)
 def test_amp_lists_2(self):
-    white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
+    white_list = (
+        copy.copy(paddle.static.amp.fp16_lists.white_list)
+        | paddle.static.amp.fp16_lists._only_supported_fp16_list
+    )
     black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
     gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
@@ -352,7 +361,10 @@ class TestImageClassification(unittest.TestCase):
     self.assertEqual(amp_lists.gray_list, gray_list)
 def test_amp_lists_3(self):
-    white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
+    white_list = (
+        copy.copy(paddle.static.amp.fp16_lists.white_list)
+        | paddle.static.amp.fp16_lists._only_supported_fp16_list
+    )
     black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
     gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
@@ -365,7 +377,10 @@ class TestImageClassification(unittest.TestCase):
     self.assertEqual(amp_lists.gray_list, gray_list)
 def test_amp_lists_4(self):
-    white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
+    white_list = (
+        copy.copy(paddle.static.amp.fp16_lists.white_list)
+        | paddle.static.amp.fp16_lists._only_supported_fp16_list
+    )
     black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
     gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
@@ -381,7 +396,10 @@ class TestImageClassification(unittest.TestCase):
     self.assertEqual(amp_lists.gray_list, gray_list)
 def test_amp_lists_5(self):
-    white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
+    white_list = (
+        copy.copy(paddle.static.amp.fp16_lists.white_list)
+        | paddle.static.amp.fp16_lists._only_supported_fp16_list
+    )
     black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
     gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
@@ -397,7 +415,10 @@ class TestImageClassification(unittest.TestCase):
     self.assertEqual(amp_lists.gray_list, gray_list)
 def test_amp_lists_6(self):
-    white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
+    white_list = (
+        copy.copy(paddle.static.amp.fp16_lists.white_list)
+        | paddle.static.amp.fp16_lists._only_supported_fp16_list
+    )
     black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
     gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
@@ -39,7 +39,7 @@ class TestFuseResNetUnit(unittest.TestCase):
 startup_program = paddle.static.Program()
 with paddle.static.amp.fp16_guard():
     with paddle.static.program_guard(program, startup_program):
-        x = paddle.static.data("x", [1, 64, 64, 8])
+        x = paddle.static.data("x", [1, 64, 64, 8], dtype="float16")
         conv2d = paddle.nn.Conv2D(
             8, 32, 1, bias_attr=False, data_format='NHWC'
         )
@@ -66,3 +66,7 @@ class TestFuseResNetUnit(unittest.TestCase):
 np.testing.assert_allclose(
     before_out[0], after_out[0], rtol=1e-05, atol=0.005
 )
+if __name__ == '__main__':
+    unittest.main()
@@ -25,10 +25,10 @@ paddle.enable_static()
 def build_resnet50(use_amp=False):
     main_program = paddle.static.Program()
     startup_program = paddle.static.Program()
+    dtype = 'float16' if use_amp else 'float32'
     with paddle.static.program_guard(main_program, startup_program):
         image = paddle.static.data(
-            name='image', shape=[32, 3, 224, 224], dtype='float32'
+            name='image', shape=[32, 3, 224, 224], dtype=dtype
        )
         label = paddle.static.data(name='label', shape=[32], dtype='int64')
         model = paddle.vision.models.resnet50()