未验证 提交 be61e9ea 编写于 作者: G guru4elephant 提交者: GitHub

Merge pull request #16597 from guru4elephant/refine_dataset

refine dataset API
......@@ -15,7 +15,7 @@
from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format
from . import core
__all__ = ['DatasetFactory']
__all__ = ['DatasetFactory', 'InMemoryDataset', 'QueueDataset']
class DatasetFactory(object):
......@@ -38,6 +38,10 @@ class DatasetFactory(object):
"""
Create "QueueDataset" or "InMemoryDataset",
the default is "QueueDataset".
Examples:
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset()
"""
try:
dataset = globals()[datafeed_class]()
......@@ -177,7 +181,8 @@ class DatasetBase(object):
class InMemoryDataset(DatasetBase):
"""
InMemoryDataset, it will load data into memory
and shuffle data before training
and shuffle data before training.
This class should be created by DatasetFactory
Example:
dataset = paddle.fluid.DatasetFactory.create_dataset("InMemoryDataset")
......@@ -259,7 +264,8 @@ class QueueDataset(DatasetBase):
def __init__(self):
"""
Init
Initialize QueueDataset
This class should be created by DatasetFactory
"""
super(QueueDataset, self).__init__()
self.proto_desc.name = "MultiSlotDataFeed"
......@@ -268,7 +274,8 @@ class QueueDataset(DatasetBase):
"""
Local shuffle
QueueDataset does not support local shuffle
Local shuffle is not supported in QueueDataset
NotImplementedError will be raised
"""
raise NotImplementedError(
"QueueDataset does not support local shuffle, "
......@@ -276,7 +283,8 @@ class QueueDataset(DatasetBase):
def global_shuffle(self, fleet=None):
"""
Global shuffle
Global shuffle is not supported in QueueDataset
NotImplementedError will be raised
"""
raise NotImplementedError(
"QueueDataset does not support global shuffle, "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册