提交 0c75fc9c 编写于 作者: S Superjom

fix infer

上级 1506c808
......@@ -94,7 +94,7 @@ class Inferer(object):
def __init__(self, param_path):
logger.info("create DSSM model")
cost, prediction, label = DSSM(
prediction = DSSM(
dnn_dims=layer_dims,
vocab_sizes=[
len(load_dic(path))
......@@ -104,7 +104,8 @@ class Inferer(object):
model_arch=args.model_arch,
share_semantic_generator=args.share_network_between_source_target,
class_num=args.class_num,
share_embed=args.share_embed)()
share_embed=args.share_embed,
is_infer=True)()
# load parameter
logger.info("load model parameters from %s" % param_path)
......
......@@ -219,7 +219,7 @@ class DSSM(object):
# but this operator is not supported currently.
# so AUC will not used.
return cost, None, label
return None, [left_score, right_score], label
return right_score
def _build_classification_or_regression_model(self, is_classification):
'''
......@@ -275,4 +275,4 @@ class DSSM(object):
if not self.is_infer:
return cost, prediction, label
return None, prediction, label
return prediction
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册