提交 f503908d 编写于 作者: Z zhangwenhui03

fix net bug

上级 46629ec6
......@@ -171,7 +171,8 @@ def train_cross_entropy_network(vocab_size, neg_size, hid_size, drop_out=0.2):
ele_mul = fluid.layers.elementwise_mul(emb_label_drop, gru)
red_sum = fluid.layers.reduce_sum(input=ele_mul, dim=1, keep_dim=True)
pre = fluid.layers.sequence_reshape(input=red_sum, new_dim=(neg_size + 1))
pre_ = fluid.layers.sequence_reshape(input=red_sum, new_dim=(neg_size + 1))
pre = fluid.layers.softmax(input=pre_)
cost = fluid.layers.cross_entropy(input=pre, label=pos_label)
cost_sum = fluid.layers.reduce_sum(input=cost)
......
......@@ -68,9 +68,11 @@ def train():
# Train program
if args.loss == 'bpr':
print('bpr loss')
src, pos_label, label, avg_cost = net.train_bpr_network(
neg_size=args.neg_size, vocab_size=vocab_size, hid_size=hid_size)
else:
print('cross-entory loss')
src, pos_label, label, avg_cost = net.train_cross_entropy_network(
neg_size=args.neg_size, vocab_size=vocab_size, hid_size=hid_size)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册