diff --git a/dssm/infer.py b/dssm/infer.py index 9192d7e19af10da9edba5353747e11ab6c786f0f..bf5abb0a8d75bd5b4610c22ece89d53b60cc09a6 100644 --- a/dssm/infer.py +++ b/dssm/infer.py @@ -94,7 +94,7 @@ class Inferer(object): def __init__(self, param_path): logger.info("create DSSM model") - cost, prediction, label = DSSM( + prediction = DSSM( dnn_dims=layer_dims, vocab_sizes=[ len(load_dic(path)) @@ -104,7 +104,8 @@ class Inferer(object): model_arch=args.model_arch, share_semantic_generator=args.share_network_between_source_target, class_num=args.class_num, - share_embed=args.share_embed)() + share_embed=args.share_embed, + is_infer=True)() # load parameter logger.info("load model parameters from %s" % param_path) diff --git a/dssm/network_conf.py b/dssm/network_conf.py index 430db232db7734368d1261adcaef87a1360c0f28..7be413fcc0d2c41d318015e580741333403e40ff 100644 --- a/dssm/network_conf.py +++ b/dssm/network_conf.py @@ -219,7 +219,7 @@ class DSSM(object): # but this operator is not supported currently. # so AUC will not used. return cost, None, label - return None, [left_score, right_score], label + return right_score def _build_classification_or_regression_model(self, is_classification): ''' @@ -275,4 +275,4 @@ class DSSM(object): if not self.is_infer: return cost, prediction, label - return None, prediction, label + return prediction