NotImplementedError: Wrong number or type of arguments for overloaded function 'IVector_create'
Created by: huayong
出现的问题如标题所示。我的主要目的是实现分割任务，所以自己写了一个 reader 生成器来读取自己的训练数据，代码如下：
def load_image(file, image_size=224, mean_file=None):
    """Load an image and return it as a mean-subtracted float32 CHW array.

    Args:
        file: path of the image; read with ``cv2.imread`` in its default
            colour mode, so channels come back in BGR order.
        image_size: the image is resized to ``image_size x image_size``.
        mean_file: not used; the fixed per-channel mean below is always
            applied. Kept so existing callers keep working.

    Returns:
        float32 array of shape (channels, image_size, image_size) holding
        ``(pixel - mean) / 255``.

    Raises:
        IOError: if OpenCV cannot read ``file``.
    """
    im = cv2.imread(file)
    if im is None:
        # cv2.imread reports a missing/unreadable file by returning None;
        # fail here with the path instead of deep inside cv2.resize.
        raise IOError("cannot read image file: %s" % file)
    im = cv2.resize(im, (image_size, image_size))
    im = np.asarray(im, dtype=np.float32)
    # Per-channel mean applied to channels 0, 1, 2 (B, G, R as loaded).
    mean = np.array([103.94, 116.78, 123.68])
    # Subtract the mean and scale by 1/255 exactly once; a second division
    # by 255 would squash every input to roughly +/- 0.002.
    im = (im - mean) / 255.0
    # HWC (OpenCV layout) -> CHW.
    im = im.transpose((2, 0, 1))
    return im.astype(np.float32)
def load_label(file, label_size=224):
    """Load a single-channel label mask as a 2-D integer array.

    Args:
        file: path of the label image; read with ``cv2.IMREAD_GRAYSCALE``.
        label_size: the mask is resized to ``label_size x label_size``.

    Returns:
        Integer numpy array of shape (label_size, label_size) with the mask's
        pixel values.

    Raises:
        IOError: if OpenCV cannot read ``file``.
    """
    label_data = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
    if label_data is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise IOError("cannot read label file: %s" % file)
    # Nearest-neighbour keeps every output pixel equal to some input label
    # value; cv2.resize's default (bilinear) interpolation would blend
    # neighbouring class ids into values that belong to neither class.
    label_data = cv2.resize(label_data, (label_size, label_size),
                            interpolation=cv2.INTER_NEAREST)
    return np.asarray(label_data).astype(int)
def reader_creator(sub_name):
    """Build a Paddle-style reader over the samples listed for ``sub_name``.

    Relies on module-level format strings defined elsewhere in the file:
    ``SET_FILE`` is formatted with ``sub_name`` to name a text file with one
    entry per line; ``DATA_FILE`` and ``LABEL_FILE`` are formatted with each
    entry to give the image and label paths.

    Returns:
        A zero-argument callable; each call returns a generator yielding
        ``(image, label)`` tuples from ``load_image`` / ``load_label``.
    """
    def reader():
        set_file = SET_FILE.format(sub_name)
        # Read the whole list up front and close the file before the slow
        # image decoding starts (the original left the handle open).
        with open(set_file, 'r') as f:
            entries = [line.strip() for line in f]
        for entry in entries:
            if not entry:
                # Skip blank lines (e.g. a trailing newline at EOF); they
                # would otherwise format into a bogus path.
                continue
            data = load_image(DATA_FILE.format(entry))
            label = load_label(LABEL_FILE.format(entry))
            # NOTE(review): `label` is a 2-D array. If the matching data
            # layer is declared as paddle.data_type.integer_value(...), which
            # takes one int per sample, the feeder fails inside
            # IVector.create. For per-pixel labels, presumably declare the
            # layer as integer_value_sequence(class_num) and yield
            # label.flatten().tolist() (plain Python ints) -- confirm.
            yield data, label
    return reader
