提交 59be4ec2 编写于 作者: L LielinJiang

Env->ParallelEnv

上级 093bfb0c
......@@ -22,7 +22,7 @@ import numpy as np
from paddle import fluid
from paddle.fluid.layers import collective
from paddle.fluid.dygraph.parallel import Env, ParallelStrategy
from paddle.fluid.dygraph.parallel import ParallelEnv, ParallelStrategy
from paddle.fluid.io import BatchSampler
......@@ -64,8 +64,8 @@ class DistributedBatchSampler(BatchSampler):
"drop_last should be a boolean number"
self.drop_last = drop_last
self.nranks = Env().nranks
self.local_rank = Env().local_rank
self.nranks = ParallelEnv().nranks
self.local_rank = ParallelEnv().local_rank
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
self.total_size = self.num_samples * self.nranks
......@@ -164,14 +164,14 @@ def init_communicator(program, rank, nranks, wait_port,
def prepare_distributed_context(place=None):
if place is None:
place = fluid.CUDAPlace(Env().dev_id) if Env().nranks > 1 \
place = fluid.CUDAPlace(ParallelEnv().dev_id) if ParallelEnv().nranks > 1 \
else fluid.CUDAPlace(0)
strategy = ParallelStrategy()
strategy.nranks = Env().nranks
strategy.local_rank = Env().local_rank
strategy.trainer_endpoints = Env().trainer_endpoints
strategy.current_endpoint = Env().current_endpoint
strategy.nranks = ParallelEnv().nranks
strategy.local_rank = ParallelEnv().local_rank
strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
strategy.current_endpoint = ParallelEnv().current_endpoint
if strategy.nranks < 2:
return
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册