functional.py 2.5 KB
Newer Older
C
chenxuyi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
C
chenxuyi 已提交
14
"""Pyreader based Dataset.

Extends ``propeller.data.functional.Dataset`` with ``paddle.fluid``
integration: one ``fluid.layers.data`` placeholder is declared per data
field, and those placeholders are fed from the dataset's generator through
an iterable ``fluid.io.PyReader``.
"""
C
chenxuyi 已提交
15 16 17 18 19 20 21 22 23 24 25 26 27 28

import sys
import numpy as np
import logging

import paddle.fluid as F
import paddle.fluid.layers as L

from propeller.data.functional import Dataset as DatasetBase

# Module-level logger; Dataset.start uses it to record (with traceback) any
# exception raised while iterating the dataset's generator.
log = logging.getLogger(__name__)


class Dataset(DatasetBase):
    """Pyreader based Dataset.

    Adds ``paddle.fluid`` graph integration on top of
    ``propeller.data.functional.Dataset``: one ``fluid.layers.data``
    placeholder is declared per data field (see ``placeholders``), and
    ``start`` feeds those placeholders from ``self.generator`` through an
    iterable ``fluid.io.PyReader``.

    The Dataset must be named (``self.name`` not None) because the name is
    used to build the placeholder variable names.
    """

    def _check_shapes_and_types(self):
        """Validate the Dataset before building placeholders from it.

        Returns:
            tuple: ``(data_shapes, data_types)``, read once from the
            public properties.

        Raises:
            ValueError: if the Dataset is unnamed, or if ``data_shapes`` and
                ``data_types`` do not have the same length.
        """
        if self.name is None:
            raise ValueError('can not get feature from unnamed Dataset')

        shapes, types = self.data_shapes, self.data_types
        if len(shapes) != len(types):
            # Report the same (public) values that were actually compared.
            raise ValueError(
                'Dataset shapes and types not match: shape:%s types:%s' %
                (repr(shapes), repr(types)))
        return shapes, types

    def placeholders(self):
        """Declare one ``fluid.layers.data`` variable per data field.

        Variables are named ``'<name>_placeholder_<i>'`` and are created with
        ``append_batch_size=False``, so each entry of ``data_shapes`` is used
        as given (no batch dimension is prepended implicitly).

        Returns:
            list: the placeholder variables, in field order.

        Raises:
            ValueError: if the Dataset is unnamed, or if its shapes and types
                do not have the same length. Without this check ``zip``
                would silently drop the unmatched entries.
        """
        shapes, types = self._check_shapes_and_types()
        return [
            L.data(
                '%s_placeholder_%d' % (self.name, i),
                shape=shape,
                append_batch_size=False,
                dtype=dtype)
            for i, (shape, dtype) in enumerate(zip(shapes, types))
        ]

    def features(self):
        """Start point of net building; call this inside a program scope.

        Returns:
            list: the placeholder variables built by ``placeholders``.

        Raises:
            ValueError: if the Dataset is unnamed, or its shapes and types
                do not match in length.
        """
        return self.placeholders()

    def start(self, places=None, capacity=50):
        """Build an iterable PyReader over this Dataset and return its iterator.

        Args:
            places: places the batches are fed to. When None, all CUDA places
                are used if Paddle is compiled with CUDA, otherwise the CPU
                places.
            capacity: buffer capacity passed to ``fluid.io.PyReader``
                (default 50, the previously hard-coded value).

        Returns:
            The result of calling the ``PyReader`` built with
            ``iterable=True``; its items feed the variables returned by
            ``placeholders()``.

        Raises:
            ValueError: if the Dataset is unnamed, or its shapes and types
                do not match in length.
        """
        if places is None:
            if F.core.is_compiled_with_cuda():
                places = F.cuda_places()
            else:
                places = F.cpu_places()

        def _gen():
            # Log failures from the user generator (with traceback) before
            # propagating them, so they are visible in this module's log.
            try:
                for batch in self.generator():
                    yield batch
            except Exception as e:
                log.exception(e)
                raise  # bare raise keeps the original traceback

        reader = F.io.PyReader(
            feed_list=self.placeholders(), capacity=capacity, iterable=True)
        reader.decorate_batch_generator(_gen, places=places)
        return reader()