提交 e4b5d10c 编写于 作者: Y yangyaming

Change inference script to avoid loading model each batch.

上级 4097a2cb
......@@ -41,9 +41,8 @@ def decode_res(infer_res, dict_size):
return predict_lbls
def predict(batch_ins, idx_word_dict, dict_size, prediction_layer, parameters):
infer_res = paddle.infer(
output_layer=prediction_layer, parameters=parameters, input=batch_ins)
def predict(batch_ins, idx_word_dict, dict_size, inferer):
infer_res = inferer.infer(input=batch_ins)
predict_lbls = decode_res(infer_res, dict_size)
predict_words = [idx_word_dict[lbl] for lbl in predict_lbls] # map to word
......@@ -66,6 +65,8 @@ def main(model_path):
with gzip.open(model_path, "r") as f:
parameters = paddle.parameters.Parameters.from_tar(f)
inferer = paddle.inference.Inference(
output_layer=prediction_layer, parameters=parameters)
idx_word_dict = dict((v, k) for k, v in word_dict.items())
batch_size = 64
batch_ins = []
......@@ -74,13 +75,11 @@ def main(model_path):
for ins in ins_iter():
batch_ins.append(ins[:-1])
if len(batch_ins) == batch_size:
predict(batch_ins, idx_word_dict, dict_size, prediction_layer,
parameters)
predict(batch_ins, idx_word_dict, dict_size, inferer)
batch_ins = []
if len(batch_ins) > 0:
predict(batch_ins, idx_word_dict, dict_size, prediction_layer,
parameters)
predict(batch_ins, idx_word_dict, dict_size, inferer)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册