Unverified commit 9ded5707, authored by J JZ-LIANG, committed by GitHub

[Auto Parallel Performance] Support BF16 Training (#51285)

* update env setting

* update pass logic

* dist op support bf16

* backward cast update

* update setting

* update backward

* revert amp pass

* update fp16 backward logic

* register c_embedding bf16

* revert engine

* add unittest

* add unittest

* update unittest

* update cmake

* update math

* update math.py

* update unittest

* update unittest

* revise unittest

* revise unittest

* update unittest

* update unittest

* update unittest
Parent 3094d475
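For reference, the user-facing switch added by this PR is a pair of new AMP options on the auto-parallel strategy, dtype and level. A minimal sketch of how they are set, mirroring the amp_o2_pass.py unit test added below (the auto.Engine / model / optimizer wiring is elided):

import paddle
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
strategy.auto_mode = "semi"
amp = strategy.amp
amp.enable = True
amp.dtype = "bfloat16"  # new option: "float16" (default) or "bfloat16"
amp.level = "o2"        # new option: "o1", "o2" or "o3"; replaces use_pure_fp16

# engine = auto.Engine(model, loss, optimizer, strategy=strategy)
# engine.fit(dataset, epochs=1, batch_size=2)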
......@@ -198,8 +198,14 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_embedding,
ops::CEmbeddingCUDAKernel<float>,
ops::CEmbeddingCUDAKernel<double>,
#if NCCL_VERSION_CODE >= 21000
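// NCCL 2.10 (version code 21000) is the first release with bf16 support, hence the guard on these bf16 kernel registrations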
ops::CEmbeddingCUDAKernel<plat::bfloat16>,
#endif
ops::CEmbeddingCUDAKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(c_embedding_grad,
ops::CEmbeddingGradCUDAKernel<float>,
ops::CEmbeddingGradCUDAKernel<double>,
#if NCCL_VERSION_CODE >= 21000
ops::CEmbeddingGradCUDAKernel<plat::bfloat16>,
#endif
ops::CEmbeddingGradCUDAKernel<plat::float16>);
......@@ -63,6 +63,8 @@ set_field_default_config(RECOMPUTE, "enable_tuning", False)
#########################################
AMP = "amp"
set_field_default_config(AMP, "enable", False)
set_field_default_config(AMP, "dtype", "float16")
set_field_default_config(AMP, "level", "o1")
set_field_default_config(AMP, "init_loss_scaling", 32768.0)
set_field_default_config(AMP, "incr_every_n_steps", 1000)
set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2)
......@@ -72,15 +74,12 @@ set_field_default_config(AMP, "use_dynamic_loss_scaling", True)
set_field_default_config(AMP, "custom_white_list", [])
set_field_default_config(AMP, "custom_black_list", [])
set_field_default_config(AMP, "custom_black_varnames", [])
set_field_default_config(AMP, "use_pure_fp16", False)
set_field_default_config(AMP, "use_fp16_guard", True)
set_field_default_config(AMP, "use_fp16_guard", False)
set_field_default_config(AMP, "use_optimizer_fp16", False)
set_field_default_config(AMP, "enable_bf16", False)
set_field_default_config(AMP, "custom_bf16_list", [])
set_field_default_config(AMP, "custom_fp32_list", [])
set_field_default_config(AMP, "custom_fp32_varnames", [])
set_field_default_config(AMP, "use_pure_bf16", False)
set_field_default_config(AMP, "use_bf16_guard", False)
#########################################
......
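The legacy FP16/BF16 switches removed above are folded into the new dtype/level pair. A rough mapping, inferred from the config defaults and the unit-test changes elsewhere in this PR (a sketch, not an exhaustive migration guide):

# old strategy flags                          -> new equivalents (inferred)
# amp.use_pure_fp16 = True                    -> amp.level = "o2" (or "o3" together with use_optimizer_fp16)
# amp.enable_bf16 / amp.use_pure_bf16         -> amp.dtype = "bfloat16" with amp.level = "o1" / "o2"
# amp.custom_bf16_list / amp.custom_fp32_list -> shared amp.custom_white_list / amp.custom_black_list
# amp.use_bf16_guard                          -> amp.use_fp16_guard (whose default changes from True to False)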
......@@ -455,7 +455,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
check_variable_and_dtype(
Out_var,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'c_allreduce_sum',
)
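# note: 'uint16' in these dtype whitelists is how Paddle's static-graph checks spell bfloat16 (BF16 variables carry a uint16 dtype tag)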
......@@ -645,7 +645,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
check_variable_and_dtype(
Out_grad,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity',
)
......@@ -687,12 +687,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
},
)
check_variable_and_dtype(
intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
intermediate_var_0,
'x',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64'],
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
......
......@@ -220,27 +220,26 @@ class Parallelizer:
self._dist_context.serial_feed_vars["inputs"]
+ self._dist_context.serial_feed_vars["labels"]
)
if config["enable_bf16"]:
auto_parallel_bf16_pass = new_pass("auto_parallel_bf16", config)
auto_parallel_bf16_pass.apply(
self._logger.info(
"Applying AMP-{}-{} ...".format(
config["dtype"], config['level']
),
)
if config['level'] == "o1":
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply(
[main_program], [startup_program], self._pass_context
)
loss = auto_parallel_bf16_pass.get_loss()
elif config["use_pure_fp16"]:
loss = auto_parallel_amp_pass.get_loss()
elif config['level'] in ['o2', 'o3']:
config["base_opt"] = optimizer
auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
auto_parallel_fp16_pass.apply(
[main_program], [startup_program], self._pass_context
)
loss = auto_parallel_fp16_pass.get_loss()
else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply(
[main_program], [startup_program], self._pass_context
)
loss = auto_parallel_amp_pass.get_loss()
raise ValueError("AMP level should be one of o1, o2, o3")
# apply quantization pass
# The pass can be applied when mode must be 'train'
......
......@@ -632,6 +632,7 @@ class AMPPass(PassBase):
self.set_attr("use_dynamic_loss_scaling", False)
self.set_attr("input_data", [])
self.set_attr("params_grads", [])
self.set_attr("dtype", "") # fp16/bf16
self._loss = None
self._loss_scaling = None
self._num_good_steps = None
......@@ -639,6 +640,8 @@ class AMPPass(PassBase):
self._loss = None
def _check_self(self):
if self.get_attr("dtype") not in ["float16", "bfloat16"]:
return False
if self.get_attr("init_loss_scaling") < 0:
return False
if self.get_attr("incr_every_n_steps") < 0:
......
......@@ -29,13 +29,9 @@ from paddle.distributed.auto_parallel.utils import (
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.framework import core
from paddle.static import default_main_program, default_startup_program
from paddle.static.amp.fp16_utils import (
AutoMixedPrecisionLists,
_dtype_to_str,
_keep_layer_norm_scale_bias_to_fp32,
_need_keep_fp32,
_valid_types,
)
# NOTE bf16 and fp16 may have diff logic for _keep_layer_norm_scale_bias_to_fp32
from paddle.static.amp.fp16_utils import _keep_layer_norm_scale_bias_to_fp32
from paddle.utils import unique_name
from ..auto_parallel.process_mesh import ProcessMesh
......@@ -50,6 +46,8 @@ __amp_skip_ops__ = [
'while',
'cast',
]
__target_dtype__ = None
__amp_utils__ = None
def set_op_dtype_to_fp16(op):
......@@ -57,17 +55,24 @@ def set_op_dtype_to_fp16(op):
op.has_attr('in_dtype')
and op.attr('in_dtype') == core.VarDesc.VarType.FP32
):
op._set_attr('in_dtype', core.VarDesc.VarType.FP16)
op._set_attr('in_dtype', __target_dtype__)
if (
op.has_attr('out_dtype')
and op.attr('out_dtype') == core.VarDesc.VarType.FP32
):
op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
op._set_attr('out_dtype', __target_dtype__)
if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16)
op._set_attr('dtype', __target_dtype__)
if __target_dtype__ == core.VarDesc.VarType.BF16:
if op.has_attr('use_mkldnn'):
op._set_attr('use_mkldnn', True)
if op.has_attr('mkldnn_data_type'):
op._set_attr('mkldnn_data_type', 'bfloat16')
# adapted for backward op
# TODO check if bf16 and fp16 still share the same logic
def _keep_fp32_input(op, in_name):
op_type = op.type
if op_type == 'batch_norm':
......@@ -96,6 +101,7 @@ def _keep_fp32_input(op, in_name):
return False
# TODO check if bf16 and fp16 still share the same logic
def _keep_fp32_output(op, out_name):
op_type = op.type
if op_type in ['batch_norm', 'fused_bn_add_activation']:
......@@ -208,7 +214,7 @@ class FP16State:
self._op_fp16_dict[op.desc.original_id()] = True
return
if _need_keep_fp32(
if __amp_utils__._need_keep_fp32(
op, self.amp_list.unsupported_list, self.use_fp16_guard
):
self._op_fp16_dict[op.desc.original_id()] = False
......@@ -240,11 +246,15 @@ class FP16State:
# NOTE(JZ-LIANG) "array_" is a hack to adopt for ernie3.0 inference, since there is
# a trick which make the LOD_TENSOR_ARRAY to the float32 in while block to reset the LOD_TENSOR_ARRAY
if var is None or var.type not in _valid_types or "array_" in var_name:
if (
var is None
or var.type not in __amp_utils__._valid_types
or "array_" in var_name
):
return
if var.dtype == core.VarDesc.VarType.FP32:
var.desc.set_dtype(core.VarDesc.VarType.FP16)
var.desc.set_dtype(__target_dtype__)
def resolute_tensor_dtype(self, block):
......@@ -274,9 +284,12 @@ class FP16State:
elif self._is_fp16_op(op.desc.original_id()) is False:
for out_var_name in op.output_arg_names:
out_var = block.vars.get(out_var_name)
if out_var is None or out_var.type not in _valid_types:
if (
out_var is None
or out_var.type not in __amp_utils__._valid_types
):
continue
if out_var.dtype == core.VarDesc.VarType.FP16:
if out_var.dtype == __target_dtype__:
out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
elif is_backward_op(op):
if self._is_fp16_op(op.desc.original_id()) is True:
......@@ -290,9 +303,12 @@ class FP16State:
elif self._is_fp16_op(op.desc.original_id()) is False:
for out_var_name in op.output_arg_names:
out_var = block.vars.get(out_var_name)
if out_var is None or out_var.type not in _valid_types:
if (
out_var is None
or out_var.type not in __amp_utils__._valid_types
):
continue
if out_var.dtype == core.VarDesc.VarType.FP16:
if out_var.dtype == __target_dtype__:
out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
def cast_block(self, block):
......@@ -311,7 +327,7 @@ class FP16State:
op,
idx,
block,
core.VarDesc.VarType.FP16,
__target_dtype__,
core.VarDesc.VarType.FP32,
self.dist_context,
)
......@@ -321,7 +337,7 @@ class FP16State:
idx,
block,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
__target_dtype__,
self.dist_context,
)
elif is_backward_op(op):
......@@ -331,7 +347,7 @@ class FP16State:
op,
idx,
block,
core.VarDesc.VarType.FP16,
__target_dtype__,
core.VarDesc.VarType.FP32,
self.dist_context,
)
......@@ -341,7 +357,7 @@ class FP16State:
idx,
block,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
__target_dtype__,
self.dist_context,
)
elif op.type == "sum":
......@@ -379,14 +395,16 @@ class FP16State:
in_var = block._find_var_recursive(in_var_name)
if (
in_var is None
or in_var.type not in _valid_types
or in_var.type not in __amp_utils__._valid_types
or in_var.dtype == dst_dtype
):
continue
if in_var.dtype == src_dtype:
cast_name = (
in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
in_var.name
+ '.cast_'
+ __amp_utils__._dtype_to_str(dst_dtype)
)
cast_var = block.vars.get(cast_name)
self.forward_input_cast_ops[op.desc.original_id()] += [
......@@ -476,14 +494,15 @@ class FP16State:
slot_name,
) in self.forward_input_cast_ops[forward_op_id]:
# rename input
# some forward output is not needed by backward computation, e.g. logit in softmax_with_cross_entropy
if slot_name not in op.input_names:
continue
if slot_name in op.input_names:
# rename input
assert src_name in op.input(
slot_name
), "var: {} not in op's {}. {}".format(src_name, slot_name, str(op))
), "var: {} not in op's {}. {}".format(
src_name, slot_name, str(op)
)
src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name)
assert src_var_dist_attr is not None
op._rename_input(src_name, cast_name)
......@@ -491,9 +510,7 @@ class FP16State:
# create cast grad
grad_slot_name = slot_name + "@GRAD"
if grad_slot_name not in op.output_names:
continue
if grad_slot_name in op.output_names:
# some forward input may be stop_gradient=True, e.g. input_mask
if len(op.output(grad_slot_name)) == 0:
continue
......@@ -521,7 +538,9 @@ class FP16State:
cast_grad, grad_dist_attr
)
op._rename_output(grad_name, cast_grad.name)
grad_op_attr.set_output_dist_attr(cast_grad.name, grad_dist_attr)
grad_op_attr.set_output_dist_attr(
cast_grad.name, grad_dist_attr
)
# add cast
cast_op = block._insert_op_without_sync(
......@@ -604,7 +623,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
def _split_grads(params_grads):
grads = [g for _, g in params_grads]
fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16]
fp16_grads = [g for g in grads if g.dtype == __target_dtype__]
assert len(fp32_grads) + len(fp16_grads) == len(
grads
), "Data types of all grads must be either fp16 or fp32."
......@@ -707,17 +726,17 @@ def cast_startup_program():
for op in startup_program.global_block().ops:
if is_initialization_op(op):
output_name = op.output_arg_names[0]
if (
param_to_dtype.get(output_name, None)
== core.VarDesc.VarType.FP16
):
if param_to_dtype.get(output_name, None) == __target_dtype__:
assert op.has_attr(
'dtype'
), "initialization op is supported to has dtype attribute but got {}.".format(
str(op)
)
out_var = startup_program.global_block().var(output_name)
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(__target_dtype__)
if op.attr('dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16)
op._set_attr('dtype', __target_dtype__)
@register_pass("auto_parallel_fp16")
......@@ -730,9 +749,37 @@ class FP16Pass(AMPPass):
# in distributed scenario, all ranks should have the same modification.
def _apply_single_impl(self, main_program, startup_program, context):
self.dist_context = self.get_attr("dist_context")
self.target_dtype = self.get_attr("dtype")
params_grads = self.get_attr("params_grads")
amp_list = AutoMixedPrecisionLists(
self.use_optimizer_fp16 = self.get_attr("use_optimizer_fp16", None)
if self.use_optimizer_fp16 is None:
self.use_optimizer_fp16 = self.get_attr("level", None) == "o3"
# switch environment for fp16 / bf16.
if self.target_dtype == "float16":
import paddle.static.amp.fp16_utils as amp_utils
AMPList = amp_utils.AutoMixedPrecisionLists
__target_dtype = core.VarDesc.VarType.FP16
elif self.target_dtype == "bfloat16":
import paddle.static.amp.bf16.amp_utils as amp_utils
AMPList = amp_utils.AutoMixedPrecisionListsBF16
__target_dtype = core.VarDesc.VarType.BF16
else:
raise NotImplementedError(
"target dtype [{}] is for amp o2 not supported yet.".format(
self.target_dtype
)
)
global __target_dtype__
__target_dtype__ = __target_dtype
global __amp_utils__
__amp_utils__ = amp_utils
amp_list = AMPList(
set(self.get_attr("custom_white_list")),
set(self.get_attr("custom_black_list")),
None,
......@@ -747,7 +794,9 @@ class FP16Pass(AMPPass):
main_program,
amp_list,
self.dist_context,
self.get_attr("use_fp16_guard"),
self.get_attr(
"use_fp16_guard"
), # TODO unify to use_amp_guard to be compatible with amp o1
input_data_var_names,
)
is_train = fp16_state._build_state()
......@@ -755,6 +804,7 @@ class FP16Pass(AMPPass):
cast_startup_program()
if is_train:
if self.target_dtype == "fp16":
with paddle.static.program_guard(main_program, startup_program):
# TODO (JZ-LIANG)support cast forward program only when inference
self._init_amp_var()
......@@ -864,11 +914,12 @@ class FP16Pass(AMPPass):
# modify optimizer
base_opt = self.get_attr("base_opt")
base_opt._multi_precision = True
if self.get_attr("use_optimizer_fp16"):
if self.use_optimizer_fp16:
base_opt._multi_precision = False
if self.target_dtype == "fp16":
if isinstance(
base_opt,
(paddle.static.Adam, paddle.optimizer.AdamW),
base_opt, (paddle.static.Adam, paddle.optimizer.AdamW)
):
with main_program._optimized_guard([]):
# found_inf = paddle.tensor.creation._memcpy(
......
......@@ -49,6 +49,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_pass_amp MODULES test_pass_amp)
set_tests_properties(test_pass_amp PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
py_test_modules(test_amp_o2_pass MODULES test_amp_o2_pass)
set_tests_properties(test_amp_o2_pass PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
py_test_modules(test_engine_callbacks MODULES test_engine_callbacks)
set_tests_properties(test_engine_callbacks
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import re
import unittest
import numpy as np
from get_gpt_model import FakeDataset, generate_model
import paddle
from paddle.distributed.fleet import auto
from paddle.framework import core
paddle.enable_static()
def get_cuda_version():
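# parse "release X.Y" from nvcc --version into X * 1000 + Y * 10, e.g. release 11.7 -> 11070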
result = os.popen("nvcc --version").read()
regex = r'release (\S+),'
match = re.search(regex, result)
if match:
num = str(match.group(1))
integer, decimal = num.split('.')
return int(integer) * 1000 + int(float(decimal) * 10)
else:
return -1
def apply_pass(use_amp=False, amp_dtype="bfloat16"):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_amp:
amp = strategy.amp
amp.enable = True
amp.dtype = amp_dtype
amp.level = "o2"
amp.custom_black_list = [
'c_softmax_with_cross_entropy',
'elementwise_div',
'reduce_sum',
]
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestShardingStage2WithNewEXE(unittest.TestCase):
def setUp(self):
self.batch_size = 2
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(paddle.distributed.ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_amp=False, amp_dtype="bfloat16"):
reset_prog()
strategy = apply_pass(use_amp, amp_dtype)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("mp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_bf16(self, program):
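# count parameter dtypes after the bf16 o2 pass: most weights are expected to be bf16, while a subset (e.g. layer_norm scales/biases kept in fp32) stays fp32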
num_bf16 = 0
num_fp16 = 0
num_fp32 = 0
for p in program.all_parameters():
if p.dtype == core.VarDesc.VarType.FP32:
num_fp32 += 1
if p.dtype == core.VarDesc.VarType.FP16:
num_fp16 += 1
if p.dtype == core.VarDesc.VarType.BF16:
num_bf16 += 1
self.assertEqual(num_bf16, 26)
self.assertEqual(num_fp16, 0)
self.assertEqual(num_fp32, 10)
def test_param_grad_fuse_overlap(self):
# std
mp_engine = self.get_engine(use_amp=False)
mp_history = mp_engine.fit(
self.dataset,
3,
epochs=1,
steps_per_epoch=self.batch_num,
log_freq=1,
batch_size=self.batch_size,
)
loss0 = mp_history.history['loss'][0]
# bf16
mp_bf16_engine = self.get_engine(use_amp=True)
if not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000:
return
mp_bf16_history = mp_bf16_engine.fit(
self.dataset,
3,
epochs=1,
steps_per_epoch=self.batch_num,
log_freq=1,
batch_size=self.batch_size,
)
loss1 = mp_bf16_history.history['loss'][0]
np.testing.assert_allclose(loss0, loss1, atol=1e-3, rtol=1e-2)
self.check_bf16(mp_bf16_engine.main_program)
if __name__ == "__main__":
unittest.main()
......@@ -37,7 +37,7 @@ def apply_pass(use_amp=False, level=None):
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.use_pure_fp16 = level in ["o2", "o3"]
amp.level = level
amp.use_optimizer_fp16 = level == "o3"
print("amp level: ", level)
return strategy
......
......@@ -39,7 +39,7 @@ def apply_pass():
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.use_pure_fp16 = True
amp.level = "o2"
qat = dist_strategy.qat
qat.enable = True
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
import tempfile
import unittest
class TestAMPO2(unittest.TestCase):
def test_bf16(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "amp_o2_pass.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = (
[sys.executable, "-u"]
+ coverage_args
+ [
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"--log_dir",
tmp_dir.name,
launch_model_path,
]
)
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
......@@ -28,6 +28,8 @@ class TestStrategy(unittest.TestCase):
amp = strategy.amp
self.assertEqual(amp.enable, False)
self.assertEqual(amp.dtype, "float16")
self.assertEqual(amp.level, "o1")
self.assertAlmostEqual(amp.init_loss_scaling, 32768.0)
self.assertEqual(amp.incr_every_n_steps, 1000)
self.assertEqual(amp.decr_every_n_nan_or_inf, 2)
......@@ -37,15 +39,11 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(amp.custom_black_list, [])
self.assertEqual(amp.custom_white_list, [])
self.assertEqual(amp.custom_black_varnames, [])
self.assertEqual(amp.use_pure_fp16, False)
self.assertEqual(amp.use_fp16_guard, True)
self.assertEqual(amp.use_fp16_guard, False)
self.assertEqual(amp.use_optimizer_fp16, False)
self.assertEqual(amp.enable_bf16, False)
self.assertEqual(amp.custom_bf16_list, [])
self.assertEqual(amp.custom_fp32_list, [])
self.assertEqual(amp.custom_fp32_varnames, [])
self.assertEqual(amp.use_pure_bf16, False)
self.assertEqual(amp.use_bf16_guard, False)
sharding = strategy.sharding
......@@ -102,7 +100,6 @@ class TestStrategy(unittest.TestCase):
amp.custom_white_list = ["x"]
amp.custom_black_list = ["y"]
amp.custom_black_varnames = ["z"]
amp.use_pure_fp16 = True
amp.use_fp16_guard = False
amp.use_optimizer_fp16 = True
self.assertEqual(amp.enable, True)
......@@ -115,7 +112,6 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(amp.custom_white_list, ["x"])
self.assertEqual(amp.custom_black_list, ["y"])
self.assertEqual(amp.custom_black_varnames, ["z"])
self.assertEqual(amp.use_pure_fp16, True)
self.assertEqual(amp.use_fp16_guard, False)
self.assertEqual(amp.use_optimizer_fp16, True)
......
......@@ -14,6 +14,7 @@
import copy
import warnings
from sqlite3 import NotSupportedError
import paddle
import paddle.autograd as imperative_base
......@@ -217,7 +218,9 @@ def _squared_l2_norm(x):
return _C_ops.squared_l2_norm(x)
op_type = 'squared_l2_norm'
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'float16'], op_type)
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'float16', 'uint16'], op_type
)
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(x.dtype)
......@@ -557,6 +560,20 @@ def _allow_pure_fp16_global_norm_clip(*args):
return old_value
_allow_pure_bf16_global_norm_clip_flag = False
def _allow_pure_bf16_global_norm_clip(*args):
global _allow_pure_bf16_global_norm_clip_flag
if len(args) == 0:
return _allow_pure_bf16_global_norm_clip_flag
else:
assert len(args) == 1 and isinstance(args[0], bool)
old_value = _allow_pure_bf16_global_norm_clip_flag
_allow_pure_bf16_global_norm_clip_flag = args[0]
return old_value
class ClipGradByGlobalNorm(ClipGradBase):
r"""
Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
......@@ -720,6 +737,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
params_and_grads = []
sum_square_list = []
sum_square_list_fp16 = []
sum_square_list_bf16 = []
sum_square_list_fp32 = []
with framework.name_scope('gradient_clip'):
for p, g in params_grads:
......@@ -735,17 +753,29 @@ class ClipGradByGlobalNorm(ClipGradBase):
sum_square = _squared_l2_norm(merge_grad)
if sum_square.dtype == core.VarDesc.VarType.FP16:
sum_square_list_fp16.append(sum_square)
elif sum_square.dtype == core.VarDesc.VarType.BF16:
sum_square_list_bf16.append(sum_square)
elif sum_square.dtype == core.VarDesc.VarType.FP32:
sum_square_list_fp32.append(sum_square)
else:
sum_square_list.append(sum_square)
if len(sum_square_list_fp16) > 0 and len(sum_square_list_bf16) > 0:
raise NotSupportedError(
'FP16 and BF16 are not supported at the same time.'
)
# all parameters have been filtered out
if (
len(sum_square_list)
+ len(sum_square_list_fp16)
+ len(sum_square_list_fp32)
== 0
) and (
len(sum_square_list)
+ len(sum_square_list_bf16)
+ len(sum_square_list_fp32)
== 0
):
return params_grads
......@@ -765,6 +795,18 @@ class ClipGradByGlobalNorm(ClipGradBase):
)
else:
global_norm_var.append(global_norm_var_fp16)
if len(sum_square_list_bf16) > 0:
global_norm_var_bf16 = paddle.add_n(sum_square_list_bf16)
if (
sum_square_list_fp32
or sum_square_list
or not _allow_pure_bf16_global_norm_clip()
):
global_norm_var.append(
global_norm_var_bf16.astype(sum_dtype)
)
else:
global_norm_var.append(global_norm_var_bf16)
if len(sum_square_list_fp32) > 0:
global_norm_var_fp32 = paddle.add_n(sum_square_list_fp32)
if sum_dtype == 'float32':
......@@ -804,12 +846,18 @@ class ClipGradByGlobalNorm(ClipGradBase):
with p.block.program._optimized_guard([p, g]):
new_g = _cast_to_mp_type_if_enabled(g)
# inplace
scale_input = (
scale_var.astype('float16')
if new_g.dtype == core.VarDesc.VarType.FP16
if (
new_g.dtype == core.VarDesc.VarType.FP16
and scale_var.dtype != core.VarDesc.VarType.FP16
else scale_var
)
):
scale_input = scale_var.astype('float16')
elif (
new_g.dtype == core.VarDesc.VarType.BF16
and scale_var.dtype != core.VarDesc.VarType.BF16
):
scale_input = scale_var.astype('bfloat16')
else:
scale_input = scale_var
# NOTE(Yuang Liu): For pure dp with gradient merge, the p and g
# will be in different blocks with the gradient clip related ops.
# We need to handle the correct block, otherwise will encounter
......
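Like its FP16 counterpart, _allow_pure_bf16_global_norm_clip doubles as getter and setter: called with no arguments it returns the current flag, called with a single bool it installs the new value and returns the previous one. A usage sketch of that save/restore pattern (the import path below is an assumption; the helper lives in the gradient-clip module patched above):

from paddle.nn import clip  # assumed module path for the patched gradient-clip code

old = clip._allow_pure_bf16_global_norm_clip(True)   # enable, remembering the previous value
try:
    ...  # build / apply ClipGradByGlobalNorm over pure-bf16 gradients here
finally:
    clip._allow_pure_bf16_global_norm_clip(old)      # restore the previous setting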
......@@ -1657,14 +1657,21 @@ def add_n(inputs, name=None):
check_variable_and_dtype(
input,
"inputs",
['float16', 'float32', 'float64', 'int32', 'int64'],
[
'float16',
'float32',
'float64',
'int32',
'int64',
'uint16',
],
'add_n',
)
else:
check_variable_and_dtype(
inputs,
"inputs",
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'add_n',
)
......