diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc
index 011d45c396579a26a804a4cf2ecd50734e7df945..cc3fbd587668b17b7edde50b157adca83e81eddc 100644
--- a/paddle/fluid/operators/math/sequence_pooling.cc
+++ b/paddle/fluid/operators/math/sequence_pooling.cc
@@ -37,18 +37,23 @@ class MaxSeqPoolFunctor {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::LoDTensor& input, T pad_value,
-                  framework::Tensor* output, framework::Tensor* index) {
+                  framework::LoDTensor* output, framework::Tensor* index) {
     auto in_dims = input.dims();
     auto out_dims = output->dims();
     auto idx_dims = index->dims();
-    PADDLE_ENFORCE_GT(in_dims.size(), 1);
-    PADDLE_ENFORCE_GT(out_dims.size(), 1);
+    PADDLE_ENFORCE_GT(in_dims.size(), 1,
+                      "The rank of input shall be greater than 1.");
+    PADDLE_ENFORCE_GT(out_dims.size(), 1,
+                      "The rank of output shall be greater than 1.");
     for (int64_t i = 1; i < in_dims.size(); ++i) {
-      PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i]);
+      PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i],
+                        "The dimensions of input and output shall be the same.");
     }
-    PADDLE_ENFORCE_EQ(idx_dims, out_dims);
+    PADDLE_ENFORCE_EQ(idx_dims, out_dims,
+                      "The dimensions of index and output shall be the same.");
 
-    auto starts = input.lod()[0];
+    auto lod_level = input.lod().size();
+    auto starts = input.lod()[lod_level - 1];
     const T* in_data = input.data<T>();
     T* out_data = output->data<T>();
     int* max_index = index->data<int>();
@@ -85,16 +90,20 @@ class MaxSeqPoolFunctor<T, true> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::LoDTensor& input, T pad_value,
-                  framework::Tensor* output, framework::Tensor* index) {
+                  framework::LoDTensor* output, framework::Tensor* index) {
     auto in_dims = input.dims();
     auto out_dims = output->dims();
-    PADDLE_ENFORCE_GT(in_dims.size(), 1);
-    PADDLE_ENFORCE_GT(out_dims.size(), 1);
+    PADDLE_ENFORCE_GT(in_dims.size(), 1,
+                      "The rank of input shall be greater than 1.");
+    PADDLE_ENFORCE_GT(out_dims.size(), 1,
+                      "The rank of output shall be greater than 1.");
     for (int64_t i = 1; i < in_dims.size(); ++i) {
-      PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i]);
+      PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i],
+                        "The dimensions of input and output shall be the same.");
     }
 
-    auto starts = input.lod()[0];
+    auto lod_level = input.lod().size();
+    auto starts = input.lod()[lod_level - 1];
     const T* in_data = input.data<T>();
     T* out_data = output->data<T>();
 
@@ -123,18 +132,23 @@ template <typename T>
 class MaxSeqPoolGradFunctor {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& out_grad,
+                  const framework::LoDTensor& out_grad,
                   const framework::Tensor& index,
                   framework::LoDTensor* in_grad) {
     auto og_dims = out_grad.dims();
     auto ig_dims = in_grad->dims();
     auto idx_dims = index.dims();
-    PADDLE_ENFORCE_GT(og_dims.size(), 1);
-    PADDLE_ENFORCE_GT(ig_dims.size(), 1);
+    PADDLE_ENFORCE_GT(og_dims.size(), 1,
+                      "The rank of output@Grad shall be greater than 1.");
+    PADDLE_ENFORCE_GT(ig_dims.size(), 1,
+                      "The rank of input@Grad shall be greater than 1.");
     for (int64_t i = 1; i < og_dims.size(); ++i) {
-      PADDLE_ENFORCE_EQ(og_dims[i], ig_dims[i]);
+      PADDLE_ENFORCE_EQ(
+          og_dims[i], ig_dims[i],
+          "The dimensions of input@Grad and output@Grad shall be the same.");
     }
-    PADDLE_ENFORCE_EQ(idx_dims, og_dims);
+    PADDLE_ENFORCE_EQ(idx_dims, og_dims,
+                      "The dimensions of index and output@Grad shall be the same.");
 
     const T* og_data = out_grad.data<T>();
     const int* max_index = index.data<int>();
@@ -159,14 +173,15 @@ class LastSeqPoolFunctor {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::LoDTensor& input, T pad_value,
-                  framework::Tensor* output) {
+                  framework::LoDTensor* output) {
     // Create pointers to input and output data
     auto* in_data = input.data<T>();
     auto* out_data = output->data<T>();
 
     // Calculate the size of each item in sequence
     int64_t item_size = input.numel() / input.dims()[0];
-    auto lod = input.lod()[0];
+    auto lod_level = input.lod().size();
+    auto lod = input.lod()[lod_level - 1];
     int seq_num = static_cast<int>(lod.size()) - 1;
     for (int i = 0; i < seq_num; ++i) {
       // Calculate the length of each sequence
@@ -191,14 +206,15 @@ class FirstSeqPoolFunctor {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::LoDTensor& input, T pad_value,
-                  framework::Tensor* output) {
+                  framework::LoDTensor* output) {
     // Create pointers to input and output data
     auto* in_data = input.data<T>();
     auto* out_data = output->data<T>();
 
     // Calculate the size of each item in sequence
     int64_t item_size = input.numel() / input.dims()[0];
-    auto lod = input.lod()[0];
+    auto lod_level = input.lod().size();
+    auto lod = input.lod()[lod_level - 1];
     int seq_num = static_cast<int>(lod.size()) - 1;
     for (int i = 0; i < seq_num; ++i) {
       // Calculate the length of each sequence
@@ -222,12 +238,15 @@ template <typename T>
 class SumSeqPoolGradFunctor {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& out_grad,
+                  const framework::LoDTensor& out_grad,
                   framework::LoDTensor* in_grad) {
-    auto lod = in_grad->lod()[0];
+    auto lod_level = in_grad->lod().size();
+    auto lod = in_grad->lod()[lod_level - 1];
     int64_t out_w = out_grad.numel() / out_grad.dims()[0];
     int64_t in_w = in_grad->numel() / in_grad->dims()[0];
-    PADDLE_ENFORCE(in_w == out_w);
+    PADDLE_ENFORCE_EQ(
+        in_w, out_w,
+        "The feature size of input@Grad and output@Grad shall be the same.");
     const T* out_g_data = out_grad.data<T>();
     T* in_g_data = in_grad->mutable_data<T>(context.GetPlace());
     auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
@@ -250,8 +269,9 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
   /* max pool has index output */
   void operator()(const platform::CPUDeviceContext& context,
                   const std::string pooltype, T pad_value,
-                  const framework::LoDTensor& input, framework::Tensor* output,
-                  bool is_test, framework::Tensor* index = nullptr) {
+                  const framework::LoDTensor& input,
+                  framework::LoDTensor* output, bool is_test,
+                  framework::Tensor* index = nullptr) {
     if (pooltype == "MAX") {
       if (is_test) {
         math::MaxSeqPoolFunctor<T, true> max_pool;
@@ -272,11 +292,13 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
       first_pool(context, input, pad_value, output);
       return;
     }
-
-    auto lod = input.lod()[0];
+    auto lod_level = input.lod().size();
+    auto lod = input.lod()[lod_level - 1];
     if (pooltype == "SUM") {
       auto place = context.GetPlace();
-      PADDLE_ENFORCE(platform::is_cpu_place(place));
+      PADDLE_ENFORCE_EQ(
+          platform::is_cpu_place(place), true,
+          "Sequence_pool should run on CPU device when pooltype is SUM.");
       const T* src = input.data<T>();
       T* dst = output->mutable_data<T>(place);
       jit::seq_pool_attr_t attr(
@@ -330,7 +352,8 @@ template <typename T>
 class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  const std::string pooltype, const framework::Tensor& out_grad,
+                  const std::string pooltype,
+                  const framework::LoDTensor& out_grad,
                   framework::LoDTensor* in_grad,
                   /* max pool has index */
                   const framework::Tensor* index = nullptr) {
@@ -352,7 +375,8 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
       return;
     }
 
-    auto lod = in_grad->lod()[0];
+    auto lod_level = in_grad->lod().size();
+    auto lod = in_grad->lod()[lod_level - 1];
     auto& place = *context.eigen_device();
     for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
       if (lod[i] == lod[i + 1]) continue;
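
The CPU functors and their gradients above now take the sequence boundaries
from the innermost LoD level, starts = input.lod()[lod_level - 1], instead of
hard-coding level 0. A minimal numpy sketch of that selection, with
hypothetical offsets (illustrative only, not the C++ functor itself):

    import numpy as np

    # Offset-form 2-level LoD: lod[0] groups 5 inner sequences into 2 outer
    # ones; lod[-1] holds the row offsets the pooling loop iterates over.
    lod = [[0, 2, 5], [0, 3, 3, 7, 9, 11]]
    starts = lod[len(lod) - 1]          # mirrors input.lod()[lod_level - 1]
    x = np.random.uniform(0.1, 1, [11, 23]).astype('float32')
    out = np.empty((len(starts) - 1, 23), dtype='float32')
    pad_value = 0.0                     # empty sequences yield pad_value
    for i in range(len(starts) - 1):
        if starts[i] == starts[i + 1]:
            out[i] = pad_value
        else:
            out[i] = x[starts[i]:starts[i + 1], :].max(axis=0)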
diff --git a/paddle/fluid/operators/math/sequence_pooling.cu b/paddle/fluid/operators/math/sequence_pooling.cu
index 4de99ba677d5108e8b70e71e3dfefa17b6e18beb..91545131e4cbb5d6dcae9c111e97598ee54cc898 100644
--- a/paddle/fluid/operators/math/sequence_pooling.cu
+++ b/paddle/fluid/operators/math/sequence_pooling.cu
@@ -159,9 +159,11 @@ class SequencePoolFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   const std::string pooltype, T pad_value,
-                  const framework::LoDTensor& input, framework::Tensor* output,
-                  bool is_test, framework::Tensor* index = nullptr) {
-    auto& lod = input.lod()[0];
+                  const framework::LoDTensor& input,
+                  framework::LoDTensor* output, bool is_test,
+                  framework::Tensor* index = nullptr) {
+    auto lod_level = input.lod().size();
+    auto& lod = input.lod()[lod_level - 1];
     const size_t item_dim = output->numel() / output->dims()[0];
     dim3 threads(1024, 1);
     dim3 grid(lod.size(), 1);
@@ -319,11 +321,13 @@ template <typename T>
 class SequencePoolGradFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  const std::string pooltype, const framework::Tensor& out_grad,
+                  const std::string pooltype,
+                  const framework::LoDTensor& out_grad,
                   framework::LoDTensor* in_grad,
                   /* max pool has index */
                   const framework::Tensor* index = nullptr) {
-    auto& lod = in_grad->lod()[0];
+    auto lod_level = in_grad->lod().size();
+    auto& lod = in_grad->lod()[lod_level - 1];
     const size_t item_dim = in_grad->numel() / in_grad->dims()[0];
     dim3 threads(1024, 1);
     dim3 grid(lod.size(), 1);
diff --git a/paddle/fluid/operators/math/sequence_pooling.h b/paddle/fluid/operators/math/sequence_pooling.h
index 1dc02eae201413b9483b31129578be144f175aa3..847d0bca951a7e54a74a6c803a4f24d50672228f 100644
--- a/paddle/fluid/operators/math/sequence_pooling.h
+++ b/paddle/fluid/operators/math/sequence_pooling.h
@@ -28,7 +28,7 @@ class SequencePoolFunctor {
   /* max pool has index output */
   void operator()(const DeviceContext& context, const std::string pooltype,
                   T pad_value, const framework::LoDTensor& input,
-                  framework::Tensor* output, bool is_test = false,
+                  framework::LoDTensor* output, bool is_test = false,
                   framework::Tensor* index = nullptr);
 };
 
@@ -36,7 +36,7 @@ template <typename DeviceContext, typename T>
 class SequencePoolGradFunctor {
  public:
   void operator()(const DeviceContext& context, const std::string pooltype,
-                  const framework::Tensor& out_grad,
+                  const framework::LoDTensor& out_grad,
                   framework::LoDTensor* in_grad,
                   /* max pool has index */
                   const framework::Tensor* index = nullptr);
diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc
index f3193fdc55609ee0cc608367c654b9d506217b6c..51e354dcd175845c3db2cce78dac6039361aed08 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc
@@ -24,14 +24,15 @@ class SequencePoolOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of SequencePoolOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of SequencePoolOp should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      "Input(X) of SequencePoolOp should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      "Output(Out) of SequencePoolOp should not be null.");
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
     if (ctx->Attrs().Get<std::string>("pooltype") == "MAX") {
-      PADDLE_ENFORCE(ctx->HasOutput("MaxIndex"),
-                     "Output(MaxIndex) of SequencePoolOp should not be null.");
+      PADDLE_ENFORCE_EQ(
+          ctx->HasOutput("MaxIndex"), true,
+          "Output(MaxIndex) of SequencePoolOp should not be null.");
       ctx->SetOutputDim("MaxIndex", ctx->GetInputDim("X"));
     }
   }
@@ -102,9 +103,10 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Gradient of Out should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
+                      "Gradient of Out should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      "The input X should not be null.");
     auto og_dims = ctx->GetInputDim(framework::GradVarName("Out"));
     auto x_dims = ctx->GetInputDim("X");
     PADDLE_ENFORCE_EQ(og_dims.size(), x_dims.size(),
diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.h b/paddle/fluid/operators/sequence_ops/sequence_pool_op.h
index c32734808c39313fcf0a0e624d246f2e52838edf..3eec4df121046e6c269cd950234c06b31b57d5a2 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.h
+++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.h
@@ -30,19 +30,30 @@ class SequencePoolKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in = context.Input<LoDTensor>("X");
-    auto* out = context.Output<Tensor>("Out");
+    auto* out = context.Output<LoDTensor>("Out");
     std::string pooltype = context.Attr<std::string>("pooltype");
     T pad_value = static_cast<T>(context.Attr<float>("pad_value"));
 
     auto dims = in->dims();
     auto lod = in->lod();
+    auto lod_level = lod.size();
     // InferShape by lod
-    PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
+    PADDLE_ENFORCE_GE(lod_level, 1UL,
+                      "The lod level of input shall be at least 1.");
+    PADDLE_ENFORCE_LE(lod_level, 2UL,
+                      "The lod level of input shall be no more than 2.");
     PADDLE_ENFORCE_GE(
         dims[0],
-        /*batch size = */ static_cast<int64_t>(lod[0].size() - 1),
+        /*batch size = */ static_cast<int64_t>(lod[lod_level - 1].size() - 1),
         "The first dimension of Input(X) must be large than batch size.");
-    dims[0] = lod[0].size() - 1;
+    if (lod_level > 1UL) {
+      PADDLE_ENFORCE_EQ(lod[0][lod[0].size() - 1], lod[1].size() - 1,
+                        "The input lod information is invalid.");
+      framework::LoD out_lod;
+      out_lod.push_back(lod[0]);
+      out->set_lod(out_lod);
+    }
+    dims[0] = lod[lod_level - 1].size() - 1;
     out->Resize({dims});
     out->mutable_data<T>(context.GetPlace());
     Tensor* index = nullptr;
@@ -68,7 +79,7 @@ template <typename DeviceContext, typename T>
 class SequencePoolGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* out_g = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* out_g = context.Input<LoDTensor>(framework::GradVarName("Out"));
     auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
     std::string pooltype = context.Attr<std::string>("pooltype");
     const Tensor* index = nullptr;
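
The kernel above holds the user-visible LoD logic: Input(X) may now carry one
or two LoD levels, pooling always runs over the innermost level, and for a
2-level input the outer level becomes the LoD of Out. A standalone sketch of
that bookkeeping in offset form (hypothetical values, not the Paddle API):

    # Offset-form LoD; lod[0][-1] must equal the inner sequence count,
    # which is what the "input lod information" check enforces.
    lod = [[0, 2, 3, 5], [0, 4, 4, 11, 11, 16]]
    lod_level = len(lod)
    assert 1 <= lod_level <= 2
    if lod_level > 1:
        assert lod[0][-1] == len(lod[1]) - 1    # 5 inner sequences
        out_lod = [lod[0]]                      # outer level survives on Out
    num_out_rows = len(lod[lod_level - 1]) - 1  # one pooled row per inner seq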
diff --git a/python/paddle/fluid/tests/unittests/test_seq_pool.py b/python/paddle/fluid/tests/unittests/test_seq_pool.py
index aa801b1f5d8c7e7c8acec7096db7010a058451ff..2de5d0345912ace44858de1be52dece846ef879a 100644
--- a/python/paddle/fluid/tests/unittests/test_seq_pool.py
+++ b/python/paddle/fluid/tests/unittests/test_seq_pool.py
@@ -21,30 +21,33 @@ from test_reorder_lod_tensor import convert_to_offset
 
 
 def compute_seqpool_sum(x, offset, out, pad_value=0.0):
-    for i in range(len(offset[0]) - 1):
-        if offset[0][i] == offset[0][i + 1]:
+    level = len(offset) - 1
+    for i in range(len(offset[level]) - 1):
+        if offset[level][i] == offset[level][i + 1]:
             out[i] = pad_value
         else:
-            sub_x = x[offset[0][i]:offset[0][i + 1], :]
+            sub_x = x[offset[level][i]:offset[level][i + 1], :]
             out[i] = sub_x.sum(axis=0)
 
 
 def compute_seqpool_avg(x, offset, out, pad_value=0.0):
-    for i in range(len(offset[0]) - 1):
-        if offset[0][i] == offset[0][i + 1]:
+    level = len(offset) - 1
+    for i in range(len(offset[level]) - 1):
+        if offset[level][i] == offset[level][i + 1]:
             out[i] = pad_value
         else:
-            sub_x = x[offset[0][i]:offset[0][i + 1], :]
+            sub_x = x[offset[level][i]:offset[level][i + 1], :]
             out[i] = sub_x.mean(axis=0)
 
 
 def compute_seqpool_sqrt(x, offset, out, pad_value=0.0):
-    for i in range(len(offset[0]) - 1):
-        if offset[0][i] == offset[0][i + 1]:
+    level = len(offset) - 1
+    for i in range(len(offset[level]) - 1):
+        if offset[level][i] == offset[level][i + 1]:
             out[i] = pad_value
         else:
-            sub_x = x[offset[0][i]:offset[0][i + 1], :]
-            seq_len = offset[0][i + 1] - offset[0][i]
+            sub_x = x[offset[level][i]:offset[level][i + 1], :]
+            seq_len = offset[level][i + 1] - offset[level][i]
             out[i] = sub_x.sum(axis=0) / np.sqrt(seq_len)
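
With the helpers parameterized on level = len(offset) - 1, one call covers
both 1- and 2-level inputs. A quick usage sketch with a hypothetical
length-based lod, converted to offsets the way convert_to_offset does, using
the compute_seqpool_sum defined above:

    lod = [[2, 0, 1, 2], [0, 4, 0, 7, 0]]   # length-based, as in the tests
    offset = [list(np.cumsum([0] + lv)) for lv in lod]
    x = np.random.uniform(0.1, 1, [11, 23]).astype('float32')
    out = np.zeros((len(lod[-1]), 23)).astype('float32')
    compute_seqpool_sum(x, offset, out)      # pools over offset[-1]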
 
 
@@ -56,9 +59,10 @@ class TestSeqAvgPool(OpTest):
         self.op_type = 'sequence_pool'
         x = np.random.uniform(0.1, 1, [11, 23]).astype('float32')
         lod = self.set_lod()
+        level = len(lod) - 1
         self.inputs = {'X': (x, lod)}
         offset = convert_to_offset(lod)
-        out = np.zeros((len(lod[0]), 23)).astype('float32')
+        out = np.zeros((len(lod[level]), 23)).astype('float32')
         self.outputs = {'Out': out}
         return x, offset, out
 
@@ -69,14 +73,19 @@
     def setUp(self):
         x, offset, out = self.set_data()
         self.compute(x, offset, out)
+        if len(offset) > 1:
+            self.outputs = {'Out': (out, [self.set_lod()[0]])}
 
     def test_check_output(self):
         self.check_output()
 
     def test_check_grad(self):
         # Remove MaxIndex after check_grad is refined.
+        out = self.outputs['Out']
+        if isinstance(out, tuple):
+            out = out[0]
         self.outputs['MaxIndex'] = \
-            np.zeros(self.outputs['Out'].shape).astype('int32')
+            np.zeros(out.shape).astype('int32')
         self.check_grad(["X"], "Out")
 
 
@@ -85,6 +93,11 @@ class TestSeqAvgPoolLen0(TestSeqAvgPool):
         return [[0, 4, 0, 7, 0]]
 
 
+class TestSeqAvgPoolLen0LoDLevel2(TestSeqAvgPool):
+    def set_lod(self):
+        return [[2, 0, 1, 2], [0, 4, 0, 7, 0]]
+
+
 class TestSeqSumPool(TestSeqAvgPool):
     def compute(self, x, offset, out):
         self.attrs = {"pad_value": 0.1, 'pooltype': "SUM"}
@@ -96,6 +109,11 @@ class TestSeqSumPoolLen0(TestSeqSumPool):
         return [[0, 4, 0, 7, 0]]
 
 
+class TestSeqSumPoolLen0LoDLevel2(TestSeqSumPool):
+    def set_lod(self):
+        return [[2, 0, 1, 2], [0, 4, 0, 7, 0]]
+
+
 class TestSeqMaxPool(TestSeqAvgPool):
     def set_lod(self):
         return [[13]]
@@ -104,25 +122,27 @@ class TestSeqMaxPool(TestSeqAvgPool):
         self.op_type = 'sequence_pool'
         x = np.random.uniform(0.1, 1, [13, 23]).astype('float32')
         lod = self.set_lod()
+        level = len(lod) - 1
         offset = convert_to_offset(lod)
-        for i in range(len(offset[0]) - 1):
-            l = offset[0][i + 1] - offset[0][i]
+        for i in range(len(offset[level]) - 1):
+            l = offset[level][i + 1] - offset[level][i]
             if l > 0:
-                x[offset[0][i] + np.random.randint(l), :] += 2.0
+                x[offset[level][i] + np.random.randint(l), :] += 2.0
 
         self.inputs = {'X': (x, lod)}
 
-        out = np.zeros((len(lod[0]), 23)).astype('float32')
+        out = np.zeros((len(lod[level]), 23)).astype('float32')
         self.outputs = {'Out': out}
         return x, offset, out
 
     def compute(self, x, offset, out):
         self.attrs = {"pad_value": 0.5, 'pooltype': "MAX"}
-        for i in range(len(offset[0]) - 1):
-            if offset[0][i] == offset[0][i + 1]:
+        level = len(offset) - 1
+        for i in range(len(offset[level]) - 1):
+            if offset[level][i] == offset[level][i + 1]:
                 out[i] = self.attrs["pad_value"]
             else:
-                sub_x = x[offset[0][i]:offset[0][i + 1], :]
+                sub_x = x[offset[level][i]:offset[level][i + 1], :]
                 out[i] = np.amax(sub_x, axis=0)
 
 
@@ -131,6 +151,11 @@ class TestSeqMaxPoolLen0(TestSeqMaxPool):
         return [[0, 1, 1, 5, 6, 0]]
 
 
+class TestSeqMaxPoolLen0LoDLevel2(TestSeqMaxPool):
+    def set_lod(self):
+        return [[2, 0, 3, 1], [0, 1, 1, 5, 6, 0]]
+
+
 class TestSeqSqrtPool(TestSeqAvgPool):
     def compute(self, x, offset, out):
         self.attrs = {"pad_value": 0.0, 'pooltype': "SQRT"}
@@ -142,14 +167,20 @@ class TestSeqSqrtPoolLen0(TestSeqSqrtPool):
         return [[0, 7, 0, 2, 2, 0]]
 
 
+class TestSeqSqrtPoolLen0LoDLevel2(TestSeqSqrtPool):
+    def set_lod(self):
+        return [[1, 2, 0, 3], [0, 7, 0, 2, 2, 0]]
+
+
 class TestSeqLastPool(TestSeqAvgPool):
     def compute(self, x, offset, out):
         self.attrs = {"pad_value": 0.0, 'pooltype': "LAST"}
-        for i in range(len(offset[0]) - 1):
-            if offset[0][i] == offset[0][i + 1]:
+        level = len(offset) - 1
+        for i in range(len(offset[level]) - 1):
+            if offset[level][i] == offset[level][i + 1]:
                 out[i] = self.attrs["pad_value"]
             else:
-                sub_x = x[offset[0][i]:offset[0][i + 1], :]
+                sub_x = x[offset[level][i]:offset[level][i + 1], :]
                 out[i] = sub_x[-1, :]
 
 
@@ -158,14 +189,20 @@ class TestSeqLastPoolLen0(TestSeqLastPool):
         return [[0, 3, 4, 0, 4, 0]]
 
 
+class TestSeqLastPoolLen0LoDLevel2(TestSeqLastPool):
+    def set_lod(self):
+        return [[1, 0, 2, 3], [0, 3, 4, 0, 4, 0]]
+
+
 class TestSeqFirstPool(TestSeqAvgPool):
     def compute(self, x, offset, out):
         self.attrs = {"pad_value": 0.3, 'pooltype': "FIRST"}
-        for i in range(len(offset[0]) - 1):
-            if offset[0][i] == offset[0][i + 1]:
+        level = len(offset) - 1
+        for i in range(len(offset[level]) - 1):
+            if offset[level][i] == offset[level][i + 1]:
                 out[i] = self.attrs["pad_value"]
             else:
-                sub_x = x[offset[0][i]:offset[0][i + 1], :]
+                sub_x = x[offset[level][i]:offset[level][i + 1], :]
                 out[i] = sub_x[0, :]
 
 
@@ -174,6 +211,11 @@ class TestSeqFirstPoolLen0(TestSeqFirstPool):
         return [[0, 2, 0, 3, 6, 0]]
 
 
+class TestSeqFirstPoolLen0LoDLevel2(TestSeqFirstPool):
+    def set_lod(self):
+        return [[1, 0, 2, 3], [0, 2, 0, 3, 6, 0]]
+
+
 class TestSeqAvgPool2D(TestSeqAvgPool):
     def set_lod(self):
         return [[4, 1, 3, 5]]
@@ -182,20 +224,22 @@ class TestSeqAvgPool2D(TestSeqAvgPool):
         self.op_type = 'sequence_pool'
         x = np.random.uniform(0.1, 1, [13, 3, 17]).astype('float32')
         lod = self.set_lod()
+        level = len(lod) - 1
         self.inputs = {'X': (x, lod)}
         offset = convert_to_offset(lod)
 
-        out = np.zeros((len(lod[0]), 3, 17)).astype('float32')
+        out = np.zeros((len(lod[level]), 3, 17)).astype('float32')
         self.outputs = {'Out': out}
         return x, offset, out
 
     def compute(self, x, offset, out):
         self.attrs = {"pad_value": 0.0, 'pooltype': "AVERAGE"}
-        for i in range(len(offset[0]) - 1):
-            if offset[0][i] == offset[0][i + 1]:
+        level = len(offset) - 1
+        for i in range(len(offset[level]) - 1):
+            if offset[level][i] == offset[level][i + 1]:
                 out[i] = self.attrs["pad_value"] * np.ones((3, 17))
             else:
-                sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
+                sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :],
                                    (-1, 3 * 17))
                 out[i] = np.reshape(sub_x.mean(axis=0), (3, 17))
 
@@ -205,14 +249,20 @@ class TestSeqAvgPool2DLen0(TestSeqAvgPool2D):
         return [[0, 5, 0, 8, 0]]
 
 
+class TestSeqAvgPool2DLen0LoDLevel2(TestSeqAvgPool2D):
+    def set_lod(self):
+        return [[1, 0, 4], [0, 5, 0, 8, 0]]
+
+
 class TestSeqSumPool2D(TestSeqAvgPool2D):
     def compute(self, x, offset, out):
         self.attrs = {"pad_value": 0.2, 'pooltype': "SUM"}
-        for i in range(len(offset[0]) - 1):
-            if offset[0][i] == offset[0][i + 1]:
+        level = len(offset) - 1
+        for i in range(len(offset[level]) - 1):
+            if offset[level][i] == offset[level][i + 1]:
                 out[i] = self.attrs["pad_value"] * np.ones((3, 17))
             else:
-                sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
+                sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :],
                                    (-1, 3 * 17))
                 out[i] = np.reshape(sub_x.sum(axis=0), (3, 17))
 
@@ -222,23 +272,32 @@ class TestSeqSumPool2DLen0(TestSeqSumPool2D):
         return [[0, 8, 0, 5, 0]]
 
 
+class TestSeqSumPool2DLen0LoDLevel2(TestSeqSumPool2D):
+    def set_lod(self):
+        return [[1, 0, 4], [0, 8, 0, 5, 0]]
+
+
 class TestSeqSqrtPool2D(TestSeqAvgPool2D):
     def compute(self, x, offset, out):
         self.attrs = {"pad_value": 0.0, 'pooltype': "SQRT"}
-        for i in range(len(offset[0]) - 1):
-            if offset[0][i] == offset[0][i + 1]:
+        level = len(offset) - 1
+        for i in range(len(offset[level]) - 1):
+            if offset[level][i] == offset[level][i + 1]:
                 out[i] = self.attrs["pad_value"] * np.ones((3, 17))
             else:
-                sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
+                sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :],
                                    (-1, 3 * 17))
-                seq_len = offset[0][i + 1] - offset[0][i]
+                seq_len = offset[level][i + 1] - offset[level][i]
                 out[i] = np.reshape(
                     sub_x.sum(axis=0) / np.sqrt(seq_len), (3, 17))
 
     def test_check_grad(self):
         # Remove MaxIndex after check_grad is refined.
+        out = self.outputs['Out']
+        if isinstance(out, tuple):
+            out = out[0]
         self.outputs['MaxIndex'] = \
-            np.zeros(self.outputs['Out'].shape).astype('int32')
+            np.zeros(out.shape).astype('int32')
         self.check_grad(["X"], "Out", max_relative_error=0.06)
 
 
@@ -247,6 +306,11 @@ class TestSeqSqrtPool2DLen0(TestSeqSqrtPool2D):
         return [[0, 8, 0, 5, 0]]
 
 
+class TestSeqSqrtPool2DLen0LoDLevel2(TestSeqSqrtPool2D):
+    def set_lod(self):
+        return [[1, 0, 2, 2], [0, 8, 0, 5, 0]]
+
+
 class TestSeqMaxPool2D(TestSeqAvgPool2D):
     def set_lod(self):
         return [[4, 1, 3, 5]]
@@ -255,25 +319,27 @@ class TestSeqMaxPool2D(TestSeqAvgPool2D):
         self.op_type = 'sequence_pool'
         x = np.random.uniform(0.1, 1, [13, 3, 11]).astype('float32')
         self.lod = self.set_lod()
+        level = len(self.lod) - 1
         self.inputs = {'X': (x, self.lod)}
         offset = convert_to_offset(self.lod)
-        for i in range(len(offset[0]) - 1):
-            l = offset[0][i + 1] - offset[0][i]
+        for i in range(len(offset[level]) - 1):
+            l = offset[level][i + 1] - offset[level][i]
             if l == 0:
                 continue
-            x[offset[0][i] + np.random.randint(l), :] += 1.0
+            x[offset[level][i] + np.random.randint(l), :] += 1.0
 
-        out = np.zeros((len(self.lod[0]), 3, 11)).astype('float32')
+        out = np.zeros((len(self.lod[level]), 3, 11)).astype('float32')
         self.outputs = {'Out': out}
         return x, offset, out
 
     def compute(self, x, offset, out):
         self.attrs = {"pad_value": 0.0, 'pooltype': "MAX"}
-        for i in range(len(offset[0]) - 1):
-            if offset[0][i] == offset[0][i + 1]:
+        level = len(offset) - 1
+        for i in range(len(offset[level]) - 1):
+            if offset[level][i] == offset[level][i + 1]:
                 out[i] = self.attrs["pad_value"] * np.ones((3, 11))
                 continue
-            sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
+            sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :],
                                (-1, 3 * 11))
             out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 11))
 
@@ -283,14 +349,20 @@ class TestSeqMaxPool2DLen0(TestSeqMaxPool2D):
         return [[0, 3, 0, 10, 0]]
 
 
+class TestSeqMaxPool2DLen0LoDLevel2(TestSeqMaxPool2D):
+    def set_lod(self):
+        return [[1, 0, 2, 2], [0, 3, 0, 10, 0]]
+
+
 class TestSeqMaxPool2DInference(TestSeqMaxPool2D):
     def compute(self, x, offset, out):
         self.attrs = {"pad_value": 1.0, 'pooltype': "MAX", 'is_test': True}
-        for i in range(len(offset[0]) - 1):
-            if offset[0][i] == offset[0][i + 1]:
+        level = len(offset) - 1
+        for i in range(len(offset[level]) - 1):
+            if offset[level][i] == offset[level][i + 1]:
                 out[i] = self.attrs["pad_value"] * np.ones((3, 11))
             else:
-                sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
+                sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :],
                                    (-1, 3 * 11))
                 out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 11))
 
@@ -305,14 +377,20 @@ class TestSeqMaxPool2DInferenceLen0(TestSeqMaxPool2DInference):
         return [[0, 3, 0, 10, 0]]
 
 
+class TestSeqMaxPool2DInferenceLen0LoDLevel2(TestSeqMaxPool2DInference):
+    def set_lod(self):
+        return [[1, 0, 2, 2], [0, 3, 0, 10, 0]]
+
+
 class TestSeqLastPool2D(TestSeqAvgPool2D):
     def compute(self, x, offset, out):
         self.attrs = {"pad_value": 0.0, 'pooltype': "LAST"}
-        for i in range(len(offset[0]) - 1):
-            if offset[0][i] == offset[0][i + 1]:
+        level = len(offset) - 1
+        for i in range(len(offset[level]) - 1):
+            if offset[level][i] == offset[level][i + 1]:
                 out[i] = self.attrs["pad_value"] * np.ones((3, 17))
             else:
-                sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
+                sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :],
                                    (-1, 3 * 17))
                 out[i] = np.reshape(sub_x[-1, :], (3, 17))
 
@@ -322,14 +400,20 @@ class TestSeqLastPool2DLen0(TestSeqLastPool2D):
         return [[0, 3, 0, 1, 9, 0]]
 
 
+class TestSeqLastPool2DLen0LoDLevel2(TestSeqLastPool2D):
+    def set_lod(self):
+        return [[1, 0, 2, 3], [0, 3, 0, 1, 9, 0]]
+
+
 class TestSeqFirstPool2D(TestSeqAvgPool2D):
     def compute(self, x, offset, out):
         self.attrs = {"pad_value": 0.0, 'pooltype': "FIRST"}
-        for i in range(len(offset[0]) - 1):
-            if offset[0][i] == offset[0][i + 1]:
+        level = len(offset) - 1
+        for i in range(len(offset[level]) - 1):
+            if offset[level][i] == offset[level][i + 1]:
                 out[i] = self.attrs["pad_value"] * np.ones((3, 17))
             else:
-                sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
+                sub_x = np.reshape(x[offset[level][i]:offset[level][i + 1], :],
                                    (-1, 3 * 17))
                 out[i] = np.reshape(sub_x[0, :], (3, 17))
 
@@ -339,5 +423,10 @@ class TestSeqFirstPool2DLen0(TestSeqFirstPool2D):
         return [[0, 3, 0, 3, 7, 0]]
 
 
+class TestSeqFirstPool2DLen0LoDLevel2(TestSeqFirstPool2D):
+    def set_lod(self):
+        return [[1, 0, 2, 3], [0, 3, 0, 3, 7, 0]]
+
+
 if __name__ == '__main__':
     unittest.main()
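
For context, an end-to-end usage sketch of what the change enables (assumes
the fluid 1.x API of this era; the program is illustrative and not part of
the change itself):

    import numpy as np
    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[23], dtype='float32', lod_level=2)
    out = fluid.layers.sequence_pool(input=x, pool_type='sum')

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # Length-based 2-level LoD: 3 outer sequences grouping 4 inner ones.
    data = np.random.uniform(0.1, 1, [5, 23]).astype('float32')
    tensor = fluid.create_lod_tensor(data, [[2, 1, 1], [2, 2, 0, 1]], place)
    result, = exe.run(feed={'x': tensor}, fetch_list=[out],
                      return_numpy=False)
    print(np.array(result).shape)  # (4, 23): one pooled row per inner seq
    print(result.lod())            # [[0, 2, 3, 4]]: the outer level survives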