From e9495400621d9bba9cf7d8553e9fe30424c42229 Mon Sep 17 00:00:00 2001
From: yujianfeng <yujianfeng5@huawei.com>
Date: Thu, 6 Aug 2020 16:39:05 +0800
Subject: [PATCH] Support int64 indices for CPU sparse optimizers

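The CPU sparse optimizer kernels (FusedSparseAdam, FusedSparseLazyAdam,
FusedSparseFtrl, FusedSparseProximalAdagrad) and EmbeddingLookup previously
accepted int32 indices only. Extend them to int64:

- Move the sparse gradient utilities (SparseGradient,
  ReduceSparseGradientParam, MultiThreadComputeParams, MultiThreadCompute
  and BucketReduceSparseGradient) from common_utils.{h,cc} into the new
  header sparse_optimizer_cpu_kernel.h, templated on the index type.
- Cache the indices device dtype in InitKernel and dispatch Launch to
  LaunchKernel<int> / LaunchKernel<int64_t>; size the index workspaces
  with sizeof(T) in InitWorkspaceSize<T>.
- Register an int64-indices variant of each kernel alongside the existing
  int32 registration.
- Rename the common_utils unit test to cpu/sparse_optimizer_cpu_kernel_test.cc.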
---
 .../backend/kernel_compiler/common_utils.cc   | 343 --------------
 .../backend/kernel_compiler/common_utils.h    |  39 --
 .../cpu/embedding_look_up_cpu_kernel.cc       |  32 +-
 .../cpu/embedding_look_up_cpu_kernel.h        |   9 +
 .../cpu/sparse_apply_adam_cpu_kernel.cc       |  77 +--
 .../cpu/sparse_apply_adam_cpu_kernel.h        |  32 +-
 .../cpu/sparse_apply_ftrl_cpu_kernel.cc       |  69 ++-
 .../cpu/sparse_apply_ftrl_cpu_kernel.h        |  25 +-
 .../cpu/sparse_apply_lazy_adam_cpu_kernel.cc  |  68 ++-
 .../cpu/sparse_apply_lazy_adam_cpu_kernel.h   |  32 +-
 ...parse_apply_proximal_adagrad_cpu_kernel.cc |  67 ++-
 ...sparse_apply_proximal_adagrad_cpu_kernel.h |  29 +-
 .../cpu/sparse_optimizer_cpu_kernel.h         | 442 ++++++++++++++++++
 .../ccsrc/backend/optimizer/common/helper.cc  |  11 +-
 .../runtime/device/cpu/cpu_kernel_runtime.cc  |  27 +-
 mindspore/ccsrc/utils/convert_utils.cc        |   4 +-
 .../sparse_optimizer_cpu_kernel_test.cc}      |  22 +-
 17 files changed, 786 insertions(+), 542 deletions(-)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_optimizer_cpu_kernel.h
 rename tests/ut/cpp/kernel/{common_utils_test.cc => cpu/sparse_optimizer_cpu_kernel_test.cc} (80%)
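
Note: the body of the new sparse_optimizer_cpu_kernel.h is not reproduced in
this excerpt. Judging from the definitions deleted from common_utils.{h,cc}
and the call sites below, it presumably carries the same utilities templated
on the index type; a minimal sketch (assumed, not the verbatim header):

    template <typename T>
    struct SparseGradient {
      float *value_{nullptr};
      T *indices_{nullptr};  // int32 or int64 indices
      size_t indices_size_{0};
    };

    template <typename T>
    struct MultiThreadComputeParams;  // as before, with SparseGradient<T> sparse_grad_

    template <typename T>
    using MultiThreadComputeFunc =
        std::function<void(MultiThreadComputeParams<T> *param, size_t start, size_t end)>;

    // Chunked fan-out over up to 24 threads, unchanged from the deleted
    // common_utils.cc implementation apart from the template parameter.
    template <typename T>
    void MultiThreadCompute(const MultiThreadComputeFunc<T> &func,
                            MultiThreadComputeParams<T> *params,
                            size_t total_compute_size) {
      const size_t kThreadNum = 24;
      std::vector<std::thread> threads;
      threads.reserve(kThreadNum);
      size_t once_compute_size = (total_compute_size + kThreadNum - 1) / kThreadNum;
      for (size_t start = 0; start < total_compute_size; start += once_compute_size) {
        size_t end =
          (start + once_compute_size) > total_compute_size ? total_compute_size : (start + once_compute_size);
        threads.emplace_back(func, params, start, end);
      }
      for (auto &t : threads) {
        t.join();
      }
    }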

diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc
index ec3652ebc..9bacf4168 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc
@@ -503,332 +503,6 @@ int Sign(float x) {
   return 0;
 }
 
-namespace {
-struct BucketSparseGradient {
-  float *value_;
-  int *indices_;
-  int *global_indices_;
-  size_t indices_size_;
-};
-
-struct MultiThreadReduceSparseGradientParam {
-  SparseGradient *input_grad_{nullptr};
-  SparseGradient *workspace_grad_{nullptr};
-  SparseGradient *output_grad_{nullptr};
-  size_t max_index_{0};
-  size_t value_stride_{0};
-  size_t thread_num_{0};
-  bool use_sort_reduce_{false};
-};
-
-void CalculateEachBucketSize(const std::shared_ptr<SparseGradient> &sparse_grad, size_t max_index,
-                             std::vector<size_t> *each_bucket_size) {
-  MS_LOG(DEBUG) << "Start";
-  MS_EXCEPTION_IF_NULL(sparse_grad);
-  MS_EXCEPTION_IF_NULL(sparse_grad->indices_);
-  MS_EXCEPTION_IF_NULL(each_bucket_size);
-  size_t bucket_num = each_bucket_size->size();
-  for (size_t i = 0; i < sparse_grad->indices_size_; ++i) {
-    int index = sparse_grad->indices_[i];
-    if (index >= 0 && IntToSize(index) < max_index) {
-      auto bucket_id = index % bucket_num;
-      each_bucket_size->at(bucket_id)++;
-    }
-  }
-  MS_LOG(DEBUG) << "End";
-}
-
-void SplitAndCalculateSegmentBucketSize(const MultiThreadReduceSparseGradientParam &param,
-                                        std::vector<std::shared_ptr<SparseGradient>> *segments_ptr,
-                                        std::vector<std::shared_ptr<std::vector<size_t>>> *segment_bucket_sizes_ptr) {
-  MS_EXCEPTION_IF_NULL(param.input_grad_);
-  MS_EXCEPTION_IF_NULL(segment_bucket_sizes_ptr);
-  MS_EXCEPTION_IF_NULL(segments_ptr);
-  auto &segments = *segments_ptr;
-  auto &segment_bucket_sizes = *segment_bucket_sizes_ptr;
-  auto input_grad = param.input_grad_;
-  if (param.thread_num_ < 1) {
-    MS_EXCEPTION(ArgumentError) << "Input param thread num must > 0!";
-  }
-  size_t thread_indices_size = input_grad->indices_size_ / param.thread_num_;
-  size_t left_indices_size = input_grad->indices_size_ % param.thread_num_;
-  std::vector<std::thread> threads;
-  threads.reserve(param.thread_num_);
-  segments.reserve(param.thread_num_);
-
-  size_t current_indices_offset = 0;
-  for (size_t i = 0; i < param.thread_num_; ++i) {
-    segment_bucket_sizes.emplace_back(std::make_shared<std::vector<size_t>>(param.thread_num_, 0));
-    size_t indices_size = thread_indices_size;
-    if (i < left_indices_size) {
-      indices_size += 1;
-    }
-    segments.emplace_back(std::make_shared<SparseGradient>());
-    segments[i]->value_ = input_grad->value_ + current_indices_offset * param.value_stride_;
-    segments[i]->indices_ = input_grad->indices_ + current_indices_offset;
-    segments[i]->indices_size_ = indices_size;
-    threads.emplace_back(
-      std::thread(CalculateEachBucketSize, segments[i], param.max_index_, segment_bucket_sizes[i].get()));
-    current_indices_offset += indices_size;
-  }
-
-  for (size_t i = 0; i < param.thread_num_; ++i) {
-    threads[i].join();
-  }
-}
-
-void CopySegmentIndicesToBucket(const MultiThreadReduceSparseGradientParam &param,
-                                const std::shared_ptr<SparseGradient> &segment, size_t bucket_offset,
-                                const std::vector<std::shared_ptr<BucketSparseGradient>> &buckets) {
-  MS_LOG(DEBUG) << "Start";
-  MS_EXCEPTION_IF_NULL(segment);
-  MS_EXCEPTION_IF_NULL(segment->indices_);
-  std::vector<size_t> bucket_data_num(param.thread_num_, 0);
-  for (size_t i = 0; i < segment->indices_size_; ++i) {
-    int index = segment->indices_[i];
-    if (index >= 0 && IntToSize(index) < param.max_index_) {
-      auto bucket_id = index % param.thread_num_;
-      auto bucket_index = bucket_data_num[bucket_id];
-      buckets[bucket_id]->indices_[bucket_index] = index;
-      buckets[bucket_id]->global_indices_[bucket_index] = bucket_offset + i;
-      bucket_data_num[bucket_id]++;
-    }
-  }
-  MS_LOG(DEBUG) << "End";
-}
-
-void GatherSegmentIndicesToOutputBucket(const MultiThreadReduceSparseGradientParam &param,
-                                        const std::vector<std::shared_ptr<SparseGradient>> &segments,
-                                        const std::vector<std::shared_ptr<std::vector<size_t>>> &segment_bucket_sizes,
-                                        std::vector<std::shared_ptr<BucketSparseGradient>> *buckets_ptr) {
-  MS_EXCEPTION_IF_NULL(param.output_grad_);
-  MS_EXCEPTION_IF_NULL(param.output_grad_->value_);
-  MS_EXCEPTION_IF_NULL(param.output_grad_->indices_);
-  MS_EXCEPTION_IF_NULL(buckets_ptr);
-  auto &buckets = *buckets_ptr;
-  size_t thread_num = param.thread_num_;
-  if (thread_num != segment_bucket_sizes.size()) {
-    MS_EXCEPTION(ArgumentError) << "Input param thread num not equal to segment size!";
-  }
-  std::vector<size_t> bucket_data_size(thread_num, 0);
-  for (size_t i = 0; i < thread_num; ++i) {
-    for (size_t j = 0; j < thread_num; ++j) {
-      bucket_data_size[j] += segment_bucket_sizes[i]->at(j);
-    }
-  }
-  size_t current_indices_offset = 0;
-  for (size_t i = 0; i < thread_num; ++i) {
-    buckets.emplace_back(std::make_shared<BucketSparseGradient>());
-    buckets[i]->value_ = param.output_grad_->value_ + current_indices_offset * param.value_stride_;
-    buckets[i]->indices_ = param.output_grad_->indices_ + current_indices_offset;
-    buckets[i]->global_indices_ = param.workspace_grad_->indices_ + current_indices_offset;
-    buckets[i]->indices_size_ = bucket_data_size[i];
-    current_indices_offset += bucket_data_size[i];
-  }
-  std::vector<size_t> tmp_bucket_data_size(thread_num, 0);
-  std::vector<std::vector<std::shared_ptr<BucketSparseGradient>>> each_thread_buckets;
-  for (size_t i = 0; i < thread_num; ++i) {
-    std::vector<std::shared_ptr<BucketSparseGradient>> thread_buckets;
-    for (size_t j = 0; j < thread_num; ++j) {
-      thread_buckets.emplace_back(std::make_shared<BucketSparseGradient>());
-      thread_buckets[j]->indices_ = buckets[j]->indices_ + tmp_bucket_data_size[j];
-      thread_buckets[j]->global_indices_ = buckets[j]->global_indices_ + tmp_bucket_data_size[j];
-      thread_buckets[j]->value_ = buckets[j]->value_ + tmp_bucket_data_size[j] * param.value_stride_;
-      thread_buckets[j]->indices_size_ = segment_bucket_sizes[i]->at(j);
-      tmp_bucket_data_size[j] += segment_bucket_sizes[i]->at(j);
-    }
-    each_thread_buckets.emplace_back(thread_buckets);
-  }
-  std::vector<std::thread> threads;
-  threads.reserve(thread_num);
-  current_indices_offset = 0;
-  for (size_t i = 0; i < thread_num; ++i) {
-    threads.emplace_back(
-      std::thread(CopySegmentIndicesToBucket, param, segments[i], current_indices_offset, each_thread_buckets[i]));
-    current_indices_offset += segments[i]->indices_size_;
-  }
-  for (size_t i = 0; i < thread_num; ++i) {
-    threads[i].join();
-  }
-}
-
-void SortAndReduceBucketSparseGradient(const MultiThreadReduceSparseGradientParam &param,
-                                       const std::shared_ptr<BucketSparseGradient> &bucket,
-                                       const std::shared_ptr<SparseGradient> &reduced_bucket) {
-  MS_LOG(DEBUG) << "Start";
-  MS_EXCEPTION_IF_NULL(bucket);
-  MS_EXCEPTION_IF_NULL(bucket->value_);
-  MS_EXCEPTION_IF_NULL(bucket->indices_);
-  MS_EXCEPTION_IF_NULL(reduced_bucket);
-  MS_EXCEPTION_IF_NULL(reduced_bucket->value_);
-  MS_EXCEPTION_IF_NULL(reduced_bucket->indices_);
-  std::vector<std::pair<int, int>> sorted_indices;
-  sorted_indices.reserve(bucket->indices_size_);
-  for (size_t i = 0; i < bucket->indices_size_; ++i) {
-    int index = bucket->indices_[i];
-    int global_index = bucket->global_indices_[i];
-    sorted_indices.emplace_back(std::pair<int, int>(index, global_index));
-  }
-  std::sort(sorted_indices.begin(), sorted_indices.end());
-
-  float *global_value = param.input_grad_->value_;
-  size_t unique_indices_size = 0;
-  size_t max_length = reduced_bucket->indices_size_ * param.value_stride_;
-  int last_index{0};
-  size_t value_offset{0};
-  for (size_t i = 0; i < sorted_indices.size(); ++i) {
-    int index = sorted_indices[i].first;
-    int global_index = sorted_indices[i].second;
-    int global_value_offset = global_index * param.value_stride_;
-    if (i == 0 || index != last_index) {
-      if (i != 0) {
-        unique_indices_size++;
-      }
-      reduced_bucket->indices_[unique_indices_size] = index;
-      value_offset = unique_indices_size * param.value_stride_;
-      auto ret_code = memcpy_s(reduced_bucket->value_ + value_offset, (max_length - value_offset) * sizeof(float),
-                               global_value + global_value_offset, param.value_stride_ * sizeof(float));
-      if (ret_code != EOK) {
-        MS_LOG(EXCEPTION) << "Failed to copy data!";
-      }
-    } else {
-      for (size_t j = 0; j < param.value_stride_; ++j) {
-        reduced_bucket->value_[value_offset + j] += global_value[global_value_offset + j];
-      }
-    }
-    last_index = index;
-  }
-  reduced_bucket->indices_size_ = unique_indices_size;
-  MS_LOG(DEBUG) << "End";
-}
-
-void ReduceBucketSparseGradient(const MultiThreadReduceSparseGradientParam &param,
-                                const std::shared_ptr<BucketSparseGradient> &bucket,
-                                const std::shared_ptr<SparseGradient> &reduced_bucket) {
-  MS_LOG(DEBUG) << "Start";
-  MS_EXCEPTION_IF_NULL(bucket);
-  MS_EXCEPTION_IF_NULL(bucket->value_);
-  MS_EXCEPTION_IF_NULL(bucket->indices_);
-  MS_EXCEPTION_IF_NULL(reduced_bucket);
-  MS_EXCEPTION_IF_NULL(reduced_bucket->value_);
-  MS_EXCEPTION_IF_NULL(reduced_bucket->indices_);
-
-  float *global_value = param.input_grad_->value_;
-  std::unordered_map<int, size_t> index_map;
-  size_t unique_indices_size = 0;
-  size_t max_length = reduced_bucket->indices_size_ * param.value_stride_;
-  for (size_t i = 0; i < bucket->indices_size_; ++i) {
-    int index = bucket->indices_[i];
-    int global_index = bucket->global_indices_[i];
-    auto iter = index_map.find(index);
-    if (iter == index_map.end()) {
-      reduced_bucket->indices_[unique_indices_size] = index;
-      size_t start_index = unique_indices_size * param.value_stride_;
-      index_map[index] = start_index;
-      auto ret_code = memcpy_s(reduced_bucket->value_ + start_index, (max_length - start_index) * sizeof(float),
-                               global_value + global_index * param.value_stride_, param.value_stride_ * sizeof(float));
-      if (ret_code != EOK) {
-        MS_LOG(EXCEPTION) << "Failed to copy data!";
-      }
-      unique_indices_size++;
-    } else {
-      size_t start_index = iter->second;
-      size_t end_index = start_index + param.value_stride_;
-      for (size_t j = start_index, k = global_index * param.value_stride_; j < end_index; ++j, ++k) {
-        reduced_bucket->value_[j] += global_value[k];
-      }
-    }
-  }
-  reduced_bucket->indices_size_ = unique_indices_size;
-  MS_LOG(DEBUG) << "End";
-}
-
-void ReduceBucketSparseGradientToWorkspace(const MultiThreadReduceSparseGradientParam &param,
-                                           const std::vector<std::shared_ptr<BucketSparseGradient>> &buckets,
-                                           std::vector<std::shared_ptr<SparseGradient>> *reduced_buckets_ptr) {
-  MS_EXCEPTION_IF_NULL(param.workspace_grad_);
-  MS_EXCEPTION_IF_NULL(param.workspace_grad_->value_);
-  MS_EXCEPTION_IF_NULL(param.workspace_grad_->indices_);
-  MS_EXCEPTION_IF_NULL(reduced_buckets_ptr);
-  auto &reduced_buckets = *reduced_buckets_ptr;
-  size_t thread_num = buckets.size();
-  std::vector<std::thread> threads;
-  threads.reserve(thread_num);
-
-  size_t current_indices_offset = 0;
-  for (size_t i = 0; i < thread_num; ++i) {
-    reduced_buckets.emplace_back(std::make_shared<SparseGradient>());
-    reduced_buckets[i]->value_ = param.workspace_grad_->value_ + current_indices_offset * param.value_stride_;
-    reduced_buckets[i]->indices_ = param.workspace_grad_->indices_ + current_indices_offset;
-    reduced_buckets[i]->indices_size_ = buckets[i]->indices_size_;
-    if (param.use_sort_reduce_) {
-      threads.emplace_back(std::thread(SortAndReduceBucketSparseGradient, param, buckets[i], reduced_buckets[i]));
-    } else {
-      threads.emplace_back(std::thread(ReduceBucketSparseGradient, param, buckets[i], reduced_buckets[i]));
-    }
-    current_indices_offset += buckets[i]->indices_size_;
-  }
-  for (size_t i = 0; i < thread_num; ++i) {
-    threads[i].join();
-  }
-}
-
-void MergeReduceSparseGradient(const MultiThreadReduceSparseGradientParam &param,
-                               const std::vector<std::shared_ptr<SparseGradient>> &reduced_buckets) {
-  MS_EXCEPTION_IF_NULL(param.output_grad_);
-  auto output_grad = param.output_grad_;
-  MS_EXCEPTION_IF_NULL(output_grad->value_);
-  MS_EXCEPTION_IF_NULL(output_grad->indices_);
-  size_t stride_data_size = param.value_stride_ * sizeof(float);
-  size_t unique_indices_size = 0;
-  for (size_t i = 0; i < reduced_buckets.size(); ++i) {
-    auto &bucket = reduced_buckets[i];
-    MS_EXCEPTION_IF_NULL(bucket);
-    if (bucket->indices_size_ == 0) {
-      continue;
-    }
-    auto ret_code = memcpy_s(output_grad->value_ + unique_indices_size * param.value_stride_,
-                             (output_grad->indices_size_ - unique_indices_size) * stride_data_size, bucket->value_,
-                             bucket->indices_size_ * stride_data_size);
-    if (ret_code != EOK) {
-      MS_LOG(EXCEPTION) << "Failed to copy data!";
-    }
-    ret_code = memcpy_s(output_grad->indices_ + unique_indices_size,
-                        (output_grad->indices_size_ - unique_indices_size) * sizeof(int), bucket->indices_,
-                        bucket->indices_size_ * sizeof(int));
-    if (ret_code != EOK) {
-      MS_LOG(EXCEPTION) << "Failed to copy data!";
-    }
-    unique_indices_size += bucket->indices_size_;
-  }
-  output_grad->indices_size_ = unique_indices_size;
-}
-}  // namespace
-
-void BucketReduceSparseGradient(const ReduceSparseGradientParam &param) {
-  MS_LOG(DEBUG) << "Start";
-  MS_EXCEPTION_IF_NULL(param.input_grad_);
-  size_t thread_num = 23;
-  if (param.input_grad_->indices_size_ < thread_num) {
-    thread_num = param.input_grad_->indices_size_;
-  }
-  MultiThreadReduceSparseGradientParam multi_thread_param({param.input_grad_, param.workspace_grad_, param.output_grad_,
-                                                           param.max_index_, param.value_stride_, thread_num,
-                                                           param.use_sort_reduce_});
-  std::vector<std::shared_ptr<SparseGradient>> segments;
-  std::vector<std::shared_ptr<std::vector<size_t>>> segment_bucket_sizes;
-  SplitAndCalculateSegmentBucketSize(multi_thread_param, &segments, &segment_bucket_sizes);
-
-  std::vector<std::shared_ptr<BucketSparseGradient>> buckets;
-  GatherSegmentIndicesToOutputBucket(multi_thread_param, segments, segment_bucket_sizes, &buckets);
-
-  std::vector<std::shared_ptr<SparseGradient>> reduced_buckets;
-  ReduceBucketSparseGradientToWorkspace(multi_thread_param, buckets, &reduced_buckets);
-
-  MergeReduceSparseGradient(multi_thread_param, reduced_buckets);
-  MS_LOG(DEBUG) << "End";
-}
-
 std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) {
   MS_EXCEPTION_IF_NULL(anf_node);
 
@@ -1073,23 +747,6 @@ bool IsWeightBoundary(const AnfNodePtr &node) {
   return false;
 }
 
-void MultiThreadCompute(const MultiThreadComputeFunc &func, MultiThreadComputeParams *params,
-                        size_t total_compute_size) {
-  const size_t kThreadNum = 24;
-  std::vector<std::thread> threads;
-  threads.reserve(kThreadNum);
-  size_t start = 0;
-  size_t once_compute_size = (total_compute_size + kThreadNum - 1) / kThreadNum;
-  while (start < total_compute_size) {
-    size_t end = (start + once_compute_size) > total_compute_size ? total_compute_size : (start + once_compute_size);
-    threads.emplace_back(std::thread(func, params, start, end));
-    start += once_compute_size;
-  }
-  for (size_t i = 0; i < threads.size(); ++i) {
-    threads[i].join();
-  }
-}
-
 std::vector<int> GetReduceAttrAxis(const CNodePtr &cnode) {
   if (AnfAlgo::GetInputTensorNum(cnode) != AnfAlgo::GetOutputTensorNum(cnode) &&
       AnfAlgo::GetInputTensorNum(cnode) != 1) {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.h b/mindspore/ccsrc/backend/kernel_compiler/common_utils.h
index a894435bc..738859a60 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/common_utils.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.h
@@ -71,42 +71,6 @@ class KernelMeta {
   std::unordered_map<std::string, std::string> kernel_meta_map_;
 };
 
-struct SparseGradient {
-  float *value_{nullptr};
-  int *indices_{nullptr};
-  size_t indices_size_{0};
-};
-
-struct ReduceSparseGradientParam {
-  SparseGradient *input_grad_{nullptr};
-  SparseGradient *workspace_grad_{nullptr};
-  SparseGradient *output_grad_{nullptr};
-  size_t max_index_{0};
-  size_t value_stride_{0};
-  bool use_sort_reduce_{false};
-};
-
-struct MultiThreadComputeParams {
-  float *var_;
-  float *accum_;
-  float *linear_;
-  float *m_;
-  float *m_t_;
-  float *v_;
-  float lr_;
-  float l1_;
-  float l2_;
-  float lr_power_;
-  float beta1_;
-  float beta2_;
-  float epsilon_;
-  SparseGradient sparse_grad_;
-  size_t var_first_dim_size_;
-  size_t var_outer_dim_size_;
-  bool use_nesterov_;
-};
-using MultiThreadComputeFunc = std::function<void(MultiThreadComputeParams *param, size_t start, size_t end)>;
-
 bool CheckCache(const std::string &kernel_name);
 KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor);
 KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor);
@@ -132,9 +96,6 @@ void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr>
 bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json);
 void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list);
 bool IsWeightBoundary(const AnfNodePtr &node);
-void MultiThreadCompute(const MultiThreadComputeFunc &func, MultiThreadComputeParams *params,
-                        size_t total_compute_size);
-void BucketReduceSparseGradient(const ReduceSparseGradientParam &param);
 std::vector<int> GetReduceAttrAxis(const CNodePtr &cnode);
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc
index 2591e57eb..54106d166 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc
@@ -22,11 +22,12 @@
 namespace mindspore {
 namespace kernel {
 namespace {
-void LookUpTableTask(const float *input_addr, const int *indices_addr, float *output_addr, size_t indices_lens,
-                     size_t outer_dim_size, int offset, size_t first_dim_size) {
+template <typename T>
+void LookUpTableTask(const float *input_addr, const T *indices_addr, float *output_addr, size_t indices_lens,
+                     size_t outer_dim_size, T offset, size_t first_dim_size) {
   size_t lens = outer_dim_size * sizeof(float);
   for (size_t i = 0; i < indices_lens; ++i) {
-    int index = indices_addr[i] - offset;
+    T index = indices_addr[i] - offset;
     if (index >= 0 && index < SizeToInt(first_dim_size)) {
       size_t pos = index * outer_dim_size;
       auto ret = memcpy_s(output_addr, lens, input_addr + pos, lens);
@@ -61,13 +62,14 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   if (AnfAlgo::HasNodeAttr(kAttrOffset, kernel_node)) {
     offset_ = AnfAlgo::GetNodeAttr<int>(kernel_node, kAttrOffset);
   }
+  indices_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
 }
 
-bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
-                                      const std::vector<kernel::AddressPtr> & /*workspace*/,
-                                      const std::vector<kernel::AddressPtr> &outputs) {
+template <typename T>
+void EmbeddingLookUpCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                                            const std::vector<kernel::AddressPtr> &outputs) const {
   auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
-  auto indices_addr = reinterpret_cast<int *>(inputs[1]->addr);
+  auto indices_addr = reinterpret_cast<T *>(inputs[1]->addr);
   auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
   const size_t thread_num = 16;
   std::thread threads[16];
@@ -80,9 +82,9 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
       break;
     }
     MS_LOG(DEBUG) << "task_offset: " << task_offset << " task_proc_lenss:" << task_proc_lens;
-    threads[i] =
-      std::thread(LookUpTableTask, input_addr, indices_addr + task_offset, output_addr + task_offset * outer_dim_size_,
-                  task_proc_lens, outer_dim_size_, offset_, first_dim_size_);
+    threads[i] = std::thread(LookUpTableTask<T>, input_addr, indices_addr + task_offset,
+                             output_addr + task_offset * outer_dim_size_, task_proc_lens, outer_dim_size_, offset_,
+                             first_dim_size_);
     task_offset += task_proc_lens;
     if (task_offset + task_proc_lens > indices_lens_) {
       task_proc_lens = indices_lens_ - task_offset;
@@ -91,6 +93,16 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
   for (size_t j = 0; j < i; j++) {
     threads[j].join();
   }
+}
+
+bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                      const std::vector<kernel::AddressPtr> & /*workspace*/,
+                                      const std::vector<kernel::AddressPtr> &outputs) {
+  if (indices_data_type_ == kNumberTypeInt32) {
+    LaunchKernel<int>(inputs, outputs);
+  } else {
+    LaunchKernel<int64_t>(inputs, outputs);
+  }
   return true;
 }
 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h
index bbde7157c..c3ffc851d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h
@@ -31,6 +31,9 @@ class EmbeddingLookUpCPUKernel : public CPUKernel {
 
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs) override;
+  template <typename T>
+  void LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                    const std::vector<kernel::AddressPtr> &outputs) const;
 
  protected:
   void CheckParam(const CNodePtr &kernel_node);
@@ -38,12 +41,18 @@ class EmbeddingLookUpCPUKernel : public CPUKernel {
   size_t indices_lens_{1};
   size_t first_dim_size_{1};
   size_t outer_dim_size_{1};
+  TypeId indices_data_type_{kNumberTypeInt32};
 };
 
 MS_REG_CPU_KERNEL(
   EmbeddingLookup,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
   EmbeddingLookUpCPUKernel);
+
+MS_REG_CPU_KERNEL(
+  EmbeddingLookup,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
+  EmbeddingLookUpCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.cc
index bd57a022f..2c66b519f 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.cc
@@ -22,7 +22,8 @@ namespace kernel {
 namespace {
 constexpr size_t kSparseApplyAdamInputSize = 11;
 
-void ComputeAdam(MultiThreadComputeParams *input_params, size_t start, size_t end) {
+template <typename T>
+void ComputeAdam(MultiThreadComputeParams<T> *input_params, size_t start, size_t end) {
   MS_EXCEPTION_IF_NULL(input_params);
   auto m = input_params->m_;
   auto m_t = input_params->m_t_;
@@ -34,8 +35,8 @@ void ComputeAdam(MultiThreadComputeParams *input_params, size_t start, size_t en
   const auto var_first_dim_size = input_params->var_first_dim_size_;
   const auto var_outer_dim_size = input_params->var_outer_dim_size_;
   for (size_t i = start; i < end; ++i) {
-    int index = unique_sparse_grad.indices_[i];
-    if (index < 0 || IntToSize(index) >= var_first_dim_size) {
+    T index = unique_sparse_grad.indices_[i];
+    if (index < 0 || LongToSize(index) >= var_first_dim_size) {
       MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process";
     }
     size_t start_index = var_outer_dim_size * index;
@@ -51,7 +52,8 @@ void ComputeAdam(MultiThreadComputeParams *input_params, size_t start, size_t en
   }
 }
 
-void ComputeMomentum(MultiThreadComputeParams *input_params, size_t start, size_t end) {
+template <typename T>
+void ComputeMomentum(MultiThreadComputeParams<T> *input_params, size_t start, size_t end) {
   MS_EXCEPTION_IF_NULL(input_params);
   auto m = input_params->m_;
   auto v = input_params->v_;
@@ -63,7 +65,8 @@ void ComputeMomentum(MultiThreadComputeParams *input_params, size_t start, size_
   }
 }
 
-void ComputeWeight(MultiThreadComputeParams *input_params, size_t start, size_t end) {
+template <typename T>
+void ComputeWeight(MultiThreadComputeParams<T> *input_params, size_t start, size_t end) {
   MS_EXCEPTION_IF_NULL(input_params);
   auto var = input_params->var_;
   const auto *m = input_params->m_;
@@ -76,16 +79,24 @@ void ComputeWeight(MultiThreadComputeParams *input_params, size_t start, size_t
 }
 }  // namespace
 
-void SparseApplyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
-  CPUKernel::InitInputOutputSize(kernel_node);
-  MS_EXCEPTION_IF_NULL(kernel_node);
+template <typename T>
+void SparseApplyAdamCPUKernel::InitWorkspaceSize() {
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(T));
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(T));
   workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float));
 }
 
+void SparseApplyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
+  CPUKernel::InitInputOutputSize(kernel_node);
+  if (indices_data_type_ == kNumberTypeInt32) {
+    InitWorkspaceSize<int>();
+  } else {
+    InitWorkspaceSize<int64_t>();
+  }
+}
+
 void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   std::vector<size_t> var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
@@ -119,15 +130,12 @@ void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) {
     use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
   }
+  indices_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 10);
 }
 
-bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
-                                      const std::vector<kernel::AddressPtr> &workspace,
-                                      const std::vector<kernel::AddressPtr> & /*outputs*/) {
-  if (inputs.size() < kSparseApplyAdamInputSize) {
-    MS_LOG(EXCEPTION) << "Error input size!";
-  }
-
+template <typename T>
+void SparseApplyAdamCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                                            const std::vector<kernel::AddressPtr> &workspace) const {
   auto var = reinterpret_cast<float *>(inputs[0]->addr);
   auto m = reinterpret_cast<float *>(inputs[1]->addr);
   auto v = reinterpret_cast<float *>(inputs[2]->addr);
@@ -141,17 +149,17 @@ bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
   auto beta2 = reinterpret_cast<float *>(inputs[7]->addr)[0];
   auto epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0];
   auto grad = reinterpret_cast<float *>(inputs[9]->addr);
-  auto indices = reinterpret_cast<int *>(inputs[10]->addr);
+  auto indices = reinterpret_cast<T *>(inputs[10]->addr);
   auto new_grad = reinterpret_cast<float *>(workspace[0]->addr);
-  auto new_indices = reinterpret_cast<int *>(workspace[1]->addr);
+  auto new_indices = reinterpret_cast<T *>(workspace[1]->addr);
   auto workspace_grad = reinterpret_cast<float *>(workspace[2]->addr);
-  auto workspace_indices = reinterpret_cast<int *>(workspace[3]->addr);
+  auto workspace_indices = reinterpret_cast<T *>(workspace[3]->addr);
   auto m_t = reinterpret_cast<float *>(workspace[4]->addr);
 
-  SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_});
-  SparseGradient workspace_sparse_grad({workspace_grad, workspace_indices, indices_size_});
-  SparseGradient input_sparse_grad({grad, indices, indices_size_});
-  ReduceSparseGradientParam param;
+  SparseGradient<T> unique_sparse_grad({new_grad, new_indices, indices_size_});
+  SparseGradient<T> workspace_sparse_grad({workspace_grad, workspace_indices, indices_size_});
+  SparseGradient<T> input_sparse_grad({grad, indices, indices_size_});
+  ReduceSparseGradientParam<T> param;
   param.input_grad_ = &input_sparse_grad;
   param.workspace_grad_ = &workspace_sparse_grad;
   param.output_grad_ = &unique_sparse_grad;
@@ -162,19 +170,19 @@ bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
   size_t total_dim_size = var_first_dim_size_ * var_outer_dim_size_;
   lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);
 
-  MultiThreadComputeParams input_params;
+  MultiThreadComputeParams<T> input_params;
   input_params.m_ = m;
   input_params.v_ = v;
   input_params.beta1_ = beta1;
   input_params.beta2_ = beta2;
-  MultiThreadCompute(ComputeMomentum, &input_params, total_dim_size);
+  MultiThreadCompute<T>(ComputeMomentum<T>, &input_params, total_dim_size);
 
   input_params.m_t_ = m_t;
   input_params.use_nesterov_ = use_nesterov_;
   input_params.sparse_grad_ = unique_sparse_grad;
   input_params.var_first_dim_size_ = var_first_dim_size_;
   input_params.var_outer_dim_size_ = var_outer_dim_size_;
-  MultiThreadCompute(ComputeAdam, &input_params, unique_sparse_grad.indices_size_);
+  MultiThreadCompute<T>(ComputeAdam<T>, &input_params, unique_sparse_grad.indices_size_);
 
   if (use_nesterov_) {
     input_params.m_ = input_params.m_t_;
@@ -182,7 +190,20 @@ bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
   input_params.var_ = var;
   input_params.lr_ = lr;
   input_params.epsilon_ = epsilon;
-  MultiThreadCompute(ComputeWeight, &input_params, total_dim_size);
+  MultiThreadCompute<T>(ComputeWeight<T>, &input_params, total_dim_size);
+}
+
+bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                      const std::vector<kernel::AddressPtr> &workspace,
+                                      const std::vector<kernel::AddressPtr> & /*outputs*/) {
+  if (inputs.size() < kSparseApplyAdamInputSize) {
+    MS_LOG(EXCEPTION) << "Error input size!";
+  }
+  if (indices_data_type_ == kNumberTypeInt32) {
+    LaunchKernel<int>(inputs, workspace);
+  } else {
+    LaunchKernel<int64_t>(inputs, workspace);
+  }
   return true;
 }
 }  // namespace kernel
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h
index 6cf716839..736d8d484 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h
@@ -17,13 +17,11 @@
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_
 
 #include <vector>
-#include <memory>
-#include "backend/kernel_compiler/cpu/cpu_kernel.h"
-#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+#include "backend/kernel_compiler/cpu/sparse_optimizer_cpu_kernel.h"
 
 namespace mindspore {
 namespace kernel {
-class SparseApplyAdamCPUKernel : public CPUKernel {
+class SparseApplyAdamCPUKernel : public SparseOptimizerCPUKernel {
  public:
   SparseApplyAdamCPUKernel() = default;
   ~SparseApplyAdamCPUKernel() override = default;
@@ -32,11 +30,13 @@ class SparseApplyAdamCPUKernel : public CPUKernel {
   void InitInputOutputSize(const CNodePtr &kernel_node) override;
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs) override;
+  template <typename T>
+  void InitWorkspaceSize();
+  template <typename T>
+  void LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                    const std::vector<kernel::AddressPtr> &workspace) const;
 
  protected:
-  size_t indices_size_{0};
-  size_t var_first_dim_size_{0};
-  size_t var_outer_dim_size_{1};
   bool use_nesterov_{false};
 };
 
@@ -57,6 +57,24 @@ MS_REG_CPU_KERNEL(FusedSparseAdam,
                     .AddOutputAttr(kNumberTypeFloat32)
                     .AddOutputAttr(kNumberTypeFloat32),
                   SparseApplyAdamCPUKernel);
+
+MS_REG_CPU_KERNEL(FusedSparseAdam,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeInt64)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32),
+                  SparseApplyAdamCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.cc
index e37531022..8149b69a5 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.cc
@@ -21,8 +21,8 @@ namespace mindspore {
 namespace kernel {
 namespace {
 constexpr size_t kSparseApplyFtrlInputSize = 5;
-
-void ComputeFtrl(MultiThreadComputeParams *input_params, size_t start, size_t end) {
+template <typename T>
+void ComputeFtrl(MultiThreadComputeParams<T> *input_params, size_t start, size_t end) {
   MS_EXCEPTION_IF_NULL(input_params);
   auto var = input_params->var_;
   auto accum = input_params->accum_;
@@ -35,8 +35,8 @@ void ComputeFtrl(MultiThreadComputeParams *input_params, size_t start, size_t en
   const auto var_first_dim_size = input_params->var_first_dim_size_;
   const auto var_outer_dim_size = input_params->var_outer_dim_size_;
   for (size_t i = start; i < end; ++i) {
-    int index = unique_sparse_grad.indices_[i];
-    if (index < 0 || IntToSize(index) >= var_first_dim_size) {
+    T index = unique_sparse_grad.indices_[i];
+    if (index < 0 || LongToSize(index) >= var_first_dim_size) {
       MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process";
     }
     size_t start_index = var_outer_dim_size * index;
@@ -61,13 +61,21 @@ void ComputeFtrl(MultiThreadComputeParams *input_params, size_t start, size_t en
 }
 }  // namespace
 
-void SparseApplyFtrlCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
-  CPUKernel::InitInputOutputSize(kernel_node);
-  MS_EXCEPTION_IF_NULL(kernel_node);
+template <typename T>
+void SparseApplyFtrlCPUKernel::InitWorkspaceSize() {
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(T));
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(T));
+}
+
+void SparseApplyFtrlCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
+  CPUKernel::InitInputOutputSize(kernel_node);
+  if (indices_data_type_ == kNumberTypeInt32) {
+    InitWorkspaceSize<int>();
+  } else {
+    InitWorkspaceSize<int64_t>();
+  }
 }
 
 void SparseApplyFtrlCPUKernel::InitKernel(const CNodePtr &kernel_node) {
@@ -116,29 +124,26 @@ void SparseApplyFtrlCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   if (lr_power_ > 0) {
     MS_LOG(EXCEPTION) << "lr_power should be a non-positive scalar";
   }
+  indices_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 4);
 }
 
-bool SparseApplyFtrlCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
-                                      const std::vector<kernel::AddressPtr> &workspace,
-                                      const std::vector<kernel::AddressPtr> & /*outputs*/) {
-  if (inputs.size() < kSparseApplyFtrlInputSize) {
-    MS_LOG(EXCEPTION) << "error input output size!";
-  }
-
+template <typename T>
+void SparseApplyFtrlCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                                            const std::vector<kernel::AddressPtr> &workspace) const {
   auto var = reinterpret_cast<float *>(inputs[0]->addr);
   auto accum = reinterpret_cast<float *>(inputs[1]->addr);
   auto linear = reinterpret_cast<float *>(inputs[2]->addr);
   auto grad = reinterpret_cast<float *>(inputs[3]->addr);
-  auto indices = reinterpret_cast<int *>(inputs[4]->addr);
+  auto indices = reinterpret_cast<T *>(inputs[4]->addr);
   auto new_grad = reinterpret_cast<float *>(workspace[0]->addr);
-  auto new_indices = reinterpret_cast<int *>(workspace[1]->addr);
+  auto new_indices = reinterpret_cast<T *>(workspace[1]->addr);
   auto workspace_grad = reinterpret_cast<float *>(workspace[2]->addr);
-  auto workspace_indices = reinterpret_cast<int *>(workspace[3]->addr);
+  auto workspace_indices = reinterpret_cast<T *>(workspace[3]->addr);
 
-  SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_});
-  SparseGradient workspace_sparse_grad({workspace_grad, workspace_indices, indices_size_});
-  SparseGradient input_sparse_grad({grad, indices, indices_size_});
-  ReduceSparseGradientParam param;
+  SparseGradient<T> unique_sparse_grad({new_grad, new_indices, indices_size_});
+  SparseGradient<T> workspace_sparse_grad({workspace_grad, workspace_indices, indices_size_});
+  SparseGradient<T> input_sparse_grad({grad, indices, indices_size_});
+  ReduceSparseGradientParam<T> param;
   param.input_grad_ = &input_sparse_grad;
   param.workspace_grad_ = &workspace_sparse_grad;
   param.output_grad_ = &unique_sparse_grad;
@@ -146,7 +151,7 @@ bool SparseApplyFtrlCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
   param.value_stride_ = var_outer_dim_size_;
   BucketReduceSparseGradient(param);
 
-  MultiThreadComputeParams input_params;
+  MultiThreadComputeParams<T> input_params;
   input_params.var_ = var;
   input_params.accum_ = accum;
   input_params.linear_ = linear;
@@ -157,7 +162,21 @@ bool SparseApplyFtrlCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
   input_params.sparse_grad_ = unique_sparse_grad;
   input_params.var_first_dim_size_ = var_first_dim_size_;
   input_params.var_outer_dim_size_ = var_outer_dim_size_;
-  MultiThreadCompute(ComputeFtrl, &input_params, unique_sparse_grad.indices_size_);
+  MultiThreadCompute<T>(ComputeFtrl<T>, &input_params, unique_sparse_grad.indices_size_);
+}
+
+bool SparseApplyFtrlCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                      const std::vector<kernel::AddressPtr> &workspace,
+                                      const std::vector<kernel::AddressPtr> & /*outputs*/) {
+  if (inputs.size() < kSparseApplyFtrlInputSize) {
+    MS_LOG(EXCEPTION) << "error input output size!";
+  }
+
+  if (indices_data_type_ == kNumberTypeInt32) {
+    LaunchKernel<int>(inputs, workspace);
+  } else {
+    LaunchKernel<int64_t>(inputs, workspace);
+  }
   return true;
 }
 }  // namespace kernel
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h
index a4523e853..bc6054d58 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h
@@ -17,12 +17,11 @@
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_
 
 #include <vector>
-#include "backend/kernel_compiler/cpu/cpu_kernel.h"
-#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+#include "backend/kernel_compiler/cpu/sparse_optimizer_cpu_kernel.h"
 
 namespace mindspore {
 namespace kernel {
-class SparseApplyFtrlCPUKernel : public CPUKernel {
+class SparseApplyFtrlCPUKernel : public SparseOptimizerCPUKernel {
  public:
   SparseApplyFtrlCPUKernel() = default;
   ~SparseApplyFtrlCPUKernel() override = default;
@@ -31,11 +30,13 @@ class SparseApplyFtrlCPUKernel : public CPUKernel {
   void InitInputOutputSize(const CNodePtr &kernel_node) override;
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs) override;
+  template <typename T>
+  void InitWorkspaceSize();
+  template <typename T>
+  void LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                    const std::vector<kernel::AddressPtr> &workspace) const;
 
  protected:
-  size_t indices_size_{0};
-  size_t var_first_dim_size_{0};
-  size_t var_outer_dim_size_{1};
   float lr_{0};
   float l1_{0};
   float l2_{0};
@@ -53,6 +54,18 @@ MS_REG_CPU_KERNEL(FusedSparseFtrl,
                     .AddOutputAttr(kNumberTypeFloat32)
                     .AddOutputAttr(kNumberTypeFloat32),
                   SparseApplyFtrlCPUKernel);
+
+MS_REG_CPU_KERNEL(FusedSparseFtrl,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeInt64)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32),
+                  SparseApplyFtrlCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.cc
index 24a48e2d7..c0aa704b8 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.cc
@@ -22,7 +22,8 @@ namespace kernel {
 namespace {
 constexpr size_t kSparseApplyLazyAdamInputSize = 11;
 
-void ComputeLazyAdam(MultiThreadComputeParams *input_params, size_t start, size_t end) {
+template <typename T>
+void ComputeLazyAdam(MultiThreadComputeParams<T> *input_params, size_t start, size_t end) {
   MS_EXCEPTION_IF_NULL(input_params);
   auto var = input_params->var_;
   auto m = input_params->m_;
@@ -36,8 +37,8 @@ void ComputeLazyAdam(MultiThreadComputeParams *input_params, size_t start, size_
   const auto var_first_dim_size = input_params->var_first_dim_size_;
   const auto var_outer_dim_size = input_params->var_outer_dim_size_;
   for (size_t i = start; i < end; ++i) {
-    int index = unique_sparse_grad.indices_[i];
-    if (index < 0 || IntToSize(index) >= var_first_dim_size) {
+    T index = unique_sparse_grad.indices_[i];
+    if (index < 0 || LongToSize(index) >= var_first_dim_size) {
       MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range";
     }
     size_t start_index = var_outer_dim_size * index;
@@ -56,13 +57,21 @@ void ComputeLazyAdam(MultiThreadComputeParams *input_params, size_t start, size_
 }
 }  // namespace
 
-void SparseApplyLazyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
-  CPUKernel::InitInputOutputSize(kernel_node);
-  MS_EXCEPTION_IF_NULL(kernel_node);
+template <typename T>
+void SparseApplyLazyAdamCPUKernel::InitWorkspaceSize() {
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(T));
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(T));
+}
+
+void SparseApplyLazyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
+  CPUKernel::InitInputOutputSize(kernel_node);
+  if (indices_data_type_ == kNumberTypeInt32) {
+    InitWorkspaceSize<int>();
+  } else {
+    InitWorkspaceSize<int64_t>();
+  }
 }
 
 void SparseApplyLazyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
@@ -98,15 +107,12 @@ void SparseApplyLazyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) {
     use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
   }
+  indices_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 10);
 }
 
-bool SparseApplyLazyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
-                                          const std::vector<kernel::AddressPtr> &workspace,
-                                          const std::vector<kernel::AddressPtr> & /*outputs*/) {
-  if (inputs.size() < kSparseApplyLazyAdamInputSize) {
-    MS_LOG(EXCEPTION) << "Error input size!";
-  }
-
+template <typename T>
+void SparseApplyLazyAdamCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                                                const std::vector<kernel::AddressPtr> &workspace) const {
   auto var = reinterpret_cast<float *>(inputs[0]->addr);
   auto m = reinterpret_cast<float *>(inputs[1]->addr);
   auto v = reinterpret_cast<float *>(inputs[2]->addr);
@@ -120,16 +126,16 @@ bool SparseApplyLazyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr>
   auto beta2 = reinterpret_cast<float *>(inputs[7]->addr)[0];
   auto epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0];
   auto grad = reinterpret_cast<float *>(inputs[9]->addr);
-  auto indices = reinterpret_cast<int *>(inputs[10]->addr);
+  auto indices = reinterpret_cast<T *>(inputs[10]->addr);
   auto new_grad = reinterpret_cast<float *>(workspace[0]->addr);
-  auto new_indices = reinterpret_cast<int *>(workspace[1]->addr);
+  auto new_indices = reinterpret_cast<T *>(workspace[1]->addr);
   auto workspace_grad = reinterpret_cast<float *>(workspace[2]->addr);
-  auto workspace_indices = reinterpret_cast<int *>(workspace[3]->addr);
+  auto workspace_indices = reinterpret_cast<T *>(workspace[3]->addr);
 
-  SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_});
-  SparseGradient workspace_sparse_grad({workspace_grad, workspace_indices, indices_size_});
-  SparseGradient input_sparse_grad({grad, indices, indices_size_});
-  ReduceSparseGradientParam param;
+  SparseGradient<T> unique_sparse_grad({new_grad, new_indices, indices_size_});
+  SparseGradient<T> workspace_sparse_grad({workspace_grad, workspace_indices, indices_size_});
+  SparseGradient<T> input_sparse_grad({grad, indices, indices_size_});
+  ReduceSparseGradientParam<T> param;
   param.input_grad_ = &input_sparse_grad;
   param.workspace_grad_ = &workspace_sparse_grad;
   param.output_grad_ = &unique_sparse_grad;
@@ -138,7 +144,7 @@ bool SparseApplyLazyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr>
   BucketReduceSparseGradient(param);
 
   lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);
-  MultiThreadComputeParams input_params;
+  MultiThreadComputeParams<T> input_params;
   input_params.var_ = var;
   input_params.m_ = m;
   input_params.v_ = v;
@@ -150,7 +156,21 @@ bool SparseApplyLazyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr>
   input_params.sparse_grad_ = unique_sparse_grad;
   input_params.var_first_dim_size_ = var_first_dim_size_;
   input_params.var_outer_dim_size_ = var_outer_dim_size_;
-  MultiThreadCompute(ComputeLazyAdam, &input_params, unique_sparse_grad.indices_size_);
+  MultiThreadCompute<T>(ComputeLazyAdam<T>, &input_params, unique_sparse_grad.indices_size_);
+}
+
+bool SparseApplyLazyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                          const std::vector<kernel::AddressPtr> &workspace,
+                                          const std::vector<kernel::AddressPtr> & /*outputs*/) {
+  if (inputs.size() < kSparseApplyLazyAdamInputSize) {
+    MS_LOG(EXCEPTION) << "Error input size!";
+  }
+
+  if (indices_data_type_ == kNumberTypeInt32) {
+    LaunchKernel<int>(inputs, workspace);
+  } else {
+    LaunchKernel<int64_t>(inputs, workspace);
+  }
   return true;
 }
 }  // namespace kernel
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h
index 2235c22ea..46d28dbda 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h
@@ -17,13 +17,11 @@
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_LAZY_ADAM_CPU_KERNEL_H_
 
 #include <vector>
-#include <memory>
-#include "backend/kernel_compiler/cpu/cpu_kernel.h"
-#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+#include "backend/kernel_compiler/cpu/sparse_optimizer_cpu_kernel.h"
 
 namespace mindspore {
 namespace kernel {
-class SparseApplyLazyAdamCPUKernel : public CPUKernel {
+class SparseApplyLazyAdamCPUKernel : public SparseOptimizerCPUKernel {
  public:
   SparseApplyLazyAdamCPUKernel() = default;
   ~SparseApplyLazyAdamCPUKernel() override = default;
@@ -32,11 +30,13 @@ class SparseApplyLazyAdamCPUKernel : public CPUKernel {
   void InitInputOutputSize(const CNodePtr &kernel_node) override;
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs) override;
+  template <typename T>
+  void InitWorkspaceSize();
+  template <typename T>
+  void LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                    const std::vector<kernel::AddressPtr> &workspace) const;
 
  protected:
-  size_t indices_size_{0};
-  size_t var_first_dim_size_{0};
-  size_t var_outer_dim_size_{1};
   bool use_nesterov_{false};
 };
 
@@ -57,6 +57,24 @@ MS_REG_CPU_KERNEL(FusedSparseLazyAdam,
                     .AddOutputAttr(kNumberTypeFloat32)
                     .AddOutputAttr(kNumberTypeFloat32),
                   SparseApplyLazyAdamCPUKernel);
+
+MS_REG_CPU_KERNEL(FusedSparseLazyAdam,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeInt64)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32),
+                  SparseApplyLazyAdamCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc
index 9e066c587..fe8c27b87 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc
@@ -22,7 +22,8 @@ namespace kernel {
 namespace {
 constexpr size_t kSparseApplyProximalAdagradInputSize = 7;
 
-void ComputeProximalAdagrad(MultiThreadComputeParams *input_params, size_t start, size_t end) {
+template <typename T>
+void ComputeProximalAdagrad(MultiThreadComputeParams<T> *input_params, size_t start, size_t end) {
   MS_EXCEPTION_IF_NULL(input_params);
   auto var = input_params->var_;
   auto accum = input_params->accum_;
@@ -33,8 +34,8 @@ void ComputeProximalAdagrad(MultiThreadComputeParams *input_params, size_t start
   const auto var_first_dim_size = input_params->var_first_dim_size_;
   const auto var_outer_dim_size = input_params->var_outer_dim_size_;
   for (size_t i = start; i < end; ++i) {
-    int index = unique_sparse_grad.indices_[i];
-    if (index < 0 || IntToSize(index) >= var_first_dim_size) {
+    T index = unique_sparse_grad.indices_[i];
+    if (index < 0 || LongToSize(index) >= var_first_dim_size) {
       MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process";
     }
     size_t start_index = var_outer_dim_size * index;
@@ -56,13 +57,21 @@ void ComputeProximalAdagrad(MultiThreadComputeParams *input_params, size_t start
 }
 }  // namespace
 
-void SparseApplyProximalAdagradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
-  CPUKernel::InitInputOutputSize(kernel_node);
-  MS_EXCEPTION_IF_NULL(kernel_node);
+template <typename T>
+void SparseApplyProximalAdagradCPUKernel::InitWorkspaceSize() {
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(T));
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(T));
+}
+
+void SparseApplyProximalAdagradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
+  CPUKernel::InitInputOutputSize(kernel_node);
+  if (indices_data_type_ == kNumberTypeInt32) {
+    InitWorkspaceSize<int>();
+  } else {
+    InitWorkspaceSize<int64_t>();
+  }
 }
 
 void SparseApplyProximalAdagradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
@@ -103,31 +112,28 @@ void SparseApplyProximalAdagradCPUKernel::InitKernel(const CNodePtr &kernel_node
   if (!l2_shape.empty()) {
     MS_LOG(EXCEPTION) << "l2 is not a scalar";
   }
+  indices_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 6);
 }
 
-bool SparseApplyProximalAdagradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
-                                                 const std::vector<kernel::AddressPtr> &workspace,
-                                                 const std::vector<kernel::AddressPtr> & /*outputs*/) {
-  if (inputs.size() < kSparseApplyProximalAdagradInputSize) {
-    MS_LOG(EXCEPTION) << "Wrong input size!";
-  }
-
+template <typename T>
+void SparseApplyProximalAdagradCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                                                       const std::vector<kernel::AddressPtr> &workspace) const {
   auto var = reinterpret_cast<float *>(inputs[0]->addr);
   auto accum = reinterpret_cast<float *>(inputs[1]->addr);
   auto lr = reinterpret_cast<float *>(inputs[2]->addr)[0];
   auto l1 = reinterpret_cast<float *>(inputs[3]->addr)[0];
   auto l2 = reinterpret_cast<float *>(inputs[4]->addr)[0];
   auto grad = reinterpret_cast<float *>(inputs[5]->addr);
-  auto indices = reinterpret_cast<int *>(inputs[6]->addr);
+  auto indices = reinterpret_cast<T *>(inputs[6]->addr);
   auto new_grad = reinterpret_cast<float *>(workspace[0]->addr);
-  auto new_indices = reinterpret_cast<int *>(workspace[1]->addr);
+  auto new_indices = reinterpret_cast<T *>(workspace[1]->addr);
   auto workspace_grad = reinterpret_cast<float *>(workspace[2]->addr);
-  auto workspace_indices = reinterpret_cast<int *>(workspace[3]->addr);
+  auto workspace_indices = reinterpret_cast<T *>(workspace[3]->addr);
 
-  SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_});
-  SparseGradient workspace_sparse_grad({workspace_grad, workspace_indices, indices_size_});
-  SparseGradient input_sparse_grad({grad, indices, indices_size_});
-  ReduceSparseGradientParam param;
+  SparseGradient<T> unique_sparse_grad({new_grad, new_indices, indices_size_});
+  SparseGradient<T> workspace_sparse_grad({workspace_grad, workspace_indices, indices_size_});
+  SparseGradient<T> input_sparse_grad({grad, indices, indices_size_});
+  ReduceSparseGradientParam<T> param;
   param.input_grad_ = &input_sparse_grad;
   param.workspace_grad_ = &workspace_sparse_grad;
   param.output_grad_ = &unique_sparse_grad;
@@ -135,7 +141,7 @@ bool SparseApplyProximalAdagradCPUKernel::Launch(const std::vector<kernel::Addre
   param.value_stride_ = var_outer_dim_size_;
   BucketReduceSparseGradient(param);
 
-  MultiThreadComputeParams input_params;
+  MultiThreadComputeParams<T> input_params;
   input_params.var_ = var;
   input_params.accum_ = accum;
   input_params.lr_ = lr;
@@ -144,7 +150,20 @@ bool SparseApplyProximalAdagradCPUKernel::Launch(const std::vector<kernel::Addre
   input_params.sparse_grad_ = unique_sparse_grad;
   input_params.var_first_dim_size_ = var_first_dim_size_;
   input_params.var_outer_dim_size_ = var_outer_dim_size_;
-  MultiThreadCompute(ComputeProximalAdagrad, &input_params, unique_sparse_grad.indices_size_);
+  MultiThreadCompute<T>(ComputeProximalAdagrad<T>, &input_params, unique_sparse_grad.indices_size_);
+}
+
+bool SparseApplyProximalAdagradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                                 const std::vector<kernel::AddressPtr> &workspace,
+                                                 const std::vector<kernel::AddressPtr> & /*outputs*/) {
+  if (inputs.size() < kSparseApplyProximalAdagradInputSize) {
+    MS_LOG(EXCEPTION) << "Wrong input size!";
+  }
+  if (indices_data_type_ == kNumberTypeInt32) {
+    LaunchKernel<int>(inputs, workspace);
+  } else {
+    LaunchKernel<int64_t>(inputs, workspace);
+  }
   return true;
 }
 }  // namespace kernel
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h
index 7bd38f556..59e2b359c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h
@@ -17,13 +17,11 @@
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_
 
 #include <vector>
-#include <memory>
-#include "backend/kernel_compiler/cpu/cpu_kernel.h"
-#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+#include "backend/kernel_compiler/cpu/sparse_optimizer_cpu_kernel.h"
 
 namespace mindspore {
 namespace kernel {
-class SparseApplyProximalAdagradCPUKernel : public CPUKernel {
+class SparseApplyProximalAdagradCPUKernel : public SparseOptimizerCPUKernel {
  public:
   SparseApplyProximalAdagradCPUKernel() = default;
   ~SparseApplyProximalAdagradCPUKernel() override = default;
@@ -32,11 +30,11 @@ class SparseApplyProximalAdagradCPUKernel : public CPUKernel {
   void InitInputOutputSize(const CNodePtr &kernel_node) override;
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs) override;
-
- private:
-  size_t indices_size_{0};
-  size_t var_first_dim_size_{0};
-  size_t var_outer_dim_size_{1};
+  template <typename T>
+  void InitWorkspaceSize();
+  template <typename T>
+  void LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                    const std::vector<kernel::AddressPtr> &workspace) const;
 };
 
 MS_REG_CPU_KERNEL(FusedSparseProximalAdagrad,
@@ -51,6 +49,19 @@ MS_REG_CPU_KERNEL(FusedSparseProximalAdagrad,
                     .AddOutputAttr(kNumberTypeFloat32)
                     .AddOutputAttr(kNumberTypeFloat32),
                   SparseApplyProximalAdagradCPUKernel);
+
+MS_REG_CPU_KERNEL(FusedSparseProximalAdagrad,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeInt64)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32),
+                  SparseApplyProximalAdagradCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
 
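The pattern above repeats across all four sparse optimizers: InitKernel records the indices dtype from the node, and Launch branches once into a statically typed LaunchKernel<T>. A minimal self-contained sketch of that dispatch shape, with hypothetical Address and IndexType stand-ins for the MindSpore types:

#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for kernel::AddressPtr and the TypeId check.
enum class IndexType { kInt32, kInt64 };
struct Address {
  void *addr{nullptr};
};

class SparseKernelSketch {
 public:
  // Mirrors InitKernel recording the indices dtype from the node.
  void Init(IndexType indices_type) { indices_type_ = indices_type; }

  // Mirrors Launch: branch once on the runtime dtype, then stay statically typed.
  bool Launch(const std::vector<Address> &inputs) {
    if (indices_type_ == IndexType::kInt32) {
      LaunchKernel<int32_t>(inputs);
    } else {
      LaunchKernel<int64_t>(inputs);
    }
    return true;
  }

 private:
  template <typename T>
  void LaunchKernel(const std::vector<Address> &inputs) {
    if (inputs.empty() || inputs[0].addr == nullptr) {
      throw std::invalid_argument("missing indices buffer");
    }
    T *indices = static_cast<T *>(inputs[0].addr);  // indices viewed as 32- or 64-bit
    (void)indices;  // real kernels hand this to the templated reduce/compute helpers
  }

  IndexType indices_type_{IndexType::kInt32};
};
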
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_optimizer_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_optimizer_cpu_kernel.h
new file mode 100644
index 000000000..060218b2d
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_optimizer_cpu_kernel.h
@@ -0,0 +1,453 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_OPTIMIZER_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_OPTIMIZER_CPU_KERNEL_H_
+
+#include <vector>
+#include <memory>
+#include <thread>
+#include <unordered_map>
+#include <algorithm>
+#include <utility>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+struct SparseGradient {
+  float *value_{nullptr};
+  T *indices_{nullptr};
+  size_t indices_size_{0};
+};
+
+template <typename T>
+struct ReduceSparseGradientParam {
+  SparseGradient<T> *input_grad_{nullptr};
+  SparseGradient<T> *workspace_grad_{nullptr};
+  SparseGradient<T> *output_grad_{nullptr};
+  size_t max_index_{0};
+  size_t value_stride_{0};
+  bool use_sort_reduce_{false};
+};
+
+template <typename T>
+struct MultiThreadComputeParams {
+  float *var_{nullptr};
+  float *accum_{nullptr};
+  float *linear_{nullptr};
+  float *m_{nullptr};
+  float *m_t_{nullptr};
+  float *v_{nullptr};
+  float lr_{0};
+  float l1_{0};
+  float l2_{0};
+  float lr_power_{0};
+  float beta1_{0};
+  float beta2_{0};
+  float epsilon_{0};
+  SparseGradient<T> sparse_grad_;
+  size_t var_first_dim_size_{0};
+  size_t var_outer_dim_size_{0};
+  bool use_nesterov_{false};
+};
+template <typename T>
+using MultiThreadComputeFunc = std::function<void(MultiThreadComputeParams<T> *param, size_t start, size_t end)>;
+
+template <typename T>
+struct BucketSparseGradient {
+  float *value_;
+  T *indices_;
+  T *global_indices_;
+  size_t indices_size_;
+};
+
+template <typename T>
+struct MultiThreadReduceSparseGradientParam {
+  SparseGradient<T> *input_grad_{nullptr};
+  SparseGradient<T> *workspace_grad_{nullptr};
+  SparseGradient<T> *output_grad_{nullptr};
+  size_t max_index_{0};
+  size_t value_stride_{0};
+  size_t thread_num_{0};
+  bool use_sort_reduce_{false};
+};
+
+class SparseOptimizerCPUKernel : public CPUKernel {
+ public:
+  SparseOptimizerCPUKernel() = default;
+  ~SparseOptimizerCPUKernel() override = default;
+
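+  // Deduplicates input_grad_ in four stages:
+  //   1) split the indices into one segment per thread and count per-bucket sizes;
+  //   2) scatter each (index, global position) pair into bucket index % thread_num;
+  //   3) reduce every bucket in parallel (hash map, or sort when use_sort_reduce_);
+  //   4) concatenate the reduced buckets into output_grad_.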
+  template <typename T>
+  static void BucketReduceSparseGradient(const ReduceSparseGradientParam<T> &param) {
+    MS_LOG(DEBUG) << "Start";
+    MS_EXCEPTION_IF_NULL(param.input_grad_);
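+    // At most 23 reduction buckets/threads; use fewer when there are fewer indices.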
+    size_t thread_num = 23;
+    if (param.input_grad_->indices_size_ < thread_num) {
+      thread_num = param.input_grad_->indices_size_;
+    }
+    MultiThreadReduceSparseGradientParam<T> multi_thread_param(
+      {param.input_grad_, param.workspace_grad_, param.output_grad_, param.max_index_, param.value_stride_, thread_num,
+       param.use_sort_reduce_});
+    std::vector<std::shared_ptr<SparseGradient<T>>> segments;
+    std::vector<std::shared_ptr<std::vector<size_t>>> segment_bucket_sizes;
+    SplitAndCalculateSegmentBucketSize(multi_thread_param, &segments, &segment_bucket_sizes);
+
+    std::vector<std::shared_ptr<BucketSparseGradient<T>>> buckets;
+    GatherSegmentIndicesToOutputBucket(multi_thread_param, segments, segment_bucket_sizes, &buckets);
+
+    std::vector<std::shared_ptr<SparseGradient<T>>> reduced_buckets;
+    ReduceBucketSparseGradientToWorkspace(multi_thread_param, buckets, &reduced_buckets);
+
+    MergeReduceSparseGradient(multi_thread_param, reduced_buckets);
+    MS_LOG(DEBUG) << "End";
+  }
+
+ protected:
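+  // Statically partitions [0, total_compute_size) into at most 24 contiguous
+  // chunks and runs func on each chunk in its own thread, joining them all.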
+  template <typename T>
+  void MultiThreadCompute(const MultiThreadComputeFunc<T> &func, MultiThreadComputeParams<T> *params,
+                          size_t total_compute_size) const {
+    const size_t kThreadNum = 24;
+    std::vector<std::thread> threads;
+    threads.reserve(kThreadNum);
+    size_t start = 0;
+    size_t once_compute_size = (total_compute_size + kThreadNum - 1) / kThreadNum;
+    while (start < total_compute_size) {
+      size_t end = (start + once_compute_size) > total_compute_size ? total_compute_size : (start + once_compute_size);
+      threads.emplace_back(std::thread(func, params, start, end));
+      start += once_compute_size;
+    }
+    for (size_t i = 0; i < threads.size(); ++i) {
+      threads[i].join();
+    }
+  }
+
+ private:
+  template <typename T>
+  static void CalculateEachBucketSize(const std::shared_ptr<SparseGradient<T>> &sparse_grad, size_t max_index,
+                                      std::vector<size_t> *each_bucket_size) {
+    MS_LOG(DEBUG) << "Start";
+    MS_EXCEPTION_IF_NULL(sparse_grad);
+    MS_EXCEPTION_IF_NULL(sparse_grad->indices_);
+    MS_EXCEPTION_IF_NULL(each_bucket_size);
+    size_t bucket_num = each_bucket_size->size();
+    for (size_t i = 0; i < sparse_grad->indices_size_; ++i) {
+      T index = sparse_grad->indices_[i];
+      if (index >= 0 && LongToSize(index) < max_index) {
+        auto bucket_id = index % bucket_num;
+        each_bucket_size->at(bucket_id)++;
+      }
+    }
+    MS_LOG(DEBUG) << "End";
+  }
+
+  template <typename T>
+  static void SplitAndCalculateSegmentBucketSize(
+    const MultiThreadReduceSparseGradientParam<T> &param, std::vector<std::shared_ptr<SparseGradient<T>>> *segments_ptr,
+    std::vector<std::shared_ptr<std::vector<size_t>>> *segment_bucket_sizes_ptr) {
+    MS_EXCEPTION_IF_NULL(param.input_grad_);
+    MS_EXCEPTION_IF_NULL(segment_bucket_sizes_ptr);
+    MS_EXCEPTION_IF_NULL(segments_ptr);
+    auto &segments = *segments_ptr;
+    auto &segment_bucket_sizes = *segment_bucket_sizes_ptr;
+    auto input_grad = param.input_grad_;
+    if (param.thread_num_ < 1) {
+      MS_EXCEPTION(ArgumentError) << "Input param thread_num_ must be greater than 0!";
+    }
+    size_t thread_indices_size = input_grad->indices_size_ / param.thread_num_;
+    size_t left_indices_size = input_grad->indices_size_ % param.thread_num_;
+    std::vector<std::thread> threads;
+    threads.reserve(param.thread_num_);
+    segments.reserve(param.thread_num_);
+
+    size_t current_indices_offset = 0;
+    for (size_t i = 0; i < param.thread_num_; ++i) {
+      segment_bucket_sizes.emplace_back(std::make_shared<std::vector<size_t>>(param.thread_num_, 0));
+      size_t indices_size = thread_indices_size;
+      if (i < left_indices_size) {
+        indices_size += 1;
+      }
+      segments.emplace_back(std::make_shared<SparseGradient<T>>());
+      segments[i]->value_ = input_grad->value_ + current_indices_offset * param.value_stride_;
+      segments[i]->indices_ = input_grad->indices_ + current_indices_offset;
+      segments[i]->indices_size_ = indices_size;
+      threads.emplace_back(
+        std::thread(CalculateEachBucketSize<T>, segments[i], param.max_index_, segment_bucket_sizes[i].get()));
+      current_indices_offset += indices_size;
+    }
+
+    for (size_t i = 0; i < param.thread_num_; ++i) {
+      threads[i].join();
+    }
+  }
+
+  template <typename T>
+  static void CopySegmentIndicesToBucket(const MultiThreadReduceSparseGradientParam<T> &param,
+                                         const std::shared_ptr<SparseGradient<T>> &segment, size_t bucket_offset,
+                                         const std::vector<std::shared_ptr<BucketSparseGradient<T>>> &buckets) {
+    MS_LOG(DEBUG) << "Start";
+    MS_EXCEPTION_IF_NULL(segment);
+    MS_EXCEPTION_IF_NULL(segment->indices_);
+    std::vector<size_t> bucket_data_num(param.thread_num_, 0);
+    for (size_t i = 0; i < segment->indices_size_; ++i) {
+      T index = segment->indices_[i];
+      if (index >= 0 && LongToSize(index) < param.max_index_) {
+        auto bucket_id = index % param.thread_num_;
+        auto bucket_index = bucket_data_num[bucket_id];
+        buckets[bucket_id]->indices_[bucket_index] = index;
+        buckets[bucket_id]->global_indices_[bucket_index] = bucket_offset + i;
+        bucket_data_num[bucket_id]++;
+      }
+    }
+    MS_LOG(DEBUG) << "End";
+  }
+
+  template <typename T>
+  static void GatherSegmentIndicesToOutputBucket(
+    const MultiThreadReduceSparseGradientParam<T> &param,
+    const std::vector<std::shared_ptr<SparseGradient<T>>> &segments,
+    const std::vector<std::shared_ptr<std::vector<size_t>>> &segment_bucket_sizes,
+    std::vector<std::shared_ptr<BucketSparseGradient<T>>> *buckets_ptr) {
+    MS_EXCEPTION_IF_NULL(param.output_grad_);
+    MS_EXCEPTION_IF_NULL(param.output_grad_->value_);
+    MS_EXCEPTION_IF_NULL(param.output_grad_->indices_);
+    MS_EXCEPTION_IF_NULL(param.workspace_grad_);
+    MS_EXCEPTION_IF_NULL(param.workspace_grad_->indices_);
+    MS_EXCEPTION_IF_NULL(buckets_ptr);
+    auto &buckets = *buckets_ptr;
+    size_t thread_num = param.thread_num_;
+    if (thread_num != segment_bucket_sizes.size()) {
+      MS_EXCEPTION(ArgumentError) << "Input param thread_num_ does not equal the number of segments!";
+    }
+    std::vector<size_t> bucket_data_size(thread_num, 0);
+    for (size_t i = 0; i < thread_num; ++i) {
+      for (size_t j = 0; j < thread_num; ++j) {
+        bucket_data_size[j] += segment_bucket_sizes[i]->at(j);
+      }
+    }
+    size_t current_indices_offset = 0;
+    for (size_t i = 0; i < thread_num; ++i) {
+      buckets.emplace_back(std::make_shared<BucketSparseGradient<T>>());
+      buckets[i]->value_ = param.output_grad_->value_ + current_indices_offset * param.value_stride_;
+      buckets[i]->indices_ = param.output_grad_->indices_ + current_indices_offset;
+      buckets[i]->global_indices_ = param.workspace_grad_->indices_ + current_indices_offset;
+      buckets[i]->indices_size_ = bucket_data_size[i];
+      current_indices_offset += bucket_data_size[i];
+    }
+    std::vector<size_t> tmp_bucket_data_size(thread_num, 0);
+    std::vector<std::vector<std::shared_ptr<BucketSparseGradient<T>>>> each_thread_buckets;
+    for (size_t i = 0; i < thread_num; ++i) {
+      std::vector<std::shared_ptr<BucketSparseGradient<T>>> thread_buckets;
+      for (size_t j = 0; j < thread_num; ++j) {
+        thread_buckets.emplace_back(std::make_shared<BucketSparseGradient<T>>());
+        thread_buckets[j]->indices_ = buckets[j]->indices_ + tmp_bucket_data_size[j];
+        thread_buckets[j]->global_indices_ = buckets[j]->global_indices_ + tmp_bucket_data_size[j];
+        thread_buckets[j]->value_ = buckets[j]->value_ + tmp_bucket_data_size[j] * param.value_stride_;
+        thread_buckets[j]->indices_size_ = segment_bucket_sizes[i]->at(j);
+        tmp_bucket_data_size[j] += segment_bucket_sizes[i]->at(j);
+      }
+      each_thread_buckets.emplace_back(thread_buckets);
+    }
+    std::vector<std::thread> threads;
+    threads.reserve(thread_num);
+    current_indices_offset = 0;
+    for (size_t i = 0; i < thread_num; ++i) {
+      threads.emplace_back(
+        std::thread(CopySegmentIndicesToBucket<T>, param, segments[i], current_indices_offset, each_thread_buckets[i]));
+      current_indices_offset += segments[i]->indices_size_;
+    }
+    for (size_t i = 0; i < thread_num; ++i) {
+      threads[i].join();
+    }
+  }
+
+  template <typename T>
+  static void SortAndReduceBucketSparseGradient(const MultiThreadReduceSparseGradientParam<T> &param,
+                                                const std::shared_ptr<BucketSparseGradient<T>> &bucket,
+                                                const std::shared_ptr<SparseGradient<T>> &reduced_bucket) {
+    MS_LOG(DEBUG) << "Start";
+    MS_EXCEPTION_IF_NULL(bucket);
+    MS_EXCEPTION_IF_NULL(bucket->value_);
+    MS_EXCEPTION_IF_NULL(bucket->indices_);
+    MS_EXCEPTION_IF_NULL(reduced_bucket);
+    MS_EXCEPTION_IF_NULL(reduced_bucket->value_);
+    MS_EXCEPTION_IF_NULL(reduced_bucket->indices_);
+    std::vector<std::pair<T, T>> sorted_indices;
+    sorted_indices.reserve(bucket->indices_size_);
+    for (size_t i = 0; i < bucket->indices_size_; ++i) {
+      T index = bucket->indices_[i];
+      T global_index = bucket->global_indices_[i];
+      sorted_indices.emplace_back(std::pair<T, T>(index, global_index));
+    }
+    std::sort(sorted_indices.begin(), sorted_indices.end());
+
+    float *global_value = param.input_grad_->value_;
+    size_t unique_indices_size = 0;
+    size_t max_length = reduced_bucket->indices_size_ * param.value_stride_;
+    T last_index{0};
+    size_t value_offset{0};
+    for (size_t i = 0; i < sorted_indices.size(); ++i) {
+      T index = sorted_indices[i].first;
+      T global_index = sorted_indices[i].second;
+      T global_value_offset = global_index * param.value_stride_;
+      if (i == 0 || index != last_index) {
+        if (i != 0) {
+          unique_indices_size++;
+        }
+        reduced_bucket->indices_[unique_indices_size] = index;
+        value_offset = unique_indices_size * param.value_stride_;
+        auto ret_code = memcpy_s(reduced_bucket->value_ + value_offset, (max_length - value_offset) * sizeof(float),
+                                 global_value + global_value_offset, param.value_stride_ * sizeof(float));
+        if (ret_code != EOK) {
+          MS_LOG(EXCEPTION) << "Failed to copy data!";
+        }
+      } else {
+        for (size_t j = 0; j < param.value_stride_; ++j) {
+          reduced_bucket->value_[value_offset + j] += global_value[global_value_offset + j];
+        }
+      }
+      last_index = index;
+    }
+    // unique_indices_size indexes the last slot written, so add one when any entry exists.
+    reduced_bucket->indices_size_ = sorted_indices.empty() ? 0 : unique_indices_size + 1;
+    MS_LOG(DEBUG) << "End";
+  }
+
+  template <typename T>
+  static void ReduceBucketSparseGradient(const MultiThreadReduceSparseGradientParam<T> &param,
+                                         const std::shared_ptr<BucketSparseGradient<T>> &bucket,
+                                         const std::shared_ptr<SparseGradient<T>> &reduced_bucket) {
+    MS_LOG(DEBUG) << "Start";
+    MS_EXCEPTION_IF_NULL(bucket);
+    MS_EXCEPTION_IF_NULL(bucket->value_);
+    MS_EXCEPTION_IF_NULL(bucket->indices_);
+    MS_EXCEPTION_IF_NULL(reduced_bucket);
+    MS_EXCEPTION_IF_NULL(reduced_bucket->value_);
+    MS_EXCEPTION_IF_NULL(reduced_bucket->indices_);
+
+    float *global_value = param.input_grad_->value_;
+    std::unordered_map<T, size_t> index_map;
+    size_t unique_indices_size = 0;
+    size_t max_length = reduced_bucket->indices_size_ * param.value_stride_;
+    for (size_t i = 0; i < bucket->indices_size_; ++i) {
+      T index = bucket->indices_[i];
+      T global_index = bucket->global_indices_[i];
+      auto iter = index_map.find(index);
+      if (iter == index_map.end()) {
+        reduced_bucket->indices_[unique_indices_size] = index;
+        size_t start_index = unique_indices_size * param.value_stride_;
+        index_map[index] = start_index;
+        auto ret_code =
+          memcpy_s(reduced_bucket->value_ + start_index, (max_length - start_index) * sizeof(float),
+                   global_value + global_index * param.value_stride_, param.value_stride_ * sizeof(float));
+        if (ret_code != EOK) {
+          MS_LOG(EXCEPTION) << "Failed to copy data!";
+        }
+        unique_indices_size++;
+      } else {
+        size_t start_index = iter->second;
+        size_t end_index = start_index + param.value_stride_;
+        for (size_t j = start_index, k = global_index * param.value_stride_; j < end_index; ++j, ++k) {
+          reduced_bucket->value_[j] += global_value[k];
+        }
+      }
+    }
+    reduced_bucket->indices_size_ = unique_indices_size;
+    MS_LOG(DEBUG) << "End";
+  }
+
+  template <typename T>
+  static void ReduceBucketSparseGradientToWorkspace(
+    const MultiThreadReduceSparseGradientParam<T> &param,
+    const std::vector<std::shared_ptr<BucketSparseGradient<T>>> &buckets,
+    std::vector<std::shared_ptr<SparseGradient<T>>> *reduced_buckets_ptr) {
+    MS_EXCEPTION_IF_NULL(param.workspace_grad_);
+    MS_EXCEPTION_IF_NULL(param.workspace_grad_->value_);
+    MS_EXCEPTION_IF_NULL(param.workspace_grad_->indices_);
+    MS_EXCEPTION_IF_NULL(reduced_buckets_ptr);
+    auto &reduced_buckets = *reduced_buckets_ptr;
+    size_t thread_num = buckets.size();
+    std::vector<std::thread> threads;
+    threads.reserve(thread_num);
+
+    size_t current_indices_offset = 0;
+    for (size_t i = 0; i < thread_num; ++i) {
+      reduced_buckets.emplace_back(std::make_shared<SparseGradient<T>>());
+      reduced_buckets[i]->value_ = param.workspace_grad_->value_ + current_indices_offset * param.value_stride_;
+      reduced_buckets[i]->indices_ = param.workspace_grad_->indices_ + current_indices_offset;
+      reduced_buckets[i]->indices_size_ = buckets[i]->indices_size_;
+      if (param.use_sort_reduce_) {
+        threads.emplace_back(std::thread(SortAndReduceBucketSparseGradient<T>, param, buckets[i], reduced_buckets[i]));
+      } else {
+        threads.emplace_back(std::thread(ReduceBucketSparseGradient<T>, param, buckets[i], reduced_buckets[i]));
+      }
+      current_indices_offset += buckets[i]->indices_size_;
+    }
+    for (size_t i = 0; i < thread_num; ++i) {
+      threads[i].join();
+    }
+  }
+
+  template <typename T>
+  static void MergeReduceSparseGradient(const MultiThreadReduceSparseGradientParam<T> &param,
+                                        const std::vector<std::shared_ptr<SparseGradient<T>>> &reduced_buckets) {
+    MS_EXCEPTION_IF_NULL(param.output_grad_);
+    auto output_grad = param.output_grad_;
+    MS_EXCEPTION_IF_NULL(output_grad->value_);
+    MS_EXCEPTION_IF_NULL(output_grad->indices_);
+    size_t stride_data_size = param.value_stride_ * sizeof(float);
+    size_t unique_indices_size = 0;
+    for (size_t i = 0; i < reduced_buckets.size(); ++i) {
+      auto &bucket = reduced_buckets[i];
+      MS_EXCEPTION_IF_NULL(bucket);
+      if (bucket->indices_size_ == 0) {
+        continue;
+      }
+      auto ret_code = memcpy_s(output_grad->value_ + unique_indices_size * param.value_stride_,
+                               (output_grad->indices_size_ - unique_indices_size) * stride_data_size, bucket->value_,
+                               bucket->indices_size_ * stride_data_size);
+      if (ret_code != EOK) {
+        MS_LOG(EXCEPTION) << "Failed to copy data!";
+      }
+      ret_code = memcpy_s(output_grad->indices_ + unique_indices_size,
+                          (output_grad->indices_size_ - unique_indices_size) * sizeof(T), bucket->indices_,
+                          bucket->indices_size_ * sizeof(T));
+      if (ret_code != EOK) {
+        MS_LOG(EXCEPTION) << "Failed to copy data!";
+      }
+      unique_indices_size += bucket->indices_size_;
+    }
+    output_grad->indices_size_ = unique_indices_size;
+  }
+
+ protected:
+  TypeId indices_data_type_{kNumberTypeInt32};
+  size_t indices_size_{0};
+  size_t var_first_dim_size_{0};
+  size_t var_outer_dim_size_{1};
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_OPTIMIZER_CPU_KERNEL_H_
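
Callers of BucketReduceSparseGradient size the workspace and output gradients for the worst case of no duplicate indices; the routine then shrinks the output indices_size_ to the unique count. A hedged usage sketch with int64 indices (buffer names and data values are illustrative):

#include <cstdint>
#include <vector>
#include "backend/kernel_compiler/cpu/sparse_optimizer_cpu_kernel.h"

namespace {
using mindspore::kernel::ReduceSparseGradientParam;
using mindspore::kernel::SparseGradient;
using mindspore::kernel::SparseOptimizerCPUKernel;

void ReduceExample() {
  const size_t value_stride = 2;  // floats per gradient row
  std::vector<int64_t> indices{0, 5, 0, 5, 1, 1};
  std::vector<float> grad(indices.size() * value_stride, 1.0f);

  // All three buffers are sized for the worst case of no duplicate indices,
  // matching how the kernels above size their workspace lists.
  std::vector<int64_t> unique_indices(indices.size());
  std::vector<float> summed_grad(grad.size());
  std::vector<int64_t> tmp_indices(indices.size());
  std::vector<float> tmp_grad(grad.size());

  SparseGradient<int64_t> input({grad.data(), indices.data(), indices.size()});
  SparseGradient<int64_t> workspace({tmp_grad.data(), tmp_indices.data(), indices.size()});
  SparseGradient<int64_t> output({summed_grad.data(), unique_indices.data(), indices.size()});

  ReduceSparseGradientParam<int64_t> param;
  param.input_grad_ = &input;
  param.workspace_grad_ = &workspace;
  param.output_grad_ = &output;
  param.max_index_ = 6;  // rows 0..5 of var are valid
  param.value_stride_ = value_stride;
  SparseOptimizerCPUKernel::BucketReduceSparseGradient(param);
  // output.indices_size_ now holds the unique count (3 for this data).
}
}  // namespace
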
diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc
index aebd12b72..ffe9be054 100644
--- a/mindspore/ccsrc/backend/optimizer/common/helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc
@@ -374,10 +374,12 @@ tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
   }
   ScalarPtr scalar = v->cast<ScalarPtr>();
   MS_EXCEPTION_IF_NULL(scalar);
-  if (scalar->isa<IntergerImm>()) {
-    tensor = CreateTensorWithValueTuple<int>(value_tuple, kInt32, kType32Len);
+  if (scalar->isa<Int32Imm>()) {
+    tensor = CreateTensorWithValueTuple<int32_t>(value_tuple, kInt32, sizeof(int32_t));
+  } else if (scalar->isa<Int64Imm>()) {
+    tensor = CreateTensorWithValueTuple<int64_t>(value_tuple, kInt64, sizeof(int64_t));
   } else if (scalar->isa<FloatImm>()) {
-    tensor = CreateTensorWithValueTuple<float>(value_tuple, kFloat32, kType32Len);
+    tensor = CreateTensorWithValueTuple<float>(value_tuple, kFloat32, sizeof(float));
   } else {
     auto type = scalar->type();
     auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
@@ -698,6 +700,9 @@ ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
   if (utils::isa<int>(sexp)) {
     return NewValueNode(utils::cast<int>(sexp));
   }
+  if (utils::isa<int64_t>(sexp)) {
+    return NewValueNode(utils::cast<int64_t>(sexp));
+  }
   if (utils::isa<float>(sexp)) {
     return NewValueNode(utils::cast<float>(sexp));
   }
diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
index f9a318be2..2ce7f70cb 100644
--- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
@@ -40,7 +40,6 @@ void CPUKernelRuntime::AssignKernelAddress(session::KernelGraph *kernel_graph) {
 
 void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
-  size_t type_size = sizeof(float);
   for (auto &item_node : kernel_graph->graph_value_nodes()) {
     MS_EXCEPTION_IF_NULL(item_node);
     if (item_node->isa<ValueNode>()) {
@@ -53,11 +52,23 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph
       }
       auto tensor = node_value->cast<TensorPtr>();
       MS_EXCEPTION_IF_NULL(tensor);
+      size_t type_size = sizeof(float);
+      if (tensor->data_type() == kNumberTypeInt64) {
+        type_size = GetTypeByte(TypeIdToType(kNumberTypeInt64));
+      }
       ShapeVector data_shape = tensor->shape();
       size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>());
-      DeviceAddressPtr address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeFloat32);
+      DeviceAddressPtr address = nullptr;
+      if (tensor->data_type() == kNumberTypeInt32) {
+        address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeInt32);
+      } else if (tensor->data_type() == kNumberTypeInt64) {
+        address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeInt64);
+      } else {
+        address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeFloat32);
+      }
       MS_EXCEPTION_IF_NULL(address);
-      if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) {
+      if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32 ||
+          tensor->data_type() == kNumberTypeInt64) {
         address->ptr_ = tensor->data_c();
       } else {
         address->ptr_ = resource_manager_.MemMalloc(tensor_size);
@@ -74,14 +85,20 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph
 
 void CPUKernelRuntime::AssignInputNodeAddress(const session::KernelGraph *kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
-  size_t type_size = sizeof(float);
   for (auto &item : kernel_graph->inputs()) {
     MS_EXCEPTION_IF_NULL(item);
     if (item->isa<Parameter>()) {
       auto output_num = AnfAlgo::GetOutputTensorNum(item);
       for (size_t index = 0; index < output_num; index++) {
         TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
+        if (output_type_id == kTypeUnknown) {
+          output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
+        }
         std::vector<size_t> fmt_shape = AnfAlgo::GetOutputDeviceShape(item, index);
+        size_t type_size = sizeof(float);
+        if (output_type_id == kNumberTypeInt64) {
+          type_size = GetTypeByte(TypeIdToType(kNumberTypeInt64));
+        }
         size_t tensor_size =
           fmt_shape.empty() ? type_size
                             : std::accumulate(fmt_shape.begin(), fmt_shape.end(), type_size, std::multiplies<size_t>());
@@ -222,7 +239,7 @@ void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const
         (void)tensor->data_sync();
       }
       if (tensor->data_type() == address->type_id_ || tensor->data_type() == kNumberTypeFloat32 ||
-          tensor->data_type() == kNumberTypeInt32) {
+          tensor->data_type() == kNumberTypeInt32 || tensor->data_type() == kNumberTypeInt64) {
         address->ptr_ = tensor->data_c();
       } else {
         ShapeVector data_shape = tensor->shape();
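
The runtime now derives the element size from the tensor dtype rather than assuming 4-byte floats everywhere. A small sketch of the sizing rule, with an illustrative Dtype enum standing in for TypeId and GetTypeByte:

#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Illustrative dtype tag; the runtime uses TypeId plus GetTypeByte(TypeIdToType(...)).
enum class Dtype { kFloat32, kInt32, kInt64 };

size_t ElementSize(Dtype t) {
  // float32 and int32 are both 4 bytes; only int64 differs here.
  return t == Dtype::kInt64 ? sizeof(int64_t) : sizeof(float);
}

// Tensor byte size: product of the dims times the element size; a scalar
// (empty shape) occupies one element, as in AssignInputNodeAddress.
size_t TensorBytes(const std::vector<size_t> &shape, Dtype t) {
  const size_t type_size = ElementSize(t);
  return shape.empty()
           ? type_size
           : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
}
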
diff --git a/mindspore/ccsrc/utils/convert_utils.cc b/mindspore/ccsrc/utils/convert_utils.cc
index 6431cc0bc..1819f1257 100644
--- a/mindspore/ccsrc/utils/convert_utils.cc
+++ b/mindspore/ccsrc/utils/convert_utils.cc
@@ -638,8 +638,10 @@ tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) {
   tensor::TensorPtr tensor = nullptr;
   if (scalar->isa<FloatImm>()) {
     tensor = std::make_shared<tensor::Tensor>(static_cast<double>(GetValue<float>(scalar)), kFloat32);
-  } else if (scalar->isa<IntergerImm>()) {
+  } else if (scalar->isa<Int32Imm>()) {
     tensor = std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int>(scalar)), kInt32);
+  } else if (scalar->isa<Int64Imm>()) {
+    tensor = std::make_shared<tensor::Tensor>(GetValue<int64_t>(scalar), kInt64);
   } else if (scalar->isa<BoolImm>()) {
     const int64_t bool_value = GetValue<bool>(scalar) ? 1 : 0;
     tensor = std::make_shared<tensor::Tensor>(bool_value, kBool);
diff --git a/tests/ut/cpp/kernel/common_utils_test.cc b/tests/ut/cpp/kernel/cpu/sparse_optimizer_cpu_kernel_test.cc
similarity index 80%
rename from tests/ut/cpp/kernel/common_utils_test.cc
rename to tests/ut/cpp/kernel/cpu/sparse_optimizer_cpu_kernel_test.cc
index 4e016cd49..66e0eb233 100644
--- a/tests/ut/cpp/kernel/common_utils_test.cc
+++ b/tests/ut/cpp/kernel/cpu/sparse_optimizer_cpu_kernel_test.cc
@@ -16,7 +16,7 @@
 
 #include <vector>
 #include "common/common_test.h"
-#include "backend/kernel_compiler/common_utils.h"
+#include "backend/kernel_compiler/cpu/sparse_optimizer_cpu_kernel.h"
 
 namespace mindspore {
 namespace kernel {
@@ -51,17 +51,17 @@ TEST_F(CommonUtilTest, BucketReduceSparseGradient1) {
   std::vector<int> tmp_indices(6);
   std::vector<float> tmp_grad(12);
 
-  SparseGradient unique_grad({summed_grad.data(), unique_indices.data(), 6});
-  SparseGradient workspace_grad({tmp_grad.data(), tmp_indices.data(), 6});
-  SparseGradient input_grad({grad.data(), indices.data(), 6});
+  SparseGradient<int> unique_grad({summed_grad.data(), unique_indices.data(), 6});
+  SparseGradient<int> workspace_grad({tmp_grad.data(), tmp_indices.data(), 6});
+  SparseGradient<int> input_grad({grad.data(), indices.data(), 6});
 
-  ReduceSparseGradientParam param;
+  ReduceSparseGradientParam<int> param;
   param.input_grad_ = &input_grad;
   param.workspace_grad_ = &workspace_grad;
   param.output_grad_ = &unique_grad;
   param.max_index_ = 6;
   param.value_stride_ = 2;
-  BucketReduceSparseGradient(param);
+  SparseOptimizerCPUKernel::BucketReduceSparseGradient(param);
 
   EXPECT_EQ(unique_grad.indices_size_, 3);
   std::vector<int> expect_indices({0, 1, 3});
@@ -103,17 +103,17 @@ TEST_F(CommonUtilTest, BucketReduceSparseGradient2) {
   std::vector<float> summed_grad(12);
   std::vector<int> tmp_indices(6);
   std::vector<float> tmp_grad(12);
-  SparseGradient unique_grad({summed_grad.data(), unique_indices.data(), 6});
-  SparseGradient workspace_grad({tmp_grad.data(), tmp_indices.data(), 6});
-  SparseGradient input_grad({grad.data(), indices.data(), 6});
+  SparseGradient<int> unique_grad({summed_grad.data(), unique_indices.data(), 6});
+  SparseGradient<int> workspace_grad({tmp_grad.data(), tmp_indices.data(), 6});
+  SparseGradient<int> input_grad({grad.data(), indices.data(), 6});
 
-  ReduceSparseGradientParam param;
+  ReduceSparseGradientParam<int> param;
   param.input_grad_ = &input_grad;
   param.workspace_grad_ = &workspace_grad;
   param.output_grad_ = &unique_grad;
   param.max_index_ = 6;
   param.value_stride_ = 2;
-  BucketReduceSparseGradient(param);
+  SparseOptimizerCPUKernel::BucketReduceSparseGradient(param);
 
   EXPECT_EQ(unique_grad.indices_size_, 2);
 
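The renamed tests still cover only int32 indices; a companion int64 case could look like the sketch below, assuming the templated header instantiates cleanly for int64_t as registered above (the data is illustrative, not copied from the cases in this file):

TEST_F(CommonUtilTest, BucketReduceSparseGradientInt64) {
  // Two occurrences of each of the indices 0, 1, and 3, with all-ones gradients.
  std::vector<int64_t> indices{0, 1, 0, 1, 3, 3};
  std::vector<float> grad(12, 1.0f);
  std::vector<int64_t> unique_indices(6);
  std::vector<float> summed_grad(12);
  std::vector<int64_t> tmp_indices(6);
  std::vector<float> tmp_grad(12);

  SparseGradient<int64_t> unique_grad({summed_grad.data(), unique_indices.data(), 6});
  SparseGradient<int64_t> workspace_grad({tmp_grad.data(), tmp_indices.data(), 6});
  SparseGradient<int64_t> input_grad({grad.data(), indices.data(), 6});

  ReduceSparseGradientParam<int64_t> param;
  param.input_grad_ = &input_grad;
  param.workspace_grad_ = &workspace_grad;
  param.output_grad_ = &unique_grad;
  param.max_index_ = 6;
  param.value_stride_ = 2;
  SparseOptimizerCPUKernel::BucketReduceSparseGradient(param);

  EXPECT_EQ(unique_grad.indices_size_, 3);
  std::vector<int64_t> expect_indices({0, 1, 3});
  for (size_t i = 0; i < unique_grad.indices_size_; ++i) {
    EXPECT_EQ(unique_grad.indices_[i], expect_indices[i]);
    // Each index appears twice with value 1.0f, so every summed entry is 2.0f.
    EXPECT_FLOAT_EQ(unique_grad.value_[i * 2], 2.0f);
    EXPECT_FLOAT_EQ(unique_grad.value_[i * 2 + 1], 2.0f);
  }
}
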
-- 
GitLab