Commit b102d256 authored by JiabinYang

polish code

Parent db405c78
infer.py

@@ -7,7 +7,6 @@ from Queue import PriorityQueue
 import logging
 import argparse
 from sklearn.metrics.pairwise import cosine_similarity
-from paddle.fluid.executor import global_scope
 
 word_to_id = dict()
 id_to_word = dict()
@@ -42,6 +41,12 @@ def parse_args():
         required=False,
         default=False,
         help='if using infer_once, (default: False)')
+    parser.add_argument(
+        '--infer_during_train',
+        action='store_true',
+        required=False,
+        default=True,
+        help='if using infer_during_train, (default: True)')
     return parser.parse_args()
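
A side note on the flag just added: because `action='store_true'` is combined with `default=True`, the parsed value is True whether or not `--infer_during_train` appears on the command line, so the watcher mode can only be avoided by choosing `--infer_once`. A minimal argparse sketch (standalone, not from the diff) showing why:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--infer_during_train', action='store_true', required=False, default=True)

    # store_true only ever sets the value to True; with default=True the
    # destination is True in both cases, so the flag cannot be switched off.
    print(parser.parse_args([]).infer_during_train)                         # True
    print(parser.parse_args(['--infer_during_train']).infer_during_train)   # True
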
@@ -78,40 +83,17 @@ def build_test_case(emb):
         [emb5, desc5]]
 
 
-def inference_test(model_dir, args):
+def inference_test(scope, model_dir, args):
     BuildWord_IdMap(args.dict_path)
-    exe = fluid.Executor(fluid.CPUPlace())
-    Scope = fluid.Scope()
     logger.info("model_dir is: {}".format(model_dir + "/"))
-    with fluid.scope_guard(Scope):
-        inference_prog()
-        fluid.io.load_persistables(executor=exe, dirname=model_dir + "/")
-        emb = np.array(Scope.find_var("embeding").get_tensor())
-        test_cases = build_test_case(emb)
-        logger.info("inference result: ====================")
-        for case in test_cases:
-            pq = topK(args.rank_num, emb, case[0])
-            logger.info("Test result for {}".format(case[1]))
-            pq_tmps = list()
-            for i in range(args.rank_num):
-                pq_tmps.append(pq.get())
-            for i in range(len(pq_tmps)):
-                logger.info("{} nearest is {}, rate is {}".format(i, id_to_word[
-                    pq_tmps[len(pq_tmps) - 1 - i].id], pq_tmps[len(pq_tmps) - 1
-                                                               - i].priority))
-            del pq_tmps[:]
-
-
-def infer_with_in_train(model_dir, rank_num, dict_path):
-    BuildWord_IdMap(dict_path)
-    emb = np.array(global_scope().find_var("embeding").get_tensor())
+    emb = np.array(scope.find_var("embeding").get_tensor())
     test_cases = build_test_case(emb)
     logger.info("inference result: ====================")
     for case in test_cases:
-        pq = topK(rank_num, emb, case[0])
+        pq = topK(args.rank_num, emb, case[0])
         logger.info("Test result for {}".format(case[1]))
         pq_tmps = list()
-        for i in range(rank_num):
+        for i in range(args.rank_num):
             pq_tmps.append(pq.get())
         for i in range(len(pq_tmps)):
             logger.info("{} nearest is {}, rate is {}".format(i, id_to_word[
@@ -149,6 +131,9 @@ def topK(k, emb, test_emb):
 def infer_during_train(args):
     model_file_list = list()
+    exe = fluid.Executor(fluid.CPUPlace())
+    Scope = fluid.Scope()
+    inference_prog()
     while True:
         time.sleep(1)
         current_list = os.listdir(args.model_output_dir)
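
The reason `exe`, `Scope`, and `inference_prog()` are hoisted above the `while` loop is that the inference network only needs to be constructed once; each polling iteration then just refreshes the weights in `Scope`, where the old code rebuilt an executor, scope, and program for every checkpoint inside `inference_test`. Roughly, under the same assumptions as the diff:

    exe = fluid.Executor(fluid.CPUPlace())  # built once
    Scope = fluid.Scope()                   # keeps weights between polls
    inference_prog()                        # network graph built once
    while True:
        time.sleep(1)
        # detect a newly finished checkpoint dir, then reload only the
        # persistable variables into Scope (see the next hunk)
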
@@ -167,15 +152,24 @@ def infer_during_train(args):
             model_dir = args.model_output_dir + "/" + model
             if os.path.exists(model_dir + "/_success"):
                 logger.info("using models from " + model_dir)
-                inference_test(model_dir, args)
+                with fluid.scope_guard(Scope):
+                    fluid.io.load_persistables(
+                        executor=exe, dirname=model_dir + "/")
+                    inference_test(Scope, model_dir, args)
         model_file_list = current_list
 
 
 def infer_once(args):
-    if os.path.exists(args.model_output_dir + "/_success"
-                      ):  # check models file has already been finished
+    # check models file has already been finished
+    if os.path.exists(args.model_output_dir + "/_success"):
         logger.info("using models from " + args.model_output_dir)
-        inference_test(args.model_output_dir, args)
+        exe = fluid.Executor(fluid.CPUPlace())
+        Scope = fluid.Scope()
+        inference_prog()
+        with fluid.scope_guard(Scope):
+            fluid.io.load_persistables(
+                executor=exe, dirname=args.model_output_dir + "/")
+            inference_test(Scope, args.model_output_dir, args)
 
 
 if __name__ == '__main__':
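
Both call sites now follow the same pattern: run `load_persistables` under `scope_guard` so the checkpoint's persistable variables land in an explicit `Scope`, then hand that `Scope` to `inference_test`. A self-contained round-trip sketch of the save/load pairing this relies on (the tiny fc program and the `tmp_model` path are illustrative, not from the repo):

    import paddle.fluid as fluid

    # a trivial program with one parameter, so there is something to persist
    main_prog = fluid.Program()
    startup = fluid.Program()
    with fluid.program_guard(main_prog, startup):
        x = fluid.layers.data(name='x', shape=[4], dtype='float32')
        y = fluid.layers.fc(input=x, size=2)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup)
    fluid.io.save_persistables(
        executor=exe, dirname="tmp_model", main_program=main_prog)

    # restore into a private scope, as infer.py now does
    scope = fluid.Scope()
    with fluid.scope_guard(scope):
        fluid.io.load_persistables(
            executor=exe, dirname="tmp_model", main_program=main_prog)
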
@@ -183,5 +177,5 @@ if __name__ == '__main__':
     # while setting infer_once please specify the dir to models file with --model_output_dir
     if args.infer_once:
         infer_once(args)
-    else:
+    if args.infer_during_train:
         infer_during_train(args)
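
One consequence of swapping `else:` for an independent `if`: since `args.infer_during_train` parses to True regardless (see the flag note above), the watcher also starts after a `--infer_once` pass finishes; whether that fall-through is intended is not clear from the commit. The control flow, reduced to a runnable snippet:

    class Args(object):
        infer_once = True          # as if --infer_once were passed
        infer_during_train = True  # store_true + default=True: always True

    args = Args()
    if args.infer_once:
        print("infer_once runs")
    if args.infer_during_train:
        print("infer_during_train also runs")  # both branches fire
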
train.py

@@ -12,10 +12,11 @@ os.environ["CUDA_VISIBLE_DEVICES"] = ""
 import paddle
 import paddle.fluid as fluid
+from paddle.fluid.executor import global_scope
 import reader
 from network_conf import skip_gram_word2vec
-from infer import infer_with_in_train
+from infer import inference_test
 
 logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger("fluid")
@@ -206,8 +207,7 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id):
                 if batch_id % 1000 == 0 and batch_id != 0:
                     model_dir = args.model_output_dir + '/batch-' + str(
                         batch_id)
-                    infer_with_in_train(model_dir, args.rank_num,
-                                        args.dict_path)
+                    inference_test(global_scope(), model_dir, args)
                 batch_id += 1
         except fluid.core.EOFException:
......
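
These last hunks are in the training script: instead of a dedicated `infer_with_in_train` helper, the periodic in-training check now calls the same `inference_test`, passing `global_scope()`. That works because parameters trained without an explicit `scope_guard` live in the executor's default scope, so the current "embeding" tensor can be read in memory, with no reload from disk. A hedged sketch of that property (the fc layer and the parameter name 'w' are illustrative):

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.executor import global_scope

    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    y = fluid.layers.fc(input=x, size=2, param_attr=fluid.ParamAttr(name='w'))

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())

    # the parameter is immediately visible in the default scope
    w = np.array(global_scope().find_var('w').get_tensor())
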