提交 1b4538dc 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #212 from Superjom/fix_dssm

Fix dssm infer error with gru
......@@ -94,7 +94,7 @@ class Inferer(object):
def __init__(self, param_path):
logger.info("create DSSM model")
cost, prediction, label = DSSM(
prediction = DSSM(
dnn_dims=layer_dims,
vocab_sizes=[
len(load_dic(path))
......@@ -104,7 +104,8 @@ class Inferer(object):
model_arch=args.model_arch,
share_semantic_generator=args.share_network_between_source_target,
class_num=args.class_num,
share_embed=args.share_embed)()
share_embed=args.share_embed,
is_infer=True)()
# load parameter
logger.info("load model parameters from %s" % param_path)
......
......@@ -102,8 +102,7 @@ class DSSM(object):
'''
A GRU sentence vector learner.
'''
gru = paddle.layer.gru_memory(
input=emb, )
gru = paddle.networks.simple_gru(input=emb, size=256)
sent_vec = paddle.layer.last_seq(gru)
return sent_vec
......@@ -219,7 +218,7 @@ class DSSM(object):
# but this operator is not supported currently.
# so AUC will not used.
return cost, None, label
return None, [left_score, right_score], label
return right_score
def _build_classification_or_regression_model(self, is_classification):
'''
......@@ -275,4 +274,4 @@ class DSSM(object):
if not self.is_infer:
return cost, prediction, label
return None, prediction, label
return prediction
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册