From e4b5d10c22bbe795a1901fc3dc69ead0b5569a3b Mon Sep 17 00:00:00 2001
From: yangyaming
Date: Mon, 26 Jun 2017 11:42:04 +0800
Subject: [PATCH] Change inference script to avoid loading model each batch.

---
 hsigmoid/infer.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/hsigmoid/infer.py b/hsigmoid/infer.py
index 32000238..ff080ad7 100644
--- a/hsigmoid/infer.py
+++ b/hsigmoid/infer.py
@@ -41,9 +41,8 @@ def decode_res(infer_res, dict_size):
     return predict_lbls
 
 
-def predict(batch_ins, idx_word_dict, dict_size, prediction_layer, parameters):
-    infer_res = paddle.infer(
-        output_layer=prediction_layer, parameters=parameters, input=batch_ins)
+def predict(batch_ins, idx_word_dict, dict_size, inferer):
+    infer_res = inferer.infer(input=batch_ins)
 
     predict_lbls = decode_res(infer_res, dict_size)
     predict_words = [idx_word_dict[lbl] for lbl in predict_lbls]  # map to word
@@ -66,6 +65,8 @@ def main(model_path):
     with gzip.open(model_path, "r") as f:
         parameters = paddle.parameters.Parameters.from_tar(f)
+    inferer = paddle.inference.Inference(
+        output_layer=prediction_layer, parameters=parameters)
 
     idx_word_dict = dict((v, k) for k, v in word_dict.items())
     batch_size = 64
     batch_ins = []
@@ -74,13 +75,11 @@ def main(model_path):
     for ins in ins_iter():
         batch_ins.append(ins[:-1])
         if len(batch_ins) == batch_size:
-            predict(batch_ins, idx_word_dict, dict_size, prediction_layer,
-                    parameters)
+            predict(batch_ins, idx_word_dict, dict_size, inferer)
             batch_ins = []
 
     if len(batch_ins) > 0:
-        predict(batch_ins, idx_word_dict, dict_size, prediction_layer,
-                parameters)
+        predict(batch_ins, idx_word_dict, dict_size, inferer)
 
 
 if __name__ == "__main__":
-- 
GitLab
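
Note on the pattern this patch applies: before the change, predict() called paddle.infer(output_layer=..., parameters=...) once per batch, which (per the commit subject) re-loads the model on every call; after the change, a single paddle.inference.Inference object is constructed in main() and its infer() method is reused for all batches. A minimal sketch of the resulting hoist-out-of-the-loop pattern follows, assuming the PaddlePaddle v2 API shown in the diff; prediction_layer, parameters, and batches are simplified stand-ins for the script's real network output, loaded parameters, and batch iteration:

    # Build the inference object once, outside the batch loop
    # (paddle.inference.Inference and infer(input=...) come from the diff).
    inferer = paddle.inference.Inference(
        output_layer=prediction_layer, parameters=parameters)

    for batch_ins in batches:  # hypothetical iterable of prepared batches
        # Reuse the same inferer for every batch; no per-batch model loading.
        infer_res = inferer.infer(input=batch_ins)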