未验证 提交 f4846bf3 编写于 作者: Q qingqing01 提交者: GitHub

loosly check in the InferShape of cross_entropy_op. (#15863)

* loosly check in cross_entropy_op when soft_label is True
* Add Runtime assertion in backward infer_shape check.
* Skip InferShape check when un-know the input dimensions
上级 2c5c7b2a
...@@ -32,14 +32,23 @@ class CrossEntropyOp : public framework::OperatorWithKernel { ...@@ -32,14 +32,23 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
int rank = x_dims.size(); int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank, label_dims.size(), PADDLE_ENFORCE_EQ(rank, label_dims.size(),
"Input(X) and Input(Label) shall have the same rank."); "Input(X) and Input(Label) shall have the same rank.");
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), bool check = true;
framework::slice_ddim(label_dims, 0, rank - 1), if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
"Input(X) and Input(Label) shall have the same shape " framework::product(label_dims) <= 0)) {
"except the last dimension."); check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(label_dims, 0, rank - 1),
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension.");
}
if (ctx->Attrs().Get<bool>("soft_label")) { if (ctx->Attrs().Get<bool>("soft_label")) {
PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1], if (check) {
"If Attr(soft_label) == true, the last dimension of " PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
"Input(X) and Input(Label) should be equal."); "If Attr(soft_label) == true, the last dimension of "
"Input(X) and Input(Label) should be equal.");
}
} else { } else {
PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL, PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL,
"If Attr(softLabel) == false, the last dimension of " "If Attr(softLabel) == false, the last dimension of "
...@@ -82,20 +91,32 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -82,20 +91,32 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
"Input(Y@Grad) and Input(X) should have the same rank."); "Input(Y@Grad) and Input(X) should have the same rank.");
PADDLE_ENFORCE_EQ(label_dims.size(), rank, PADDLE_ENFORCE_EQ(label_dims.size(), rank,
"Input(Label) and Input(X) should have the same rank."); "Input(Label) and Input(X) should have the same rank.");
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(label_dims, 0, rank - 1), bool check = true;
"The Input(X) and Input(Label) should have the same " if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
"shape except the last dimension."); framework::product(label_dims) <= 0)) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), check = false;
framework::slice_ddim(dy_dims, 0, rank - 1), }
"The Input(X) and Input(Y@Grad) should have the same "
"shape except the last dimension."); if (check) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(label_dims, 0, rank - 1),
"The Input(X) and Input(Label) should have the same "
"shape except the last dimension.");
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(dy_dims, 0, rank - 1),
"The Input(X) and Input(Y@Grad) should have the same "
"shape except the last dimension.");
}
PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1, PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
"The last dimension of Input(Y@Grad) should be 1."); "The last dimension of Input(Y@Grad) should be 1.");
if (ctx->Attrs().Get<bool>("soft_label")) { if (ctx->Attrs().Get<bool>("soft_label")) {
PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1], if (check) {
"When Attr(soft_label) == true, the last dimension of " PADDLE_ENFORCE_EQ(
"Input(X) and Input(Label) should be equal."); x_dims[rank - 1], label_dims[rank - 1],
"When Attr(soft_label) == true, the last dimension of "
"Input(X) and Input(Label) should be equal.");
}
} else { } else {
PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1, PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
"When Attr(soft_label) == false, the last dimension of " "When Attr(soft_label) == false, the last dimension of "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册