未验证 提交 841553e1 编写于 作者: W wopeizl 提交者: GitHub

use pyreader to read data in dygraph mode (#17314)

* use pyreader to read data

* add return_list to PyReader to support return value represented as list
上级 5436d666
......@@ -55,7 +55,7 @@ paddle.fluid.io.load_params (ArgSpec(args=['executor', 'dirname', 'main_program'
paddle.fluid.io.load_persistables (ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None)), ('document', '28df5bfe26ca7a077f91156abb0fe6d2'))
paddle.fluid.io.save_inference_model (ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename', 'export_for_deployment'], varargs=None, keywords=None, defaults=(None, None, None, True)), ('document', '89539e459eb959145f15c9c3e38fa97c'))
paddle.fluid.io.load_inference_model (ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename', 'pserver_endpoints'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '2f54d7c206b62f8c10f4f9d78c731cfd'))
paddle.fluid.io.PyReader.__init__ (ArgSpec(args=['self', 'feed_list', 'capacity', 'use_double_buffer', 'iterable'], varargs=None, keywords=None, defaults=(True, False)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.io.PyReader.__init__ (ArgSpec(args=['self', 'feed_list', 'capacity', 'use_double_buffer', 'iterable', 'return_list'], varargs=None, keywords=None, defaults=(None, None, True, True, False)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.io.PyReader.decorate_batch_generator (ArgSpec(args=['self', 'reader', 'places'], varargs=None, keywords=None, defaults=(None,)), ('document', '4a072de39998ee4e0de33fcec11325a6'))
paddle.fluid.io.PyReader.decorate_sample_generator (ArgSpec(args=['self', 'sample_generator', 'batch_size', 'drop_last', 'places'], varargs=None, keywords=None, defaults=(True, None)), ('document', '3db4b24d33fe4f711e303f9673dc5c6a'))
paddle.fluid.io.PyReader.decorate_sample_list_generator (ArgSpec(args=['self', 'reader', 'places'], varargs=None, keywords=None, defaults=(None,)), ('document', '94adc0fb71c4b2ae6c3c74886c9cb898'))
......
......@@ -31,6 +31,7 @@ class MultiDeviceFeedReader {
public:
using ResultDictList =
std::vector<std::unordered_map<std::string, framework::LoDTensor>>;
using ResultList = std::vector<std::vector<framework::LoDTensor>>;
MultiDeviceFeedReader(
const std::shared_ptr<operators::reader::LoDTensorBlockingQueue> &queue,
......@@ -81,6 +82,21 @@ class MultiDeviceFeedReader {
return result;
}
ResultList ReadNextList() {
bool success = WaitFutures();
if (!success) {
return {};
}
ResultList result;
result.reserve(ret_.size());
for (size_t i = 0; i < ret_.size(); ++i) {
result.emplace_back(std::move(ret_[i]));
}
ReadAsync();
return result;
}
void Reset() {
Shutdown();
Start();
......@@ -142,6 +158,8 @@ void BindReader(py::module *module) {
py::class_<MultiDeviceFeedReader>(m, "MultiDeviceFeedReader", "")
.def("read_next", &MultiDeviceFeedReader::ReadNext,
py::call_guard<py::gil_scoped_release>())
.def("read_next_list", &MultiDeviceFeedReader::ReadNextList,
py::call_guard<py::gil_scoped_release>())
.def("reset", &MultiDeviceFeedReader::Reset,
py::call_guard<py::gil_scoped_release>());
......
......@@ -21,7 +21,7 @@ import six
from six.moves import zip, range, xrange
import multiprocessing
from .framework import Variable, default_main_program
from .framework import Variable, default_main_program, _current_expected_place
__all__ = ['DataFeeder']
......@@ -432,3 +432,63 @@ class DataFeeder(object):
"not implemented")
return __reader_creator__
class NumpyToLoDTensorConverter(object):
def __init__(self, place):
self.place = place
self.data = []
self._reset()
def _reset(self):
self.data = []
def feed(self, data):
self.data.append(data)
def done(self):
arr = numpy.array(self.data)
t = core.LoDTensor()
t.set(arr, self.place)
self._reset()
return t
class ListTensorProvider(object):
def __init__(self, generator, places):
self.generator = generator
self.converters = []
self.places = []
if places:
if not isinstance(places, (list, tuple)):
places = [places]
assert len(
places) == 1, "dygraph mode CAN NOT specify multiple places."
for place in places:
if isinstance(place, (core.CUDAPlace, core.CPUPlace)):
self.places.append(place)
else:
raise ValueError(
"Please specify a valid place values such as core.CPUPlace or core.CUDAPlace"
)
if len(self.places) == 0:
self.places.append(_current_expected_place())
def _readData(self, iterable, places):
for place, each_sample in six.moves.zip(places, iterable):
for item in each_sample:
if len(self.converters) < len(item):
for i in item:
self.converters.append(NumpyToLoDTensorConverter(place))
for each_converter, each_slot in six.moves.zip(self.converters,
item):
each_converter.feed(each_slot)
yield [c.done() for c in self.converters]
def __call__(self):
item = []
for batch in self.generator():
item.append(batch)
if len(item) == len(self.places):
yield list(self._readData(item, self.places))
item = []
......@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import core
from . import core, dygraph
import six
import warnings
import numpy as np
import threading
from .framework import Program, Variable, program_guard, default_main_program, default_startup_program
import paddle
from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode
from .executor import global_scope
from .data_feeder import DataFeeder, BatchedTensorProvider
from .data_feeder import DataFeeder, BatchedTensorProvider, ListTensorProvider
from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer
from .unique_name import UniqueNameGenerator
......@@ -48,12 +51,13 @@ class PyReader(object):
Args:
feed_list (list(Variable)|tuple(Variable)): feed variable list.
The variables should be created by :code:`fluid.layers.data()`.
The variables should be created by :code:`fluid.layers.data()`.
it can be None under iterable mode.
capacity (int): capacity of the queue maintained in PyReader object.
use_double_buffer (bool): whether to use double_buffer_reader to
speed up data feeding.
iterable (bool): whether the created reader object is iterable.
return_list (bool): whether the return value presented as list.
Returns:
reader (Reader): the created reader object.
......@@ -124,7 +128,7 @@ class PyReader(object):
return reader
image = fluid.layers.data(name='image', shape=[784, 784], dtype='float32')
reader = fluid.io.PyReader(feed_list=[image], capacity=4, iterable=True)
reader = fluid.io.PyReader(feed_list=[image], capacity=4, iterable=True, return_list=False)
user_defined_reader = reader_creator_random_image(784, 784)
reader.decorate_sample_list_generator(
......@@ -138,26 +142,79 @@ class PyReader(object):
for data in reader():
executor.run(feed=data)
3. If return_list=True, the return values would be presented as list instead of dict`.
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
EPOCH_NUM = 3
ITER_NUM = 5
BATCH_SIZE = 10
def reader_creator_random_image(height, width):
def reader():
for i in range(ITER_NUM):
yield np.random.uniform(low=0, high=255, size=[height, width]),
return reader
image = fluid.layers.data(name='image', shape=[784, 784], dtype='float32')
reader = fluid.io.PyReader(feed_list=[image], capacity=4, iterable=True, return_list=True)
user_defined_reader = reader_creator_random_image(784, 784)
reader.decorate_sample_list_generator(
paddle.batch(user_defined_reader, batch_size=BATCH_SIZE),
fluid.core.CPUPlace())
# definition of network is omitted
executor = fluid.Executor(fluid.core.CPUPlace())
executor.run(fluid.default_main_program())
for _ in range(EPOCH_NUM):
for data in reader():
executor.run(feed={"image": data[0]})
"""
unique_name_generator = UniqueNameGenerator()
def __init__(self,
feed_list,
capacity,
feed_list=None,
capacity=None,
use_double_buffer=True,
iterable=False):
iterable=True,
return_list=False):
self._tensor_reader = None
self._thread = None
self._iterable = iterable
self._feed_list = feed_list
if not capacity:
raise ValueError("Please give value to capacity.")
# force to use iterable mode under dygraph mode
if in_dygraph_mode():
if not iterable:
warnings.warn(
"Please NOTE: dygraph can support iterable mode only.")
self._iterable = True
if not return_list:
warnings.warn(
"Please NOTE: dygraph can support return as list only.")
self._return_list = True
else:
self._iterable = iterable
self._return_list = return_list
if not self._feed_list:
raise Exception("Feed list must be given under static mode.")
self._use_double_buffer = use_double_buffer
self._capacity = capacity
self._feed_list = feed_list
if not self._iterable:
self._init_non_iterable()
def _init_iterable(self, places):
self._var_names = [v.name for v in self._feed_list]
if in_dygraph_mode():
self._var_names = []
else:
self._var_names = [v.name for v in self._feed_list]
self._places = _convert_places(places)
self._queue = core.init_lod_tensor_blocking_queue(core.Variable(),
self._capacity)
......@@ -240,6 +297,7 @@ class PyReader(object):
def __init__(self, reader):
self._reader = reader._reader
self._reset = reader._reset
self._return_list = reader._return_list
def __iter__(self):
return self
......@@ -248,12 +306,28 @@ class PyReader(object):
return self.next()
def next(self):
ret = self._reader.read_next()
if ret:
return ret
if not in_dygraph_mode():
if self._return_list:
ret = self._reader.read_next_list()
ret = ret[0] if ret is not None and len(
ret) > 0 else None
else:
ret = self._reader.read_next()
if ret:
return ret
else:
self._reset()
raise StopIteration
else:
self._reset()
raise StopIteration
ret = self._reader.read_next_list()
if ret and ret[0]:
return [
dygraph.base.to_variable(np.array(v))
for v in ret[0]
]
else:
self._reset()
raise StopIteration
self._start()
return Iterator(self)
......@@ -293,8 +367,9 @@ class PyReader(object):
break
'''
assert not self._iterable, "start() cannot be called when PyReader is iterable"
self._start()
if not in_dygraph_mode():
assert not self._iterable, "start() cannot be called when PyReader is iterable"
self._start()
def reset(self):
'''
......@@ -327,8 +402,9 @@ class PyReader(object):
break
'''
assert not self._iterable, "reset() cannot be called when PyReader is iterable"
self._reset()
if not in_dygraph_mode():
assert not self._iterable, "reset() cannot be called when PyReader is iterable"
self._reset()
def _start(self):
def __thread_main__():
......@@ -415,27 +491,35 @@ class PyReader(object):
'''
assert batch_size > 0, "batch_size must be larger than 0"
has_lod = False
for f in self._feed_list:
if f.lod_level != 0:
has_lod = True
break
if has_lod:
if not in_dygraph_mode():
has_lod = False
for f in self._feed_list:
if f.lod_level != 0:
has_lod = True
break
if has_lod:
self.decorate_sample_list_generator(
paddle.batch(
sample_generator,
batch_size=batch_size,
drop_last=drop_last),
places=places)
else:
reader = BatchedTensorProvider(
feed_list=self._feed_list,
place=core.CPUPlace(),
batch_size=batch_size,
generator=sample_generator,
drop_last=drop_last)
self.decorate_batch_generator(reader, places=places)
else:
self.decorate_sample_list_generator(
paddle.batch(
sample_generator,
batch_size=batch_size,
drop_last=drop_last),
places=places)
else:
reader = BatchedTensorProvider(
feed_list=self._feed_list,
place=core.CPUPlace(),
batch_size=batch_size,
generator=sample_generator,
drop_last=drop_last)
self.decorate_batch_generator(reader, places=places)
def decorate_sample_list_generator(self, reader, places=None):
'''
......@@ -488,14 +572,22 @@ class PyReader(object):
'''
assert self._tensor_reader is None, \
"Cannot reset the data source of PyReader"
with program_guard(Program(), Program()):
feeder = DataFeeder(
feed_list=self._feed_list, place=core.CPUPlace())
paddle_reader = feeder.decorate_reader(reader, multi_devices=False)
def __tensor_reader_impl__():
for slots in paddle_reader():
yield [slots[var.name] for var in self._feed_list]
if not in_dygraph_mode():
with program_guard(Program(), Program()):
feeder = DataFeeder(
feed_list=self._feed_list, place=core.CPUPlace())
paddle_reader = feeder.decorate_reader(
reader, multi_devices=False)
def __tensor_reader_impl__():
for slots in paddle_reader():
yield [slots[var.name] for var in self._feed_list]
else:
provider = ListTensorProvider(reader, places)
def __tensor_reader_impl__():
for slots in provider():
yield slots[0]
self.decorate_batch_generator(__tensor_reader_impl__, places)
......
......@@ -80,19 +80,21 @@ def main():
train_reader.start()
try:
while True:
print 'train_loss', numpy.array(
trainer.run(fetch_list=[loss.name]))
print(
'train_loss',
numpy.array(trainer.run(fetch_list=[loss.name])))
except fluid.core.EOFException:
print 'End of epoch', epoch_id
print('End of epoch', epoch_id)
train_reader.reset()
test_reader.start()
try:
while True:
print 'test loss', numpy.array(
tester.run(fetch_list=[test_loss.name]))
print(
'test loss',
numpy.array(tester.run(fetch_list=[test_loss.name])))
except fluid.core.EOFException:
print 'End of testing'
print('End of testing')
test_reader.reset()
......
......@@ -18,7 +18,7 @@ import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid import Conv2D, Pool2D, FC
from paddle.fluid import Conv2D, Pool2D, FC, core
from paddle.fluid.dygraph.base import to_variable
......@@ -99,9 +99,19 @@ class MNIST(fluid.Layer):
class TestDygraphCheckpoint(unittest.TestCase):
def reader_decorator(self, reader):
def _reader_imple():
for item in reader():
image = np.array(item[0]).reshape(1, 28, 28)
label = np.array(item[1]).astype('int64').reshape(1)
yield image, label
return _reader_imple
def test_save_load_persistables(self):
seed = 90
epoch_num = 1
batch_size = 128
with fluid.dygraph.guard():
fluid.default_startup_program().random_seed = seed
......@@ -109,22 +119,21 @@ class TestDygraphCheckpoint(unittest.TestCase):
mnist = MNIST("mnist")
sgd = SGDOptimizer(learning_rate=1e-3)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
batch_py_reader = fluid.io.PyReader(capacity=1)
batch_py_reader.decorate_sample_list_generator(
paddle.batch(
self.reader_decorator(paddle.dataset.mnist.train()),
batch_size=batch_size,
drop_last=True),
places=fluid.CPUPlace())
dy_param_init_value = {}
step = 0
for epoch in range(epoch_num):
for batch_id, data in enumerate(train_reader()):
dy_x_data = np.array(
[x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(128, 1)
img = to_variable(dy_x_data)
label = to_variable(y_data)
for batch_id, data in enumerate(batch_py_reader()):
img = data[0]
label = data[1]
label.stop_gradient = True
cost = mnist(img)
......@@ -153,9 +162,7 @@ class TestDygraphCheckpoint(unittest.TestCase):
self.assertTrue(np.isfinite(value.numpy().all()))
self.assertFalse(np.isnan(value.numpy().any()))
step += 1
if step > 10:
if batch_id > 10:
break
......
......@@ -105,30 +105,45 @@ class MNIST(fluid.dygraph.Layer):
class TestImperativeMnist(unittest.TestCase):
def reader_decorator(self, reader):
def _reader_imple():
for item in reader():
image = np.array(item[0]).reshape(1, 28, 28)
label = np.array(item[1]).astype('int64').reshape(1)
yield image, label
return _reader_imple
def test_mnist_float32(self):
seed = 90
epoch_num = 1
batch_size = 128
batch_num = 50
with fluid.dygraph.guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
mnist = MNIST("mnist")
sgd = SGDOptimizer(learning_rate=1e-3)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
batch_py_reader = fluid.io.PyReader(capacity=1)
batch_py_reader.decorate_sample_list_generator(
paddle.batch(
self.reader_decorator(paddle.dataset.mnist.train()),
batch_size=batch_size,
drop_last=True),
places=fluid.CPUPlace())
mnist.train()
dy_param_init_value = {}
for epoch in range(epoch_num):
for batch_id, data in enumerate(train_reader()):
dy_x_data = np.array(
[x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(128, 1)
img = to_variable(dy_x_data)
label = to_variable(y_data)
for batch_id, data in enumerate(batch_py_reader()):
if batch_id >= batch_num:
break
img = data[0]
dy_x_data = img.numpy()
label = data[1]
label.stop_gradient = True
cost = mnist(img)
......@@ -159,7 +174,9 @@ class TestImperativeMnist(unittest.TestCase):
mnist = MNIST("mnist")
sgd = SGDOptimizer(learning_rate=1e-3)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
paddle.dataset.mnist.train(),
batch_size=batch_size,
drop_last=True)
img = fluid.layers.data(
name='pixel', shape=[1, 28, 28], dtype='float32')
......@@ -183,11 +200,14 @@ class TestImperativeMnist(unittest.TestCase):
for epoch in range(epoch_num):
for batch_id, data in enumerate(train_reader()):
if batch_id >= batch_num:
break
static_x_data = np.array(
[x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape([128, 1])
[x[1] for x in data]).astype('int64').reshape(
[batch_size, 1])
fetch_list = [avg_loss.name]
fetch_list.extend(static_param_name_list)
......
......@@ -48,29 +48,41 @@ class TestImperativeOptimizerBase(unittest.TestCase):
def get_optimizer(self):
raise NotImplementedError()
def reader_decorator(self, reader):
def _reader_imple():
for item in reader():
image = np.array(item[0]).reshape(1, 28, 28)
label = np.array(item[1]).astype('int64').reshape(1)
yield image, label
return _reader_imple
def _check_mlp(self):
seed = 90
batch_size = 128
with fluid.dygraph.guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
mlp = MLP('mlp')
optimizer = self.get_optimizer()
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
batch_py_reader = fluid.io.PyReader(capacity=1)
batch_py_reader.decorate_sample_list_generator(
paddle.batch(
self.reader_decorator(paddle.dataset.mnist.train()),
batch_size=batch_size,
drop_last=True),
places=fluid.CPUPlace())
dy_param_init_value = {}
for batch_id, data in enumerate(train_reader()):
for batch_id, data in enumerate(batch_py_reader()):
if batch_id >= self.batch_num:
break
dy_x_data = np.array(
[x[0].reshape(1, 28, 28) for x in data]).astype('float32')
y_data = np.array([x[1] for x in data]).astype('int64').reshape(
128, 1)
img = to_variable(dy_x_data)
label = to_variable(y_data)
img = data[0]
label = data[1]
label._stop_gradient = True
cost = mlp(img)
......
......@@ -227,6 +227,15 @@ class ResNet(fluid.Layer):
class TestDygraphResnet(unittest.TestCase):
def reader_decorator(self, reader):
def _reader_imple():
for item in reader():
doc = np.array(item[0]).reshape(3, 224, 224)
label = np.array(item[1]).astype('int64').reshape(1)
yield doc, label
return _reader_imple
def test_resnet_float32(self):
seed = 90
......@@ -242,25 +251,26 @@ class TestDygraphResnet(unittest.TestCase):
np.random.seed(seed)
import random
random.seed = seed
train_reader = paddle.batch(
paddle.dataset.flowers.train(use_xmap=False),
batch_size=batch_size)
batch_py_reader = fluid.io.PyReader(capacity=1)
batch_py_reader.decorate_sample_list_generator(
paddle.batch(
self.reader_decorator(
paddle.dataset.flowers.train(use_xmap=False)),
batch_size=batch_size,
drop_last=True),
places=fluid.CPUPlace())
dy_param_init_value = {}
for param in resnet.parameters():
dy_param_init_value[param.name] = param.numpy()
for batch_id, data in enumerate(train_reader()):
for batch_id, data in enumerate(batch_py_reader()):
if batch_id >= batch_num:
break
dy_x_data = np.array(
[x[0].reshape(3, 224, 224) for x in data]).astype('float32')
y_data = np.array([x[1] for x in data]).astype('int64').reshape(
batch_size, 1)
img = to_variable(dy_x_data)
label = to_variable(y_data)
img = data[0]
label = data[1]
label.stop_gradient = True
out = resnet(img)
......
......@@ -311,6 +311,15 @@ class SeResNeXt(fluid.dygraph.Layer):
class TestImperativeResneXt(unittest.TestCase):
def reader_decorator(self, reader):
def _reader_imple():
for item in reader():
doc = np.array(item[0]).reshape(3, 224, 224)
label = np.array(item[1]).astype('int64').reshape(1)
yield doc, label
return _reader_imple
def test_se_resnext_float32(self):
seed = 90
......@@ -326,29 +335,28 @@ class TestImperativeResneXt(unittest.TestCase):
np.random.seed(seed)
import random
random.seed = seed
train_reader = paddle.batch(
paddle.dataset.flowers.train(use_xmap=False),
batch_size=batch_size,
drop_last=True)
batch_py_reader = fluid.io.PyReader(capacity=1)
batch_py_reader.decorate_sample_list_generator(
paddle.batch(
self.reader_decorator(
paddle.dataset.flowers.train(use_xmap=False)),
batch_size=batch_size,
drop_last=True),
places=fluid.CPUPlace())
dy_param_init_value = {}
for param in se_resnext.parameters():
dy_param_init_value[param.name] = param.numpy()
for epoch_id in range(epoch_num):
for batch_id, data in enumerate(train_reader()):
for batch_id, data in enumerate(batch_py_reader()):
if batch_id >= batch_num and batch_num != -1:
break
dy_x_data = np.array(
[x[0].reshape(3, 224, 224)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(
batch_size, 1)
img = to_variable(dy_x_data)
label = to_variable(y_data)
img = data[0]
label = data[1]
label.stop_gradient = True
label.stop_gradient = True
out = se_resnext(img)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
import unittest
import numpy as np
class TestPyReader(unittest.TestCase):
def setUp(self):
self.batch_size = 32
self.epoch_num = 2
self.sample_num = 10
def test_returnlist(self):
def reader_creator_random_image(height, width):
def reader():
for i in range(self.sample_num):
yield np.random.uniform(
low=0, high=255, size=[height, width]),
return reader
for return_list in [True, False]:
with fluid.program_guard(fluid.Program(), fluid.Program()):
image = fluid.layers.data(
name='image', shape=[784, 784], dtype='float32')
reader = fluid.io.PyReader(
feed_list=[image],
capacity=4,
iterable=True,
return_list=return_list)
user_defined_reader = reader_creator_random_image(784, 784)
reader.decorate_sample_list_generator(
paddle.batch(
user_defined_reader, batch_size=self.batch_size),
fluid.core.CPUPlace())
# definition of network is omitted
executor = fluid.Executor(fluid.core.CPUPlace())
executor.run(fluid.default_main_program())
for _ in range(self.epoch_num):
for data in reader():
if return_list:
executor.run(feed={"image": data[0]})
else:
executor.run(feed=data)
with fluid.dygraph.guard():
batch_py_reader = fluid.io.PyReader(
feed_list=[
np.empty(
[self.batch_size, 784, 784], dtype='float32')
],
capacity=2,
use_double_buffer=True,
return_list=return_list)
user_defined_reader = reader_creator_random_image(784, 784)
batch_py_reader.decorate_sample_generator(
user_defined_reader,
batch_size=self.batch_size,
places=fluid.core.CPUPlace())
for epoch in range(self.epoch_num):
for _, data in enumerate(batch_py_reader()):
# empty network
pass
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册