未验证 提交 9ded5707 编写于 作者: J JZ-LIANG 提交者: GitHub

[Auto Parallel Performance] Support BF16 Training (#51285)

* update env setting

* update pass logic

* dist op support bf16

* backward cast update

* update setting

* update backward

* revert amp pass

* update fp16 backward logic

* register c_embedding bf16

* revert engine

* add unitest

* add unitest

* update unitest

* update cmake

* update math

* update math.py

* update unitest

* update unitest

* revise unitest

* revise unitest

* update unitest

* update unitest

* update unitest
上级 3094d475
...@@ -198,8 +198,14 @@ namespace plat = paddle::platform; ...@@ -198,8 +198,14 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_embedding, REGISTER_OP_CUDA_KERNEL(c_embedding,
ops::CEmbeddingCUDAKernel<float>, ops::CEmbeddingCUDAKernel<float>,
ops::CEmbeddingCUDAKernel<double>, ops::CEmbeddingCUDAKernel<double>,
#if NCCL_VERSION_CODE >= 21000
ops::CEmbeddingCUDAKernel<plat::bfloat16>,
#endif
ops::CEmbeddingCUDAKernel<plat::float16>); ops::CEmbeddingCUDAKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(c_embedding_grad, REGISTER_OP_CUDA_KERNEL(c_embedding_grad,
ops::CEmbeddingGradCUDAKernel<float>, ops::CEmbeddingGradCUDAKernel<float>,
ops::CEmbeddingGradCUDAKernel<double>, ops::CEmbeddingGradCUDAKernel<double>,
#if NCCL_VERSION_CODE >= 21000
ops::CEmbeddingGradCUDAKernel<plat::bfloat16>,
#endif
ops::CEmbeddingGradCUDAKernel<plat::float16>); ops::CEmbeddingGradCUDAKernel<plat::float16>);
...@@ -63,6 +63,8 @@ set_field_default_config(RECOMPUTE, "enable_tuning", False) ...@@ -63,6 +63,8 @@ set_field_default_config(RECOMPUTE, "enable_tuning", False)
######################################### #########################################
AMP = "amp" AMP = "amp"
set_field_default_config(AMP, "enable", False) set_field_default_config(AMP, "enable", False)
set_field_default_config(AMP, "dtype", "float16")
set_field_default_config(AMP, "level", "o1")
set_field_default_config(AMP, "init_loss_scaling", 32768.0) set_field_default_config(AMP, "init_loss_scaling", 32768.0)
set_field_default_config(AMP, "incr_every_n_steps", 1000) set_field_default_config(AMP, "incr_every_n_steps", 1000)
set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2) set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2)
...@@ -72,15 +74,12 @@ set_field_default_config(AMP, "use_dynamic_loss_scaling", True) ...@@ -72,15 +74,12 @@ set_field_default_config(AMP, "use_dynamic_loss_scaling", True)
set_field_default_config(AMP, "custom_white_list", []) set_field_default_config(AMP, "custom_white_list", [])
set_field_default_config(AMP, "custom_black_list", []) set_field_default_config(AMP, "custom_black_list", [])
set_field_default_config(AMP, "custom_black_varnames", []) set_field_default_config(AMP, "custom_black_varnames", [])
set_field_default_config(AMP, "use_pure_fp16", False) set_field_default_config(AMP, "use_fp16_guard", False)
set_field_default_config(AMP, "use_fp16_guard", True)
set_field_default_config(AMP, "use_optimizer_fp16", False) set_field_default_config(AMP, "use_optimizer_fp16", False)
set_field_default_config(AMP, "enable_bf16", False)
set_field_default_config(AMP, "custom_bf16_list", []) set_field_default_config(AMP, "custom_bf16_list", [])
set_field_default_config(AMP, "custom_fp32_list", []) set_field_default_config(AMP, "custom_fp32_list", [])
set_field_default_config(AMP, "custom_fp32_varnames", []) set_field_default_config(AMP, "custom_fp32_varnames", [])
set_field_default_config(AMP, "use_pure_bf16", False)
set_field_default_config(AMP, "use_bf16_guard", False) set_field_default_config(AMP, "use_bf16_guard", False)
######################################### #########################################
......
...@@ -455,7 +455,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -455,7 +455,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
check_variable_and_dtype( check_variable_and_dtype(
Out_var, Out_var,
'tensor', 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'c_allreduce_sum', 'c_allreduce_sum',
) )
...@@ -645,7 +645,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -645,7 +645,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
check_variable_and_dtype( check_variable_and_dtype(
Out_grad, Out_grad,
'tensor', 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity', '_c_identity',
) )
...@@ -687,12 +687,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -687,12 +687,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
}, },
) )
check_variable_and_dtype( check_variable_and_dtype(
intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear' intermediate_var_0,
'x',
['float16', 'float32', 'float64', 'uint16'],
'linear',
) )
check_dtype( check_dtype(
intermediate_var_0.dtype, intermediate_var_0.dtype,
'dtype', 'dtype',
['float16', 'float32', 'float64'], ['float16', 'float32', 'float64', 'uint16'],
'linear', 'linear',
) )
......
...@@ -220,27 +220,26 @@ class Parallelizer: ...@@ -220,27 +220,26 @@ class Parallelizer:
self._dist_context.serial_feed_vars["inputs"] self._dist_context.serial_feed_vars["inputs"]
+ self._dist_context.serial_feed_vars["labels"] + self._dist_context.serial_feed_vars["labels"]
) )
if config["enable_bf16"]: self._logger.info(
auto_parallel_bf16_pass = new_pass("auto_parallel_bf16", config) "Applying AMP-{}-{} ...".format(
auto_parallel_bf16_pass.apply( config["dtype"], config['level']
),
)
if config['level'] == "o1":
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply(
[main_program], [startup_program], self._pass_context [main_program], [startup_program], self._pass_context
) )
loss = auto_parallel_bf16_pass.get_loss() loss = auto_parallel_amp_pass.get_loss()
elif config['level'] in ['o2', 'o3']:
elif config["use_pure_fp16"]:
config["base_opt"] = optimizer config["base_opt"] = optimizer
auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config) auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
auto_parallel_fp16_pass.apply( auto_parallel_fp16_pass.apply(
[main_program], [startup_program], self._pass_context [main_program], [startup_program], self._pass_context
) )
loss = auto_parallel_fp16_pass.get_loss() loss = auto_parallel_fp16_pass.get_loss()
else: else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) raise ValueError("AMP level should be one of o1, o2, o3")
auto_parallel_amp_pass.apply(
[main_program], [startup_program], self._pass_context
)
loss = auto_parallel_amp_pass.get_loss()
# apply quantization pass # apply quantization pass
# The pass can be applied when mode must be 'train' # The pass can be applied when mode must be 'train'
......
...@@ -632,6 +632,7 @@ class AMPPass(PassBase): ...@@ -632,6 +632,7 @@ class AMPPass(PassBase):
self.set_attr("use_dynamic_loss_scaling", False) self.set_attr("use_dynamic_loss_scaling", False)
self.set_attr("input_data", []) self.set_attr("input_data", [])
self.set_attr("params_grads", []) self.set_attr("params_grads", [])
self.set_attr("dtype", "") # fp16/bf16
self._loss = None self._loss = None
self._loss_scaling = None self._loss_scaling = None
self._num_good_steps = None self._num_good_steps = None
...@@ -639,6 +640,8 @@ class AMPPass(PassBase): ...@@ -639,6 +640,8 @@ class AMPPass(PassBase):
self._loss = None self._loss = None
def _check_self(self): def _check_self(self):
if self.get_attr("dtype") not in ["float16", "bfloat16"]:
return False
if self.get_attr("init_loss_scaling") < 0: if self.get_attr("init_loss_scaling") < 0:
return False return False
if self.get_attr("incr_every_n_steps") < 0: if self.get_attr("incr_every_n_steps") < 0:
......
...@@ -49,6 +49,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ...@@ -49,6 +49,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_pass_amp MODULES test_pass_amp) py_test_modules(test_pass_amp MODULES test_pass_amp)
set_tests_properties(test_pass_amp PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" set_tests_properties(test_pass_amp PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50) TIMEOUT 50)
py_test_modules(test_amp_o2_pass MODULES test_amp_o2_pass)
set_tests_properties(test_amp_o2_pass PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
py_test_modules(test_engine_callbacks MODULES test_engine_callbacks) py_test_modules(test_engine_callbacks MODULES test_engine_callbacks)
set_tests_properties(test_engine_callbacks set_tests_properties(test_engine_callbacks
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import re
import unittest
import numpy as np
from get_gpt_model import FakeDataset, generate_model
import paddle
from paddle.distributed.fleet import auto
from paddle.framework import core
paddle.enable_static()
def get_cuda_version():
result = os.popen("nvcc --version").read()
regex = r'release (\S+),'
match = re.search(regex, result)
if match:
num = str(match.group(1))
integer, decimal = num.split('.')
return int(integer) * 1000 + int(float(decimal) * 10)
else:
return -1
def apply_pass(use_amp=False, amp_dtype="bfloat16"):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_amp:
amp = strategy.amp
amp.enable = True
amp.dtype = amp_dtype
amp.level = "o2"
amp.custom_black_list = [
'c_softmax_with_cross_entropy',
'elementwise_div',
'reduce_sum',
]
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestShardingStage2WithNewEXE(unittest.TestCase):
def setUp(self):
self.batch_size = 2
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(paddle.distributed.ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_amp=False, amp_dtype="bfloat16"):
reset_prog()
strategy = apply_pass(use_amp, amp_dtype)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("mp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_bf16(self, program):
num_bf16 = 0
num_fp16 = 0
num_fp32 = 0
for p in program.all_parameters():
if p.dtype == core.VarDesc.VarType.FP32:
num_fp32 += 1
if p.dtype == core.VarDesc.VarType.FP16:
num_fp16 += 1
if p.dtype == core.VarDesc.VarType.BF16:
num_bf16 += 1
self.assertEqual(num_bf16, 26)
self.assertEqual(num_fp16, 0)
self.assertEqual(num_fp32, 10)
def test_param_grad_fuse_overlap(self):
# std
mp_engine = self.get_engine(use_amp=False)
mp_history = mp_engine.fit(
self.dataset,
3,
epochs=1,
steps_per_epoch=self.batch_num,
log_freq=1,
batch_size=self.batch_size,
)
loss0 = mp_history.history['loss'][0]
# bf16
mp_bf16_engine = self.get_engine(use_amp=True)
if not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000:
return
mp_bf16_history = mp_bf16_engine.fit(
self.dataset,
3,
epochs=1,
steps_per_epoch=self.batch_num,
log_freq=1,
batch_size=self.batch_size,
)
loss1 = mp_bf16_history.history['loss'][0]
np.testing.assert_allclose(loss0, loss1, atol=1e-3, rtol=1e-2)
self.check_bf16(mp_bf16_engine.main_program)
if __name__ == "__main__":
unittest.main()
...@@ -37,7 +37,7 @@ def apply_pass(use_amp=False, level=None): ...@@ -37,7 +37,7 @@ def apply_pass(use_amp=False, level=None):
] ]
amp.init_loss_scaling = 32768 amp.init_loss_scaling = 32768
amp.use_fp16_guard = False amp.use_fp16_guard = False
amp.use_pure_fp16 = level in ["o2", "o3"] amp.level = level
amp.use_optimizer_fp16 = level == "o3" amp.use_optimizer_fp16 = level == "o3"
print("amp level: ", level) print("amp level: ", level)
return strategy return strategy
......
...@@ -39,7 +39,7 @@ def apply_pass(): ...@@ -39,7 +39,7 @@ def apply_pass():
] ]
amp.init_loss_scaling = 32768 amp.init_loss_scaling = 32768
amp.use_fp16_guard = False amp.use_fp16_guard = False
amp.use_pure_fp16 = True amp.level = "o2"
qat = dist_strategy.qat qat = dist_strategy.qat
qat.enable = True qat.enable = True
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
import tempfile
import unittest
class TestAMPO2(unittest.TestCase):
def test_bf16(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "amp_o2_pass.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = (
[sys.executable, "-u"]
+ coverage_args
+ [
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"--log_dir",
tmp_dir.name,
launch_model_path,
]
)
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
...@@ -28,6 +28,8 @@ class TestStrategy(unittest.TestCase): ...@@ -28,6 +28,8 @@ class TestStrategy(unittest.TestCase):
amp = strategy.amp amp = strategy.amp
self.assertEqual(amp.enable, False) self.assertEqual(amp.enable, False)
self.assertAlmostEqual(amp.dtype, "float16")
self.assertAlmostEqual(amp.level, "o1")
self.assertAlmostEqual(amp.init_loss_scaling, 32768.0) self.assertAlmostEqual(amp.init_loss_scaling, 32768.0)
self.assertEqual(amp.incr_every_n_steps, 1000) self.assertEqual(amp.incr_every_n_steps, 1000)
self.assertEqual(amp.decr_every_n_nan_or_inf, 2) self.assertEqual(amp.decr_every_n_nan_or_inf, 2)
...@@ -37,15 +39,11 @@ class TestStrategy(unittest.TestCase): ...@@ -37,15 +39,11 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(amp.custom_black_list, []) self.assertEqual(amp.custom_black_list, [])
self.assertEqual(amp.custom_white_list, []) self.assertEqual(amp.custom_white_list, [])
self.assertEqual(amp.custom_black_varnames, []) self.assertEqual(amp.custom_black_varnames, [])
self.assertEqual(amp.use_pure_fp16, False) self.assertEqual(amp.use_fp16_guard, False)
self.assertEqual(amp.use_fp16_guard, True)
self.assertEqual(amp.use_optimizer_fp16, False) self.assertEqual(amp.use_optimizer_fp16, False)
self.assertEqual(amp.enable_bf16, False)
self.assertEqual(amp.custom_bf16_list, []) self.assertEqual(amp.custom_bf16_list, [])
self.assertEqual(amp.custom_fp32_list, []) self.assertEqual(amp.custom_fp32_list, [])
self.assertEqual(amp.custom_fp32_varnames, []) self.assertEqual(amp.custom_fp32_varnames, [])
self.assertEqual(amp.use_pure_bf16, False)
self.assertEqual(amp.use_bf16_guard, False) self.assertEqual(amp.use_bf16_guard, False)
sharding = strategy.sharding sharding = strategy.sharding
...@@ -102,7 +100,6 @@ class TestStrategy(unittest.TestCase): ...@@ -102,7 +100,6 @@ class TestStrategy(unittest.TestCase):
amp.custom_white_list = ["x"] amp.custom_white_list = ["x"]
amp.custom_black_list = ["y"] amp.custom_black_list = ["y"]
amp.custom_black_varnames = ["z"] amp.custom_black_varnames = ["z"]
amp.use_pure_fp16 = True
amp.use_fp16_guard = False amp.use_fp16_guard = False
amp.use_optimizer_fp16 = True amp.use_optimizer_fp16 = True
self.assertEqual(amp.enable, True) self.assertEqual(amp.enable, True)
...@@ -115,7 +112,6 @@ class TestStrategy(unittest.TestCase): ...@@ -115,7 +112,6 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(amp.custom_white_list, ["x"]) self.assertEqual(amp.custom_white_list, ["x"])
self.assertEqual(amp.custom_black_list, ["y"]) self.assertEqual(amp.custom_black_list, ["y"])
self.assertEqual(amp.custom_black_varnames, ["z"]) self.assertEqual(amp.custom_black_varnames, ["z"])
self.assertEqual(amp.use_pure_fp16, True)
self.assertEqual(amp.use_fp16_guard, False) self.assertEqual(amp.use_fp16_guard, False)
self.assertEqual(amp.use_optimizer_fp16, True) self.assertEqual(amp.use_optimizer_fp16, True)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import copy import copy
import warnings import warnings
from sqlite3 import NotSupportedError
import paddle import paddle
import paddle.autograd as imperative_base import paddle.autograd as imperative_base
...@@ -217,7 +218,9 @@ def _squared_l2_norm(x): ...@@ -217,7 +218,9 @@ def _squared_l2_norm(x):
return _C_ops.squared_l2_norm(x) return _C_ops.squared_l2_norm(x)
op_type = 'squared_l2_norm' op_type = 'squared_l2_norm'
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'float16'], op_type) check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'float16', 'uint16'], op_type
)
helper = LayerHelper(op_type, **locals()) helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(x.dtype) out = helper.create_variable_for_type_inference(x.dtype)
...@@ -557,6 +560,20 @@ def _allow_pure_fp16_global_norm_clip(*args): ...@@ -557,6 +560,20 @@ def _allow_pure_fp16_global_norm_clip(*args):
return old_value return old_value
_allow_pure_bf16_global_norm_clip_flag = False
def _allow_pure_bf16_global_norm_clip(*args):
global _allow_pure_bf16_global_norm_clip_flag
if len(args) == 0:
return _allow_pure_bf16_global_norm_clip_flag
else:
assert len(args) == 1 and isinstance(args[0], bool)
old_value = _allow_pure_bf16_global_norm_clip_flag
_allow_pure_bf16_global_norm_clip_flag = args[0]
return old_value
class ClipGradByGlobalNorm(ClipGradBase): class ClipGradByGlobalNorm(ClipGradBase):
r""" r"""
Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
...@@ -720,6 +737,7 @@ class ClipGradByGlobalNorm(ClipGradBase): ...@@ -720,6 +737,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
params_and_grads = [] params_and_grads = []
sum_square_list = [] sum_square_list = []
sum_square_list_fp16 = [] sum_square_list_fp16 = []
sum_square_list_bf16 = []
sum_square_list_fp32 = [] sum_square_list_fp32 = []
with framework.name_scope('gradient_clip'): with framework.name_scope('gradient_clip'):
for p, g in params_grads: for p, g in params_grads:
...@@ -735,17 +753,29 @@ class ClipGradByGlobalNorm(ClipGradBase): ...@@ -735,17 +753,29 @@ class ClipGradByGlobalNorm(ClipGradBase):
sum_square = _squared_l2_norm(merge_grad) sum_square = _squared_l2_norm(merge_grad)
if sum_square.dtype == core.VarDesc.VarType.FP16: if sum_square.dtype == core.VarDesc.VarType.FP16:
sum_square_list_fp16.append(sum_square) sum_square_list_fp16.append(sum_square)
elif sum_square.dtype == core.VarDesc.VarType.BF16:
sum_square_list_bf16.append(sum_square)
elif sum_square.dtype == core.VarDesc.VarType.FP32: elif sum_square.dtype == core.VarDesc.VarType.FP32:
sum_square_list_fp32.append(sum_square) sum_square_list_fp32.append(sum_square)
else: else:
sum_square_list.append(sum_square) sum_square_list.append(sum_square)
if len(sum_square_list_fp16) > 0 and len(sum_square_list_bf16) > 0:
raise NotSupportedError(
'FP16 and BF16 are not supported at the same time.'
)
# all parameters have been filterd out # all parameters have been filterd out
if ( if (
len(sum_square_list) len(sum_square_list)
+ len(sum_square_list_fp16) + len(sum_square_list_fp16)
+ len(sum_square_list_fp32) + len(sum_square_list_fp32)
== 0 == 0
) and (
len(sum_square_list)
+ len(sum_square_list_bf16)
+ len(sum_square_list_fp32)
== 0
): ):
return params_grads return params_grads
...@@ -765,6 +795,18 @@ class ClipGradByGlobalNorm(ClipGradBase): ...@@ -765,6 +795,18 @@ class ClipGradByGlobalNorm(ClipGradBase):
) )
else: else:
global_norm_var.append(global_norm_var_fp16) global_norm_var.append(global_norm_var_fp16)
if len(sum_square_list_bf16) > 0:
global_norm_var_bf16 = paddle.add_n(sum_square_list_bf16)
if (
sum_square_list_fp32
or sum_square_list
or not _allow_pure_bf16_global_norm_clip()
):
global_norm_var.append(
global_norm_var_bf16.astype(sum_dtype)
)
else:
global_norm_var.append(global_norm_var_bf16)
if len(sum_square_list_fp32) > 0: if len(sum_square_list_fp32) > 0:
global_norm_var_fp32 = paddle.add_n(sum_square_list_fp32) global_norm_var_fp32 = paddle.add_n(sum_square_list_fp32)
if sum_dtype == 'float32': if sum_dtype == 'float32':
...@@ -804,12 +846,18 @@ class ClipGradByGlobalNorm(ClipGradBase): ...@@ -804,12 +846,18 @@ class ClipGradByGlobalNorm(ClipGradBase):
with p.block.program._optimized_guard([p, g]): with p.block.program._optimized_guard([p, g]):
new_g = _cast_to_mp_type_if_enabled(g) new_g = _cast_to_mp_type_if_enabled(g)
# inplace # inplace
scale_input = ( if (
scale_var.astype('float16') new_g.dtype == core.VarDesc.VarType.FP16
if new_g.dtype == core.VarDesc.VarType.FP16
and scale_var.dtype != core.VarDesc.VarType.FP16 and scale_var.dtype != core.VarDesc.VarType.FP16
else scale_var ):
) scale_input = scale_var.astype('float16')
elif (
new_g.dtype == core.VarDesc.VarType.BF16
and scale_var.dtype != core.VarDesc.VarType.BF16
):
scale_input = scale_var.astype('bfloat16')
else:
scale_input = scale_var
# NOTE(Yuang Liu): For pure dp with gradient merge, the p and g # NOTE(Yuang Liu): For pure dp with gradient merge, the p and g
# will be in different blocks with the gradient clip related ops. # will be in different blocks with the gradient clip related ops.
# We need to handle the correct block, otherwise will encounter # We need to handle the correct block, otherwise will encounter
......
...@@ -1657,14 +1657,21 @@ def add_n(inputs, name=None): ...@@ -1657,14 +1657,21 @@ def add_n(inputs, name=None):
check_variable_and_dtype( check_variable_and_dtype(
input, input,
"inputs", "inputs",
['float16', 'float32', 'float64', 'int32', 'int64'], [
'float16',
'float32',
'float64',
'int32',
'int64',
'uint16',
],
'add_n', 'add_n',
) )
else: else:
check_variable_and_dtype( check_variable_and_dtype(
inputs, inputs,
"inputs", "inputs",
['float16', 'float32', 'float64', 'int32', 'int64'], ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'add_n', 'add_n',
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册