From ee2869cae96e032610800ba1f740b6713b276ed7 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 15 Oct 2019 13:49:53 +0800 Subject: [PATCH] Remove redundant infershape in linear chain crf grad, test=develop (#20629) --- paddle/fluid/operators/linear_chain_crf_op.cc | 48 +------------------ 1 file changed, 1 insertion(+), 47 deletions(-) diff --git a/paddle/fluid/operators/linear_chain_crf_op.cc b/paddle/fluid/operators/linear_chain_crf_op.cc index d78d496187a..b78a6ceb519 100755 --- a/paddle/fluid/operators/linear_chain_crf_op.cc +++ b/paddle/fluid/operators/linear_chain_crf_op.cc @@ -242,60 +242,14 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel { "Input(LogLikelihood@GRAD) shoudl be not null."); auto transition_exps_dims = ctx->GetInputDim("TransitionExps"); - PADDLE_ENFORCE_EQ(transition_exps_dims.size(), 2, - "The Input(TransitionExps) should be a 2-D tensor."); - bool check = true; - if ((!ctx->IsRuntime()) && - (transition_exps_dims[0] <= 0 || transition_exps_dims[1] <= 0)) { - check = false; - } - if (check) { - PADDLE_ENFORCE_EQ( - transition_exps_dims[0] - 2, transition_exps_dims[1], - "An invalid dimension for the Input(TransitionExps), which should " - "be a 2-D tensor with shape [(D + 2) x D]."); - } - auto emission_exps_dims = ctx->GetInputDim("EmissionExps"); - auto label_dims = ctx->GetInputDim("Label"); - if (ctx->HasInput("Length")) { - PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 3, - "The Input(EmissionExps) should be a 3-D tensor."); - PADDLE_INFERSHAPE_ENFORCE_EQ( - ctx, emission_exps_dims[2], transition_exps_dims[1], - "The 3nd dimension of the Input(EmissionExps) and the " - "Input(TransitionExps) should be equal to the tag number."); - PADDLE_ENFORCE_EQ(label_dims.size(), 3, - "The Input(Label) should be a 3-D tensor with the 3nd " - "dimensions fixed to 1."); - } else { - PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 2, - "The Input(EmissionExps) should be a 2-D tensor."); - PADDLE_INFERSHAPE_ENFORCE_EQ( - ctx, emission_exps_dims[1], transition_exps_dims[1], - "The 2nd dimension of the Input(EmissionExps) and the " - "Input(TransitionExps) should be equal to the tag number."); - PADDLE_ENFORCE_EQ(label_dims.size(), 2, - "The Input(Label) should be a 2-D tensor"); - PADDLE_ENFORCE_EQ(label_dims[1], 1, - "The Input(Label) 2nd dimensions fixed to 1."); - } - PADDLE_ENFORCE_NE(emission_exps_dims[0], 0, - "An empty mini-batch is not allowed."); - - PADDLE_INFERSHAPE_ENFORCE_EQ( - ctx, emission_exps_dims[0], label_dims[0], - "The height of Input(EmissionExps) and the height of Input(Label) " - "should be the same."); - if (ctx->HasOutput(framework::GradVarName("Emission"))) { ctx->SetOutputDim(framework::GradVarName("Emission"), emission_exps_dims); if (ctx->HasInput("Length") == false) { ctx->ShareLoD("Emission", framework::GradVarName("Emission")); } } - // ctx->SetOutputDim(framework::GradVarName("Emission"), - // emission_exps_dims); + if (ctx->HasOutput(framework::GradVarName("Transition"))) { ctx->SetOutputDim(framework::GradVarName("Transition"), transition_exps_dims); -- GitLab