diff --git a/paddle/fluid/operators/linear_chain_crf_op.cc b/paddle/fluid/operators/linear_chain_crf_op.cc index d78d496187a1da3aaa7d77f8eef1e4ef46b72ff4..b78a6ceb5199bc39d04e3560d350f1bd1b6aee52 100755 --- a/paddle/fluid/operators/linear_chain_crf_op.cc +++ b/paddle/fluid/operators/linear_chain_crf_op.cc @@ -242,60 +242,14 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel { "Input(LogLikelihood@GRAD) shoudl be not null."); auto transition_exps_dims = ctx->GetInputDim("TransitionExps"); - PADDLE_ENFORCE_EQ(transition_exps_dims.size(), 2, - "The Input(TransitionExps) should be a 2-D tensor."); - bool check = true; - if ((!ctx->IsRuntime()) && - (transition_exps_dims[0] <= 0 || transition_exps_dims[1] <= 0)) { - check = false; - } - if (check) { - PADDLE_ENFORCE_EQ( - transition_exps_dims[0] - 2, transition_exps_dims[1], - "An invalid dimension for the Input(TransitionExps), which should " - "be a 2-D tensor with shape [(D + 2) x D]."); - } - auto emission_exps_dims = ctx->GetInputDim("EmissionExps"); - auto label_dims = ctx->GetInputDim("Label"); - if (ctx->HasInput("Length")) { - PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 3, - "The Input(EmissionExps) should be a 3-D tensor."); - PADDLE_INFERSHAPE_ENFORCE_EQ( - ctx, emission_exps_dims[2], transition_exps_dims[1], - "The 3nd dimension of the Input(EmissionExps) and the " - "Input(TransitionExps) should be equal to the tag number."); - PADDLE_ENFORCE_EQ(label_dims.size(), 3, - "The Input(Label) should be a 3-D tensor with the 3nd " - "dimensions fixed to 1."); - } else { - PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 2, - "The Input(EmissionExps) should be a 2-D tensor."); - PADDLE_INFERSHAPE_ENFORCE_EQ( - ctx, emission_exps_dims[1], transition_exps_dims[1], - "The 2nd dimension of the Input(EmissionExps) and the " - "Input(TransitionExps) should be equal to the tag number."); - PADDLE_ENFORCE_EQ(label_dims.size(), 2, - "The Input(Label) should be a 2-D tensor"); - PADDLE_ENFORCE_EQ(label_dims[1], 1, - "The Input(Label) 2nd dimensions fixed to 1."); - } - PADDLE_ENFORCE_NE(emission_exps_dims[0], 0, - "An empty mini-batch is not allowed."); - - PADDLE_INFERSHAPE_ENFORCE_EQ( - ctx, emission_exps_dims[0], label_dims[0], - "The height of Input(EmissionExps) and the height of Input(Label) " - "should be the same."); - if (ctx->HasOutput(framework::GradVarName("Emission"))) { ctx->SetOutputDim(framework::GradVarName("Emission"), emission_exps_dims); if (ctx->HasInput("Length") == false) { ctx->ShareLoD("Emission", framework::GradVarName("Emission")); } } - // ctx->SetOutputDim(framework::GradVarName("Emission"), - // emission_exps_dims); + if (ctx->HasOutput(framework::GradVarName("Transition"))) { ctx->SetOutputDim(framework::GradVarName("Transition"), transition_exps_dims);