“d83b861a511937c8bb7527c59435ef2d0965a649”上不存在“develop/doc_cn/v1_api_tutorials/README.html”
提交 f503908d 编写于 作者: Z zhangwenhui03

fix net bug

上级 46629ec6
...@@ -171,7 +171,8 @@ def train_cross_entropy_network(vocab_size, neg_size, hid_size, drop_out=0.2): ...@@ -171,7 +171,8 @@ def train_cross_entropy_network(vocab_size, neg_size, hid_size, drop_out=0.2):
ele_mul = fluid.layers.elementwise_mul(emb_label_drop, gru) ele_mul = fluid.layers.elementwise_mul(emb_label_drop, gru)
red_sum = fluid.layers.reduce_sum(input=ele_mul, dim=1, keep_dim=True) red_sum = fluid.layers.reduce_sum(input=ele_mul, dim=1, keep_dim=True)
pre = fluid.layers.sequence_reshape(input=red_sum, new_dim=(neg_size + 1)) pre_ = fluid.layers.sequence_reshape(input=red_sum, new_dim=(neg_size + 1))
pre = fluid.layers.softmax(input=pre_)
cost = fluid.layers.cross_entropy(input=pre, label=pos_label) cost = fluid.layers.cross_entropy(input=pre, label=pos_label)
cost_sum = fluid.layers.reduce_sum(input=cost) cost_sum = fluid.layers.reduce_sum(input=cost)
......
...@@ -68,9 +68,11 @@ def train(): ...@@ -68,9 +68,11 @@ def train():
# Train program # Train program
if args.loss == 'bpr': if args.loss == 'bpr':
print('bpr loss')
src, pos_label, label, avg_cost = net.train_bpr_network( src, pos_label, label, avg_cost = net.train_bpr_network(
neg_size=args.neg_size, vocab_size=vocab_size, hid_size=hid_size) neg_size=args.neg_size, vocab_size=vocab_size, hid_size=hid_size)
else: else:
print('cross-entory loss')
src, pos_label, label, avg_cost = net.train_cross_entropy_network( src, pos_label, label, avg_cost = net.train_cross_entropy_network(
neg_size=args.neg_size, vocab_size=vocab_size, hid_size=hid_size) neg_size=args.neg_size, vocab_size=vocab_size, hid_size=hid_size)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册