Unverified · Commit f55c0b33 authored by ShenLiang, committed by GitHub

[Bug]Fix recompute random in modelparallel (#42747)

* fix recompute in mp

* fix recompute
Parent 8eecd852
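In model parallelism, random ops such as dropout inside tensor-parallel layers draw their randomness from the per-rank RNG state tracker, not only from the global CUDA generator. The old `swith_rng_state` helper restored just the global CUDA RNG state during the recompute replay in backward, so recomputed dropout masks could diverge from the original forward pass. This change turns the helper into `swith_rng_state_tracker` (shared via `fleet.utils.recompute`) and has every recompute path capture and restore both states. Below is a minimal sketch of that save/restore pattern, assuming a GPU build of Paddle; the two small driver functions are illustrative and not part of the commit:

```python
import paddle
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
    get_rng_state_tracker,
)
# After this commit the shared helper lives in fleet.utils.recompute.
from paddle.distributed.fleet.utils.recompute import swith_rng_state_tracker


def capture_rng_states():
    # Forward: stash the global CUDA RNG state *and* the model-parallel
    # tracker states alongside the checkpointed inputs.
    fwd_cuda_rng_state = paddle.get_cuda_rng_state()
    fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states_tracker()
    return fwd_cuda_rng_state, fwd_cuda_rng_state_tracker


def replay_segment(run_function, detached_inputs, rng_state, rng_tracker):
    # Backward: re-run the checkpointed segment under the captured states so
    # dropout draws the same masks as in the original forward; both the
    # global state and the tracker are rolled back when the context exits.
    with swith_rng_state_tracker(rng_state, rng_tracker):
        return run_function(*detached_inputs)
```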
@@ -19,7 +19,7 @@ from paddle.fluid import core
 from paddle import _C_ops
 from paddle.autograd import PyLayer, EagerPyLayer
 from paddle.fluid import framework
-from ...utils.recompute import check_recompute_necessary, detach_variable
+from ...utils.recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker
 from ..parallel_layers.random import get_rng_state_tracker
 from paddle.fluid.framework import in_dygraph_mode
@@ -151,20 +151,6 @@ def _merge_activation(tensor):
     return _all_gather(tensor, group=mp_group)
-@contextlib.contextmanager
-def _swith_rng_state_tracker(rng_state, tracker):
-    orig_cuda_rng_state = paddle.get_cuda_rng_state()
-    orig_cuda_rng_tracker = get_rng_state_tracker().get_states_tracker()
-    paddle.set_cuda_rng_state(rng_state)
-    get_rng_state_tracker().set_states_tracker(tracker)
-    try:
-        yield
-    finally:
-        paddle.set_cuda_rng_state(orig_cuda_rng_state)
-        get_rng_state_tracker().set_states_tracker(orig_cuda_rng_tracker)
 class _HPEagerRecomputeFunction(EagerPyLayer):
     """
     Compared with paddle.distributed.fleet.utils.recompute, there are the following differences:
@@ -261,7 +247,7 @@ class _HPEagerRecomputeFunction(EagerPyLayer):
             tracer._has_grad = True
             # need restore auto_cast state as well as w/b list
-            with _swith_rng_state_tracker(ctx.fwd_cuda_rng_state,
+            with swith_rng_state_tracker(ctx.fwd_cuda_rng_state,
                                          ctx.fwd_cuda_rng_state_tracker):
                 with paddle.amp.auto_cast(
                         enable=ctx.is_fw_autocast,
@@ -393,7 +379,7 @@ class _HPRecomputeFunction(PyLayer):
             tracer._has_grad = True
             # need restore auto_cast state as well as w/b list
-            with _swith_rng_state_tracker(ctx.fwd_cuda_rng_state,
+            with swith_rng_state_tracker(ctx.fwd_cuda_rng_state,
                                          ctx.fwd_cuda_rng_state_tracker):
                 with paddle.amp.auto_cast(
                         enable=ctx.is_fw_autocast,
@@ -53,18 +53,24 @@ def check_recompute_necessary(inputs):
 @contextlib.contextmanager
-def swith_rng_state(rng_state):
+def swith_rng_state_tracker(rng_state, tracker):
+    from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
     orig_cuda_rng_state = paddle.get_cuda_rng_state()
+    orig_cuda_rng_tracker = get_rng_state_tracker().get_states_tracker()
     paddle.set_cuda_rng_state(rng_state)
+    get_rng_state_tracker().set_states_tracker(tracker)
     try:
         yield
     finally:
         paddle.set_cuda_rng_state(orig_cuda_rng_state)
+        get_rng_state_tracker().set_states_tracker(orig_cuda_rng_tracker)
 class EagerRecomputeFunction(EagerPyLayer):
     @staticmethod
     def forward(ctx, run_function, preserve_rng_state, *args):
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
         if framework._dygraph_tracer()._has_grad:
             check_recompute_necessary(args)
@@ -98,6 +104,8 @@ class EagerRecomputeFunction(EagerPyLayer):
                     "Recompute with RNG perserve is not support current device: {}.".
                     format(cur_device))
             ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
+            ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker(
+            ).get_states_tracker()
         # TODO support AMP
         tracer = framework._dygraph_tracer()
@@ -126,6 +134,7 @@ class EagerRecomputeFunction(EagerPyLayer):
     @staticmethod
     def backward(ctx, *args):
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
         with paddle.fluid.dygraph.guard():
             # TODO need to check the recompute calling is vaild or not
@@ -143,7 +152,8 @@ class EagerRecomputeFunction(EagerPyLayer):
             # NOTE support AMP
             # need restore auto_cast state as well as w/b list
             if ctx.preserve_rng_state:
-                with swith_rng_state(ctx.fw_cuda_rng_state):
+                with swith_rng_state_tracker(ctx.fw_cuda_rng_state,
+                                             ctx.fwd_cuda_rng_state_tracker):
                     with paddle.amp.auto_cast(
                             enable=ctx.is_fw_autocast,
                             custom_white_list=ctx.amp_white_list,
@@ -199,6 +209,7 @@ class EagerRecomputeFunction(EagerPyLayer):
 class RecomputeFunction(PyLayer):
     @staticmethod
     def forward(ctx, run_function, preserve_rng_state, *args):
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
         if framework._dygraph_tracer()._has_grad:
             check_recompute_necessary(args)
@@ -232,6 +243,8 @@ class RecomputeFunction(PyLayer):
                     "Recompute with RNG perserve is not support current device: {}.".
                     format(cur_device))
             ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
+            ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker(
+            ).get_states_tracker()
         # TODO support AMP
         tracer = framework._dygraph_tracer()
@@ -260,6 +273,7 @@ class RecomputeFunction(PyLayer):
     @staticmethod
     def backward(ctx, *args):
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
         with paddle.fluid.dygraph.guard():
             # TODO need to check the recompute calling is vaild or not
@@ -277,7 +291,8 @@ class RecomputeFunction(PyLayer):
             # NOTE support AMP
             # need restore auto_cast state as well as w/b list
             if ctx.preserve_rng_state:
-                with swith_rng_state(ctx.fw_cuda_rng_state):
+                with swith_rng_state_tracker(ctx.fw_cuda_rng_state,
+                                             ctx.fwd_cuda_rng_state_tracker):
                     with paddle.amp.auto_cast(
                             enable=ctx.is_fw_autocast,
                             custom_white_list=ctx.amp_white_list,
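For reference, a hedged end-to-end sketch of the scenario this fix targets: a recomputed segment whose dropout is sampled from the model-parallel RNG tracker. The block name `mp_dropout_block` and the seed value are illustrative; `recompute` is the utility from `paddle.distributed.fleet.utils`, and a GPU device is assumed since RNG-preserving recompute requires one.

```python
import paddle
from paddle.distributed.fleet.utils import recompute
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
    get_rng_state_tracker,
)

# Register a model-parallel RNG state on the tracker (normally done by
# fleet's hybrid-parallel setup); the name and seed here are arbitrary.
get_rng_state_tracker().add("model_parallel_rng", 2023)


def mp_dropout_block(x):
    # Dropout sampled from the model-parallel RNG tracker, as
    # tensor-parallel layers typically do.
    with get_rng_state_tracker().rng_state("model_parallel_rng"):
        return paddle.nn.functional.dropout(x, p=0.1, training=True)


x = paddle.randn([4, 16])
x.stop_gradient = False

# With this fix, the dropout mask sampled while replaying the segment in
# backward matches the mask used in the original forward pass, because the
# tracker states are restored along with the global CUDA RNG state.
y = recompute(mp_dropout_block, x)
y.sum().backward()
```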