Commit 58e7a556 authored by wangxiao

fix cross_entropy

Parent c087d295
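
The hunks below replace `fluid.layers.softmax_with_cross_entropy`, which takes raw logits and fuses the softmax into the loss op, with `fluid.layers.cross_entropy`, which expects an already-normalized probability distribution as `input`. A minimal sketch of the two calls, assuming Paddle 1.x static-graph `fluid`; the tensor names and the explicit softmax step are illustrative and not taken from this commit:

```python
import paddle.fluid as fluid

# Illustrative placeholders (3-way classification), not from this repo.
logits = fluid.layers.data(name="logits", shape=[3], dtype="float32")
label = fluid.layers.data(name="label", shape=[1], dtype="int64")

# softmax_with_cross_entropy: the softmax is applied inside the op,
# so unnormalized logits go in directly.
loss_fused = fluid.layers.softmax_with_cross_entropy(logits=logits, label=label)

# cross_entropy: the input is expected to be a probability distribution,
# so a softmax is normally applied first.
probs = fluid.layers.softmax(logits)
loss_plain = fluid.layers.cross_entropy(input=probs, label=label)

loss = fluid.layers.mean(loss_plain)
```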
......
@@ -75,8 +75,8 @@ class TaskParadigm(task_paradigm):
                 name=scope_name+"cls_out_b", initializer=fluid.initializer.Constant(0.)))
         if self._is_training:
-            loss = fluid.layers.softmax_with_cross_entropy(
-                logits=logits, label=label_ids)
+            loss = fluid.layers.cross_entropy(
+                input=logits, label=label_ids)
             loss = layers.mean(loss)
             return {"loss": loss}
         else:
......
......
@@ -79,8 +79,8 @@ class TaskParadigm(task_paradigm):
                 initializer=fluid.initializer.Constant(0.)))
         if self._is_training:
-            ce_loss, probs = fluid.layers.softmax_with_cross_entropy(
-                logits=logits, label=labels, return_softmax=True)
+            ce_loss, probs = fluid.layers.cross_entropy(
+                input=logits, label=labels, return_softmax=True)
             loss = fluid.layers.mean(x=ce_loss)
             return {'loss': loss}
         else:
......
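
This hunk still unpacks two return values and passes `return_softmax=True`, which is an argument of `softmax_with_cross_entropy`; in the fluid 1.x API, `fluid.layers.cross_entropy` returns a single loss tensor and does not take that argument. A hedged sketch of how the loss and the probabilities could be kept as separate tensors instead (names are illustrative, not from this repo):

```python
import paddle.fluid as fluid

logits = fluid.layers.data(name="cls_logits", shape=[2], dtype="float32")
labels = fluid.layers.data(name="cls_labels", shape=[1], dtype="int64")

# Compute the probabilities explicitly, then feed them to cross_entropy;
# this keeps `probs` available without a return_softmax-style argument.
probs = fluid.layers.softmax(logits)
ce_loss = fluid.layers.cross_entropy(input=probs, label=labels)
loss = fluid.layers.mean(x=ce_loss)
```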
......
@@ -100,8 +100,8 @@ class TaskParadigm(task_paradigm):
             is_bias=True)
         if self._is_training:
-            mask_lm_loss = fluid.layers.softmax_with_cross_entropy(
-                logits=fc_out, label=mask_label)
+            mask_lm_loss = fluid.layers.cross_entropy(
+                input=fc_out, label=mask_label)
             loss = fluid.layers.mean(mask_lm_loss)
             return {'loss': loss}
         else:
......
......
@@ -100,9 +100,10 @@ class TaskParadigm(task_paradigm):
         start_logits, end_logits = fluid.layers.unstack(x=logits, axis=0)

         def _compute_single_loss(logits, positions):
-            """Compute start/end loss for mrc model"""
-            loss = fluid.layers.softmax_with_cross_entropy(
-                logits=logits, label=positions)
+            """Compute start/en
+d loss for mrc model"""
+            loss = fluid.layers.cross_entropy(
+                input=logits, label=positions)
             loss = fluid.layers.mean(x=loss)
             return loss
......
......
@@ -22,8 +22,6 @@ from paddle import fluid
 def _check_and_adapt_shape_dtype(rt_val, attr, message=""):
-    print(rt_val)
-    print(attr)
     if not isinstance(rt_val, np.ndarray):
         rt_val = np.array(rt_val)
     assert rt_val.dtype != np.dtype('O'), "yielded data is not a valid tensor(number of elements on some dimension may differ)."
......
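
The last hunk drops two debug `print` calls from `_check_and_adapt_shape_dtype`. The assertion that remains rejects object-dtype arrays, which is what NumPy produces when the yielded data is ragged (rows of unequal length) and therefore cannot form a dense tensor. A small standalone illustration, unrelated to the repo's own data:

```python
import numpy as np

# A rectangular nested list converts to a dense numeric array.
ok = np.array([[1, 2, 3], [4, 5, 6]])
assert ok.dtype != np.dtype('O')

# Rows of unequal length cannot form a dense tensor. Older NumPy silently
# falls back to an object-dtype array (newer NumPy raises unless
# dtype=object is given), which is exactly what the assertion guards against.
ragged = np.array([[1, 2, 3], [4, 5]], dtype=object)
assert ragged.dtype == np.dtype('O')
```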