# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import collections
import io
import math
import os
import warnings

import paddle.distributed as dist
from paddle.io import Dataset
from paddle.dataset.common import md5file
from paddle.utils.download import get_path_from_url
from paddlenlp.utils.env import DATA_HOME

__all__ = [
    'MapDatasetWrapper',
    'TSVDataset',
]


@classmethod
def get_datasets(cls, *args, **kwargs):
    """
    Get muitiple datasets like train, valid and test of current dataset.

    Example:
        .. code-block:: python
            from paddlenlp.datasets import GlueQNLI
            train_dataset, dev_dataset, test_dataset = GlueQNLI.get_datasets(['train', 'dev', 'test'])
S
smallv0221 已提交
43
            train_dataset, dev_dataset, test_dataset = GlueQNLI.get_datasets(mode=['train', 'dev', 'test'])
Z
Zeyu Chen 已提交
44 45
            train_dataset = GlueQNLI.get_datasets('train')
            train_dataset = GlueQNLI.get_datasets(['train'])
S
smallv0221 已提交
46
            train_dataset = GlueQNLI.get_datasets(mode='train')
Z
Zeyu Chen 已提交
47 48 49
    """
    if not args and not kwargs:
        try:
            args = cls.SPLITS.keys()
        except AttributeError:
            raise AttributeError(
                'Dataset must have a SPLITS attribute to use get_datasets '
                'when no configs are given.')

        datasets = tuple(MapDatasetWrapper(cls(arg)) for arg in args)
    else:
        # If any positional or keyword config is not a list, build a single
        # dataset from the given configs.
        for arg in args:
            if not isinstance(arg, list):
                return MapDatasetWrapper(cls(*args, **kwargs))
        for value in kwargs.values():
            if not isinstance(value, list):
                return MapDatasetWrapper(cls(*args, **kwargs))

        num_datasets = len(args[0]) if args else len(list(kwargs.values())[0])
        datasets = tuple(
            MapDatasetWrapper(
                cls(*(arg[i] for arg in args),
                    **{key: value[i]
                       for key, value in kwargs.items()}))
            for i in range(num_datasets))

    return datasets if len(datasets) > 1 else datasets[0]


# Attach `get_datasets` to the base `Dataset` class so that every dataset
# subclass can create its splits (e.g. train/dev/test) with a single call.
Dataset.get_datasets = get_datasets


class MapDatasetWrapper(Dataset):
    """
    Wraps a dataset-like object as an instance of Dataset, and equips it with
    `apply` and other utility methods. All non-magic methods of the raw object
    are also accessible.
    Args:
        data (list|Dataset): A dataset-like object. It can be a list or a
            subclass of Dataset.
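    Example:
        A minimal sketch wrapping a plain list:
        .. code-block:: python
            dataset = MapDatasetWrapper(['a', 'b', 'c'])
            len(dataset)  # 3
            dataset[0]  # 'a'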
    """

    def __init__(self, data):
        self.data = data
        self._transform_pipeline = []
        self.new_data = self.data

    def _transform(self, data, pipeline):
        # Apply the lazy transformations in the order they were added.
        for fn in pipeline:
            data = fn(data)
        return data

    def __getitem__(self, idx):
        return self._transform(
            self.new_data[idx], self._transform_pipeline
        ) if self._transform_pipeline else self.new_data[idx]

    def __len__(self):
        return len(self.new_data)

    def filter(self, fn):
        """
        Filters samples by the filter function and keeps only the samples for
        which the function returns True. The filtering is done in place.
        Args:
            fn (callable): A filter function that takes a sample as input and
                returns a boolean. Samples for which it returns False are
                discarded.
        Returns:
            MapDatasetWrapper: The dataset itself, with the filtered data.
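        Example:
            A minimal sketch, assuming samples are dicts with a hypothetical
            'label' field:
            .. code-block:: python
                dataset = MapDatasetWrapper([{'label': 0}, {'label': 1}])
                dataset.filter(lambda sample: sample['label'] == 1)
                len(dataset)  # 1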
        """
        self.new_data = [sample for sample in self.new_data if fn(sample)]
        return self

    def shard(self, num_shards=None, index=None):
        """
        Splits the dataset into `num_shards` pieces and keeps the samples
        whose index mod `num_shards` equals `index`. The sharding is done in
        place.
        Args:
            num_shards (int, optional): An integer representing the number of
                data shards. If None, `num_shards` would be the number of
                trainers. Default: None.
            index (int, optional): An integer representing the index of the
                current shard. If None, `index` would be the rank id of the
                current trainer. Default: None.
        Returns:
            MapDatasetWrapper: The dataset itself, containing only the
                current shard.
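        Example:
            A minimal sketch splitting a small dataset across two hypothetical
            workers:
            .. code-block:: python
                dataset = MapDatasetWrapper(list(range(10)))
                dataset.shard(num_shards=2, index=0)
                # The dataset now holds samples 0, 2, 4, 6 and 8.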
        """
        if num_shards is None:
            num_shards = dist.get_world_size()
        if index is None:
            index = dist.get_rank()

        num_samples = int(math.ceil(len(self.new_data) * 1.0 / num_shards))
        # Keep only the samples that belong to this shard.
        self.new_data = [
            sample for idx, sample in enumerate(self.new_data)
            if idx % num_shards == index
        ]
        # Add an extra sample to make the shards evenly divisible: pad the
        # short shard with one of its own samples, selected by negative index.
        if len(self.new_data) < num_samples:
            self.new_data.append(self.new_data[index + 1 - num_shards])

        return self

    def apply(self, fn, lazy=False):
        """
        Performs a specific function on the dataset to transform every sample.
        Args:
            fn (callable): The transformation to be performed. It receives a
                single sample as its argument rather than the whole dataset.
            lazy (bool, optional): If True, the transformation is delayed and
                performed on demand. Otherwise, all samples are transformed at
                once. Note that if `fn` is stochastic, `lazy` should be True,
                or you will get the same result on all epochs. Default: False.
        Returns:
            MapDatasetWrapper: The dataset itself. If `lazy` is True, `fn` is
                appended to the transform pipeline and applied on demand;
                otherwise the stored samples are transformed eagerly.
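        Example:
            A minimal sketch, assuming samples are plain strings:
            .. code-block:: python
                dataset = MapDatasetWrapper(['a', 'b'])
                dataset.apply(lambda sample: sample.upper())
                dataset[0]  # 'A'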
        """
        if lazy:
            self._transform_pipeline.append(fn)
        else:
            self.new_data = [fn(sample) for sample in self.new_data]
        return self

    def __getattr__(self, name):
        return getattr(self.data, name)


class TSVDataset(Dataset):
    """
    Common tab separated text dataset that reads text fields based on provided
    sample splitter and field separator.
    The returned dataset includes samples, each of which is either a list of
    text fields if `field_separator` is specified, or otherwise a single
    string segment produced by `sample_splitter`.
    Args:
        filename (str|list of str): Path to the input text file or list of
            paths to the input text files.
        encoding (str): File encoding format. Default: 'utf-8'.
        sample_splitter (function): A function that splits the dataset string
            into samples. Default: str.splitlines
        field_separator (function|None): A function that splits each sample
            string into list of text fields. If None, raw samples are returned
            according to `sample_splitter`. Default: split method of str with
            tab as separator.
        num_discard_samples (int): Number of samples discarded at the head of
            the first file. Default: 0.
        field_indices (list|int|None): If set, for each sample, only fields
            with provided indices are selected as the output. Otherwise all
            fields are returned. Default: None.
        allow_missing (bool): If set to True, no exception will be thrown if
            the number of fields is smaller than the maximum field index
            provided.  Default: False.

    Example:
        Assume `test.tsv` contains the following content:
        Id\tFirstName\tLastName
        a\tmale\tTom
        b\tFemale\tCat
        Discard the first line and select the 0th and 2nd fields:
        .. code-block:: python
            from paddlenlp.datasets import TSVDataset
            dataset = TSVDataset('test.tsv', num_discard_samples=1,
                                 field_indices=[0, 2])
            dataset[0] # ['a', 'Tom']
            dataset[1] # ['b', 'Cat']
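        Keep raw lines by disabling field separation (a hypothetical variation
        on the same `test.tsv`):
        .. code-block:: python
            dataset = TSVDataset('test.tsv', field_separator=None)
            dataset[0] # 'Id\tFirstName\tLastName'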
    """

    def __init__(self,
                 filename,
                 encoding='utf-8',
                 sample_splitter=lambda x: x.splitlines(),
                 field_separator=lambda x: x.split('\t'),
                 num_discard_samples=0,
                 field_indices=None,
                 allow_missing=False):
        assert sample_splitter, 'sample_splitter must be specified.'

        if not isinstance(filename, (tuple, list)):
            filename = (filename, )

        self._filenames = [os.path.expanduser(f) for f in filename]
        self._encoding = encoding
        self._sample_splitter = sample_splitter
        self._field_separator = field_separator
        self._num_discard_samples = num_discard_samples
        self._field_indices = field_indices
        self._allow_missing = allow_missing
        self.data = self._read()

    def _should_discard(self):
        discard = self._num_discard_samples > 0
        self._num_discard_samples -= 1
        return discard

    def _field_selector(self, fields):
        if not self._field_indices:
            return fields
        try:
            result = [fields[i] for i in self._field_indices]
        except IndexError as e:
            raise IndexError('%s. Fields = %s' % (str(e), str(fields)))
        return result

    def _read(self):
        all_samples = []
        for filename in self._filenames:
            with io.open(filename, 'r', encoding=self._encoding) as fin:
                content = fin.read()
            samples = (s for s in self._sample_splitter(content)
                       if not self._should_discard())
            if self._field_separator:
                if not self._allow_missing:
                    samples = [
                        self._field_selector(self._field_separator(s))
                        for s in samples
                    ]
                else:
                    selected_samples = []
                    num_missing = 0
                    for s in samples:
                        try:
                            fields = self._field_separator(s)
                            selected_samples.append(
                                self._field_selector(fields))
                        except IndexError:
                            num_missing += 1
                    if num_missing > 0:
                        warnings.warn('%d incomplete samples in %s' %
                                      (num_missing, filename))
                    samples = selected_samples
            all_samples += samples
        return all_samples

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)