未验证 提交 74a28f5e 编写于 作者: A Aurelius84 提交者: GitHub

fix fill_constant shape with -1 and enhance cross_entropy test=develop (#20722)

上级 48a774c7
......@@ -136,8 +136,8 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
"Input(Y@Grad) and Input(Y) should have the same rank.");
bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
framework::product(label_dims) <= 0)) {
if ((!ctx->IsRuntime()) &&
(framework::product(x_dims) <= 0 || framework::product(dy_dims) <= 0)) {
check = false;
}
......
......@@ -1251,7 +1251,7 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
op_desc = _create_op_desc_("fill_constant",
{"ShapeTensor": [target_shape.name]},
{"Out": [grad_name]}, {
"shape": [],
"shape": target.shape,
"value": 1.0,
"dtype": target.dtype,
})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册