提交 b102d256 编写于 作者: JiabinYang

polish code

上级 db405c78
...@@ -7,7 +7,6 @@ from Queue import PriorityQueue ...@@ -7,7 +7,6 @@ from Queue import PriorityQueue
import logging import logging
import argparse import argparse
from sklearn.metrics.pairwise import cosine_similarity from sklearn.metrics.pairwise import cosine_similarity
from paddle.fluid.executor import global_scope
word_to_id = dict() word_to_id = dict()
id_to_word = dict() id_to_word = dict()
...@@ -42,6 +41,12 @@ def parse_args(): ...@@ -42,6 +41,12 @@ def parse_args():
required=False, required=False,
default=False, default=False,
help='if using infer_once, (default: False)') help='if using infer_once, (default: False)')
parser.add_argument(
'--infer_during_train',
action='store_true',
required=False,
default=True,
help='if using infer_during_train, (default: True)')
return parser.parse_args() return parser.parse_args()
...@@ -78,40 +83,17 @@ def build_test_case(emb): ...@@ -78,40 +83,17 @@ def build_test_case(emb):
[emb5, desc5]] [emb5, desc5]]
def inference_test(model_dir, args): def inference_test(scope, model_dir, args):
BuildWord_IdMap(args.dict_path) BuildWord_IdMap(args.dict_path)
exe = fluid.Executor(fluid.CPUPlace())
Scope = fluid.Scope()
logger.info("model_dir is: {}".format(model_dir + "/")) logger.info("model_dir is: {}".format(model_dir + "/"))
with fluid.scope_guard(Scope): emb = np.array(scope.find_var("embeding").get_tensor())
inference_prog()
fluid.io.load_persistables(executor=exe, dirname=model_dir + "/")
emb = np.array(Scope.find_var("embeding").get_tensor())
test_cases = build_test_case(emb)
logger.info("inference result: ====================")
for case in test_cases:
pq = topK(args.rank_num, emb, case[0])
logger.info("Test result for {}".format(case[1]))
pq_tmps = list()
for i in range(args.rank_num):
pq_tmps.append(pq.get())
for i in range(len(pq_tmps)):
logger.info("{} nearest is {}, rate is {}".format(i, id_to_word[
pq_tmps[len(pq_tmps) - 1 - i].id], pq_tmps[len(pq_tmps) - 1
- i].priority))
del pq_tmps[:]
def infer_with_in_train(model_dir, rank_num, dict_path):
BuildWord_IdMap(dict_path)
emb = np.array(global_scope().find_var("embeding").get_tensor())
test_cases = build_test_case(emb) test_cases = build_test_case(emb)
logger.info("inference result: ====================") logger.info("inference result: ====================")
for case in test_cases: for case in test_cases:
pq = topK(rank_num, emb, case[0]) pq = topK(args.rank_num, emb, case[0])
logger.info("Test result for {}".format(case[1])) logger.info("Test result for {}".format(case[1]))
pq_tmps = list() pq_tmps = list()
for i in range(rank_num): for i in range(args.rank_num):
pq_tmps.append(pq.get()) pq_tmps.append(pq.get())
for i in range(len(pq_tmps)): for i in range(len(pq_tmps)):
logger.info("{} nearest is {}, rate is {}".format(i, id_to_word[ logger.info("{} nearest is {}, rate is {}".format(i, id_to_word[
...@@ -149,6 +131,9 @@ def topK(k, emb, test_emb): ...@@ -149,6 +131,9 @@ def topK(k, emb, test_emb):
def infer_during_train(args): def infer_during_train(args):
model_file_list = list() model_file_list = list()
exe = fluid.Executor(fluid.CPUPlace())
Scope = fluid.Scope()
inference_prog()
while True: while True:
time.sleep(1) time.sleep(1)
current_list = os.listdir(args.model_output_dir) current_list = os.listdir(args.model_output_dir)
...@@ -167,15 +152,24 @@ def infer_during_train(args): ...@@ -167,15 +152,24 @@ def infer_during_train(args):
model_dir = args.model_output_dir + "/" + model model_dir = args.model_output_dir + "/" + model
if os.path.exists(model_dir + "/_success"): if os.path.exists(model_dir + "/_success"):
logger.info("using models from " + model_dir) logger.info("using models from " + model_dir)
inference_test(model_dir, args) with fluid.scope_guard(Scope):
fluid.io.load_persistables(
executor=exe, dirname=model_dir + "/")
inference_test(Scope, model_dir, args)
model_file_list = current_list model_file_list = current_list
def infer_once(args): def infer_once(args):
if os.path.exists(args.model_output_dir + "/_success" # check models file has already been finished
): # check models file has already been finished if os.path.exists(args.model_output_dir + "/_success"):
logger.info("using models from " + args.model_output_dir) logger.info("using models from " + args.model_output_dir)
inference_test(args.model_output_dir, args) exe = fluid.Executor(fluid.CPUPlace())
Scope = fluid.Scope()
inference_prog()
with fluid.scope_guard(Scope):
fluid.io.load_persistables(
executor=exe, dirname=args.model_output_dir + "/")
inference_test(Scope, args.model_output_dir, args)
if __name__ == '__main__': if __name__ == '__main__':
...@@ -183,5 +177,5 @@ if __name__ == '__main__': ...@@ -183,5 +177,5 @@ if __name__ == '__main__':
# while setting infer_once please specify the dir to models file with --model_output_dir # while setting infer_once please specify the dir to models file with --model_output_dir
if args.infer_once: if args.infer_once:
infer_once(args) infer_once(args)
else: if args.infer_during_train:
infer_during_train(args) infer_during_train(args)
...@@ -12,10 +12,11 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "" ...@@ -12,10 +12,11 @@ os.environ["CUDA_VISIBLE_DEVICES"] = ""
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.executor import global_scope
import reader import reader
from network_conf import skip_gram_word2vec from network_conf import skip_gram_word2vec
from infer import infer_with_train from infer import inference_test
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s') logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid") logger = logging.getLogger("fluid")
...@@ -206,8 +207,7 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id): ...@@ -206,8 +207,7 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id):
if batch_id % 1000 == 0 and batch_id != 0: if batch_id % 1000 == 0 and batch_id != 0:
model_dir = args.model_output_dir + '/batch-' + str( model_dir = args.model_output_dir + '/batch-' + str(
batch_id) batch_id)
infer_with_in_train(model_dir, args.rank_num, inference_test(global_scope(), model_dir, args)
args.dict_path)
batch_id += 1 batch_id += 1
except fluid.core.EOFException: except fluid.core.EOFException:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册