未验证 提交 1e39b1ec 编写于 作者: H Haohongxiang 提交者: GitHub

[incubate/new_frl] Support detach of EagerParamBase in recompute (#56767)

上级 5a9214d8
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import contextlib import contextlib
import copy
import weakref import weakref
import paddle import paddle
...@@ -21,6 +22,7 @@ from paddle.autograd import PyLayer ...@@ -21,6 +22,7 @@ from paddle.autograd import PyLayer
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import ( from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
get_rng_state_tracker, get_rng_state_tracker,
) )
from paddle.fluid.framework import EagerParamBase
from paddle.framework import core, in_dygraph_mode from paddle.framework import core, in_dygraph_mode
from ..utils.log_util import logger from ..utils.log_util import logger
...@@ -28,6 +30,17 @@ from ..utils.log_util import logger ...@@ -28,6 +30,17 @@ from ..utils.log_util import logger
__all__ = [] __all__ = []
def _varbase_help(param):
    """Build a fresh EagerParamBase mirroring *param* and sharing its buffer.

    The new parameter is constructed with the same shape/dtype/name plus a
    deep copy of *param*'s instance attributes, then made to share the
    underlying storage via ``_share_buffer_to`` so no data is duplicated.
    """
    # Deep-copy the attribute dict so the clone does not alias mutable state.
    attrs = copy.deepcopy(param.__dict__)
    cloned = EagerParamBase(
        shape=param.shape,
        dtype=param.dtype,
        name=param.name,
        **attrs,
    )
    # Share storage: the clone views the same tensor buffer as the original.
    param._share_buffer_to(cloned)
    return cloned
def detach_variable(inputs): def detach_variable(inputs):
out = [] out = []
for inp in inputs: for inp in inputs:
...@@ -35,6 +48,10 @@ def detach_variable(inputs): ...@@ -35,6 +48,10 @@ def detach_variable(inputs):
out.append(inp) out.append(inp)
continue continue
if isinstance(inp, EagerParamBase):
out.append(_varbase_help(inp))
continue
x = inp.detach() x = inp.detach()
x.stop_gradient = inp.stop_gradient x.stop_gradient = inp.stop_gradient
out.append(x) out.append(x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册