diff --git a/word_embedding/hsigmoid_predict.py b/word_embedding/hsigmoid_predict.py index 210f87ee103a2ac145e3c42cea536cd00d2994bb..287dbfea254b0d05a4d6b357535bfe95265c5ae7 100644 --- a/word_embedding/hsigmoid_predict.py +++ b/word_embedding/hsigmoid_predict.py @@ -36,9 +36,8 @@ def decode_res(infer_res, dict_size): return predict_lbls -def predict(batch_ins, idx_word_dict, dict_size, prediction_layer, parameters): - infer_res = paddle.infer( - output_layer=prediction_layer, parameters=parameters, input=batch_ins) +def predict(batch_ins, idx_word_dict, dict_size, inferer): + infer_res = inferer.infer(input=batch_ins) predict_lbls = decode_res(infer_res, dict_size) predict_words = [idx_word_dict[lbl] for lbl in predict_lbls] # map to word @@ -62,6 +61,8 @@ def main(): with gzip.open('./models/model_pass_00000.tar.gz') as f: parameters = paddle.parameters.Parameters.from_tar(f) + inferer = paddle.inference.Inference( + output_layer=prediction_layer, parameters=parameters) idx_word_dict = dict((v, k) for k, v in word_dict.items()) batch_size = 64 batch_ins = [] @@ -70,13 +71,11 @@ def main(): for ins in ins_iter(): batch_ins.append(ins[:-1]) if len(batch_ins) == batch_size: - predict(batch_ins, idx_word_dict, dict_size, prediction_layer, - parameters) + predict(batch_ins, idx_word_dict, dict_size, inferer) batch_ins = [] if len(batch_ins) > 0: - predict(batch_ins, idx_word_dict, dict_size, prediction_layer, - parameters) + predict(batch_ins, idx_word_dict, dict_size, inferer) if __name__ == '__main__':