未验证 提交 8c1b58ae 编写于 作者: M Meiyim 提交者: GitHub

Merge pull request #309 from wang001/patch-1

fix bug for regression where is_prediction ==true
......@@ -92,8 +92,12 @@ def create_model(args,
name=task_name + "_cls_out_b",
initializer=fluid.initializer.Constant(0.)))
assert is_classify != is_regression, 'is_classify or is_regression must be true and only one of them can be true'
if is_prediction:
probs = fluid.layers.softmax(logits)
if is_classify:
probs = fluid.layers.softmax(logits)
else:
probs = logits
feed_targets_name = [
src_ids.name, sent_ids.name, pos_ids.name, input_mask.name
]
......@@ -101,7 +105,6 @@ def create_model(args,
feed_targets_name += [task_ids.name]
return pyreader, probs, feed_targets_name
assert is_classify != is_regression, 'is_classify or is_regression must be true and only one of them can be true'
num_seqs = fluid.layers.create_tensor(dtype='int64')
if is_classify:
ce_loss, probs = fluid.layers.softmax_with_cross_entropy(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册