未验证 提交 507ea06f 编写于 作者: S ShenLiang 提交者: GitHub

[Bug-Fix]fix bug of py36 import utils (#34873)

* fix bug of py36 import
上级 e92f0388
...@@ -17,10 +17,9 @@ import contextlib ...@@ -17,10 +17,9 @@ import contextlib
import paddle import paddle
from paddle.fluid import core from paddle.fluid import core
from paddle import _C_ops from paddle import _C_ops
import paddle.distributed as dist
from paddle.autograd import PyLayer from paddle.autograd import PyLayer
from paddle.fluid import framework from paddle.fluid import framework
from paddle.distributed.fleet.utils.recompute import check_recompute_necessary, detach_variable from ...utils.recompute import check_recompute_necessary, detach_variable
from ..parallel_layers.random import get_rng_state_tracker from ..parallel_layers.random import get_rng_state_tracker
__all__ = [] __all__ = []
...@@ -239,7 +238,7 @@ class _HPRecomputeFunction(PyLayer): ...@@ -239,7 +238,7 @@ class _HPRecomputeFunction(PyLayer):
tensor_shapes = ctx.tensor_shapes tensor_shapes = ctx.tensor_shapes
tensors = list(ctx.saved_tensor()) tensors = list(ctx.saved_tensor())
device_id = dist.ParallelEnv().device_id device_id = paddle.distributed.ParallelEnv().device_id
for i, idx in enumerate(tensor_indices): for i, idx in enumerate(tensor_indices):
if _recompute_partition: if _recompute_partition:
state = tensors[i].stop_gradient state = tensors[i].stop_gradient
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册