提交 a38b98cb 编写于 作者: X xjqbest 提交者: dongdaxiang

fix code style & runtime error

test=develop
上级 8e14d8f9
......@@ -235,15 +235,15 @@ class InMemoryDataset(DatasetBase):
"""
trainer_num = 1
if fleet is not None:
fleet.fleet_instance.role_maker_.barrier_worker()
fleet.fleet_instance.role_maker_._barrier_worker()
trainer_num = fleet.worker_num()
self.dataset.register_client2client_msg_handler()
self.dataset.set_trainer_num(trainer_num)
if fleet is not None:
fleet.fleet_instance.role_maker_.barrier_worker()
fleet.fleet_instance.role_maker_._barrier_worker()
self.dataset.global_shuffle()
if fleet is not None:
fleet.fleet_instance.role_maker_.barrier_worker()
fleet.fleet_instance.role_maker_._barrier_worker()
class QueueDataset(DatasetBase):
......
......@@ -98,7 +98,7 @@ class MPIRoleMaker(RoleMakerBase):
"""
all_gather(obj) will call MPI's allgather function
"""
self.barrier_all()
self._barrier_all()
return self.comm_.allgather(obj)
def _barrier_all(self):
......@@ -112,7 +112,7 @@ class MPIRoleMaker(RoleMakerBase):
collect current distributed job's ip list
"""
if self.ips_ == None:
self.ips_ = self.comm_.allgather(self.get_local_ip())
self.ips_ = self.comm_.allgather(self._get_local_ip())
return self.ips_
def _finalize(self):
......@@ -146,7 +146,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
return whether current process is the first worker assigned by role maker
"""
if self._check_role_generation():
return self.is_worker() and 0 == self.worker_index()
return self._is_worker() and 0 == self._worker_index()
return False
def _is_worker(self):
......@@ -170,8 +170,8 @@ class MPISymetricRoleMaker(MPIRoleMaker):
return the current number of worker
"""
if self._check_role_generation():
if self.is_worker():
return self.get_size() / 2
if self._is_worker():
return self._get_size() / 2
return 0
def _server_num(self):
......@@ -179,8 +179,8 @@ class MPISymetricRoleMaker(MPIRoleMaker):
return the current number of server
"""
if self._check_role_generation():
if self.is_server():
return self.get_size() / 2
if self._is_server():
return self._get_size() / 2
return 0
def _worker_index(self):
......@@ -204,7 +204,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
barrier all workers in current distributed job
"""
if self._check_role_generation():
if self.is_worker():
if self._is_worker():
self.node_type_comm_.barrier()
def _barrier_server(self):
......@@ -212,7 +212,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
barrier all servers in current distributed job
"""
if self._check_role_generation():
if self.is_server():
if self._is_server():
self.node_type_comm_.barrier()
def _generate_role(self):
......@@ -221,10 +221,10 @@ class MPISymetricRoleMaker(MPIRoleMaker):
"""
if not self.role_is_generated_:
# TODO(guru4elephant): only allow to be called once
self.trainer_endpoints_ = self.get_ips()
self.pserver_endpoints_ = self.get_ips()
self.trainer_endpoints_ = self._get_ips()
self.pserver_endpoints_ = self._get_ips()
if 0 == self.get_rank() % self.proc_per_node_ % 2:
if 0 == self._get_rank() % self.proc_per_node_ % 2:
self.node_type_ = 0
else:
self.node_type_ = 1
......
......@@ -88,7 +88,7 @@ class Fleet(object):
stop(): will be called after a user finishes his/her training task. Fleet instance will be
destroyed when stop() is called.
"""
self.role_maker_.barrier_worker()
self.role_maker_._barrier_worker()
if self.role_maker_._is_first_worker():
self._fleet_ptr.stop_server()
self.role_maker_._barrier_worker()
......
"""
dataset testcases
"""
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -21,13 +25,9 @@ import unittest
class TestDataset(unittest.TestCase):
"""
TestCases for Dataset.
"""
""" TestCases for Dataset. """
def test_dataset_create(self):
"""
Testcase for dataset create
"""
""" Testcase for dataset create """
try:
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
except:
......@@ -45,9 +45,7 @@ class TestDataset(unittest.TestCase):
self.assertTrue(True)
def test_dataset_config(self):
"""
Testcase for dataset configuration
"""
""" Testcase for dataset configuration """
dataset = fluid.core.Dataset("MultiSlotDataset")
dataset.set_thread_num(12)
dataset.set_filelist(["a.txt", "b.txt", "c.txt"])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册