提交 b31baa95 编写于 作者: Y yangyaming

change infer logic to avoid loading model each batch

上级 7f0f5a1f
......@@ -36,9 +36,8 @@ def decode_res(infer_res, dict_size):
return predict_lbls
def predict(batch_ins, idx_word_dict, dict_size, prediction_layer, parameters):
infer_res = paddle.infer(
output_layer=prediction_layer, parameters=parameters, input=batch_ins)
def predict(batch_ins, idx_word_dict, dict_size, inferer):
infer_res = inferer.infer(input=batch_ins)
predict_lbls = decode_res(infer_res, dict_size)
predict_words = [idx_word_dict[lbl] for lbl in predict_lbls] # map to word
......@@ -62,6 +61,8 @@ def main():
with gzip.open('./models/model_pass_00000.tar.gz') as f:
parameters = paddle.parameters.Parameters.from_tar(f)
inferer = paddle.inference.Inference(
output_layer=prediction_layer, parameters=parameters)
idx_word_dict = dict((v, k) for k, v in word_dict.items())
batch_size = 64
batch_ins = []
......@@ -70,13 +71,11 @@ def main():
for ins in ins_iter():
batch_ins.append(ins[:-1])
if len(batch_ins) == batch_size:
predict(batch_ins, idx_word_dict, dict_size, prediction_layer,
parameters)
predict(batch_ins, idx_word_dict, dict_size, inferer)
batch_ins = []
if len(batch_ins) > 0:
predict(batch_ins, idx_word_dict, dict_size, prediction_layer,
parameters)
predict(batch_ins, idx_word_dict, dict_size, inferer)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册