未验证 提交 1e39b1ec 编写于 作者: H Haohongxiang 提交者: GitHub

[incubate/new_frl] Support detach of EagerParamBase in recompute (#56767)

上级 5a9214d8
......@@ -13,6 +13,7 @@
# limitations under the License.
import contextlib
import copy
import weakref
import paddle
......@@ -21,6 +22,7 @@ from paddle.autograd import PyLayer
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
get_rng_state_tracker,
)
from paddle.fluid.framework import EagerParamBase
from paddle.framework import core, in_dygraph_mode
from ..utils.log_util import logger
......@@ -28,6 +30,17 @@ from ..utils.log_util import logger
__all__ = []
def _varbase_help(param):
    """Build a new ``EagerParamBase`` that mirrors ``param``.

    The new parameter is constructed with the same ``shape``, ``dtype`` and
    ``name`` as ``param``, plus a deep copy of ``param.__dict__`` forwarded
    as extra keyword arguments. Afterwards ``param._share_buffer_to`` is
    called with the new parameter.

    Args:
        param: The source parameter; ``detach_variable`` calls this helper
            only for inputs that are ``EagerParamBase`` instances.

    Returns:
        The newly created ``EagerParamBase``.
    """
    # Deep-copy the Python-level attributes so the new object does not alias
    # the original's attribute containers.
    attrs = copy.deepcopy(param.__dict__)
    clone = EagerParamBase(
        shape=param.shape,
        dtype=param.dtype,
        name=param.name,
        **attrs,
    )
    # NOTE(review): presumably this makes ``clone`` reuse the tensor storage
    # of ``param`` instead of copying data -- confirm against the Paddle API.
    param._share_buffer_to(clone)
    return clone
def detach_variable(inputs):
out = []
for inp in inputs:
......@@ -35,6 +48,10 @@ def detach_variable(inputs):
out.append(inp)
continue
if isinstance(inp, EagerParamBase):
out.append(_varbase_help(inp))
continue
x = inp.detach()
x.stop_gradient = inp.stop_gradient
out.append(x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册