From 35ca30090b872a57b19f192def126505d2b3a574 Mon Sep 17 00:00:00 2001
From: zmxdream <zhangminxu01@baidu.com>
Date: Thu, 30 Jun 2022 15:52:55 +0800
Subject: [PATCH] Revert "[GPUPS]Optimize dymf kernel (#43911)" (#43958)

* Revert "[GPUPS]Optimize dymf kernel (#43911)"
---
 .../fleet/heter_ps/hashtable_kernel.cu        |  50 ++---
 .../framework/fleet/heter_ps/heter_comm_inl.h |   1 -
 .../fleet/heter_ps/heter_comm_kernel.cu       | 174 ++----------------
 .../fleet/heter_ps/heter_comm_kernel.h        |  33 +---
 4 files changed, 41 insertions(+), 217 deletions(-)

diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
index a7e00bb083f..92df8d8581a 100644
--- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
@@ -89,42 +89,30 @@ __global__ void dy_mf_search_kernel(Table* table,
                                     char* vals,
                                     size_t len,
                                     size_t pull_feature_value_size) {
-  const size_t i = blockIdx.x * blockDim.y + threadIdx.y;
-  const size_t k = threadIdx.x;
+  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
   if (i < len) {
     auto it = table->find(keys[i]);
+
     if (it != table->end()) {
       uint64_t offset = i * pull_feature_value_size;
       FeatureValue* cur = (FeatureValue*)(vals + offset);
       FeatureValue& input = *(FeatureValue*)(it->second);
-      char* cur_p = (char*)cur;
-      char* input_p = (char*)(&input);
-      int len = 9 + input.mf_dim + 1;
-      if (k == 3 || k == 6 || k == 7)
-        *(int*)(cur_p + k * 4) = *(int*)(input_p + k * 4);
-      else if (k < 8)
-        *(float*)(cur_p + k * 4) = *(float*)(input_p + k * 4);
-      else if (k == 8) {
-        *(uint64_t*)(cur_p + k * 4) = *(uint64_t*)(input_p + k * 4);
-      } else {
-        int len_per_thread = (len - 9) / (blockDim.y - 9);
-        int remain = (len - 9) % (blockDim.y - 9);
-        int real_len = len_per_thread;
-        if ((k - 9) < remain) real_len++;
-        int left = -1, right = -1;
-        if ((k - 9) < remain) {
-          left = 9 + (k - 9) * (len_per_thread + 1);
-          right = left + real_len;
-        } else {
-          left = 9 + remain * (len_per_thread + 1) +
-                 (k - 9 - remain) * len_per_thread;
-          right = left + real_len;
-        }
-        for (int j = left; j < right; j++)
-          *(float*)(cur_p + (j + 1) * 4) = *(float*)(input_p + (j + 1) * 4);
+      cur->slot = input.slot;
+      cur->show = input.show;
+      cur->clk = input.clk;
+      cur->mf_dim = input.mf_dim;
+      cur->lr = input.lr;
+      cur->mf_size = input.mf_size;
+      cur->cpu_ptr = input.cpu_ptr;
+      cur->delta_score = input.delta_score;
+      cur->lr_g2sum = input.lr_g2sum;
+      for (int j = 0; j < cur->mf_dim + 1; ++j) {
+        cur->mf[j] = input.mf[j];
       }
     } else {
-      if (keys[i] != 0) printf("pull miss key: %llu", keys[i]);
+      if (keys[i] != 0) {
+        printf("warning::pull miss key: %llu", keys[i]);
+      }
     }
   }
 }
@@ -231,10 +219,8 @@ void HashTable<KeyType, ValType>::get(const KeyType* d_keys,
   if (len == 0) {
     return;
   }
-  dim3 block_dims(32, 32);
-  const int grid_size = (len - 1) / 32 + 1;
-  dim3 grid_dims(grid_size);
-  dy_mf_search_kernel<<<grid_dims, block_dims, 0, stream>>>(
+  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
+  dy_mf_search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
       container_, d_keys, d_vals, len, pull_feature_value_size_);
 }
 
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
index 8952039299d..ace533cb0c7 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
@@ -760,7 +760,6 @@ void HeterComm<KeyType, ValType, GradType>::dynamic_merge_grad(
                                      (char*)d_grads,
                                      (char*)d_merge_grads_ptr,
                                      uniq_len,
-                                     max_mf_dim_,
                                      grad_value_size,
                                      merger_,
                                      stream);
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
index 8a13d9abe63..fd0dd1a72cc 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
@@ -144,106 +144,28 @@ __global__ void dy_mf_fill_shard_grads_kernel(KeyType* d_shard_keys,
   }
 }
 
-// optimized version
-template <>
-__global__ void
-dy_mf_fill_shard_grads_kernel<FeatureKey, FeaturePushValue, int>(
-    FeatureKey* d_shard_keys,
-    FeatureKey* d_keys,
-    FeaturePushValue* d_shard_grads,
-    FeaturePushValue* d_grads,
-    int* idx,
-    size_t len,
-    size_t grad_value_size) {
-  const size_t i = blockIdx.x * blockDim.y + threadIdx.y;
-  const size_t k = threadIdx.x;
-  if (i < len) {
-    if (k == 0) {
-      d_shard_keys[i] = d_keys[idx[i]];
-    }
-    FeaturePushValue* cur =
-        (FeaturePushValue*)((char*)d_shard_grads + i * grad_value_size);
-    FeaturePushValue& input = *(
-        FeaturePushValue*)((char*)d_grads + uint64_t(idx[i]) * grad_value_size);
-    char* cur_p = (char*)cur;
-    char* input_p = (char*)(&input);
-    int len = 5 + input.mf_dim;
-    if (k == 2 || k == 4)
-      *(int*)(cur_p + k * 4) = *(int*)(input_p + k * 4);
-    else if (k < 5)
-      *(float*)(cur_p + k * 4) = *(float*)(input_p + k * 4);
-    else {
-      int len_per_thread = (len - 5) / (blockDim.y - 5);
-      int remain = (len - 5) % (blockDim.y - 5);
-      int real_len = len_per_thread;
-      if ((k - 5) < remain) real_len++;
-      int left = -1, right = -1;
-      if ((k - 5) < remain) {
-        left = 5 + (k - 5) * (len_per_thread + 1);
-        right = left + real_len;
-      } else {
-        left = 5 + remain * (len_per_thread + 1) +
-               (k - 5 - remain) * len_per_thread;
-        right = left + real_len;
-      }
-      for (int j = left; j < right; j++)
-        *(float*)(cur_p + j * 4) = *(float*)(input_p + j * 4);
-    }
-  }
-}
-
-__global__ void merge_gradients_basic_kernel(const uint32_t* offset,
-                                             const uint32_t* fea_num,
-                                             const uint32_t* index,
-                                             const char* input,
-                                             char* output,
-                                             int n,
-                                             size_t grad_value_size,
-                                             DynamicGradMerger& merger) {
+__global__ void merge_gradients_kernel(const uint32_t* offset,
+                                       const uint32_t* fea_num,
+                                       const uint32_t* index,
+                                       const char* input,
+                                       char* output,
+                                       int n,
+                                       size_t grad_value_size,
+                                       DynamicGradMerger& merger_) {
   const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
   if (i < n) {
     uint32_t start = offset[i];
     uint32_t num = fea_num[i];
     int ori_index = index[start];
-    FeaturePushValue& lhs = *(FeaturePushValue*)(output + i * grad_value_size);
+    FeaturePushValue& out = *(FeaturePushValue*)(output + i * grad_value_size);
     FeaturePushValue& in =
         *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
-    merger.update_basic(lhs, in);
+    merger_.update_one(out, in);
     for (int j = 1; j < num; ++j) {
       ori_index = index[start + j];
       FeaturePushValue& rhs =
           *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
-      merger.merge_basic(lhs, rhs);
-    }
-  }
-}
-
-__global__ void merge_gradients_embedx_kernel(const uint32_t* offset,
-                                              const uint32_t* fea_num,
-                                              const uint32_t* index,
-                                              const char* input,
-                                              char* output,
-                                              int n,
-                                              size_t grad_dim,
-                                              size_t grad_value_size,
-                                              DynamicGradMerger& merger) {
-  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i < n) {
-    size_t value_idx = i / grad_dim;
-    size_t field_idx = i % grad_dim;
-    uint32_t start = offset[value_idx];
-    uint32_t num = fea_num[value_idx];
-    int ori_index = index[start];
-    FeaturePushValue& in =
-        *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
-    FeaturePushValue& lhs =
-        *(FeaturePushValue*)(output + value_idx * grad_value_size);
-    merger.update_embedx(lhs, in, field_idx);
-    for (int j = 1; j < num; ++j) {
-      int ori_index = index[start + j];
-      FeaturePushValue& rhs =
-          *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
-      merger.merge_embedx(lhs, rhs, field_idx);
+      merger_.merge_one(out, rhs);
     }
   }
 }
@@ -262,49 +184,6 @@ __global__ void dy_mf_fill_dvals_kernel(ValType* d_shard_vals,
   }
 }
 
