Unverified commit 86f05911, authored by gongweibao, committed by GitHub

Remove node_num function. (#19167)

node_num is not needed by users, so remove it and fix the related bugs.
Parent: c27b0813
@@ -183,6 +183,9 @@ def start_procs(args):
         "PADDLE_TRAINER_ENDPOINTS": trainers_endpoints
     })
+    if num_nodes > 1:
+        current_env.update({"FLAGS_sync_nccl_allreduce": "0"})
+
     cmd = [sys.executable, "-u", args.training_script
            ] + args.training_script_args
...
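Note: the branch added above disables synchronous NCCL allreduce whenever the job spans more than one node. A minimal sketch of the idea, assuming num_nodes is derived by deduplicating the IP part of the trainer endpoints (the helper and the endpoint values are illustrative, not the actual launch.py code):

def count_nodes(trainer_endpoints):
    # One node per distinct IP among "ip:port" endpoints.
    return len({ep.split(":")[0].strip() for ep in trainer_endpoints})

current_env = {}
if count_nodes(["10.0.0.1:6170", "10.0.0.1:6171", "10.0.0.2:6170"]) > 1:
    # Let allreduce overlap with computation across nodes.
    current_env.update({"FLAGS_sync_nccl_allreduce": "0"})
print(current_env)  # {'FLAGS_sync_nccl_allreduce': '0'}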
@@ -232,14 +232,6 @@ class Fleet(object):
     def save_persistables(self, executor, dirname, main_program=None):
         pass

-    @abc.abstractmethod
-    def node_num(self):
-        pass
-
-    @abc.abstractmethod
-    def node_id(self):
-        pass
-

 class DistributedOptimizer(object):
     """
...
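Note: dropping the two @abc.abstractmethod declarations shrinks the public Fleet contract; every concrete subclass (Collective, DistributedTranspiler) previously had to ship node_num/node_id stubs even where they were meaningless. A minimal sketch of the enforcement that goes away, using Python 3's abc.ABC for brevity (FleetLike is a stand-in, not the real base class):

import abc

class FleetLike(abc.ABC):
    @abc.abstractmethod
    def save_persistables(self, executor, dirname, main_program=None):
        pass

class MyFleet(FleetLike):
    # Only the remaining abstract methods must be implemented now.
    def save_persistables(self, executor, dirname, main_program=None):
        print("saving to", dirname)

MyFleet().save_persistables(None, "/tmp/model")  # ok: contract satisfied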
@@ -384,27 +384,8 @@ class PaddleCloudRoleMaker(RoleMakerBase):
             self._worker_endpoints = self._worker_endpoints.split(",")
             self._trainers_num = len(self._worker_endpoints)

-            self._node_ips = self._get_node_ips_from_endpoints(
-                self._worker_endpoints)
-            self._node_ip = self._current_endpoint.split(":")[0].strip()
-            self._node_num = len(self._node_ips)
-            self._node_id = self._node_ips.index(self._node_ip)
-
             self._role_is_generated = True

-    def _get_node_ips_from_endpoints(self, endpoints):
-        ss = set()
-        ips = []
-        for ep in endpoints:
-            ip = ep.split(":")[0].strip()
-            if ip not in ss:
-                ss.add(ip)
-                ips.append(ip)
-            else:
-                continue
-        return ips
-
     def get_pserver_endpoints(self):
         if not self._role_is_generated:
             self.generate_role()
...
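Note: _get_node_ips_from_endpoints deduplicates the IP half of "ip:port" endpoints while keeping first-seen order; the commit deletes it from the role maker here and re-adds it on CollectiveOptimizer below. A standalone sketch of its behavior (names illustrative):

def get_node_ips_from_endpoints(endpoints):
    seen, ips = set(), []
    for ep in endpoints:
        ip = ep.split(":")[0].strip()
        if ip not in seen:  # keep only the first occurrence of each IP
            seen.add(ip)
            ips.append(ip)
    return ips

eps = ["10.0.0.1:6170", "10.0.0.1:6171", "10.0.0.2:6170"]
print(get_node_ips_from_endpoints(eps))  # ['10.0.0.1', '10.0.0.2'] -> 2 nodes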
@@ -85,12 +85,6 @@ class Collective(Fleet):
     def save_persistables(self, executor, dirname, main_program=None):
         io.save_persistables(executor, dirname, main_program, None)

-    def node_num(self):
-        return self._role_maker._node_num
-
-    def node_id(self):
-        return self._role_maker._node_id
-

 fleet = Collective()
@@ -102,9 +96,6 @@ class DistributedStrategy(fluid.BuildStrategy):
     def __init__(self):
         super(DistributedStrategy, self).__init__()

-        self.fuse_memory_size = -1
-        self.fuse_layer_size = 1
-
         self.use_local_sgd = False
         self.use_dist_fc = False
@@ -112,21 +103,9 @@ class DistributedStrategy(fluid.BuildStrategy):
         self.dist_fc_config = None  # DistFCConfig
         self.mode = "nccl2"  # or collective
         self.collective_mode = None  # local_sgd or grad_allreduce
-        self.nccl_comm_num = 1
+        self.nccl_comm_num = 2
         self.exec_strategy = fluid.ExecutionStrategy()

-        sync_allreduce = os.getenv("FLAGS_sync_nccl_allreduce")
-        if sync_allreduce == "0":
-            self._exec_strategy.num_threads = self.nccl_comm_num + 1
-            if sef.use_hierarchical_allreduce:
-                self._exec_strategy.num_threads = 2 * self.nccl_comm_num + 1
-            if self._exec_strategy.num_threads > 4:
-                print(
-                    sys.stderr,
-                    "WARNING: if you use use_hierarchical_allreduce or "
-                    "with multi nccl comm, please set FLAGS_sync_nccl_allreduce = 0"
-                )
-

 class CollectiveOpBasedOptimizer(DistributedOptimizer):
@@ -215,12 +194,6 @@ class CollectiveOptimizer(DistributedOptimizer):
         """
         Transpile the programs to distributed programs. And add the variables.
         """
-        if self._strategy.fuse_all_reduce_ops:
-            os.environ[
-                'FLAGS_fuse_parameter_memory_size'] = self.fuse_memory_size
-            os.environ[
-                'FLAGS_fuse_parameter_groups_size'] = self.fuse_layer_size
-
         worker_endpoints = fleet.worker_endpoints()
         trainer_id = fleet.worker_index()
         current_endpoint = fleet.worker_endpoints()[trainer_id]
@@ -249,7 +222,67 @@ class CollectiveOptimizer(DistributedOptimizer):
             program=main_program,
             current_endpoint=current_endpoint)

+    def _get_node_ips_from_endpoints(self, endpoints):
+        ss = set()
+        ips = []
+        for ep in endpoints:
+            ip = ep.split(":")[0].strip()
+            if ip not in ss:
+                ss.add(ip)
+                ips.append(ip)
+            else:
+                continue
+        return ips
+
+    def _node_num(self):
+        worker_endpoints = fleet.worker_endpoints()
+        current_endpoint = fleet.worker_endpoints()[fleet.worker_index()]
+        worker_endpoints_env = ','.join(worker_endpoints)
+
+        node_ips = self._get_node_ips_from_endpoints(worker_endpoints)
+        node_ip = current_endpoint.split(":")[0].strip()
+
+        node_num = len(node_ips)
+
+        return node_num
+
     def _try_to_compile(self, startup_program, main_program):
+        node_num = self._node_num()
+        assert node_num >= 1, "nccl2 node_num must >= 1, now:{}".format(
+            node_num)
+
+        self._strategy.fuse_all_reduce_ops = True
+        exec_strategy = self._strategy.exec_strategy
+
+        if node_num <= 1:
+            if self._strategy.nccl_comm_num > 1:
+                logging.warn("set nccl_comm_num=1 since you only have 1 node.")
+            self._strategy.nccl_comm_num = 1
+
+            if self._strategy.use_hierarchical_allreduce:
+                logging.warn(
+                    "set use_hierarchical_allreduce=False since you only have 1 node."
+                )
+            self._strategy.use_hierarchical_allreduce = False
+
+        sync_allreduce = os.getenv("FLAGS_sync_nccl_allreduce")
+        if sync_allreduce is None or sync_allreduce == "1":
+            exec_strategy.num_threads = self._strategy.nccl_comm_num + 1
+            if self._strategy.use_hierarchical_allreduce:
+                exec_strategy.num_threads = 2 * self._strategy.nccl_comm_num + 1
+            if exec_strategy.num_threads > 4:
+                logging.warn(
+                    "if you use use_hierarchical_allreduce or "
+                    "with multi nccl comm, please export FLAGS_sync_nccl_allreduce = 0"
+                )
+
+        if self.print_config:
+            print("node_num:", node_num, "num_threads:",
+                  exec_strategy.num_threads, "use_hierarchical_allreduce:",
+                  self._strategy.use_hierarchical_allreduce, "nccl_comm_num:",
+                  self._strategy.nccl_comm_num, "FLAGS_sync_nccl_allreduce:",
+                  sync_allreduce)
+
         self._transpile(startup_program, main_program)

         if self._strategy.mode == "collective":
...
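Note: the num_threads tuning moves out of DistributedStrategy.__init__ (where the removed block read the env var too early, referenced the misspelled `sef`, and wrote to a nonexistent `self._exec_strategy`) into _try_to_compile, where node_num is known. A minimal sketch of the thread rule as added above, assuming the async branch simply keeps the executor's default of one thread:

import os

def pick_num_threads(nccl_comm_num, use_hierarchical_allreduce):
    sync = os.getenv("FLAGS_sync_nccl_allreduce")
    if sync is None or sync == "1":  # synchronous allreduce (the default)
        if use_hierarchical_allreduce:
            return 2 * nccl_comm_num + 1
        return nccl_comm_num + 1
    return 1  # assumed: async allreduce leaves the executor default

print(pick_num_threads(2, False))  # 3
print(pick_num_threads(2, True))   # 5 (> 4: consider FLAGS_sync_nccl_allreduce=0)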
@@ -239,14 +239,6 @@ class DistributedTranspiler(Fleet):
         self.main_program, self.startup_program = \
             self._transpiler.get_pserver_programs(self.server_endpoints()[self.server_index()])

-    def node_num(self):
-        logging.warn(
-            "You should not call 'node_num' method for collective mode.")
-
-    def node_id(self):
-        logging.warn(
-            "You should not call 'node_id' method for collective mode.")
-

 fleet = DistributedTranspiler()
...
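Note: with node_num()/node_id() gone from the public fleet API, user code that still needs them can recover both from worker_endpoints(), mirroring the private _node_num above. A hedged migration sketch (the helper is hypothetical, not part of fleet):

def node_num_and_id(worker_endpoints, current_endpoint):
    ips = []
    for ep in worker_endpoints:
        ip = ep.split(":")[0].strip()
        if ip not in ips:  # order-preserving dedup, as in _node_num
            ips.append(ip)
    return len(ips), ips.index(current_endpoint.split(":")[0].strip())

eps = ["10.0.0.1:6170", "10.0.0.2:6170"]
print(node_num_and_id(eps, "10.0.0.2:6170"))  # (2, 1)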