未验证 提交 f4846bf3 编写于 作者: Q qingqing01 提交者: GitHub

loosly check in the InferShape of cross_entropy_op. (#15863)

* loosly check in cross_entropy_op when soft_label is True
* Add Runtime assertion in backward infer_shape check.
* Skip InferShape check when un-know the input dimensions
上级 2c5c7b2a
...@@ -32,14 +32,23 @@ class CrossEntropyOp : public framework::OperatorWithKernel { ...@@ -32,14 +32,23 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
int rank = x_dims.size(); int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank, label_dims.size(), PADDLE_ENFORCE_EQ(rank, label_dims.size(),
"Input(X) and Input(Label) shall have the same rank."); "Input(X) and Input(Label) shall have the same rank.");
bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
framework::product(label_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(label_dims, 0, rank - 1), framework::slice_ddim(label_dims, 0, rank - 1),
"Input(X) and Input(Label) shall have the same shape " "Input(X) and Input(Label) shall have the same shape "
"except the last dimension."); "except the last dimension.");
}
if (ctx->Attrs().Get<bool>("soft_label")) { if (ctx->Attrs().Get<bool>("soft_label")) {
if (check) {
PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1], PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
"If Attr(soft_label) == true, the last dimension of " "If Attr(soft_label) == true, the last dimension of "
"Input(X) and Input(Label) should be equal."); "Input(X) and Input(Label) should be equal.");
}
} else { } else {
PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL, PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL,
"If Attr(softLabel) == false, the last dimension of " "If Attr(softLabel) == false, the last dimension of "
...@@ -82,6 +91,14 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -82,6 +91,14 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
"Input(Y@Grad) and Input(X) should have the same rank."); "Input(Y@Grad) and Input(X) should have the same rank.");
PADDLE_ENFORCE_EQ(label_dims.size(), rank, PADDLE_ENFORCE_EQ(label_dims.size(), rank,
"Input(Label) and Input(X) should have the same rank."); "Input(Label) and Input(X) should have the same rank.");
bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
framework::product(label_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(label_dims, 0, rank - 1), framework::slice_ddim(label_dims, 0, rank - 1),
"The Input(X) and Input(Label) should have the same " "The Input(X) and Input(Label) should have the same "
...@@ -90,12 +107,16 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -90,12 +107,16 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
framework::slice_ddim(dy_dims, 0, rank - 1), framework::slice_ddim(dy_dims, 0, rank - 1),
"The Input(X) and Input(Y@Grad) should have the same " "The Input(X) and Input(Y@Grad) should have the same "
"shape except the last dimension."); "shape except the last dimension.");
}
PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1, PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
"The last dimension of Input(Y@Grad) should be 1."); "The last dimension of Input(Y@Grad) should be 1.");
if (ctx->Attrs().Get<bool>("soft_label")) { if (ctx->Attrs().Get<bool>("soft_label")) {
PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1], if (check) {
PADDLE_ENFORCE_EQ(
x_dims[rank - 1], label_dims[rank - 1],
"When Attr(soft_label) == true, the last dimension of " "When Attr(soft_label) == true, the last dimension of "
"Input(X) and Input(Label) should be equal."); "Input(X) and Input(Label) should be equal.");
}
} else { } else {
PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1, PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
"When Attr(soft_label) == false, the last dimension of " "When Attr(soft_label) == false, the last dimension of "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册