Unverified commit 8db15a42, authored by Ghost Screaming, committed by GitHub

Fix hybrid parallel training strategy using bf16 (#51103)

* Fix a bug in the reduce_sum op: when input.numel() > INT32_MAX, its result
is wrong.

* Remove climits.

* Fix a bug in the hybrid parallel strategy when recompute is used with bf16.

* Fix a bug in recompute_hybrid: ctx.amp_dtype was not recorded.

* Fix a bug in amp_dtype handling.

* Fix a bug in auto_cast.
Parent c191b707
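The hunks below touch three areas: the dtype check in amp_guard, the dtype lookup tables, and the recompute replay in _HPRecomputeFunction. For orientation, here is a minimal sketch (not code from this commit; the layer, optimizer, and shapes are placeholders) of the bf16 AMP training pattern these fixes are meant to support:

```python
# Minimal bf16 AMP sketch (illustrative only, not code from this commit).
import paddle

model = paddle.nn.Linear(1024, 1024)        # placeholder model
opt = paddle.optimizer.AdamW(parameters=model.parameters())

x = paddle.randn([8, 1024], dtype='float32')
with paddle.amp.auto_cast(enable=True, dtype='bfloat16', level='O1'):
    # white-listed ops run in bfloat16, the rest stay in float32
    loss = model(x).mean()
loss.backward()
opt.step()
opt.clear_grad()
```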
@@ -353,8 +353,11 @@ def amp_guard(
     # check amp_dtype: float16 or bfloat16
     dtype = dtype.lower()
-    if not (dtype in ['float16', 'bfloat16']):
-        raise ValueError("dtype should be 'float16' or 'bfloat16'.")
+    if enable:
+        if not (dtype in ['float16', 'bfloat16']):
+            raise ValueError(
+                "If enable amp, dtype should be 'float16' or 'bfloat16'."
+            )
     # check tracer
     tracer = _dygraph_tracer()
...
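In other words, the dtype argument is now validated only when AMP is actually enabled. A standalone restatement of the new check (a sketch with a hypothetical helper name, not the PR's code) makes the behavior change explicit:

```python
# Sketch: the validation logic amp_guard now applies (hypothetical helper name).
def _check_amp_dtype(enable: bool, dtype: str) -> str:
    dtype = dtype.lower()
    if enable and dtype not in ('float16', 'bfloat16'):
        raise ValueError("If enable amp, dtype should be 'float16' or 'bfloat16'.")
    return dtype

assert _check_amp_dtype(False, 'float32') == 'float32'   # previously raised ValueError
assert _check_amp_dtype(True, 'bfloat16') == 'bfloat16'  # still accepted
```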
@@ -21,6 +21,7 @@ FLOAT_TYPE_DICT = {
     paddle.float16: "float16",
     paddle.float32: "float32",
     paddle.float64: "float64",
+    paddle.bfloat16: "bfloat16",
 }
@@ -29,6 +30,7 @@ PADDLE_TO_NUMBER = {
     paddle.float64: 2,
     paddle.int32: 3,
     paddle.int64: 4,
+    paddle.bfloat16: 5,
 }
@@ -37,6 +39,7 @@ NUMBER_TO_DTYPE = {
     2: "float64",
     3: "int32",
     4: "int64",
+    5: "bfloat16",
 }
...
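These tables map tensor dtypes to small integer codes and back; they previously had no entry for bfloat16, so a lookup with a bf16 dtype failed with a KeyError. A short round-trip sketch, reproducing only the entries visible in this diff:

```python
# Round-trip sketch; only the entries visible in this diff are reproduced here.
import paddle

PADDLE_TO_NUMBER = {paddle.float64: 2, paddle.int32: 3, paddle.int64: 4, paddle.bfloat16: 5}
NUMBER_TO_DTYPE = {2: "float64", 3: "int32", 4: "int64", 5: "bfloat16"}

code = PADDLE_TO_NUMBER[paddle.bfloat16]   # raised KeyError before this change
assert NUMBER_TO_DTYPE[code] == "bfloat16"
```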
@@ -128,6 +128,7 @@ class _HPRecomputeFunction(PyLayer):
             raise ValueError(
                 "unsupported amp level: {}".format(tracer._amp_level)
             )
+        ctx.amp_dtype = tracer._amp_dtype
         ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

         with paddle.no_grad():
@@ -203,12 +204,19 @@ class _HPRecomputeFunction(PyLayer):
             with swith_rng_state_tracker(
                 ctx.fwd_rng_state, ctx.fwd_rng_state_tracker
             ):
-                with paddle.amp.auto_cast(
-                    enable=ctx.is_fw_autocast,
-                    custom_white_list=ctx.amp_white_list,
-                    custom_black_list=ctx.amp_black_list,
-                    level=ctx.amp_level,
-                ):
+                if ctx.is_fw_autocast:
+                    with paddle.amp.auto_cast(
+                        enable=ctx.is_fw_autocast,
+                        custom_white_list=ctx.amp_white_list,
+                        custom_black_list=ctx.amp_black_list,
+                        level=ctx.amp_level,
+                        dtype=ctx.amp_dtype,
+                    ):
+                        detached_inputs = detach_variable(tuple(inputs))
+                        outputs = ctx.run_function(
+                            *detached_inputs, **ctx.kwargs
+                        )
+                else:
                     detached_inputs = detach_variable(tuple(inputs))
                     outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
...
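The effect of the two changes above: the forward pass records the tracer's AMP dtype on the context, and the backward-time replay either re-enters auto_cast with that same dtype or, when the original forward was not autocast, skips auto_cast entirely. A toy illustration of the record-and-replay idea (hypothetical names, not _HPRecomputeFunction itself):

```python
# Toy record-and-replay sketch (hypothetical names; not the PR code).
import paddle

def replay_forward(fn, x, is_fw_autocast, amp_dtype, amp_level='O1'):
    # The recomputation must mirror the original forward: same autocast on/off
    # state and, when on, the same dtype (e.g. bf16 forwards replayed in bf16).
    if is_fw_autocast:
        with paddle.amp.auto_cast(enable=True, dtype=amp_dtype, level=amp_level):
            return fn(x)
    return fn(x)

linear = paddle.nn.Linear(4, 4)
x = paddle.randn([2, 4])
y_bf16 = replay_forward(linear, x, is_fw_autocast=True, amp_dtype='bfloat16')
y_fp32 = replay_forward(linear, x, is_fw_autocast=False, amp_dtype='float16')
```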