From 507ea06f9b66630385ab96dec144631492877353 Mon Sep 17 00:00:00 2001
From: ShenLiang <1422485404@qq.com>
Date: Fri, 13 Aug 2021 14:01:51 +0800
Subject: [PATCH] [Bug-Fix]fix bug of py36 import utils (#34873)

* fix bug of py36 import
---
 .../paddle/distributed/fleet/meta_parallel/pp_utils/utils.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
index 728080a7cd2..fc1fc4f992e 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
@@ -17,10 +17,9 @@ import contextlib
 import paddle
 from paddle.fluid import core
 from paddle import _C_ops
-import paddle.distributed as dist
 from paddle.autograd import PyLayer
 from paddle.fluid import framework
-from paddle.distributed.fleet.utils.recompute import check_recompute_necessary, detach_variable
+from ...utils.recompute import check_recompute_necessary, detach_variable
 from ..parallel_layers.random import get_rng_state_tracker
 
 __all__ = []
@@ -239,7 +238,7 @@ class _HPRecomputeFunction(PyLayer):
         tensor_shapes = ctx.tensor_shapes
         tensors = list(ctx.saved_tensor())
 
-        device_id = dist.ParallelEnv().device_id
+        device_id = paddle.distributed.ParallelEnv().device_id
         for i, idx in enumerate(tensor_indices):
             if _recompute_partition:
                 state = tensors[i].stop_gradient
-- 
GitLab
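
For readers unfamiliar with the failure mode, the sketch below illustrates the pattern this patch adopts. It is a minimal, hypothetical illustration, not Paddle source: the helper name current_device_id is invented here, while paddle.distributed.ParallelEnv().device_id is the real call used in the patch. The idea is to stop binding the paddle.distributed submodule at module import time (where a circular import can leave it partially initialized on Python 3.6) and instead resolve the attribute chain lazily at call time; within the package itself, the explicit relative import serves the same purpose.

    import paddle  # importing only the top-level package here

    # Hypothetical helper (not in the patch): the attribute chain is
    # resolved when the function runs, after `paddle.distributed` has
    # finished initializing, rather than at module import time via
    # `import paddle.distributed as dist`.
    def current_device_id():
        return paddle.distributed.ParallelEnv().device_id

    # Inside pp_utils/utils.py itself, the patch uses an explicit
    # relative import for the same reason: `...utils.recompute` resolves
    # against paddle.distributed.fleet (two package levels above
    # pp_utils) without re-entering the absolute paddle.distributed
    # import path while it may still be initializing:
    #
    # from ...utils.recompute import check_recompute_necessary, detach_variable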