未验证 提交 74a28f5e 编写于 作者: A Aurelius84 提交者: GitHub

fix fill_constant shape with -1 and enhance cross_entropy test=develop (#20722)

上级 48a774c7
...@@ -136,8 +136,8 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel { ...@@ -136,8 +136,8 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
"Input(Y@Grad) and Input(Y) should have the same rank."); "Input(Y@Grad) and Input(Y) should have the same rank.");
bool check = true; bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 || if ((!ctx->IsRuntime()) &&
framework::product(label_dims) <= 0)) { (framework::product(x_dims) <= 0 || framework::product(dy_dims) <= 0)) {
check = false; check = false;
} }
......
...@@ -1251,7 +1251,7 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None): ...@@ -1251,7 +1251,7 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
op_desc = _create_op_desc_("fill_constant", op_desc = _create_op_desc_("fill_constant",
{"ShapeTensor": [target_shape.name]}, {"ShapeTensor": [target_shape.name]},
{"Out": [grad_name]}, { {"Out": [grad_name]}, {
"shape": [], "shape": target.shape,
"value": 1.0, "value": 1.0,
"dtype": target.dtype, "dtype": target.dtype,
}) })
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册