Commit 59be4ec2 authored by LielinJiang

Env->ParallelEnv

Parent 093bfb0c
@@ -22,7 +22,7 @@ import numpy as np
 from paddle import fluid
 from paddle.fluid.layers import collective
-from paddle.fluid.dygraph.parallel import Env, ParallelStrategy
+from paddle.fluid.dygraph.parallel import ParallelEnv, ParallelStrategy
 from paddle.fluid.io import BatchSampler
@@ -64,8 +64,8 @@ class DistributedBatchSampler(BatchSampler):
             "drop_last should be a boolean number"
         self.drop_last = drop_last
-        self.nranks = Env().nranks
-        self.local_rank = Env().local_rank
+        self.nranks = ParallelEnv().nranks
+        self.local_rank = ParallelEnv().local_rank
         self.epoch = 0
         self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
         self.total_size = self.num_samples * self.nranks
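The last two lines of this hunk set how the sampler splits work across ranks: each rank draws num_samples = ceil(len(dataset) / nranks) samples, and total_size rounds the dataset up to a multiple of nranks so every rank yields the same number of batches. A minimal sketch of the arithmetic (the dataset size and rank count are hypothetical, chosen only for illustration):

import math

dataset_len, nranks = 10, 4   # hypothetical values for illustration
num_samples = int(math.ceil(dataset_len * 1.0 / nranks))  # 3 samples per rank
total_size = num_samples * nranks                         # 12; 2 indices get padded
print(num_samples, total_size)                            # -> 3 12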
@@ -164,14 +164,14 @@ def init_communicator(program, rank, nranks, wait_port,
 def prepare_distributed_context(place=None):
     if place is None:
-        place = fluid.CUDAPlace(Env().dev_id) if Env().nranks > 1 \
+        place = fluid.CUDAPlace(ParallelEnv().dev_id) if ParallelEnv().nranks > 1 \
             else fluid.CUDAPlace(0)

     strategy = ParallelStrategy()
-    strategy.nranks = Env().nranks
-    strategy.local_rank = Env().local_rank
-    strategy.trainer_endpoints = Env().trainer_endpoints
-    strategy.current_endpoint = Env().current_endpoint
+    strategy.nranks = ParallelEnv().nranks
+    strategy.local_rank = ParallelEnv().local_rank
+    strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
+    strategy.current_endpoint = ParallelEnv().current_endpoint

     if strategy.nranks < 2:
         return
...
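After this change, every read of the trainer environment goes through ParallelEnv rather than the old Env name. A minimal sketch of querying the renamed class (assuming the same paddle.fluid.dygraph.parallel module path as in the diff; the values printed depend on how the distributed job is launched):

from paddle.fluid.dygraph.parallel import ParallelEnv

env = ParallelEnv()
print(env.nranks)             # total number of trainers in the job
print(env.local_rank)         # this trainer's 0-based rank
print(env.dev_id)             # CUDA device bound to this trainer
print(env.current_endpoint)   # this trainer's IP:port
print(env.trainer_endpoints)  # endpoints of all trainers

All five properties are exactly the ones the diff reads, so existing call sites only need the class name updated.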