diff --git a/paddle/fluid/operators/fake_quantize_op.cu.h b/paddle/fluid/operators/fake_quantize_op.cu.h
index 22ba8254cdc2c2dcf0d668639ab110fc06c94622..9c71cce770f0e6925c7e855063a9c6796976f04b 100644
--- a/paddle/fluid/operators/fake_quantize_op.cu.h
+++ b/paddle/fluid/operators/fake_quantize_op.cu.h
@@ -590,20 +590,16 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis0(const T *in,
                                                            const T *scale,
                                                            const int bin_cnt,
                                                            const int round_type,
-                                                           const int n,
-                                                           const int c,
+                                                           const int wh_size,
+                                                           const int num,
+                                                           const int cout,
                                                            T *out) {
-  int tid = threadIdx.x;
-
-  int channel_size = n / c;
-  const T *in_c = in + blockIdx.x * channel_size;
-  T *out_c = out + blockIdx.x * channel_size;
-
-  T s = scale[blockIdx.x];
-  T inv_s = inverse(s);
+  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
 
-  for (int i = tid; i < channel_size; i += blockDim.x) {
-    T x = in_c[i];
+  for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
+    T s = scale[(i / wh_size) % cout];
+    T inv_s = inverse(s);
+    T x = in[i];
     if (round_type == 0) {
       x = bin_cnt * inv_s * x;
       x = roundWithTiesToEven(x);
@@ -611,12 +607,12 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis0(const T *in,
       T min_bound = -bin_cnt - static_cast<T>(1);
       x = x > max_bound ? max_bound : x;
       x = x < min_bound ? min_bound : x;
-      out_c[i] = (x * s) / bin_cnt;
+      out[i] = (x * s) / bin_cnt;
     } else {
       T v = x > s ? s : x;
       v = v < -s ? -s : v;
       v = bin_cnt * inv_s * v;
-      out_c[i] = round(v) * s / bin_cnt;
+      out[i] = round(v) * s / bin_cnt;
     }
   }
 }
@@ -627,19 +623,16 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis1(const T *in,
                                                            const T *scale,
                                                            const int bin_cnt,
                                                            const int round_type,
-                                                           const int n,
-                                                           const int cin,
+                                                           const int wh_size,
+                                                           const int num,
                                                            const int cout,
                                                            T *out) {
-  T s = scale[blockIdx.x % cout];
-  T inv_s = inverse(s);
-
-  int wh_size = n / (cin * cout);
-  const T *in_c = in + blockIdx.x * wh_size;
-  T *out_c = out + blockIdx.x * wh_size;
+  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
 
-  for (int i = threadIdx.x; i < wh_size; i += blockDim.x) {
-    T x = in_c[i];
+  for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
+    T s = scale[(i / wh_size) % cout];
+    T inv_s = inverse(s);
+    T x = in[i];
     if (round_type == 0) {
       x = bin_cnt * inv_s * x;
       x = roundWithTiesToEven(x);
@@ -647,12 +640,12 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis1(const T *in,
       T min_bound = -bin_cnt - static_cast<T>(1);
       x = x > max_bound ? max_bound : x;
       x = x < min_bound ? min_bound : x;
-      out_c[i] = (x * s) / bin_cnt;
+      out[i] = (x * s) / bin_cnt;
     } else {
       T v = x > s ? s : x;
       v = v < -s ? -s : v;
       v = bin_cnt * inv_s * v;
-      out_c[i] = round(v) * s / bin_cnt;
+      out[i] = round(v) * s / bin_cnt;
     }
   }
 }
@@ -682,30 +675,39 @@ struct ChannelClipFakeQuantDequantFunctor<phi::GPUContext, T> {
     const T *scale_data = scale.data<T>();
     T *out_data = out->mutable_data<T>(ctx.GetPlace());
 
+    int64_t block_size =
+        std::min(static_cast<int64_t>(num),
+                 static_cast<int64_t>(ctx.GetMaxThreadsPerBlock() / 4));
+
+    int64_t max_threads = ctx.GetMaxPhysicalThreadCount();  // SM * block_per_SM
+    const int64_t max_blocks =
+        std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
+    const int64_t grid_size =
+        std::min(max_blocks, (num + block_size - 1) / block_size);
+
     if (quant_axis == 0) {
-      int grid = in_dims[0];
-      int block = 1024;
+      const int window_size = num / in_dims[0];
       ChannelClipAndQuantDequantKernelQuantAxis0<T>
-          <<<grid, block, 0, ctx.stream()>>>(in_data,
-                                             scale_data,
-                                             bin_cnt,
-                                             round_type,
-                                             num,
-                                             in_dims[0],
-                                             out_data);
+          <<<grid_size, block_size, 0, ctx.stream()>>>(in_data,
+                                                       scale_data,
+                                                       bin_cnt,
+                                                       round_type,
+                                                       window_size,
+                                                       num,
+                                                       in_dims[0],
+                                                       out_data);
     } else if (quant_axis == 1) {
-      int grid = in_dims[0] * in_dims[1];
-      int block = 1024;
+      const int window_size = num / (in_dims[0] * in_dims[1]);
 
       ChannelClipAndQuantDequantKernelQuantAxis1<T>
-          <<<grid, block, 0, ctx.stream()>>>(in_data,
-                                             scale_data,
-                                             bin_cnt,
-                                             round_type,
-                                             num,
-                                             in_dims[0],
-                                             in_dims[1],
-                                             out_data);
+          <<<grid_size, block_size, 0, ctx.stream()>>>(in_data,
+                                                       scale_data,
+                                                       bin_cnt,
+                                                       round_type,
+                                                       window_size,
+                                                       num,
+                                                       in_dims[1],
+                                                       out_data);
     }
   }
 };