Commit a38b98cb, authored by xjqbest, committed by dongdaxiang

fix code style & runtime error

test=develop
Parent: 8e14d8f9
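The diff below renames role-maker helper methods (for example barrier_worker, get_ips, is_worker) to underscore-prefixed forms (_barrier_worker, _get_ips, _is_worker) and updates every call site, marking them as internal API. A minimal sketch of the renaming pattern, using a hypothetical class for illustration only:

# Hypothetical illustration of the rename pattern in this commit;
# RoleMakerSketch is not a real Paddle class.
class RoleMakerSketch(object):
    def _barrier_worker(self):      # formerly barrier_worker()
        pass                        # leading underscore marks it as an internal helper

    def _is_worker(self):           # formerly is_worker()
        return True

role_maker = RoleMakerSketch()
role_maker._barrier_worker()        # call sites are updated to the new names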
@@ -235,15 +235,15 @@ class InMemoryDataset(DatasetBase):
         """
         trainer_num = 1
         if fleet is not None:
-            fleet.fleet_instance.role_maker_.barrier_worker()
+            fleet.fleet_instance.role_maker_._barrier_worker()
             trainer_num = fleet.worker_num()
         self.dataset.register_client2client_msg_handler()
         self.dataset.set_trainer_num(trainer_num)
         if fleet is not None:
-            fleet.fleet_instance.role_maker_.barrier_worker()
+            fleet.fleet_instance.role_maker_._barrier_worker()
         self.dataset.global_shuffle()
         if fleet is not None:
-            fleet.fleet_instance.role_maker_.barrier_worker()
+            fleet.fleet_instance.role_maker_._barrier_worker()
 
 class QueueDataset(DatasetBase):
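For context, global_shuffle() barriers all workers before and after shuffling so every trainer sees a consistent view of the data. A usage sketch, assuming the signature global_shuffle(fleet=None) implied by the "if fleet is not None" checks in the hunk above:

# Usage sketch only, not part of this commit.
import paddle.fluid as fluid

dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
# ... configure and fill the dataset (filelist, slots, loading) as usual ...
dataset.global_shuffle()          # single-trainer shuffle: trainer_num stays 1
# dataset.global_shuffle(fleet)   # distributed: barriers and shuffles across fleet.worker_num() trainers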
@@ -98,7 +98,7 @@ class MPIRoleMaker(RoleMakerBase):
         """
         all_gather(obj) will call MPI's allgather function
         """
-        self.barrier_all()
+        self._barrier_all()
         return self.comm_.allgather(obj)
 
     def _barrier_all(self):
@@ -112,7 +112,7 @@ class MPIRoleMaker(RoleMakerBase):
         collect current distributed job's ip list
         """
         if self.ips_ == None:
-            self.ips_ = self.comm_.allgather(self.get_local_ip())
+            self.ips_ = self.comm_.allgather(self._get_local_ip())
         return self.ips_
 
     def _finalize(self):
@@ -146,7 +146,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
         return whether current process is the first worker assigned by role maker
         """
         if self._check_role_generation():
-            return self.is_worker() and 0 == self.worker_index()
+            return self._is_worker() and 0 == self._worker_index()
         return False
 
     def _is_worker(self):
@@ -170,8 +170,8 @@ class MPISymetricRoleMaker(MPIRoleMaker):
         return the current number of worker
         """
         if self._check_role_generation():
-            if self.is_worker():
-                return self.get_size() / 2
+            if self._is_worker():
+                return self._get_size() / 2
         return 0
 
     def _server_num(self):
@@ -179,8 +179,8 @@ class MPISymetricRoleMaker(MPIRoleMaker):
         return the current number of server
         """
         if self._check_role_generation():
-            if self.is_server():
-                return self.get_size() / 2
+            if self._is_server():
+                return self._get_size() / 2
         return 0
 
     def _worker_index(self):
@@ -204,7 +204,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
         barrier all workers in current distributed job
         """
         if self._check_role_generation():
-            if self.is_worker():
+            if self._is_worker():
                 self.node_type_comm_.barrier()
 
     def _barrier_server(self):
@@ -212,7 +212,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
         barrier all servers in current distributed job
         """
         if self._check_role_generation():
-            if self.is_server():
+            if self._is_server():
                 self.node_type_comm_.barrier()
 
     def _generate_role(self):
@@ -221,10 +221,10 @@ class MPISymetricRoleMaker(MPIRoleMaker):
         """
         if not self.role_is_generated_:
             # TODO(guru4elephant): only allow to be called once
-            self.trainer_endpoints_ = self.get_ips()
-            self.pserver_endpoints_ = self.get_ips()
-            if 0 == self.get_rank() % self.proc_per_node_ % 2:
+            self.trainer_endpoints_ = self._get_ips()
+            self.pserver_endpoints_ = self._get_ips()
+            if 0 == self._get_rank() % self.proc_per_node_ % 2:
                 self.node_type_ = 0
             else:
                 self.node_type_ = 1
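In the symmetric role maker every MPI rank acts as either a worker or a server, which is why both _worker_num() and _server_num() return _get_size() / 2 and why _generate_role() derives node_type_ from the rank. A small sketch of that arithmetic (plain Python, illustration only; proc_per_node_ = 2 is an assumption):

# Illustration of the role-assignment logic in _generate_role() above; no MPI involved.
def node_type_for(rank, proc_per_node=2):
    # even position on a node -> worker (node_type_ = 0), odd -> server (node_type_ = 1)
    return 0 if rank % proc_per_node % 2 == 0 else 1

size = 8                            # total MPI processes
workers = size // 2                 # matches _worker_num(): _get_size() / 2
servers = size // 2                 # matches _server_num(): _get_size() / 2
roles = [node_type_for(r) for r in range(size)]
print(workers, servers, roles)      # 4 4 [0, 1, 0, 1, 0, 1, 0, 1]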
@@ -88,7 +88,7 @@ class Fleet(object):
         stop(): will be called after a user finishes his/her training task. Fleet instance will be
         destroyed when stop() is called.
         """
-        self.role_maker_.barrier_worker()
+        self.role_maker_._barrier_worker()
         if self.role_maker_._is_first_worker():
             self._fleet_ptr.stop_server()
         self.role_maker_._barrier_worker()
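The stop() change keeps the same shutdown pattern: every worker enters a barrier, only the first worker asks the server to stop, and a second barrier keeps the others from exiting early. A schematic sketch of that pattern (generic code, not the actual Fleet implementation):

# Generic barrier / single-caller / barrier shutdown pattern used by stop();
# barrier() and is_first_worker() stand in for the role maker's private helpers.
def shutdown(role_maker, fleet_ptr):
    role_maker.barrier()            # wait until all workers reach shutdown
    if role_maker.is_first_worker():
        fleet_ptr.stop_server()     # exactly one worker stops the parameter server
    role_maker.barrier()            # everyone waits for the stop to complete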
"""
dataset testcases
"""
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,13 +25,9 @@ import unittest
 class TestDataset(unittest.TestCase):
-    """
-    TestCases for Dataset.
-    """
+    """ TestCases for Dataset. """
 
     def test_dataset_create(self):
-        """
-        Testcase for dataset create
-        """
+        """ Testcase for dataset create """
         try:
             dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
         except:
@@ -45,9 +45,7 @@ class TestDataset(unittest.TestCase):
         self.assertTrue(True)
 
     def test_dataset_config(self):
-        """
-        Testcase for dataset configuration
-        """
+        """ Testcase for dataset configuration """
         dataset = fluid.core.Dataset("MultiSlotDataset")
         dataset.set_thread_num(12)
         dataset.set_filelist(["a.txt", "b.txt", "c.txt"])
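The collapsed docstrings above keep each test's intent on a single line. Assuming these cases live in the usual test_dataset module, they can be run directly with unittest, for example:

# Run sketch; the import path is an assumption, adjust it to where TestDataset actually lives.
import unittest
from test_dataset import TestDataset

suite = unittest.TestLoader().loadTestsFromTestCase(TestDataset)
unittest.TextTestRunner(verbosity=2).run(suite)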