Unverified commit 8db15a42, authored by G Ghost Screaming, committed by GitHub

Fix hybrid parallel training strategy using bf16 (#51103)

* Fix bug of reduce_sum op. When input.numel() > INT32_MAX, its result is wrong.

* Remove climits.

* Fix bug of hybrid parallel strategy with recompute using bf16.

* Fix bug of recompute_hybrid ctx.amp_dtype

* Fix bug of amp_dtype.

* Fix bug of auto_cast.
Parent c191b707
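The core of the change is making recomputation respect a bfloat16 autocast context. Below is a minimal, single-card sketch of the scenario (illustration only: the hybrid-parallel/pipeline setup that actually exercises the patched recompute_hybrid path is omitted, and paddle.distributed.fleet.utils.recompute stands in for it). Before this fix, the forward pass replayed during backward fell back to the default float16 dtype.

import paddle
from paddle.distributed.fleet.utils import recompute

layer = paddle.nn.Linear(16, 16)
x = paddle.rand([4, 16])

# The forward pass runs under bf16 autocast; the forward replayed inside
# backward() should run under the same dtype rather than the fp16 default.
with paddle.amp.auto_cast(enable=True, level='O1', dtype='bfloat16'):
    y = recompute(layer, x)
y.sum().backward()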
@@ -353,8 +353,11 @@ def amp_guard(
     # check amp_dtype: float16 or bfloat16
     dtype = dtype.lower()
-    if not (dtype in ['float16', 'bfloat16']):
-        raise ValueError("dtype should be 'float16' or 'bfloat16'.")
+    if enable:
+        if not (dtype in ['float16', 'bfloat16']):
+            raise ValueError(
+                "If enable amp, dtype should be 'float16' or 'bfloat16'."
+            )

     # check tracer
     tracer = _dygraph_tracer()
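With this change the dtype argument is only validated when AMP is actually enabled, so entering the guard with enable=False and an unrelated dtype no longer raises. A standalone sketch of the new logic (check_amp_dtype is a made-up name for illustration, not the Paddle source itself):

def check_amp_dtype(enable, dtype):
    # Mirror of the updated amp_guard check: lower-case the dtype and only
    # validate it when AMP is enabled.
    dtype = dtype.lower()
    if enable:
        if dtype not in ('float16', 'bfloat16'):
            raise ValueError(
                "If enable amp, dtype should be 'float16' or 'bfloat16'."
            )
    return dtype

check_amp_dtype(False, 'float32')   # accepted now: AMP is disabled
check_amp_dtype(True, 'bfloat16')   # accepted: valid AMP dtype
# check_amp_dtype(True, 'float32')  # would raise ValueError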
@@ -21,6 +21,7 @@ FLOAT_TYPE_DICT = {
     paddle.float16: "float16",
     paddle.float32: "float32",
     paddle.float64: "float64",
+    paddle.bfloat16: "bfloat16",
 }

 PADDLE_TO_NUMBER = {
@@ -29,6 +30,7 @@ PADDLE_TO_NUMBER = {
     paddle.float64: 2,
     paddle.int32: 3,
     paddle.int64: 4,
+    paddle.bfloat16: 5,
 }

 NUMBER_TO_DTYPE = {
@@ -37,6 +39,7 @@ NUMBER_TO_DTYPE = {
     2: "float64",
     3: "int32",
     4: "int64",
+    5: "bfloat16",
 }
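Tables like these encode a tensor's dtype as a small integer and decode it back to a dtype name (for example when tensor metadata has to be exchanged between processes); without the new entries a bfloat16 tensor could not be mapped. A standalone round-trip sketch mirroring the tables above:

import paddle

PADDLE_TO_NUMBER = {
    paddle.float16: 0,
    paddle.float32: 1,
    paddle.float64: 2,
    paddle.int32: 3,
    paddle.int64: 4,
    paddle.bfloat16: 5,   # added by this commit
}
NUMBER_TO_DTYPE = {
    0: "float16",
    1: "float32",
    2: "float64",
    3: "int32",
    4: "int64",
    5: "bfloat16",        # added by this commit
}

t = paddle.ones([2, 2]).astype('bfloat16')
code = PADDLE_TO_NUMBER[t.dtype]    # 5
name = NUMBER_TO_DTYPE[code]        # "bfloat16"
restored = paddle.cast(t, name)     # the dtype name round-trips cleanly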
@@ -128,6 +128,7 @@ class _HPRecomputeFunction(PyLayer):
             raise ValueError(
                 "unsupported amp level: {}".format(tracer._amp_level)
             )
+        ctx.amp_dtype = tracer._amp_dtype
         ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

         with paddle.no_grad():
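Recording the dtype matters because paddle.amp.auto_cast defaults to float16: without ctx.amp_dtype, the forward replayed in backward would run in fp16 even when the original forward ran in bf16. A small sketch of the value being saved, using the same internal tracer attribute as the diff (internal API, may change across versions):

import paddle
from paddle.fluid.framework import _dygraph_tracer

with paddle.amp.auto_cast(enable=True, level='O1', dtype='bfloat16'):
    # The tracer keeps the active AMP dtype as a string; this is the value
    # that forward() stores into ctx.amp_dtype.
    print(_dygraph_tracer()._amp_dtype)   # expected: 'bfloat16'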
@@ -203,12 +204,19 @@
             with swith_rng_state_tracker(
                 ctx.fwd_rng_state, ctx.fwd_rng_state_tracker
             ):
-                with paddle.amp.auto_cast(
-                    enable=ctx.is_fw_autocast,
-                    custom_white_list=ctx.amp_white_list,
-                    custom_black_list=ctx.amp_black_list,
-                    level=ctx.amp_level,
-                ):
+                if ctx.is_fw_autocast:
+                    with paddle.amp.auto_cast(
+                        enable=ctx.is_fw_autocast,
+                        custom_white_list=ctx.amp_white_list,
+                        custom_black_list=ctx.amp_black_list,
+                        level=ctx.amp_level,
+                        dtype=ctx.amp_dtype,
+                    ):
+                        detached_inputs = detach_variable(tuple(inputs))
+                        outputs = ctx.run_function(
+                            *detached_inputs, **ctx.kwargs
+                        )
+                else:
                     detached_inputs = detach_variable(tuple(inputs))
                     outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
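The control flow above, distilled: only enter autocast when the forward pass actually ran under it, and reuse the recorded dtype so a bf16 forward is replayed in bf16 rather than the fp16 default. A hypothetical helper, sketched for illustration only:

import paddle

def replay_forward(run_function, inputs, ctx):
    # ctx is assumed to carry the state saved in forward(): is_fw_autocast,
    # amp_white_list, amp_black_list, amp_level and amp_dtype.
    if ctx.is_fw_autocast:
        with paddle.amp.auto_cast(
            enable=True,
            custom_white_list=ctx.amp_white_list,
            custom_black_list=ctx.amp_black_list,
            level=ctx.amp_level,
            dtype=ctx.amp_dtype,
        ):
            return run_function(*inputs)
    # The forward pass did not use AMP: recompute at full precision without
    # touching the autocast state at all.
    return run_function(*inputs)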