未验证 提交 1458cc0c 编写于 作者: L liym27 提交者: GitHub

Fix bug: Don't check dims if contain_unknown_dim of cross_entropy_grad_op in compile time (#25221)

上级 cb0472b0
......@@ -145,11 +145,10 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
"But received: Y@Grad's rank is [%d], Y's rank is [%d]",
dy_dims.size(), label_dims.size()));
bool check = true;
if ((!ctx->IsRuntime()) &&
(framework::product(x_dims) <= 0 || framework::product(dy_dims) <= 0)) {
check = false;
}
bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) ||
framework::contain_unknown_dim(dy_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
PADDLE_ENFORCE_EQ(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册