reader.py 3.9 KB
Newer Older
S
sneaxiy 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import core
import six
import threading
from .framework import Program, Variable, program_guard
from .data_feeder import DataFeeder
S
sneaxiy 已提交
20
import paddle.reader.decorator as decorator
S
sneaxiy 已提交
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39

__all__ = ['PyReader']


def _convert_places(places):
    if not isinstance(places, (list, tuple)):
        places = [places]

    ret = []
    for p in places:
        if not isinstance(p, core.Place):
            tmp = core.Place()
            tmp.set_place(p)
            p = tmp

        ret.append(p)
    return ret


S
sneaxiy 已提交
40 41
class PyReader(object):
    """Feed batches from a Python data source into a blocking tensor queue.

    A background thread iterates the configured tensor reader, converts each
    item to ``core.LoDTensor`` when it is not one already, and pushes the
    resulting ``core.LoDTensorArray`` into the queue created in ``__init__``.
    Calling the instance returns an iterator that pulls results from the
    underlying ``core`` reader through ``read_next``.

    Args:
        feed_list: variables to feed; each one must expose ``.name``.
        places: a single place or a list/tuple of places; normalized with
            ``_convert_places``.
        capacity: capacity handed to ``core.init_lod_tensor_blocking_queue``.
    """

    def __init__(self, feed_list, places, capacity):
        self._tensor_reader = None
        self._thread = None

        # TODO(zjl): to support drop_last = False
        self._drop_last = True

        self._feed_list = feed_list
        self._var_names = [v.name for v in feed_list]

        self._queues = []

        self._places = _convert_places(places)

        self._queue_capacity = capacity

        # Stored as ``_queue``: the reader created below and the feeding
        # thread in ``_reset`` must all refer to the same attribute.
        self._queue = core.init_lod_tensor_blocking_queue(
            core.Variable(), self._queue_capacity)

        self._reader = core.create_py_reader(self._queue, self._var_names,
                                             self._places, self._drop_last)

    def __call__(self):
        """Return an iterator over the results of the underlying reader.

        Raises:
            AssertionError: if no data source has been set yet.
        """
        assert self._tensor_reader is not None, \
            "Data source of PyReader has not set yet"

        class Iterator(object):
            """Iterator that ends, and resets the reader, on a falsy read."""

            def __init__(self, reader):
                self._reader = reader._reader
                self._reset = reader._reset

            def __iter__(self):
                return self

            def next(self):
                ret = self._reader.read_next()
                if ret:
                    return ret
                else:
                    # A falsy result marks the end of the pass: restart the
                    # feeding thread before signalling exhaustion.
                    self._reset()
                    raise StopIteration

            # Python 3 looks up ``__next__``; ``next`` is kept for Python 2.
            __next__ = next

        return Iterator(self)

    def _reset(self):
        """Reset the reader, join any running feeder and start a new one."""
        if self._thread:
            self._reader.reset()
            self._thread.join()

        def __thread_main__():
            try:
                for tensors in self._tensor_reader():
                    array = core.LoDTensorArray()
                    for item in tensors:
                        if not isinstance(item, core.LoDTensor):
                            tmp = core.LoDTensor()
                            tmp.set(item, core.CPUPlace())
                            item = tmp

                        array.append(item)

                    # A falsy push result ends feeding early.
                    if not self._queue.push(array):
                        break
            finally:
                # Close the queue even when the data source raises, so the
                # queue is not left open after the feeder thread has died.
                self._queue.close()

        self._thread = threading.Thread(target=__thread_main__)
        self._thread.daemon = True
        self._thread.start()

    def set_numpy_reader(self, reader):
        """Use ``reader`` as the data source, converting it via DataFeeder.

        The reader is decorated with a ``DataFeeder`` built on
        ``core.CPUPlace()``; each decorated batch is indexed by variable name
        and re-emitted as a list ordered like ``feed_list``.

        Raises:
            AssertionError: if a data source has already been set.
        """
        assert self._tensor_reader is None, \
            "Cannot reset the data source of PyReader"
        with program_guard(Program(), Program()):
            feeder = DataFeeder(
                feed_list=self._feed_list, place=core.CPUPlace())
            paddle_reader = feeder.decorate_reader(reader, multi_devices=False)

        def __tensor_reader_impl__():
            for slots in paddle_reader():
                yield [slots[var.name] for var in self._feed_list]

        self.set_tensor_reader(__tensor_reader_impl__)

    def set_tensor_reader(self, reader):
        """Set ``reader`` as the data source and start the feeding thread.

        Raises:
            AssertionError: if a data source has already been set.
        """
        assert self._tensor_reader is None, \
            "Cannot reset the data source of PyReader"
        self._tensor_reader = reader
        self._reset()