未验证 提交 ee2869ca 编写于 作者: Y Yibing Liu 提交者: GitHub

Remove redundant infershape in linear chain crf grad, test=develop (#20629)

上级 b4a3b750
......@@ -242,60 +242,14 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
"Input(LogLikelihood@GRAD) shoudl be not null.");
auto transition_exps_dims = ctx->GetInputDim("TransitionExps");
PADDLE_ENFORCE_EQ(transition_exps_dims.size(), 2,
"The Input(TransitionExps) should be a 2-D tensor.");
bool check = true;
if ((!ctx->IsRuntime()) &&
(transition_exps_dims[0] <= 0 || transition_exps_dims[1] <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(
transition_exps_dims[0] - 2, transition_exps_dims[1],
"An invalid dimension for the Input(TransitionExps), which should "
"be a 2-D tensor with shape [(D + 2) x D].");
}
auto emission_exps_dims = ctx->GetInputDim("EmissionExps");
auto label_dims = ctx->GetInputDim("Label");
if (ctx->HasInput("Length")) {
PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 3,
"The Input(EmissionExps) should be a 3-D tensor.");
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_exps_dims[2], transition_exps_dims[1],
"The 3nd dimension of the Input(EmissionExps) and the "
"Input(TransitionExps) should be equal to the tag number.");
PADDLE_ENFORCE_EQ(label_dims.size(), 3,
"The Input(Label) should be a 3-D tensor with the 3nd "
"dimensions fixed to 1.");
} else {
PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 2,
"The Input(EmissionExps) should be a 2-D tensor.");
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_exps_dims[1], transition_exps_dims[1],
"The 2nd dimension of the Input(EmissionExps) and the "
"Input(TransitionExps) should be equal to the tag number.");
PADDLE_ENFORCE_EQ(label_dims.size(), 2,
"The Input(Label) should be a 2-D tensor");
PADDLE_ENFORCE_EQ(label_dims[1], 1,
"The Input(Label) 2nd dimensions fixed to 1.");
}
PADDLE_ENFORCE_NE(emission_exps_dims[0], 0,
"An empty mini-batch is not allowed.");
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_exps_dims[0], label_dims[0],
"The height of Input(EmissionExps) and the height of Input(Label) "
"should be the same.");
if (ctx->HasOutput(framework::GradVarName("Emission"))) {
ctx->SetOutputDim(framework::GradVarName("Emission"), emission_exps_dims);
if (ctx->HasInput("Length") == false) {
ctx->ShareLoD("Emission", framework::GradVarName("Emission"));
}
}
// ctx->SetOutputDim(framework::GradVarName("Emission"),
// emission_exps_dims);
if (ctx->HasOutput(framework::GradVarName("Transition"))) {
ctx->SetOutputDim(framework::GradVarName("Transition"),
transition_exps_dims);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册