但是在 train 过程中出现了如下错误：
Traceback (most recent call last):
  File "/Applications/PyCharm CE.app/Contents/helpers/pydev/pydevd.py", line 1599, in <module>
    globals = debugger.run(setup['file'], None, None, is_module)
  File "/Applications/PyCharm CE.app/Contents/helpers/pydev/pydevd.py", line 1026, in run
    pydev_imports.execfile(file, globals, locals) # execute the script
  File "/Volumes/Huayong/baidu/paddlepaddle/Seg/seg_train.py", line 96, in <module>
    train(net, class_num)
  File "/Volumes/Huayong/baidu/paddlepaddle/Seg/seg_train.py", line 79, in train
    'label': 1})
  File "/Library/Python/2.7/site-packages/paddle/v2/trainer.py", line 169, in train
    in_args = feeder(data_batch)
  File "/Library/Python/2.7/site-packages/py_paddle/dataprovider_converter.py", line 282, in __call__
    return self.convert(dat, argument)
  File "/Library/Python/2.7/site-packages/paddle/v2/data_feeder.py", line 133, in convert
    return DataProviderConverter.convert(self, reorder_data(dat), argument)
  File "/Library/Python/2.7/site-packages/py_paddle/dataprovider_converter.py", line 277, in convert
    scanner.finish_scan(argument)
  File "/Library/Python/2.7/site-packages/py_paddle/dataprovider_converter.py", line 211, in finish_scan
    ids = swig_paddle.IVector.create(self.__ids__, self.data_in_gpu)
  File "/Library/Python/2.7/site-packages/py_paddle/swig_paddle.py", line 1344, in create
    return _swig_paddle.IVector_create(*args)
NotImplementedError: Wrong number or type of arguments for overloaded function 'IVector_create'.
  Possible C/C++ prototypes are:
    IVector::create(std::vector< int,std::allocator< int > > const &,bool)
    IVector::create(std::vector< int,std::allocator< int > > const &)
train 代码如下：
def train(net, class_num): lbl = paddle.layer.data( name="label", type=paddle.data_type.integer_value(224224))
# Cross-entropy cost between the network output `net` and the label input
# `lbl` (declared at the start of train).
# NOTE(review): `lbl` is declared as integer_value(224224), i.e. a single
# integer per sample, while the reader posted above yields each label as a
# 2-D array; that mismatch is presumably what makes the data feeder fail in
# IVector.create. 224224 also looks like an attempt at 224*224 pixels, but the
# argument is presumably the number of distinct label values (class_num).
# For per-pixel labels, consider integer_value_sequence(class_num) fed with a
# flattened list of ints -- confirm against the Paddle v2 data_type docs.
# NOTE(review): `class_num` is not used anywhere in this part of train.
cost = paddle.layer.cross_entropy_cost(input=net,
label=lbl)
# Create the trainable parameters of every layer reachable from `cost`.
parameters = paddle.parameters.create(cost)
# Alternative optimizer, currently disabled (kept for reference only):
# momentum_optimizer = paddle.optimizer.Momentum(
# momentum=0.9,
# regularization=paddle.optimizer.L2Regularization(rate=0.0002 * 128),
# learning_rate=0.1 / 128.0,
# learning_rate_decay_a=0.1,
# learning_rate_decay_b=50000 * 100,
# learning_rate_schedule='discexp')
# Active optimizer: Adam with L2 regularization and model averaging.
adam_optimizer = paddle.optimizer.Adam(
learning_rate=2e-3,
regularization=paddle.optimizer.L2Regularization(rate=8e-4),
model_average=paddle.optimizer.ModelAverage(average_window=0.5))
# Trainer that minimizes `cost` over `parameters` using the Adam rule above.
trainer = paddle.trainer.SGD(
cost=cost, parameters=parameters, update_equation=adam_optimizer)
# Event handler: logs progress after each batch; at the end of each pass it
# saves the parameters and runs an evaluation.
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
# Full stats on even batch ids, a progress dot on odd ones.
if event.batch_id % 2 == 0:
print "\nPass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
else:
sys.stdout.write('.')
sys.stdout.flush()
if isinstance(event, paddle.event.EndPass):
# Save parameters to params_pass_<pass_id>.tar.
# NOTE(review): opened in text mode ('w'); 'wb' is safer for a binary tar
# on platforms that translate newlines.
with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
trainer.save_parameter_to_tar(f)
# NOTE(review): this evaluates on seg_reader.train(), the same data used for
# training; a held-out reader is presumably intended.
result = trainer.test(
reader=paddle.batch(
seg_reader.train(), batch_size=1),
# Data-layer names -> positions in each (image, label) sample tuple.
feeding={'image': 0,
'label': 1})
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
# Train for one pass with batch size 1 over a shuffled training reader.
# NOTE(review): shuffle with buf_size=50000 presumably buffers up to 50000
# decoded samples; if each sample is a 3x224x224 float32 image (~600 KB), a
# large dataset could need tens of GB of RAM -- consider a smaller buffer.
trainer.train(
reader=paddle.batch(
paddle.reader.shuffle(
seg_reader.train(), buf_size=50000),
batch_size=1),
num_passes=1,
event_handler=event_handler,
# Same name -> tuple-position mapping as used for evaluation.
feeding={'image': 0,
'label': 1})
return
在 PyCharm 的 debug 模式下查看过，读取到的数据应该没有问题。