diff --git a/paddle/fluid/operators/math/gru_compute.cu b/paddle/fluid/operators/math/gru_compute.cu
index 75417cced237c48dda1f6e87c0647b10a66d0907..b564f990b4920a3a01b6ce0dd53e8f5e5d0464aa 100644
--- a/paddle/fluid/operators/math/gru_compute.cu
+++ b/paddle/fluid/operators/math/gru_compute.cu
@@ -30,25 +30,31 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
     dim3 threads;
     dim3 grid;
     if (batch_size == 1) {
-      constexpr int tiled_size = 16;
-      int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
-      threads = dim3(tiled_size, 1);
-      grid = dim3(frame_blocks, 1);
-
-      detail::KeFastCollectiveGruGate<T,
-                                      tiled_size><<<grid, threads, 0, stream>>>(
-          value.gate_value, value.prev_out_value, value.gate_weight,
-          value.reset_output_value, frame_size, active_gate);
-
-      frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
-      grid = dim3(frame_blocks, 1);
-      detail::KeFastCollectiveGruOut<T,
-                                     tiled_size><<<grid, threads, 0, stream>>>(
-          value.state_weight, value.prev_out_value, value.output_value,
-          value.gate_value, value.reset_output_value, frame_size, active_node,
-          origin_mode);
-
-      return;
+      if (context.GetComputeCapability() >= 70) {
+        constexpr int tiled_size = 16;
+        int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
+        threads = dim3(tiled_size, 1);
+        grid = dim3(frame_blocks, 1);
+        detail::KeFastCollectiveGruGate<
+            T, tiled_size><<<grid, threads, 0, stream>>>(
+            value.gate_value, value.prev_out_value, value.gate_weight,
+            value.reset_output_value, frame_size, active_gate);
+
+        frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
+        grid = dim3(frame_blocks, 1);
+        detail::KeFastCollectiveGruOut<
+            T, tiled_size><<<grid, threads, 0, stream>>>(
+            value.state_weight, value.prev_out_value, value.output_value,
+            value.gate_value, value.reset_output_value, frame_size, active_node,
+            origin_mode);
+
+        return;
+      } else {
+        int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
+        int frame_blocks = (frame_size + 1024 - 1) / 1024;
+        threads = dim3(frame_per_block, 1);
+        grid = dim3(frame_blocks, 1);
+      }
     } else {
       threads = dim3(32, 32);
       grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);