未验证 提交 f55c0b33 编写于 作者: S ShenLiang 提交者: GitHub

[Bug]Fix recompute random in modelparallel (#42747)

* fix recompute in mp

* fix recompute
上级 8eecd852
...@@ -19,7 +19,7 @@ from paddle.fluid import core ...@@ -19,7 +19,7 @@ from paddle.fluid import core
from paddle import _C_ops from paddle import _C_ops
from paddle.autograd import PyLayer, EagerPyLayer from paddle.autograd import PyLayer, EagerPyLayer
from paddle.fluid import framework from paddle.fluid import framework
from ...utils.recompute import check_recompute_necessary, detach_variable from ...utils.recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker
from ..parallel_layers.random import get_rng_state_tracker from ..parallel_layers.random import get_rng_state_tracker
from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import in_dygraph_mode
...@@ -151,20 +151,6 @@ def _merge_activation(tensor): ...@@ -151,20 +151,6 @@ def _merge_activation(tensor):
return _all_gather(tensor, group=mp_group) return _all_gather(tensor, group=mp_group)
@contextlib.contextmanager
def _swith_rng_state_tracker(rng_state, tracker):
orig_cuda_rng_state = paddle.get_cuda_rng_state()
orig_cuda_rng_tracker = get_rng_state_tracker().get_states_tracker()
paddle.set_cuda_rng_state(rng_state)
get_rng_state_tracker().set_states_tracker(tracker)
try:
yield
finally:
paddle.set_cuda_rng_state(orig_cuda_rng_state)
get_rng_state_tracker().set_states_tracker(orig_cuda_rng_tracker)
class _HPEagerRecomputeFunction(EagerPyLayer): class _HPEagerRecomputeFunction(EagerPyLayer):
""" """
Compared with paddle.distributed.fleet.utils.recompute, there are the following differences: Compared with paddle.distributed.fleet.utils.recompute, there are the following differences:
...@@ -261,8 +247,8 @@ class _HPEagerRecomputeFunction(EagerPyLayer): ...@@ -261,8 +247,8 @@ class _HPEagerRecomputeFunction(EagerPyLayer):
tracer._has_grad = True tracer._has_grad = True
# need restore auto_cast state as well as w/b list # need restore auto_cast state as well as w/b list
with _swith_rng_state_tracker(ctx.fwd_cuda_rng_state, with swith_rng_state_tracker(ctx.fwd_cuda_rng_state,
ctx.fwd_cuda_rng_state_tracker): ctx.fwd_cuda_rng_state_tracker):
with paddle.amp.auto_cast( with paddle.amp.auto_cast(
enable=ctx.is_fw_autocast, enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list, custom_white_list=ctx.amp_white_list,
...@@ -393,8 +379,8 @@ class _HPRecomputeFunction(PyLayer): ...@@ -393,8 +379,8 @@ class _HPRecomputeFunction(PyLayer):
tracer._has_grad = True tracer._has_grad = True
# need restore auto_cast state as well as w/b list # need restore auto_cast state as well as w/b list
with _swith_rng_state_tracker(ctx.fwd_cuda_rng_state, with swith_rng_state_tracker(ctx.fwd_cuda_rng_state,
ctx.fwd_cuda_rng_state_tracker): ctx.fwd_cuda_rng_state_tracker):
with paddle.amp.auto_cast( with paddle.amp.auto_cast(
enable=ctx.is_fw_autocast, enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list, custom_white_list=ctx.amp_white_list,
......
...@@ -53,18 +53,24 @@ def check_recompute_necessary(inputs): ...@@ -53,18 +53,24 @@ def check_recompute_necessary(inputs):
@contextlib.contextmanager @contextlib.contextmanager
def swith_rng_state(rng_state): def swith_rng_state_tracker(rng_state, tracker):
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
orig_cuda_rng_state = paddle.get_cuda_rng_state() orig_cuda_rng_state = paddle.get_cuda_rng_state()
orig_cuda_rng_tracker = get_rng_state_tracker().get_states_tracker()
paddle.set_cuda_rng_state(rng_state) paddle.set_cuda_rng_state(rng_state)
get_rng_state_tracker().set_states_tracker(tracker)
try: try:
yield yield
finally: finally:
paddle.set_cuda_rng_state(orig_cuda_rng_state) paddle.set_cuda_rng_state(orig_cuda_rng_state)
get_rng_state_tracker().set_states_tracker(orig_cuda_rng_tracker)
class EagerRecomputeFunction(EagerPyLayer): class EagerRecomputeFunction(EagerPyLayer):
@staticmethod @staticmethod
def forward(ctx, run_function, preserve_rng_state, *args): def forward(ctx, run_function, preserve_rng_state, *args):
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
if framework._dygraph_tracer()._has_grad: if framework._dygraph_tracer()._has_grad:
check_recompute_necessary(args) check_recompute_necessary(args)
...@@ -98,6 +104,8 @@ class EagerRecomputeFunction(EagerPyLayer): ...@@ -98,6 +104,8 @@ class EagerRecomputeFunction(EagerPyLayer):
"Recompute with RNG perserve is not support current device: {}.". "Recompute with RNG perserve is not support current device: {}.".
format(cur_device)) format(cur_device))
ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state() ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker(
).get_states_tracker()
# TODO support AMP # TODO support AMP
tracer = framework._dygraph_tracer() tracer = framework._dygraph_tracer()
...@@ -126,6 +134,7 @@ class EagerRecomputeFunction(EagerPyLayer): ...@@ -126,6 +134,7 @@ class EagerRecomputeFunction(EagerPyLayer):
@staticmethod @staticmethod
def backward(ctx, *args): def backward(ctx, *args):
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
with paddle.fluid.dygraph.guard(): with paddle.fluid.dygraph.guard():
# TODO need to check the recompute calling is vaild or not # TODO need to check the recompute calling is vaild or not
...@@ -143,7 +152,8 @@ class EagerRecomputeFunction(EagerPyLayer): ...@@ -143,7 +152,8 @@ class EagerRecomputeFunction(EagerPyLayer):
# NOTE support AMP # NOTE support AMP
# need restore auto_cast state as well as w/b list # need restore auto_cast state as well as w/b list
if ctx.preserve_rng_state: if ctx.preserve_rng_state:
with swith_rng_state(ctx.fw_cuda_rng_state): with swith_rng_state_tracker(ctx.fw_cuda_rng_state,
ctx.fwd_cuda_rng_state_tracker):
with paddle.amp.auto_cast( with paddle.amp.auto_cast(
enable=ctx.is_fw_autocast, enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list, custom_white_list=ctx.amp_white_list,
...@@ -199,6 +209,7 @@ class EagerRecomputeFunction(EagerPyLayer): ...@@ -199,6 +209,7 @@ class EagerRecomputeFunction(EagerPyLayer):
class RecomputeFunction(PyLayer): class RecomputeFunction(PyLayer):
@staticmethod @staticmethod
def forward(ctx, run_function, preserve_rng_state, *args): def forward(ctx, run_function, preserve_rng_state, *args):
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
if framework._dygraph_tracer()._has_grad: if framework._dygraph_tracer()._has_grad:
check_recompute_necessary(args) check_recompute_necessary(args)
...@@ -232,6 +243,8 @@ class RecomputeFunction(PyLayer): ...@@ -232,6 +243,8 @@ class RecomputeFunction(PyLayer):
"Recompute with RNG perserve is not support current device: {}.". "Recompute with RNG perserve is not support current device: {}.".
format(cur_device)) format(cur_device))
ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state() ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker(
).get_states_tracker()
# TODO support AMP # TODO support AMP
tracer = framework._dygraph_tracer() tracer = framework._dygraph_tracer()
...@@ -260,6 +273,7 @@ class RecomputeFunction(PyLayer): ...@@ -260,6 +273,7 @@ class RecomputeFunction(PyLayer):
@staticmethod @staticmethod
def backward(ctx, *args): def backward(ctx, *args):
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
with paddle.fluid.dygraph.guard(): with paddle.fluid.dygraph.guard():
# TODO need to check the recompute calling is vaild or not # TODO need to check the recompute calling is vaild or not
...@@ -277,7 +291,8 @@ class RecomputeFunction(PyLayer): ...@@ -277,7 +291,8 @@ class RecomputeFunction(PyLayer):
# NOTE support AMP # NOTE support AMP
# need restore auto_cast state as well as w/b list # need restore auto_cast state as well as w/b list
if ctx.preserve_rng_state: if ctx.preserve_rng_state:
with swith_rng_state(ctx.fw_cuda_rng_state): with swith_rng_state_tracker(ctx.fw_cuda_rng_state,
ctx.fwd_cuda_rng_state_tracker):
with paddle.amp.auto_cast( with paddle.amp.auto_cast(
enable=ctx.is_fw_autocast, enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list, custom_white_list=ctx.amp_white_list,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册