DataLoader中set_sample_generator丢数据的问题
Created by: shuoyin
测试代码如下。测试环境:Python 2.7.13 + Paddle 1.7,CPU。
import paddle.fluid as fluid
def sample_generator(num_samples=10):
    """Build a sample-level reader, one sample per ``yield``.

    Args:
        num_samples: How many samples the reader produces. Defaults to 10,
            the count used by the reproduction below.

    Returns:
        A zero-argument callable. Every call returns a fresh generator that
        yields the tuple ``(i, i)`` for each ``i`` in ``range(num_samples)``.
    """
    def __reader__():
        for i in range(num_samples):
            yield (i, i)

    return __reader__
def sample_list_generator(batch_size=4, num_samples=10):
    """Build a batch-level reader, one list of samples per ``yield``.

    Args:
        batch_size: Number of samples collected before a batch is yielded.
        num_samples: Total number of samples to produce. Defaults to 10,
            the count used by the reproduction below.

    Returns:
        A zero-argument callable. Every call returns a fresh generator that
        yields lists of ``(i, i)`` tuples, each of length ``batch_size``
        except possibly the last one.
    """
    def __reader__():
        batch_data = []
        for i in range(num_samples):
            batch_data.append((i, i))
            if len(batch_data) == batch_size:
                yield batch_data
                # Start a new list so the batch already handed out is not
                # mutated by later appends.
                batch_data = []
        # Emit whatever is left over: this reader never drops the trailing
        # partial batch.
        if batch_data:
            yield batch_data

    return __reader__
def net():
    """Declare the two input variables and the DataLoader that feeds them.

    Returns:
        A ``(loader, fetch_dict)`` pair: ``loader`` is the object returned by
        ``fluid.io.DataLoader.from_generator`` for the feed list ``[x, y]``,
        and ``fetch_dict`` maps the keys ``'x'`` and ``'y'`` to the names of
        the corresponding variables.
    """
    # Two int32 inputs declared with shape [None].
    x = fluid.data(name='x', dtype='int32', shape=[None])
    y = fluid.data(name='y', dtype='int32', shape=[None])

    feed_vars = [x, y]
    loader = fluid.io.DataLoader.from_generator(
        feed_list=feed_vars,
        capacity=2,
        use_double_buffer=True,
        iterable=True,
    )

    return loader, {'x': x.name, 'y': y.name}
# Driver: build the network once, then feed the same ten samples through the
# loader twice, first via a sample-level reader and then via a batch-level one.
loader, fetch_dict = net()
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

# Case 1: the loader groups individual samples into batches of 4 itself.
# NOTE(review): set_sample_generator also accepts a drop_last argument which,
# per the Paddle 1.7 API docs, defaults to True, so the trailing partial batch
# is discarded. Passing drop_last=False should keep it; confirm against the
# installed version.
print("use sample_generator")
gen = sample_generator()
loader.set_sample_generator(gen, batch_size=4, places=place)
for data in loader:
    print(exe.run(program=fluid.default_main_program(), feed=data,
                  fetch_list=[fetch_dict['x'], fetch_dict['y']]))

# Case 2: the reader yields ready-made batches, including the final partial
# one, and the loader passes them through.
print("use sample_list_generator")
gen = sample_list_generator()
loader.set_sample_list_generator(gen, places=place)
for data in loader:
    print(exe.run(program=fluid.default_main_program(), feed=data,
                  fetch_list=[fetch_dict['x'], fetch_dict['y']]))
运行之后,输出如下:
use sample_generator
[array([0, 1, 2, 3], dtype=int32), array([0, 1, 2, 3], dtype=int32)]
[array([4, 5, 6, 7], dtype=int32), array([4, 5, 6, 7], dtype=int32)]
use sample_list_generator
[array([0, 1, 2, 3], dtype=int32), array([0, 1, 2, 3], dtype=int32)]
[array([4, 5, 6, 7], dtype=int32), array([4, 5, 6, 7], dtype=int32)]
[array([8, 9], dtype=int32), array([8, 9], dtype=int32)]
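可以看到,使用 set_sample_generator 时最后一个不足 batch_size 的 batch([8, 9])被丢掉了,而使用 set_sample_list_generator 时没有丢。这个问题还请 paddle 的同学解答一下。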