seqToseq: error when running prediction on input data
Created by: OleNet
The error output is as follows:
I1227 20:44:49.066433 30847 Util.cpp:158] commandline: --use_gpu=0
I1227 20:44:49.066570 30847 Util.cpp:132] Calling runInitFunctions
I1227 20:44:49.066891 30847 Util.cpp:146] Call runInitFunctions done.
File "<string>", line 13
paddle version:
^
SyntaxError: invalid syntax
[INFO 2016-12-27 20:44:49,263 networks.py:1466] The input order is [source_language_word, sent_id]
[INFO 2016-12-27 20:44:49,263 networks.py:1472] The output order is [__beam_search_predict__]
I1227 20:44:49.476338 30847 GradientMachine.cpp:124] Loading parameters from /home/liujiaxiang/paddle_predit/exp1/thirdparty/pass-00017
Traceback (most recent call last):
  File "model.py", line 57, in <module>
    result = model.predict('头疼怎么办')
  File "model.py", line 51, in predict
    result = self.network.forwardTest([inArg])
  File "/home/liujiaxiang/paddle_internal_release_tools/idl/paddle/output/python27-gcc482/lib/python2.7/site-packages/py_paddle/util.py", line 146, in forwardTest
    self.forward(inArgs, outArgs, swig_paddle.PASS_TEST)
  File "/home/liujiaxiang/paddle_internal_release_tools/idl/paddle/output/python27-gcc482/lib/python2.7/site-packages/py_paddle/swig_paddle.py", line 1342, in forward
    return _swig_paddle.GradientMachine_forward(self, *args)
TypeError: in method 'GradientMachine_forward', argument 2 of type 'Arguments const &'
This is how I am using it:
class Model():
    def __init__(self, pass_path=None, trg_lang_dict=None):
        if not pass_path:
            pass_path = PASS_PATH
        self.trg_dict = dict()
        for line_count, line in enumerate(open(trg_lang_dict, "r")):
            self.trg_dict[line.strip()] = line_count
        self.network = self.init_network(pass_path)

    def init_network(self, pass_path):
        swig_paddle.initPaddle("--use_gpu=0")
        conf = parse_config(pass_path+"/trainer_config.conf", "is_predict=1")
        network = swig_paddle.GradientMachine.createFromConfigProto(conf.model_config)
        assert isinstance(network, swig_paddle.GradientMachine) # For code hint.
        network.loadParameters(pass_path)
        return network

    def predict(self, sentence):
        ids = dataprovider._get_ids(sentence, self.trg_dict)
        self.converter = DataProviderConverter([integer_value_sequence(50000)])
        inArg = self.converter.convert([[ids]])
        result = self.network.forwardTest([inArg])
        return result