diff --git a/paddle/operators/linear_chain_crf_op.cc b/paddle/operators/linear_chain_crf_op.cc
index e127811a101f133802fc9c038d42e843d45d4368..14ae74ab662748c1cc774293680e758bc3cb560e 100644
--- a/paddle/operators/linear_chain_crf_op.cc
+++ b/paddle/operators/linear_chain_crf_op.cc
@@ -17,6 +17,22 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+namespace {
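+// NormalizeL1 rescales the length-`len` buffer x in place so that its
+// entries sum to 1, and returns the original sum. The forward and backward
+// recursions below renormalize their dynamic-programming vectors at every
+// step to avoid floating-point underflow; the discarded scale factors are
+// folded back into the log-likelihood as log(sum).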
+template <typename T>
+T NormalizeL1(T* x, size_t len) {
+  T sum = 0.;
+  for (size_t i = 0; i < len; ++i) sum += x[i];
+  // (This comment is from the old LinearChainCRFLayer.)
+  // Right now, we just bet that sum won't be zero. If this really happens, we
+  // will figure out what should be done then.
+  PADDLE_ENFORCE(sum,
+                 "The unnormalized probabilites of all possible unfinished "
+                 "sequences must be greater than 0.");
+  for (size_t i = 0; i < len; ++i) x[i] /= sum;
+  return sum;
+}
+}  // namespace
+
 using framework::LoDTensor;
 using framework::LoD;
 
@@ -54,13 +70,25 @@ class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
         "each tag value \f$v$\f. This vector is called a forward vecotr and "
         "will also be used in backward computations.")
         .AsIntermediate();
+    AddOutput("EmissionExps",
+              "The exponentials of Input(Emission). This is an intermediate "
+              "computational result in forward computation, and will be reused "
+              "in backward computation.")
+        .AsIntermediate();
+    AddOutput("TransitionExps",
+              "The exponentials of Input(Transition). This is an intermediate "
+              "computational result in forward computation, and will be reused "
+              "in backward computation.")
+        .AsIntermediate();
     AddOutput(
         "LogLikelihood",
-        "(Tensor, default: Tensor<float>). The logarithm of the conditional "
+        "(Tensor, default: Tensor<float>). The logarithm of the "
+        "conditional "
         "likelihood of each training sample in a mini-batch. This is a 2-D "
         "tensor with shape [S x 1], where S is the sequence number in a "
         "mini-batch. "
-        "Note: S is equal to the sequence number in a mini-batch. The output "
+        "Note: S is equal to the sequence number in a mini-batch. The "
+        "output "
         "is no longer a LoDTensor.");
     AddComment(R"DOC(
 Conditional Random Field defines an undirected probabilistic graph with nodes
@@ -129,6 +157,10 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
 
     PADDLE_ENFORCE(ctx->HasOutput("Alpha"),
                    "Output(Alpha) should be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput("EmissionExps"),
+                   "Output(EmissionExps) should be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput("TransitionExps"),
+                   "Output(TransitionExps) should be not null.");
     PADDLE_ENFORCE(ctx->HasOutput("LogLikelihood"),
                    "Output(LogLikelihood) should be not null.");
 
@@ -143,7 +175,7 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(
         transition_dims[0] - 2, transition_dims[1],
         "An invalid dimension for the Input(Transition), which should "
-        "be a 2-D tensor with shape [D + 2 x D].");
+        "be a 2-D tensor with shape [(D + 2) x D].");
     PADDLE_ENFORCE_EQ(
         emission_dims[1], transition_dims[1],
         "The 2nd dimension of the Input(Emission) and the Input(Transition) "
@@ -157,11 +189,14 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
         "should be the same.");
 
     ctx->SetOutputDim("Alpha", emission_dims);
-
+    ctx->SetOutputDim("EmissionExps", emission_dims);
+    ctx->SetOutputDim("TransitionExps", transition_dims);
     // (TODO caoying) This is tricky. The 1st dimension of Output(LogLikelihood)
     // is the sequence number in a mini-batch. The dimension set here should be
     // resized to its correct size in the function Compute.
     ctx->SetOutputDim("LogLikelihood", {emission_dims[0], 1});
+
+    ctx->ShareLoD("Emission", /*->*/ "EmissionExps");
   }
 
  protected:
@@ -180,9 +215,12 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                    "This kernel only runs on CPU.");
-
     auto* emission_weights = ctx.Input<LoDTensor>("Emission");
     auto* transition_weights = ctx.Input<Tensor>("Transition");
+    auto* emission_exps = ctx.Output<LoDTensor>("EmissionExps");
+    emission_exps->mutable_data<T>(platform::CPUPlace());
+    auto* transition_exps = ctx.Output<Tensor>("TransitionExps");
+    transition_exps->mutable_data<T>(platform::CPUPlace());
     auto* label = ctx.Input<LoDTensor>("Label");
 
     auto in_lod = emission_weights->lod();
@@ -195,18 +233,29 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
     const size_t level = 0;
 
     auto emission_dims = emission_weights->dims();
+    const size_t batch_size = emission_dims[0];
+    const size_t tag_num = emission_dims[1];
     const size_t seq_num = in_lod[level].size() - 1;
 
-    // TODO(caoying) These local variables seems to be created and destroied
-    // every time this function is called. Will this bring additional overhead?
-    Tensor emission_exps;
     Tensor emission_row_max;
-    Tensor transition_exps;
-    emission_exps.mutable_data<T>(emission_dims, platform::CPUPlace());
     emission_row_max.mutable_data<T>(
-        framework::make_ddim({emission_dims[0], 1}), platform::CPUPlace());
-    transition_exps.mutable_data<T>(transition_weights->dims(),
-                                    platform::CPUPlace());
+        framework::make_ddim({static_cast<int>(batch_size), 1}),
+        platform::CPUPlace());
+
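+    // Subtracting the per-row maximum before exponentiation keeps exp() in a
+    // numerically safe range; the subtracted maxima re-enter the
+    // log-likelihood inside ForwardOneSequence.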
+    auto place = ctx.GetEigenDevice<platform::CPUPlace>();
+    auto x = EigenMatrix<T>::From(*emission_weights);
+    auto x_row_max = EigenMatrix<T>::From(emission_row_max);
+    x_row_max.device(place) =
+        x.maximum(Eigen::DSizes<int, 1>(1))
+            .reshape(Eigen::DSizes<int, 2>(static_cast<int>(batch_size), 1));
+
+    auto x_exps = EigenMatrix<T>::From(*emission_exps);
+    x_exps.device(place) =
+        (x - x_row_max.broadcast(Eigen::DSizes<int, 2>(1, tag_num))).exp();
+
+    auto w = EigenMatrix<T>::From(*transition_weights);
+    auto w_exps = EigenMatrix<T>::From(*transition_exps);
+    w_exps.device(place) = w.exp();
 
     auto* alpha = ctx.Output<LoDTensor>("Alpha");
     alpha->mutable_data<T>(ctx.GetPlace());
@@ -214,117 +263,124 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
     // resize the output tensor to the correct dimension.
     ll->Resize({static_cast<int>(seq_num), 1});
     T* log_likelihood = ll->mutable_data<T>(ctx.GetPlace());
-
     for (size_t i = 0; i < seq_num; ++i) {
       int start_pos = static_cast<int>(in_lod[level][i]);
       int end_pos = static_cast<int>(in_lod[level][i + 1]);
 
       const Tensor one_seq = emission_weights->Slice<T>(start_pos, end_pos);
       Tensor one_seq_row_max = emission_row_max.Slice<T>(start_pos, end_pos);
-      Tensor one_seq_exps = emission_exps.Slice<T>(start_pos, end_pos);
+      Tensor one_seq_exps = emission_exps->Slice<T>(start_pos, end_pos);
       const Tensor one_seq_label = label->Slice<T>(start_pos, end_pos);
       Tensor one_seq_alpha = alpha->Slice<T>(start_pos, end_pos);
 
       log_likelihood[i] = ForwardOneSequence(
-          ctx.device_context(), one_seq, one_seq_row_max, one_seq_exps,
-          (*transition_weights), transition_exps, one_seq_label, one_seq_alpha);
+          &one_seq, &one_seq_row_max, &one_seq_exps, transition_weights,
+          transition_exps, &one_seq_label, &one_seq_alpha);
     }
   }
 
  protected:
-  T ForwardOneSequence(const platform::DeviceContext& ctx,
-                       const Tensor& emission, Tensor& emission_row_max,
-                       Tensor& emission_exps, const Tensor& trans_weights,
-                       Tensor& trans_weight_exps, const Tensor& label,
-                       Tensor& alpha) const {
-    // (TODO caoying) Evaluate and optimize this.
-    // The Eigen compution kernel will be invoked for multiple times.
-    // Some computations regardless of sequence inforamtion could be performed
-    // only one time for the entire batch. This potentially could be optimized.
-
-    auto x_dims = emission.dims();
+  T ForwardOneSequence(const Tensor* emission, const Tensor* emission_row_max,
+                       const Tensor* emission_exps, const Tensor* trans_weights,
+                       const Tensor* trans_weight_exps, const Tensor* label,
+                       Tensor* alpha) const {
+    const T* x = emission->data<T>();
+    const T* x_row_max = emission_row_max->data<T>();
+    const T* x_exps = emission_exps->data<T>();
+    const T* w = trans_weights->data<T>();
+    const T* w_exps = trans_weight_exps->data<T>();
+    T* alpha_value = alpha->data<T>();
+
+    auto x_dims = emission->dims();
     const size_t seq_length = x_dims[0];
     const size_t tag_num = x_dims[1];
-
-    T* alpha_value = alpha.data<T>();
-
-    auto x = EigenMatrix<T>::From(emission);
-    auto x_row_max = EigenMatrix<T>::From(emission_row_max);
-    const int class_dim = 1;
-    x_row_max.device(*ctx.GetEigenDevice<platform::CPUPlace>()) =
-        x.maximum(Eigen::DSizes<int, 1>(class_dim))
-            .reshape(Eigen::DSizes<int, 2>(int(seq_length), 1));
-
-    auto x_exps = EigenMatrix<T>::From(emission_exps);
-    x_exps.device(*ctx.GetEigenDevice<platform::CPUPlace>()) =
-        (x - x_row_max.broadcast(Eigen::DSizes<int, 2>(1, tag_num))).exp();
-
-    auto w = EigenMatrix<T>::From(trans_weights);
-    auto w_exps = EigenMatrix<T>::From(trans_weight_exps);
-    w_exps.device(*ctx.GetEigenDevice<platform::CPUPlace>()) = w.exp();
     // The 1st row of w are transition weights for start mask.
-    const size_t start_ridx = 0;
     // The 2nd row of w are transition weights for end mask.
-    const size_t end_ridx = 1;
     // Transition weights among other tags begins from the 3rd row of w.
-    const size_t state_base_ridx = 2;
+    const size_t state_trans_base_idx = 2;
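+    // With w stored row-major as a [(D + 2) x D] matrix, w_exps[i] is the
+    // exponentiated start-transition weight of tag i, w_exps[tag_num + i] the
+    // exponentiated end-transition weight, and
+    // w_exps[(j + state_trans_base_idx) * tag_num + i] the exponentiated
+    // weight of moving from tag j to tag i.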
 
     for (size_t i = 0; i < tag_num; ++i) {
-      alpha_value[i] = w_exps(start_ridx, i) * x_exps(0, i);
+      alpha_value[i] = w_exps[i] * x_exps[i];
     }
-    T ll = -x_row_max(0, 1) - std::log(NormalizeL1(alpha_value, tag_num));
+    T ll = -x_row_max[0] - std::log(NormalizeL1<T>(alpha_value, tag_num));
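+    // Forward recursion:
+    //   alpha[k, i] = x_exps[k, i] * sum_j alpha[k - 1, j] * w_exps[j -> i],
+    // with each row L1-normalized and the dropped scale factor added back to
+    // ll as a log term.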
 
     for (size_t k = 1; k < seq_length; ++k) {
       for (size_t i = 0; i < tag_num; ++i) {
         T sum = 0.;
         for (size_t j = 0; j < tag_num; ++j) {
           sum += alpha_value[(k - 1) * tag_num + j] *
-                 w_exps(j + state_base_ridx, i);
+                 w_exps[(j + state_trans_base_idx) * tag_num + i];
         }
-        alpha_value[k * tag_num + i] = x_exps(k, i) * sum;
+        alpha_value[k * tag_num + i] = x_exps[k * tag_num + i] * sum;
       }
-      ll -= x_row_max(k, 1) +
-            std::log(NormalizeL1(alpha_value + k * tag_num, tag_num));
+      ll -= x_row_max[k] +
+            std::log(NormalizeL1<T>(alpha_value + k * tag_num, tag_num));
     }
     T sum = 0.;
     for (size_t i = 0; i < tag_num; ++i) {
-      sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps(end_ridx, i);
+      sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps[tag_num + i];
     }
     ll -= std::log(sum);
 
-    const int* lbl = label.data<int>();
+    const int* lbl = label->data<int>();
     PADDLE_ENFORCE_LT(
         *std::max_element(lbl, lbl + seq_length), tag_num,
         "An invalid tag label that execesses the largest tag number.");
-
     // Calculate the nominator part, which depends on the label sequence.
-    ll += w(start_ridx, lbl[0]) + x(start_ridx, lbl[0]) +
-          w(end_ridx, lbl[seq_length - 1]);
+    ll += w[lbl[0]] /*start transition*/ + x[lbl[0]] +
+          w[tag_num + lbl[seq_length - 1]] /*end transition*/;
     for (size_t k = 1; k < seq_length; ++k)
-      ll += x(k, lbl[k]) + w(lbl[k - 1], lbl[k]);
+      ll += x[k * tag_num + lbl[k]] +
+            w[(lbl[k - 1] + state_trans_base_idx) * tag_num + lbl[k]];
     return -ll;
   }
-
- private:
-  T NormalizeL1(T* x, size_t len) const {
-    T sum = 0.;
-    for (size_t i = 0; i < len; ++i) sum += x[i];
-    // (This comment is from the old LinearChainCRFLayer.)
-    // Right now, we just bet that sum won't be zero. If this really happens, we
-    // will figure out what should be done then.
-    PADDLE_ENFORCE(sum,
-                   "The unnormalized probabilites of all possible unfinished "
-                   "sequences must be greater than 0.");
-    for (size_t i = 0; i < len; ++i) x[i] /= sum;
-    return sum;
-  }
 };
 
 class LinearChainCrfGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {}
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("EmissionExps"),
+                   "Input(EmissionExps) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput("TransitionExps"),
+                   "Input(TransitionExps) should be not null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("LogLikelihood")),
+                   "Input(LogLikelihood@GRAD) shoudl be not null.");
+
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Emission")),
+                   "Output(Emission@GRAD) should be not null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Transition")),
+                   "Output(Transition@GRAD) should be not null.");
+
+    auto emission_exps_dims = ctx->GetInputDim("EmissionExps");
+    auto transition_exps_dims = ctx->GetInputDim("TransitionExps");
+    auto label_dims = ctx->GetInputDim("Label");
+
+    PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 2UL,
+                      "The Input(EmissionExps) should be a 2-D tensor.");
+    PADDLE_ENFORCE_EQ(transition_exps_dims.size(), 2UL,
+                      "The Input(TransitionExps) should be a 2-D tensor.");
+    PADDLE_ENFORCE_EQ(
+        transition_exps_dims[0] - 2, transition_exps_dims[1],
+        "An invalid dimension for the Input(TransitionExps), which should "
+        "be a 2-D tensor with shape [(D + 2) x D].");
+    PADDLE_ENFORCE_EQ(
+        emission_exps_dims[1], transition_exps_dims[1],
+        "The 2nd dimension of the Input(EmissionExps) and the "
+        "Input(TransitionExps) should be equal to the tag number.");
+    PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
+                   "The Input(Label) should be a 2-D tensor with the 2nd "
+                   "dimensions fixed to 1.");
+    PADDLE_ENFORCE_EQ(
+        emission_exps_dims[0], label_dims[0],
+        "The height of Input(EmissionExps) and the height of Input(Label) "
+        "should be the same.");
+
+    ctx->SetOutputDim(framework::GradVarName("Emission"), emission_exps_dims);
+    ctx->SetOutputDim(framework::GradVarName("Transition"),
+                      transition_exps_dims);
+  }
 };
 
 template <typename T>
@@ -334,6 +390,134 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                    "This kernel only runs on CPU.");
+    auto* ll_grad =
+        ctx.Input<LoDTensor>(framework::GradVarName("LogLikelihood"));
+    auto* label = ctx.Input<LoDTensor>("Label");
+    auto* emission_exps = ctx.Input<LoDTensor>("EmissionExps");
+    auto* transition_exps = ctx.Input<Tensor>("TransitionExps");
+    auto* alpha = ctx.Input<Tensor>("Alpha");
+
+    auto* emission_grad =
+        ctx.Output<Tensor>(framework::GradVarName("Emission"));
+    emission_grad->mutable_data<T>(platform::CPUPlace());
+
+    auto* trans_grad = ctx.Output<Tensor>(framework::GradVarName("Transition"));
+    if (trans_grad) {
+      // The transition gradient is accumulated across all sequences in the
+      // mini-batch, so its buffer must start from zero.
+      T* trans_grad_data = trans_grad->mutable_data<T>(platform::CPUPlace());
+      for (int64_t i = 0; i < framework::product(trans_grad->dims()); ++i) {
+        trans_grad_data[i] = static_cast<T>(0.);
+      }
+    }
+
+    auto emission_dims = emission_exps->dims();
+
+    // Beta is the memo table used in dynamic programming to calculate the
+    // backward vectors. For a backward vector i (the i-th row of beta), it
+    // captures the unnormalized probabilities of partial sequences starting
+    // at position i.
+    Tensor beta;
+    beta.mutable_data<T>(emission_dims, platform::CPUPlace());
+
+    auto place = ctx.GetEigenDevice<platform::CPUPlace>();
+    // Input(LogLikelihood@GRAD) holds one scalar per sequence; it scales the
+    // emission gradient of the corresponding sequence once that gradient has
+    // been computed in the loop below.
+    const T* ll_grad_data = ll_grad->data<T>();
+
+    const size_t level = 0;  // currently, only support sequence.
+    auto lod = emission_exps->lod();
+    for (size_t i = 0; i < lod[level].size() - 1; ++i) {
+      int start_pos = static_cast<int>(lod[level][i]);
+      int end_pos = static_cast<int>(lod[level][i + 1]);
+
+      const Tensor one_seq_emission_exps =
+          emission_exps->Slice<T>(start_pos, end_pos);
+      const Tensor one_seq_label = label->Slice<T>(start_pos, end_pos);
+      const Tensor one_seq_alpha = alpha->Slice<T>(start_pos, end_pos);
+      Tensor one_seq_beta = beta.Slice<T>(start_pos, end_pos);
+      Tensor one_seq_emission_grad =
+          emission_grad->Slice<T>(start_pos, end_pos);
+
+      BackwardOneSequence(ctx.device_context(), &one_seq_emission_exps,
+                          transition_exps, &one_seq_alpha, &one_seq_label,
+                          &one_seq_beta, trans_grad, &one_seq_emission_grad);
+      // Apply the chain rule: scale this sequence's emission gradient by the
+      // incoming gradient of its log-likelihood output.
+      auto one_seq_grad_mat = EigenMatrix<T>::From(one_seq_emission_grad);
+      one_seq_grad_mat.device(place) = one_seq_grad_mat * ll_grad_data[i];
+    }
+  }
+
+ protected:
+  void BackwardOneSequence(const platform::DeviceContext& ctx,
+                           const Tensor* emission_exps,
+                           const Tensor* transition_exps, const Tensor* alpha,
+                           const Tensor* label, Tensor* beta,
+                           Tensor* transition_grad,
+                           Tensor* emission_grad) const {
+    const T* w_exps = transition_exps->data<T>();
+    const T* x_exps = emission_exps->data<T>();
+    const int* label_value = label->data<int>();
+    T* beta_value = beta->data<T>();
+
+    auto x_dims = emission_exps->dims();
+    const size_t seq_length = x_dims[0];
+    const size_t tag_num = x_dims[1];
+    const size_t state_trans_base_idx = 2;
+
+    // Calculate the backward vectors beta.
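+    // The recursion mirrors the forward pass:
+    //   beta[k, i] = sum_j w_exps[i -> j] * x_exps[k + 1, j] * beta[k + 1, j],
+    // initialized from the exponentiated end-transition weights.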
+    for (int i = 0; i < tag_num; ++i)
+      beta_value[(seq_length - 1) * tag_num + i] = w_exps[tag_num + i];
+    NormalizeL1<T>(beta_value + (seq_length - 1) * tag_num, tag_num);
+
+    for (int k = static_cast<int>(seq_length) - 2; k >= 0; --k) {
+      for (int i = 0; i < tag_num; ++i) {
+        T sum = 0.;
+        for (int j = 0; j < tag_num; ++j) {
+          sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *
+                 beta_value[(k + 1) * tag_num + j] *
+                 x_exps[(k + 1) * tag_num + j];
+        }
+        beta_value[k * tag_num + i] = sum;
+      }
+      NormalizeL1<T>(beta_value + k * tag_num, tag_num);
+    }
+
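+    // The gradient w.r.t. the emission weights is the per-position marginal
+    // minus the one-hot gold label: p(y_k = i | x) - 1{i == label[k]}, where
+    // the marginal of tag i at position k is proportional to
+    // alpha[k, i] * beta[k, i].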
+    auto alpha_mat = EigenMatrix<T>::From(*alpha);
+    auto beta_mat = EigenMatrix<T>::From(*beta);
+    auto x_grad_mat = EigenMatrix<T>::From(*emission_grad);
+
+    auto* place = ctx.GetEigenDevice<platform::CPUPlace>();
+    x_grad_mat.device(*place) = alpha_mat * beta_mat;
+    // Normalize each row in place so that row k holds the marginal
+    // distribution over tags at position k.
+    T* x_grad_value = emission_grad->data<T>();
+    for (size_t k = 0; k < seq_length; ++k) {
+      NormalizeL1<T>(x_grad_value + k * tag_num, tag_num);
+    }
+
+    for (int k = 0; k < seq_length; ++k)
+      x_grad_mat(k, label_value[k]) -= static_cast<T>(1);
+
+    if (transition_grad) {
+      T* trans_grad = transition_grad->data<T>();
+      for (size_t k = 0; k < tag_num; ++k) {
+        trans_grad[k] += x_grad_mat(/*from start state*/ 0, k);
+        trans_grad[tag_num + k] +=
+            x_grad_mat(/*to end state*/ seq_length - 1, k);
+      }
+
+      auto x_exps_mat = EigenMatrix<T>::From(*emission_exps);
+      beta_mat.device(*place) = beta_mat * x_exps_mat;
+      // Fold the emission factors into beta and renormalize each row; the
+      // result is reused when accumulating the transition gradient below.
+      for (size_t k = 0; k < seq_length; ++k) {
+        NormalizeL1<T>(beta_value + k * tag_num, tag_num);
+      }
+
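+      // For tag-to-tag transitions, the gradient is the expected number of
+      // i -> j transitions under the model minus the transitions observed in
+      // the gold label sequence. The pairwise marginal at position k is
+      // proportional to alpha[k - 1, i] * w_exps[i -> j] * (x_exps * beta)[k, j].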
+      for (size_t k = 1; k < seq_length; ++k) {
+        T sum = 0.;
+        for (size_t i = 0; i < tag_num; ++i) {
+          for (size_t j = 0; j < tag_num; ++j)
+            sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *
+                   alpha_mat(k - 1, i) * beta_mat(k, j);
+        }
+        sum = static_cast<T>(1) / sum;
+        for (size_t i = 0; i < tag_num; ++i) {
+          for (size_t j = 0; j < tag_num; ++j) {
+            trans_grad[(i + state_trans_base_idx) * tag_num + j] +=
+                sum * w_exps[(i + state_trans_base_idx) * tag_num + j] *
+                alpha_mat(k - 1, i) * beta_mat(k, j);
+          }
+        }
+        trans_grad[(label_value[k - 1] + state_trans_base_idx) * tag_num +
+                   label_value[k]] -= static_cast<T>(1);
+      }
+    }
   }
 };
 
diff --git a/paddle/operators/linear_chain_crf_op.h b/paddle/operators/linear_chain_crf_op.h
index a656e233c2c6331affca345283a3c22ee32852e1..e9852de5959fd9ab56bfa62e33fc3acc519c9f3a 100644
--- a/paddle/operators/linear_chain_crf_op.h
+++ b/paddle/operators/linear_chain_crf_op.h
@@ -30,20 +30,24 @@ class LinearChainCrfOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override;
 
  protected:
-  T ForwardOneSequence(const platform::DeviceContext& ctx,
-                       const Tensor& emission, Tensor& emission_row_max,
-                       Tensor& emission_exps, const Tensor& trans_weights,
-                       Tensor& trans_weight_exps, const Tensor& label,
-                       Tensor& a) const;
-
- private:
-  T NormalizeL1(T* x, size_t len) const;
+  T ForwardOneSequence(const Tensor* emission, const Tensor* emission_row_max,
+                       const Tensor* emission_exps, const Tensor* trans_weights,
+                       const Tensor* trans_weight_exps, const Tensor* label,
+                       Tensor* alpha) const;
 };
 
 template <typename Place, typename T>
 class LinearChainCrfGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override;
+
+ protected:
+  void BackwardOneSequence(const platform::DeviceContext& ctx,
+                           const Tensor* emission_exps,
+                           const Tensor* transition_exps, const Tensor* alpha,
+                           const Tensor* label, Tensor* beta,
+                           Tensor* transition_grad,
+                           Tensor* emission_grad) const;
 };
 
 }  // namespace operators
diff --git a/python/paddle/v2/framework/tests/test_linear_chain_crf_op.py b/python/paddle/v2/framework/tests/test_linear_chain_crf_op.py
index 413210e75b8feeaf76710eb3965a007446aba852..9b73e26eb98dbfac65166a83ab570d244362f2d2 100644
--- a/python/paddle/v2/framework/tests/test_linear_chain_crf_op.py
+++ b/python/paddle/v2/framework/tests/test_linear_chain_crf_op.py
@@ -4,10 +4,10 @@ import numpy as np
 
 from op_test import OpTest
 
 
 class LinearChainCrfForward(object):
-    def __init__(self, seq_start_positions, emission_weights,
-                 transition_weights, labels):
+    def __init__(self, seq_start_positions, emission_weights, emission_row_max,
+                 emission_exps, transition_weights, transition_exps, labels):
         self.tag_num = emission_weights.shape[1]
         self.seq_num = len(seq_start_positions) - 1
 
@@ -15,25 +17,25 @@ class LinearChainCrfForward(object):
         self.labels = labels
         self.x = emission_weights
 
-        self.x_row_max = np.amax(self.x, axis=1, keepdims=True)
-        self.x_exps = np.exp(self.x - self.x_row_max)
+        self.x_row_max = emission_row_max
+        self.x_exps = emission_exps
 
         # unnormalized logits of the transition weights for the start mark.
         self.a = transition_weights[0, :]
-        self.a_exps = np.exp(self.a)
+        self.a_exps = transition_exps[0, :]
         # unnormalized logits of the transition weights for the end mark.
         self.b = transition_weights[1, :]
-        self.b_exps = np.exp(self.b)
+        self.b_exps = transition_exps[1, :]
         # unnormalized logits of the transition weights for all the other tags.
         self.w = transition_weights[2:, :]
-        self.w_exps = np.exp(self.w)
+        self.w_exps = transition_exps[2:, :]
 
         # The output of linear chain crf operator.
         # alpha is a memo table in dynamic programming to caculate
         # nomalization factor.
         self.alpha = np.zeros(
             (seq_start_positions[-1], self.tag_num), dtype="float32")
-        self.log_likelihood = np.zeros((self.tag_num, 1))
+        self.log_likelihood = np.zeros((self.seq_num, 1))
 
     def _l1_norm(self, x):
         s = np.sum(x)
@@ -91,11 +93,15 @@ class TestLinearChainCrfOp(OpTest):
         lod = [[0]]
         for i in range(SEQ_NUM):
             lod[-1].append(lod[-1][-1] + random.randint(1, MAX_SEQ_LEN))
-
         emission = np.random.uniform(-1, 1,
                                      [lod[-1][-1], TAG_NUM]).astype("float32")
+        emission_row_max = np.amax(emission, axis=1, keepdims=True)
+        emission_exps = np.exp(emission - emission_row_max)
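+        # The operator also exposes these exponentiated intermediates as
+        # outputs, so the test precomputes them here and checks them below.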
+
         transition = np.random.uniform(-0.5, 0.5,
                                        [TAG_NUM + 2, TAG_NUM]).astype("float32")
+        transition_exps = np.exp(transition)
+
         labels = np.random.randint(
             low=0, high=TAG_NUM, size=(lod[-1][-1], 1), dtype="int32")
 
@@ -105,10 +111,17 @@ class TestLinearChainCrfOp(OpTest):
             "Label": (labels, lod)
         }
 
-        crf = LinearChainCrfForward(lod[0], emission, transition, labels)
+        crf = LinearChainCrfForward(lod[0], emission, emission_row_max,
+                                    emission_exps, transition, transition_exps,
+                                    labels)
         alpha, log_likelihood = crf.crf_forward_compute()
 
-        self.outputs = {"Alpha": alpha, "LogLikelihood": log_likelihood}
+        self.outputs = {
+            "Alpha": alpha,
+            "EmissionExps": emission_exps,
+            "TransitionExps": transition_exps,
+            "LogLikelihood": log_likelihood
+        }
 
     def setUp(self):
         self.op_type = "linear_chain_crf"
@@ -117,6 +130,13 @@ class TestLinearChainCrfOp(OpTest):
     def test_check_output(self):
         self.check_output()
 
+    def test_check_grad(self):
+        self.check_grad(["Emission", "Transition"], "LogLikelihood")
+
+    def test_check_grad_ignore_transition(self):
+        self.check_grad(
+            ["Emission"], "LogLikelihood", no_grad_set=set("Transition"))
+
 
 if __name__ == "__main__":
     unittest.main()