-// optimized version
-template <>
-__global__ void dy_mf_fill_dvals_kernel<FeatureValue, int>(
-    FeatureValue* d_shard_vals,
-    FeatureValue* d_vals,
-    int* idx,
-    size_t len,
-    size_t val_size) {
-  const size_t i = blockIdx.x * blockDim.y + threadIdx.y;
-  const size_t k = threadIdx.x;
-  if (i < len) {
-    uint64_t new_offset = uint64_t(idx[i]) * val_size;
-    FeatureValue* cur = (FeatureValue*)((char*)d_vals + new_offset);
-    FeatureValue& input = *(FeatureValue*)((char*)d_shard_vals + i * val_size);
-    char* cur_p = (char*)cur;
-    char* input_p = (char*)(&input);
-    int len = 9 + input.mf_dim + 1;
-    if (k == 3 || k == 6 || k == 7)
-      *(int*)(cur_p + k * 4) = *(int*)(input_p + k * 4);
-    else if (k < 8)
-      *(float*)(cur_p + k * 4) = *(float*)(input_p + k * 4);
-    else if (k == 8) {
-      *(uint64_t*)(cur_p + k * 4) = *(uint64_t*)(input_p + k * 4);
-    } else {
-      int len_per_thread = (len - 9) / (blockDim.x - 9);
-      int remain = (len - 9) % (blockDim.y - 9);
-      int real_len = len_per_thread;
-      if ((k - 9) < remain) real_len++;
-      int left = -1, right = -1;
-      if ((k - 9) < remain) {
-        left = 9 + (k - 9) * (len_per_thread + 1);
-        right = left + real_len;
-      } else {
-        left = 9 + remain * (len_per_thread + 1) +
-               (k - 9 - remain) * len_per_thread;
-        right = left + real_len;
-      }
-      for (int j = left; j < right; j++)
-        *(float*)(cur_p + (j + 1) * 4) = *(float*)(input_p + (j + 1) * 4);
-    }
-  }
-}
-
 // cuda implemention of  heter_comm_kernel.h
 template <typename T, typename StreamType>
 void HeterCommKernel::fill_idx(T* idx,
@@ -442,12 +321,9 @@ void HeterCommKernel::dy_mf_fill_shard_grads(KeyType* d_shard_keys,
                                              long long len,
                                              size_t grad_value_size,
                                              const StreamType& stream) {
-  // int grid_size = (len - 1) / block_size_ + 1;
+  int grid_size = (len - 1) / block_size_ + 1;
   size_t c_len = (size_t)len;
-  dim3 block_dims(32, 32);
-  const size_t grid_size = (len - 1) / 32 + 1;
-  dim3 grid_dims(grid_size);
-  dy_mf_fill_shard_grads_kernel<<<grid_dims, block_dims, 0, stream>>>(
+  dy_mf_fill_shard_grads_kernel<<<grid_size, block_size_, 0, stream>>>(
       d_shard_keys,
       d_keys,
       d_shard_grads,
@@ -464,26 +340,12 @@ void HeterCommKernel::merge_gradient(const uint32_t* offset,
                                      const char* input,
                                      char* output,
                                      int n,
-                                     size_t grad_dim,
                                      size_t grad_value_size,
                                      DynamicGradMerger& merger_,
                                      const StreamType& stream) {
   int grid_size = (n - 1) / block_size_ + 1;
-  merge_gradients_basic_kernel<<<grid_size, block_size_, 0, stream>>>(
+  merge_gradients_kernel<<<grid_size, block_size_, 0, stream>>>(
       offset, fea_num, index, input, output, n, grad_value_size, merger_);
-  if (grad_dim > 0) {
-    int grid_size2 = (n * grad_dim - 1) / block_size_ + 1;
-    merge_gradients_embedx_kernel<<<grid_size2, block_size_, 0, stream>>>(
-        offset,
-        fea_num,
-        index,
-        input,
-        output,
-        n * grad_dim,
-        grad_dim,
-        grad_value_size,
-        merger_);
-  }
 }
 
 template <typename ValType, typename T, typename StreamType>
@@ -493,12 +355,9 @@ void HeterCommKernel::dy_mf_fill_dvals(ValType* d_shard_vals,
                                        long long len,
                                        size_t val_size,
                                        const StreamType& stream) {
-  // int grid_size = (len - 1) / block_size_ + 1;
+  int grid_size = (len - 1) / block_size_ + 1;
   size_t c_len = (size_t)len;
-  dim3 block_dims(32, 32);
-  const size_t grid_size_ = (len - 1) / 32 + 1;
-  dim3 grid_dims(grid_size_);
-  dy_mf_fill_dvals_kernel<<<grid_dims, block_dims, 0, stream>>>(
+  dy_mf_fill_dvals_kernel<<<grid_size, block_size_, 0, stream>>>(
       d_shard_vals, d_vals, idx, c_len, val_size);
 }
 
@@ -628,7 +487,6 @@ template void HeterCommKernel::merge_gradient<cudaStream_t>(
     const char* input,
     char* output,
     int n,
-    size_t grad_dim,
     size_t grad_value_size,
     DynamicGradMerger& merger_,
     const cudaStream_t& stream);
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h
index 6859161a5fe..d1555dc2e09 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h
@@ -42,41 +42,23 @@ struct DynamicGradMerger {
   }
 
   template <typename T>
-  __device__ __forceinline__ void update_basic(T& output, const T& input) {
+  __device__ __forceinline__ void update_one(T& output, const T& input) {
     output.slot = input.slot;
     output.show = input.show;
     output.clk = input.clk;
     output.mf_dim = input.mf_dim;
     output.lr_g = input.lr_g;
-    // for (int i = 0; i < output.mf_dim; ++i) {
-    //  output.mf_g[i] = input.mf_g[i];
-    //}
+    for (int i = 0; i < output.mf_dim; ++i) {
+      output.mf_g[i] = input.mf_g[i];
+    }
   }
   template <typename T>
-  __device__ __forceinline__ void merge_basic(T& output, const T& input) {
+  __device__ __forceinline__ void merge_one(T& output, const T& input) {
     output.show += input.show;
     output.clk += input.clk;
     output.lr_g += input.lr_g;
-    // for (int i = 0; i < input.mf_dim; ++i) {
-    //  output.mf_g[i] += input.mf_g[i];
-    //}
-  }
-
-  template <typename T>
-  __device__ __forceinline__ void update_embedx(T& output,
-                                                const T& input,
-                                                size_t embedx_id) {
-    if (embedx_id < output.mf_dim) {
-      output.mf_g[embedx_id] = input.mf_g[embedx_id];
-    }
-  }
-
-  template <typename T>
-  __device__ __forceinline__ void merge_embedx(T& output,
-                                               const T& input,
-                                               size_t embedx_id) {
-    if (embedx_id < output.mf_dim) {
-      output.mf_g[embedx_id] += input.mf_g[embedx_id];
+    for (int i = 0; i < input.mf_dim; ++i) {
+      output.mf_g[i] += input.mf_g[i];
     }
   }
 };
@@ -183,7 +165,6 @@ class HeterCommKernel {
                       const char* input,
                       char* output,
                       int n,
-                      size_t grad_dim,
                       size_t grad_value_size,
                       DynamicGradMerger& merger_,
                       const StreamType& stream);
-- 
GitLab