提交 a2354d8b 编写于 作者: W wangjiawei04

add inmemory dataset

上级 b4063f51
...@@ -26,6 +26,7 @@ from paddle.fluid.contrib.utils.hdfs_utils import HDFSClient ...@@ -26,6 +26,7 @@ from paddle.fluid.contrib.utils.hdfs_utils import HDFSClient
__all__ = ["DatasetBase", "DataLoader", "QueueDataset", "InMemoryDataset"] __all__ = ["DatasetBase", "DataLoader", "QueueDataset", "InMemoryDataset"]
class DatasetBase(object): class DatasetBase(object):
"""R """R
""" """
...@@ -152,6 +153,7 @@ class QueueDataset(DatasetBase): ...@@ -152,6 +153,7 @@ class QueueDataset(DatasetBase):
break break
return dataset return dataset
class InMemoryDataset(QueueDataset): class InMemoryDataset(QueueDataset):
def _get_dataset(self, dataset_name, context): def _get_dataset(self, dataset_name, context):
with open("context.txt", "w+") as fout: with open("context.txt", "w+") as fout:
...@@ -197,7 +199,10 @@ class InMemoryDataset(QueueDataset): ...@@ -197,7 +199,10 @@ class InMemoryDataset(QueueDataset):
"hadoop.job.ugi": hdfs_ugi "hadoop.job.ugi": hdfs_ugi
} }
hdfs_client = HDFSClient(hadoop_home, hdfs_configs) hdfs_client = HDFSClient(hadoop_home, hdfs_configs)
file_list = ["{}/{}".format(hdfs_addr, x) for x in hdfs_client.lsr(train_data_path)] file_list = [
"{}/{}".format(hdfs_addr, x)
for x in hdfs_client.lsr(train_data_path)
]
if context["engine"] == EngineMode.LOCAL_CLUSTER: if context["engine"] == EngineMode.LOCAL_CLUSTER:
file_list = split_files(file_list, context["fleet"].worker_index(), file_list = split_files(file_list, context["fleet"].worker_index(),
context["fleet"].worker_num()) context["fleet"].worker_num())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册