Commit d52586a9 authored by xjqbest, committed by dongdaxiang

add doc string

test=develop
Parent 6be9f719
......@@ -80,6 +80,11 @@ class AsyncExecutor(object):
def __init__(self, place=None, run_mode=""):
"""
Init.
Example:
>>> place = fluid.CPUPlace()
>>> async_executor = fluid.AsyncExecutor(place)
Args:
place(Place): CPUPlace or GPUPlace.
run_mode(str): default is empty string.
......@@ -99,6 +104,14 @@ class AsyncExecutor(object):
def run(self, program, data_feed, filelist, thread_num, fetch, debug=False):
"""
Run program by this AsyncExecutor.
Example:
>>> place = fluid.CPUPlace()
>>> async_executor = fluid.AsyncExecutor(place)
>>> async_executor.run(default_main_program(),
my_data_feed_desc,
["a.txt", "b.txt"])
Args:
program(Program): the program that needs to run; if not provided,
then default_main_program will be used.
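A fuller usage sketch of run() under the signature above; the proto file name, file list, thread count, and fetch list are placeholders, not values taken from this commit beyond the docstring example:

    import paddle.fluid as fluid

    place = fluid.CPUPlace()
    async_executor = fluid.AsyncExecutor(place)
    data_feed = fluid.DataFeedDesc('data.proto')    # hypothetical feed description file
    async_executor.run(fluid.default_main_program(),
                       data_feed,
                       ["a.txt", "b.txt"],          # filelist of training files
                       4,                           # thread_num (illustrative)
                       [],                          # fetch list, empty here for brevity
                       debug=False)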
......@@ -235,12 +248,13 @@ class AsyncExecutor(object):
>>> exe.download_data("/xxx/xxx/xx/",
>>> "./data", "afs://
>>> xxx.xxx.xxx.xxx:9901", "xxx,yyy")
Args:
afs_path(str): afs_path defined by users
local_path(str): download data path
fs_default_name(str): file system server address
ugi(str): hadoop ugi
file_cn(int): a user can specify file number for debugging
file_cnt(int): a user can specify file number for debugging
hadoop_home(str): hadoop home path
process_num(int): download process num
"""
......@@ -298,7 +312,8 @@ class AsyncExecutor(object):
def init_server(self, dist_desc):
"""
initialize server of current node if current process is a server
Initialize server of current node if current process is a server.
Args:
dist_desc(str): a protobuf string that describes
how to init a worker and a server
......@@ -319,7 +334,8 @@ class AsyncExecutor(object):
def init_worker(self, dist_desc, startup_program):
"""
initialize worker of current node if current process is a worker
Initialize worker of current node if current process is a worker.
Args:
dist_desc(str): a protobuf string that describes
how to init a worker and a server
......@@ -364,7 +380,8 @@ class AsyncExecutor(object):
def save_model(self, save_path):
"""
save_model command that can be invoked from one of the workers
model parameters are saved in servers and upload to save_path of file system
Model parameters are saved on servers and uploaded to save_path on the file system.
Args:
save_path(str): save path to file system
"""
......
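A rough sketch of the call order that init_server, init_worker, and save_model imply on a distributed job; dist_desc, the role flag, and the save path are placeholders that a real job launcher would prepare:

    import paddle.fluid as fluid

    dist_desc = ""              # placeholder; a real job passes a serialized PSParameter string
    node_is_server = False      # hypothetical role flag set by the launcher

    exe = fluid.AsyncExecutor(fluid.CPUPlace())
    if node_is_server:
        exe.init_server(dist_desc)
    else:
        exe.init_worker(dist_desc, fluid.default_startup_program())
        # ... train with exe.run(...) as in the run() example above ...
        exe.save_model("afs:/user/paddle/model")    # placeholder save path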
......@@ -17,32 +17,83 @@ __all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD']
class DeviceWorker(object):
"""
DeviceWorker is an abstract class, which generates worker desc.
"""
def __init__(self):
"""
Init.
"""
self.program_ = None
def set_fleet_desc(self, fleet_desc):
"""
Set fleet desc.
Args:
fleet_desc(PSParameter): pslib.PSParameter object
"""
self.fleet_desc_ = fleet_desc
def set_program(self, program):
"""
Set program.
Args:
program(Program): a Program object
"""
self.program_ = program
def gen_worker_desc(self, trainer_desc):
pass
"""
Generate worker desc.
Args:
trainer_desc(TrainerDesc): a TrainerDesc object
"""
raise NotImplementedError(
"DeviceWorker does not implement gen_worker_desc, "
"please use Hogwild or DownpourSGD, etc.")
class Hogwild(DeviceWorker):
"""
Hogwild is a kind of SGD algorithm.
"""
def __init__(self):
"""
Init.
"""
super(Hogwild, self).__init__()
def gen_worker_desc(self, trainer_desc):
"""
Generate worker desc, where the device worker is HogwildWorker.
Args:
trainer_desc(TrainerDesc): a TrainerDesc object
"""
trainer_desc.device_worker_name = "HogwildWorker"
class DownpourSGD(DeviceWorker):
"""
DownpourSGD is a kind of distributed SGD algorithm.
"""
def __init__(self):
"""
Init.
"""
super(DownpourSGD, self).__init__()
def gen_worker_desc(self, trainer_desc):
"""
Generate worker desc, where the device worker is DownpourWorker.
Args:
trainer_desc(TrainerDesc): a TrainerDesc object
"""
dense_table_set = set()
program_id = str(id(self.program_))
if self.program_ == None:
......
......@@ -127,6 +127,10 @@ class Fleet(object):
init_worker(): will be called by the user. When a user knows the current process is_server(), he/she
should call init_worker() to initialize global information about the worker and connect
the worker with the pserver.
Args:
programs(Program|list): a Program or a list of Programs
"""
if not isinstance(programs, list):
programs = [programs]
......
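The two lines after the docstring implement the Program|list convention named in Args; a standalone sketch of that normalization, with a hypothetical helper name:

    def _as_program_list(programs):
        # Accept a single Program or a list of Programs, mirroring init_worker.
        if not isinstance(programs, list):
            programs = [programs]
        return programs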
......@@ -21,7 +21,13 @@ import unittest
class TestDataset(unittest.TestCase):
"""
TestCases for Dataset.
"""
def test_dataset_create(self):
"""
Testcase for dataset creation
"""
try:
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
except:
......@@ -39,6 +45,9 @@ class TestDataset(unittest.TestCase):
self.assertTrue(True)
def test_dataset_config(self):
"""
Testcase for dataset configuration
"""
dataset = fluid.core.Dataset("MultiSlotDataset")
dataset.set_thread_num(12)
dataset.set_filelist(["a.txt", "b.txt", "c.txt"])
......@@ -62,12 +71,15 @@ class TestDataset(unittest.TestCase):
self.assertEqual(ugi, "my_fs_ugi")
def test_in_memory_dataset_run(self):
with open("test_dataset_a.txt", "w") as f:
"""
Testcase for InMemoryDataset from create to run
"""
with open("test_in_memory_dataset_run_a.txt", "w") as f:
data = "1 1 2 3 3 4 5 5 5 5 1 1\n"
data += "1 2 2 3 4 4 6 6 6 6 1 2\n"
data += "1 3 2 3 5 4 7 7 7 7 1 3\n"
f.write(data)
with open("test_dataset_b.txt", "w") as f:
with open("test_in_memory_dataset_run_b.txt", "w") as f:
data = "1 4 2 3 3 4 5 5 5 5 1 4\n"
data += "1 5 2 3 4 4 6 6 6 6 1 5\n"
data += "1 6 2 3 5 4 7 7 7 7 1 6\n"
......@@ -84,7 +96,8 @@ class TestDataset(unittest.TestCase):
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_batch_size(32)
dataset.set_thread(3)
dataset.set_filelist(["test_dataset_a.txt", "test_dataset_b.txt"])
dataset.set_filelist(["test_in_memory_dataset_run_a.txt",
"test_in_memory_dataset_run_b.txt"])
dataset.set_pipe_command("cat")
dataset.set_use_var(slots_vars)
dataset.load_into_memory()
......@@ -98,16 +111,19 @@ class TestDataset(unittest.TestCase):
except:
self.assertTrue(False)
os.remove("./test_dataset_a.txt")
os.remove("./test_dataset_b.txt")
os.remove("./test_in_memory_dataset_run_a.txt")
os.remove("./test_in_memory_dataset_run_b.txt")
def test_queue_dataset_run(self):
with open("test_dataset_a.txt", "w") as f:
"""
Testcase for QueueDataset from create to run
"""
with open("test_queue_dataset_run_a.txt", "w") as f:
data = "1 1 2 3 3 4 5 5 5 5 1 1\n"
data += "1 2 2 3 4 4 6 6 6 6 1 2\n"
data += "1 3 2 3 5 4 7 7 7 7 1 3\n"
f.write(data)
with open("test_dataset_b.txt", "w") as f:
with open("test_queue_dataset_run_b.txt", "w") as f:
data = "1 4 2 3 3 4 5 5 5 5 1 4\n"
data += "1 5 2 3 4 4 6 6 6 6 1 5\n"
data += "1 6 2 3 5 4 7 7 7 7 1 6\n"
......@@ -124,7 +140,8 @@ class TestDataset(unittest.TestCase):
dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
dataset.set_batch_size(32)
dataset.set_thread(3)
dataset.set_filelist(["test_dataset_a.txt", "test_dataset_b.txt"])
dataset.set_filelist(["test_queue_dataset_run_a.txt",
"test_queue_dataset_run_b.txt"])
dataset.set_pipe_command("cat")
dataset.set_use_var(slots_vars)
......@@ -136,8 +153,8 @@ class TestDataset(unittest.TestCase):
except:
self.assertTrue(False)
os.remove("./test_dataset_a.txt")
os.remove("./test_dataset_b.txt")
os.remove("./test_queue_dataset_run_a.txt")
os.remove("./test_queue_dataset_run_b.txt")
if __name__ == '__main__':
......
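The two run tests above share one create/configure/run flow; a hedged end-to-end sketch of that flow, where the slot layout mirrors the tests and train_from_dataset is assumed to be available on Executor in this branch:

    import paddle.fluid as fluid

    slots_vars = [fluid.layers.data(name="slot%d" % i, shape=[1],
                                    dtype="int64", lod_level=1)
                  for i in range(4)]                        # placeholder slots
    dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    dataset.set_batch_size(32)
    dataset.set_thread(3)
    dataset.set_filelist(["a.txt", "b.txt"])                # files like the ones the tests write
    dataset.set_pipe_command("cat")
    dataset.set_use_var(slots_vars)
    dataset.load_into_memory()

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    exe.train_from_dataset(fluid.default_main_program(), dataset)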