提交 0c75fc9c 编写于 作者: S Superjom

fix infer

上级 1506c808
...@@ -94,7 +94,7 @@ class Inferer(object): ...@@ -94,7 +94,7 @@ class Inferer(object):
def __init__(self, param_path): def __init__(self, param_path):
logger.info("create DSSM model") logger.info("create DSSM model")
cost, prediction, label = DSSM( prediction = DSSM(
dnn_dims=layer_dims, dnn_dims=layer_dims,
vocab_sizes=[ vocab_sizes=[
len(load_dic(path)) len(load_dic(path))
...@@ -104,7 +104,8 @@ class Inferer(object): ...@@ -104,7 +104,8 @@ class Inferer(object):
model_arch=args.model_arch, model_arch=args.model_arch,
share_semantic_generator=args.share_network_between_source_target, share_semantic_generator=args.share_network_between_source_target,
class_num=args.class_num, class_num=args.class_num,
share_embed=args.share_embed)() share_embed=args.share_embed,
is_infer=True)()
# load parameter # load parameter
logger.info("load model parameters from %s" % param_path) logger.info("load model parameters from %s" % param_path)
......
...@@ -219,7 +219,7 @@ class DSSM(object): ...@@ -219,7 +219,7 @@ class DSSM(object):
# but this operator is not supported currently. # but this operator is not supported currently.
# so AUC will not used. # so AUC will not used.
return cost, None, label return cost, None, label
return None, [left_score, right_score], label return right_score
def _build_classification_or_regression_model(self, is_classification): def _build_classification_or_regression_model(self, is_classification):
''' '''
...@@ -275,4 +275,4 @@ class DSSM(object): ...@@ -275,4 +275,4 @@ class DSSM(object):
if not self.is_infer: if not self.is_infer:
return cost, prediction, label return cost, prediction, label
return None, prediction, label return prediction
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册