diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index 728080a7cd248ec72fa4e88393a08b231a628615..fc1fc4f992e36ea5b81d245dc965741fde08455c 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -17,10 +17,9 @@ import contextlib import paddle from paddle.fluid import core from paddle import _C_ops -import paddle.distributed as dist from paddle.autograd import PyLayer from paddle.fluid import framework -from paddle.distributed.fleet.utils.recompute import check_recompute_necessary, detach_variable +from ...utils.recompute import check_recompute_necessary, detach_variable from ..parallel_layers.random import get_rng_state_tracker __all__ = [] @@ -239,7 +238,7 @@ class _HPRecomputeFunction(PyLayer): tensor_shapes = ctx.tensor_shapes tensors = list(ctx.saved_tensor()) - device_id = dist.ParallelEnv().device_id + device_id = paddle.distributed.ParallelEnv().device_id for i, idx in enumerate(tensor_indices): if _recompute_partition: state = tensors[i].stop_gradient