提交 d52586a9 编写于 作者: X xjqbest 提交者: dongdaxiang

add doc string

test=develop
上级 6be9f719
...@@ -80,6 +80,11 @@ class AsyncExecutor(object): ...@@ -80,6 +80,11 @@ class AsyncExecutor(object):
def __init__(self, place=None, run_mode=""): def __init__(self, place=None, run_mode=""):
""" """
Init. Init.
Example:
>>> place = fluid.CPUPlace()
>>> async_executor = fluid.AsyncExecutor(place)
Args: Args:
place(Place): CPUPlace or GPUPlace. place(Place): CPUPlace or GPUPlace.
run_mode(str): default is empty string. run_mode(str): default is empty string.
...@@ -99,6 +104,14 @@ class AsyncExecutor(object): ...@@ -99,6 +104,14 @@ class AsyncExecutor(object):
def run(self, program, data_feed, filelist, thread_num, fetch, debug=False): def run(self, program, data_feed, filelist, thread_num, fetch, debug=False):
""" """
Run program by this AsyncExecutor. Run program by this AsyncExecutor.
Example:
>>> place = fluid.CPUPlace()
>>> async_executor = fluid.AsyncExecutor(place)
>>> async_executor.run(default_main_program(),
my_data_feed_desc,
["a.txt", "b.txt"])
Args: Args:
program(Program): the program that need to run, if not provied, program(Program): the program that need to run, if not provied,
then default_main_program will be used. then default_main_program will be used.
...@@ -235,12 +248,13 @@ class AsyncExecutor(object): ...@@ -235,12 +248,13 @@ class AsyncExecutor(object):
>>> exe.download_data("/xxx/xxx/xx/", >>> exe.download_data("/xxx/xxx/xx/",
>>> "./data", "afs:// >>> "./data", "afs://
>>> xxx.xxx.xxx.xxx:9901", "xxx,yyy") >>> xxx.xxx.xxx.xxx:9901", "xxx,yyy")
Args: Args:
afs_path(str): afs_path defined by users afs_path(str): afs_path defined by users
local_path(str): download data path local_path(str): download data path
fs_default_name(str): file system server address fs_default_name(str): file system server address
ugi(str): hadoop ugi ugi(str): hadoop ugi
file_cn(int): a user can specify file number for debugging file_cnt(int): a user can specify file number for debugging
hadoop_home(str): hadoop home path hadoop_home(str): hadoop home path
process_num(int): download process num process_num(int): download process num
""" """
...@@ -298,7 +312,8 @@ class AsyncExecutor(object): ...@@ -298,7 +312,8 @@ class AsyncExecutor(object):
def init_server(self, dist_desc): def init_server(self, dist_desc):
""" """
initialize server of current node if current process is a server Initialize server of current node if current process is a server.
Args: Args:
dist_desc(str): a protobuf string that describes dist_desc(str): a protobuf string that describes
how to init a worker and a server how to init a worker and a server
...@@ -319,7 +334,8 @@ class AsyncExecutor(object): ...@@ -319,7 +334,8 @@ class AsyncExecutor(object):
def init_worker(self, dist_desc, startup_program): def init_worker(self, dist_desc, startup_program):
""" """
initialize worker of current node if current process is a worker Initialize worker of current node if current process is a worker.
Args: Args:
dist_desc(str): a protobuf string that describes dist_desc(str): a protobuf string that describes
how to init a worker and a server how to init a worker and a server
...@@ -364,7 +380,8 @@ class AsyncExecutor(object): ...@@ -364,7 +380,8 @@ class AsyncExecutor(object):
def save_model(self, save_path): def save_model(self, save_path):
""" """
save_model command that can be invoked from one of the worker save_model command that can be invoked from one of the worker
model parameters are saved in servers and upload to save_path of file system model parameters are saved in servers and upload to save_path of file system.
Args: Args:
save_path(str): save path to file system save_path(str): save path to file system
""" """
......
...@@ -17,32 +17,83 @@ __all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD'] ...@@ -17,32 +17,83 @@ __all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD']
class DeviceWorker(object): class DeviceWorker(object):
"""
DeviceWorker is a abstract class, which generates worker desc.
"""
def __init__(self): def __init__(self):
"""
Init.
"""
self.program_ = None self.program_ = None
def set_fleet_desc(self, fleet_desc): def set_fleet_desc(self, fleet_desc):
"""
Set fleet desc.
Args:
fleet_desc(PSParameter): pslib.PSParameter object
"""
self.fleet_desc_ = fleet_desc self.fleet_desc_ = fleet_desc
def set_program(self, program): def set_program(self, program):
"""
Set program.
Args:
program(Program): a Program object
"""
self.program_ = program self.program_ = program
def gen_worker_desc(self, trainer_desc): def gen_worker_desc(self, trainer_desc):
pass """
Generator worker desc.
Args:
trainer_desc(TrainerDesc): a TrainerDesc object
"""
raise NotImplementedError(
"DeviceWorker does not implement gen_worker_desc, "
"please use Hogwild or DownpourSGD, etc.")
class Hogwild(DeviceWorker): class Hogwild(DeviceWorker):
"""
Hogwild is a kind of SGD algorithm.
"""
def __init__(self): def __init__(self):
"""
Init.
"""
super(Hogwild, self).__init__() super(Hogwild, self).__init__()
def gen_worker_desc(self, trainer_desc): def gen_worker_desc(self, trainer_desc):
"""
Generator worker desc, which device worker is HogwildWorker.
Args:
trainer_desc(TrainerDesc): a TrainerDesc object
"""
trainer_desc.device_worker_name = "HogwildWorker" trainer_desc.device_worker_name = "HogwildWorker"
class DownpourSGD(DeviceWorker): class DownpourSGD(DeviceWorker):
"""
DownpourSGD is a kind of distributed SGD algorithm.
"""
def __init__(self): def __init__(self):
"""
Init.
"""
super(DownpourSGD, self).__init__() super(DownpourSGD, self).__init__()
def gen_worker_desc(self, trainer_desc): def gen_worker_desc(self, trainer_desc):
"""
Generator worker desc, which device worker is DownpourWorker.
Args:
trainer_desc(TrainerDesc): a TrainerDesc object
"""
dense_table_set = set() dense_table_set = set()
program_id = str(id(self.program_)) program_id = str(id(self.program_))
if self.program_ == None: if self.program_ == None:
......
...@@ -127,6 +127,10 @@ class Fleet(object): ...@@ -127,6 +127,10 @@ class Fleet(object):
init_worker(): will be called by user. When a user knows current process is_server(), he/she init_worker(): will be called by user. When a user knows current process is_server(), he/she
should call init_worker() to initialize global information about worker and connect should call init_worker() to initialize global information about worker and connect
worker with pserver. worker with pserver.
Args:
programs(Program|list): a Program or a list of Programs
""" """
if not isinstance(programs, list): if not isinstance(programs, list):
programs = [programs] programs = [programs]
......
...@@ -21,7 +21,13 @@ import unittest ...@@ -21,7 +21,13 @@ import unittest
class TestDataset(unittest.TestCase): class TestDataset(unittest.TestCase):
"""
TestCases for Dataset.
"""
def test_dataset_create(self): def test_dataset_create(self):
"""
Testcase for dataset create
"""
try: try:
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
except: except:
...@@ -39,6 +45,9 @@ class TestDataset(unittest.TestCase): ...@@ -39,6 +45,9 @@ class TestDataset(unittest.TestCase):
self.assertTrue(True) self.assertTrue(True)
def test_dataset_config(self): def test_dataset_config(self):
"""
Testcase for dataset configuration
"""
dataset = fluid.core.Dataset("MultiSlotDataset") dataset = fluid.core.Dataset("MultiSlotDataset")
dataset.set_thread_num(12) dataset.set_thread_num(12)
dataset.set_filelist(["a.txt", "b.txt", "c.txt"]) dataset.set_filelist(["a.txt", "b.txt", "c.txt"])
...@@ -62,12 +71,15 @@ class TestDataset(unittest.TestCase): ...@@ -62,12 +71,15 @@ class TestDataset(unittest.TestCase):
self.assertEqual(ugi, "my_fs_ugi") self.assertEqual(ugi, "my_fs_ugi")
def test_in_memory_dataset_run(self): def test_in_memory_dataset_run(self):
with open("test_dataset_a.txt", "w") as f: """
Testcase for InMemoryDataset from create to run
"""
with open("test_in_memory_dataset_run_a.txt", "w") as f:
data = "1 1 2 3 3 4 5 5 5 5 1 1\n" data = "1 1 2 3 3 4 5 5 5 5 1 1\n"
data += "1 2 2 3 4 4 6 6 6 6 1 2\n" data += "1 2 2 3 4 4 6 6 6 6 1 2\n"
data += "1 3 2 3 5 4 7 7 7 7 1 3\n" data += "1 3 2 3 5 4 7 7 7 7 1 3\n"
f.write(data) f.write(data)
with open("test_dataset_b.txt", "w") as f: with open("test_in_memory_dataset_run_b.txt", "w") as f:
data = "1 4 2 3 3 4 5 5 5 5 1 4\n" data = "1 4 2 3 3 4 5 5 5 5 1 4\n"
data += "1 5 2 3 4 4 6 6 6 6 1 5\n" data += "1 5 2 3 4 4 6 6 6 6 1 5\n"
data += "1 6 2 3 5 4 7 7 7 7 1 6\n" data += "1 6 2 3 5 4 7 7 7 7 1 6\n"
...@@ -84,7 +96,8 @@ class TestDataset(unittest.TestCase): ...@@ -84,7 +96,8 @@ class TestDataset(unittest.TestCase):
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_batch_size(32) dataset.set_batch_size(32)
dataset.set_thread(3) dataset.set_thread(3)
dataset.set_filelist(["test_dataset_a.txt", "test_dataset_b.txt"]) dataset.set_filelist(["test_in_memory_dataset_run_a.txt",
"test_in_memory_dataset_run_b.txt"])
dataset.set_pipe_command("cat") dataset.set_pipe_command("cat")
dataset.set_use_var(slots_vars) dataset.set_use_var(slots_vars)
dataset.load_into_memory() dataset.load_into_memory()
...@@ -98,16 +111,19 @@ class TestDataset(unittest.TestCase): ...@@ -98,16 +111,19 @@ class TestDataset(unittest.TestCase):
except: except:
self.assertTrue(False) self.assertTrue(False)
os.remove("./test_dataset_a.txt") os.remove("./test_in_memory_dataset_run_a.txt")
os.remove("./test_dataset_b.txt") os.remove("./test_in_memory_dataset_run_b.txt")
def test_queue_dataset_run(self): def test_queue_dataset_run(self):
with open("test_dataset_a.txt", "w") as f: """
Testcase for QueueDataset from create to run
"""
with open("test_queue_dataset_run_a.txt", "w") as f:
data = "1 1 2 3 3 4 5 5 5 5 1 1\n" data = "1 1 2 3 3 4 5 5 5 5 1 1\n"
data += "1 2 2 3 4 4 6 6 6 6 1 2\n" data += "1 2 2 3 4 4 6 6 6 6 1 2\n"
data += "1 3 2 3 5 4 7 7 7 7 1 3\n" data += "1 3 2 3 5 4 7 7 7 7 1 3\n"
f.write(data) f.write(data)
with open("test_dataset_b.txt", "w") as f: with open("test_queue_dataset_run_b.txt", "w") as f:
data = "1 4 2 3 3 4 5 5 5 5 1 4\n" data = "1 4 2 3 3 4 5 5 5 5 1 4\n"
data += "1 5 2 3 4 4 6 6 6 6 1 5\n" data += "1 5 2 3 4 4 6 6 6 6 1 5\n"
data += "1 6 2 3 5 4 7 7 7 7 1 6\n" data += "1 6 2 3 5 4 7 7 7 7 1 6\n"
...@@ -124,7 +140,8 @@ class TestDataset(unittest.TestCase): ...@@ -124,7 +140,8 @@ class TestDataset(unittest.TestCase):
dataset = fluid.DatasetFactory().create_dataset("QueueDataset") dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
dataset.set_batch_size(32) dataset.set_batch_size(32)
dataset.set_thread(3) dataset.set_thread(3)
dataset.set_filelist(["test_dataset_a.txt", "test_dataset_b.txt"]) dataset.set_filelist(["test_queue_dataset_run_a.txt",
"test_queue_dataset_run_b.txt"])
dataset.set_pipe_command("cat") dataset.set_pipe_command("cat")
dataset.set_use_var(slots_vars) dataset.set_use_var(slots_vars)
...@@ -136,8 +153,8 @@ class TestDataset(unittest.TestCase): ...@@ -136,8 +153,8 @@ class TestDataset(unittest.TestCase):
except: except:
self.assertTrue(False) self.assertTrue(False)
os.remove("./test_dataset_a.txt") os.remove("./test_queue_dataset_run_a.txt")
os.remove("./test_dataset_b.txt") os.remove("./test_queue_dataset_run_b.txt")
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册