未验证 提交 be61e9ea 编写于 作者: G guru4elephant 提交者: GitHub

Merge pull request #16597 from guru4elephant/refine_dataset

refine dataset API
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
from paddle.fluid.proto import data_feed_pb2 from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format from google.protobuf import text_format
from . import core from . import core
__all__ = ['DatasetFactory'] __all__ = ['DatasetFactory', 'InMemoryDataset', 'QueueDataset']
class DatasetFactory(object): class DatasetFactory(object):
...@@ -38,6 +38,10 @@ class DatasetFactory(object): ...@@ -38,6 +38,10 @@ class DatasetFactory(object):
""" """
Create "QueueDataset" or "InMemoryDataset", Create "QueueDataset" or "InMemoryDataset",
the default is "QueueDataset". the default is "QueueDataset".
Examples:
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset()
""" """
try: try:
dataset = globals()[datafeed_class]() dataset = globals()[datafeed_class]()
...@@ -177,7 +181,8 @@ class DatasetBase(object): ...@@ -177,7 +181,8 @@ class DatasetBase(object):
class InMemoryDataset(DatasetBase): class InMemoryDataset(DatasetBase):
""" """
InMemoryDataset, it will load data into memory InMemoryDataset, it will load data into memory
and shuffle data before training and shuffle data before training.
This class should be created by DatasetFactory
Example: Example:
dataset = paddle.fluid.DatasetFactory.create_dataset("InMemoryDataset") dataset = paddle.fluid.DatasetFactory.create_dataset("InMemoryDataset")
...@@ -259,7 +264,8 @@ class QueueDataset(DatasetBase): ...@@ -259,7 +264,8 @@ class QueueDataset(DatasetBase):
def __init__(self): def __init__(self):
""" """
Init Initialize QueueDataset
This class should be created by DatasetFactory
""" """
super(QueueDataset, self).__init__() super(QueueDataset, self).__init__()
self.proto_desc.name = "MultiSlotDataFeed" self.proto_desc.name = "MultiSlotDataFeed"
...@@ -268,7 +274,8 @@ class QueueDataset(DatasetBase): ...@@ -268,7 +274,8 @@ class QueueDataset(DatasetBase):
""" """
Local shuffle Local shuffle
QueueDataset does not support local shuffle Local shuffle is not supported in QueueDataset
NotImplementedError will be raised
""" """
raise NotImplementedError( raise NotImplementedError(
"QueueDataset does not support local shuffle, " "QueueDataset does not support local shuffle, "
...@@ -276,7 +283,8 @@ class QueueDataset(DatasetBase): ...@@ -276,7 +283,8 @@ class QueueDataset(DatasetBase):
def global_shuffle(self, fleet=None): def global_shuffle(self, fleet=None):
""" """
Global shuffle Global shuffle is not supported in QueueDataset
NotImplementedError will be raised
""" """
raise NotImplementedError( raise NotImplementedError(
"QueueDataset does not support global shuffle, " "QueueDataset does not support global shuffle, "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册