未验证 提交 8ad5f46f 编写于 作者: W Wang Yuanhong 提交者: GitHub

Update run_classifier.py

上级 a40349a4
......@@ -317,6 +317,28 @@ def train(conf_dict, args):
else:
test_auc = test_result
logging.info("AUC of test is %f" % test_auc)
if args.output_dir is not None:
model_save_dir = os.path.join(args.output_dir,
conf_dict["model_path"])
model_path = os.path.join(model_save_dir, args.task_name)
if not os.path.exists(model_save_dir):
os.makedirs(model_save_dir)
if args.task_mode == "pairwise":
feed_var_names = [left.name, pos_right.name]
target_vars = [left_feat, pos_score]
else:
feed_var_names = [
left.name,
right.name,
]
target_vars = [left_feat, pred]
fluid.io.save_inference_model(model_path, feed_var_names,
target_vars, executor,
infer_program)
logging.info("saving infer model in %s" % model_path)
def test(conf_dict, args):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册