# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data Sources are helpers to define paddle training data or testing data.
"""
from paddle.trainer.config_parser import *
from .utils import deprecated

try:
    import cPickle as pickle
except ImportError:
    import pickle

# Only the define_py_data_sources2 entry point is exported by
# `from ... import *`.
__all__ = ['define_py_data_sources2']


def define_py_data_source(file_list,
                          cls,
                          module,
                          obj,
                          args=None,
                          async=False,
Z
zhangjinchao01 已提交
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
                          data_cls=PyData):
    """
    Define a python data source.

    For example, the simplest usage in trainer_config.py as follow:

    ..  code-block:: python

        define_py_data_source("train.list", TrainData, "data_provider", "process")

    Or. if you want to pass arguments from trainer_config to data_provider.py, then

    ..  code-block:: python

        define_py_data_source("train.list", TrainData, "data_provider", "process",
                              args={"dictionary": dict_name})

    :param data_cls:
L
Luo Tao 已提交
52
    :param file_list: file list name, which contains all data file paths
Z
zhangjinchao01 已提交
53 54 55 56 57 58 59 60
    :type file_list: basestring
    :param cls: Train or Test Class.
    :type cls: TrainData or TestData
    :param module: python module name.
    :type module: basestring
    :param obj: python object name. May be a function name if using
                PyDataProviderWrapper.
    :type obj: basestring
61 62
    :param args: The best practice is using dict to pass arguments into
                 DataProvider, and use :code:`@init_hook_wrapper` to
Z
zhangjinchao01 已提交
63 64 65 66 67 68 69 70 71
                 receive arguments.
    :type args: string or picklable object
    :param async: Load Data asynchronously or not.
    :type async: bool
    :return: None
    :rtype: None
    """
    if isinstance(file_list, list):
        file_list_name = 'train.list'
Y
Yu Yang 已提交
72
        if cls == TestData:
Z
zhangjinchao01 已提交
73
            file_list_name = 'test.list'
74
        with open(file_list_name, 'w') as f:
Z
zhangjinchao01 已提交
75 76 77 78 79 80
            f.writelines(file_list)
        file_list = file_list_name

    if not isinstance(args, basestring) and args is not None:
        args = pickle.dumps(args, 0)

Q
qijun 已提交
81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
    cls(
        data_cls(
            files=file_list,
            load_data_module=module,
            load_data_object=obj,
            load_data_args=args,
            async_load_data=async))


def define_py_data_sources(train_list,
                           test_list,
                           module,
                           obj,
                           args=None,
                           train_async=False,
                           data_cls=PyData):
    """
    The annotation is almost the same as define_py_data_sources2, except that
    it can also specify train_async and data_cls.

    :param train_list: Train list name.
    :type train_list: basestring
    :param test_list: Test list name.
    :type test_list: basestring
    :param module: python module name. If train and test are different, then
                   pass a tuple or list to this argument.
    :type module: basestring or tuple or list
    :param obj: python object name. May be a function name if using
                PyDataProviderWrapper. If train and test are different, then
                pass a tuple or list to this argument.
    :type obj: basestring or tuple or list
    :param args: The best practice is using dict() to pass arguments into
                 DataProvider, and use :code:`@init_hook_wrapper` to receive
                 arguments. If train and test are different, then pass a tuple
                 or list to this argument. None is replaced by "".
    :type args: string or picklable object or list or tuple.
    :param train_async: Whether training data is loaded asynchronously.
                        Test data is always loaded synchronously.
    :type train_async: bool
    :param data_cls: factory used to build each data config; forwarded
                     unchanged to define_py_data_source for both sources.
    :type data_cls: callable
    :return: None
    :rtype: None
    """

    def _split_train_test(value):
        # A 2-element list/tuple carries separate (train, test) values;
        # anything else is shared by both sources.
        if isinstance(value, (list, tuple)) and len(value) == 2:
            train_value, test_value = value
            return train_value, test_value
        return value, value

    assert train_list is not None or test_list is not None, \
        "at least one of train_list and test_list must be given"
    assert module is not None and obj is not None, \
        "module and obj must not be None"

    train_module, test_module = _split_train_test(module)
    train_obj, test_obj = _split_train_test(obj)
    train_args, test_args = _split_train_test("" if args is None else args)

    if train_list is not None:
        define_py_data_source(train_list, TrainData, train_module, train_obj,
                              train_args, train_async, data_cls)

    if test_list is not None:
        define_py_data_source(test_list, TestData, test_module, test_obj,
                              test_args, False, data_cls)


def define_py_data_sources2(train_list, test_list, module, obj, args=None):
    """
    Define python Train/Test data sources in one method. If train/test use
    the same Data Provider configuration, module/obj/args contain one argument,
    otherwise contain a list or tuple of arguments. For example:

    ..  code-block:: python

        define_py_data_sources2(train_list="train.list",
                                test_list="test.list",
                                module="data_provider",
                                # if train/test use different configurations,
                                # obj=["process_train", "process_test"]
                                obj="process",
                                args={"dictionary": dict_name})

    For the related data provider, refer to
    :ref:`api_pydataprovider2_sequential_model` .

    :param train_list: Train list name.
    :type train_list: basestring
    :param test_list: Test list name.
    :type test_list: basestring
    :param module: python module name. If train and test are different, then
                   pass a tuple or list to this argument.
    :type module: basestring or tuple or list
    :param obj: python object name. May be a function name if using
                PyDataProviderWrapper. If train and test are different, then
                pass a tuple or list to this argument.
    :type obj: basestring or tuple or list
    :param args: The best practice is using dict() to pass arguments into
                 DataProvider, and use :code:`@init_hook_wrapper` to receive
                 arguments. If train and test are different, then pass a tuple
                 or list to this argument.
    :type args: string or picklable object or list or tuple.
    :return: None
    :rtype: None
    """

    def py_data2(files, load_data_module, load_data_object, load_data_args,
                 **kwargs):
        # Build a 'py2' data config. Extra keyword arguments (notably the
        # async_load_data flag that define_py_data_source always passes)
        # are ignored: py2 data sources are always set to load
        # asynchronously.
        data = create_data_config_proto()
        data.type = 'py2'
        data.files = files
        data.load_data_module = load_data_module
        data.load_data_object = load_data_object
        data.load_data_args = load_data_args
        data.async_load_data = True
        return data

    define_py_data_sources(
        train_list=train_list,
        test_list=test_list,
        module=module,
        obj=obj,
        args=args,
        data_cls=py_data2)