From 5c0962acfa36928324cefb856d17ef98773fd963 Mon Sep 17 00:00:00 2001
From: zhaoting <zhaoting23@huawei.com>
Date: Wed, 15 Jul 2020 11:35:34 +0800
Subject: [PATCH] add gpu split and restructure gpu concat

---
 .../gpu/arrays/concatv2_gpu_kernel.h          |  92 ++++++-----
 .../gpu/arrays/split_gpu_kernel.cc            |  31 ++++
 .../gpu/arrays/split_gpu_kernel.h             | 153 ++++++++++++++++++
 .../gpu/cuda_impl/concatv2_impl.cu            | 117 +++++---------
 .../gpu/cuda_impl/concatv2_impl.cuh           |  11 +-
 .../gpu/cuda_impl/split_impl.cu               |  50 ++++++
 .../gpu/cuda_impl/split_impl.cuh              |  24 +++
 tests/st/ops/gpu/test_split.py                |  58 +++++++
 8 files changed, 406 insertions(+), 130 deletions(-)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h
 create mode 100755 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cu
 create mode 100755 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh
 create mode 100644 tests/st/ops/gpu/test_split.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h
index 15ccedcae..bae315d1c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h
@@ -18,6 +18,7 @@
 #define MINDSPORE_CCSRC_KERNEL_GPU_CONCATV2_GPU_KERNEL_H
 
 #include <vector>
+#include <memory>
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
 #include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh"
@@ -27,40 +28,35 @@ namespace kernel {
 template <typename T>
 class ConcatV2GpuFwdKernel : public GpuKernel {
  public:
-  ConcatV2GpuFwdKernel() : axis_(0), output_size_(0) {}
+  ConcatV2GpuFwdKernel()
+      : axis_(0),
+        input_num_(1),
+        output_size_(0),
+        all_size_before_axis_(1),
+        all_size_axis_(1),
+        inputs_host_(nullptr),
+        len_axis_(nullptr) {}
   ~ConcatV2GpuFwdKernel() override = default;
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
   const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
 
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    if (inputs.size() == 2) {
-      T *input_0 = GetDeviceAddress<T>(inputs, 0);
-      T *input_1 = GetDeviceAddress<T>(inputs, 1);
-      T *output = GetDeviceAddress<T>(outputs, 0);
-      ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], input_0, input_1, output,
-                   reinterpret_cast<cudaStream_t>(stream_ptr));
-    }
-
-    if (inputs.size() == 3) {
-      T *input_0 = GetDeviceAddress<T>(inputs, 0);
-      T *input_1 = GetDeviceAddress<T>(inputs, 1);
-      T *input_2 = GetDeviceAddress<T>(inputs, 2);
-      T *output = GetDeviceAddress<T>(outputs, 0);
-      ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], input_0, input_1, input_2, output,
-                   reinterpret_cast<cudaStream_t>(stream_ptr));
-    }
-
-    if (inputs.size() == 4) {
-      T *input_0 = GetDeviceAddress<T>(inputs, 0);
-      T *input_1 = GetDeviceAddress<T>(inputs, 1);
-      T *input_2 = GetDeviceAddress<T>(inputs, 2);
-      T *input_3 = GetDeviceAddress<T>(inputs, 3);
-      T *output = GetDeviceAddress<T>(outputs, 0);
-      ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], w_[3], input_0, input_1, input_2, input_3, output,
-                   reinterpret_cast<cudaStream_t>(stream_ptr));
+    T *output = GetDeviceAddress<T>(outputs, 0);
+    T **inputs_device = GetDeviceAddress<T *>(workspace, 0);
+    int *len_axis_device = GetDeviceAddress<int>(workspace, 1);
+    for (size_t i = 0; i < inputs.size(); i++) {
+      inputs_host_[i] = GetDeviceAddress<T>(inputs, i);
     }
+    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(inputs_device, inputs_host_.get(), sizeof(T *) * input_num_,
+                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "ConcatV2 opt cudaMemcpyAsync inputs failed");
+    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(len_axis_device, len_axis_.get(), sizeof(int) * input_num_,
+                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "ConcatV2 opt cudaMemcpyAsync length on axis failed");
+    ConcatKernel(output_size_, input_num_, all_size_before_axis_, all_size_axis_, len_axis_device, inputs_device,
+                 output, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
   bool Init(const CNodePtr &kernel_node) override {
@@ -74,25 +70,34 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
       axis_ += SizeToInt(input_shape.size());
     }
 
-    auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    for (size_t i = 0; i < input_num; i++) {
-      auto input_size = sizeof(T);
+    input_num_ = SizeToInt(AnfAlgo::GetInputTensorNum(kernel_node));
+    inputs_host_ = std::make_unique<T *[]>(input_num_);
+    len_axis_ = std::make_unique<int[]>(input_num_);
+    for (int i = 0; i < input_num_; i++) {
+      int input_size = 1;
       auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
       for (size_t j = 0; j < input_shape.size(); j++) {
         input_size *= SizeToInt(input_shape[j]);
-        if (j >= IntToSize(axis_)) {
-          w_[i] *= SizeToInt(input_shape[j]);
-        }
-        input_size_list_.push_back(input_size);
       }
+      input_size_list_.push_back(IntToSize(input_size * sizeof(T)));
+      len_axis_[i] = SizeToInt(input_shape[axis_]);
     }
+    workspace_size_list_.push_back(sizeof(T *) * input_num_);
+    workspace_size_list_.push_back(sizeof(int) * input_num_);
 
     auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
-    output_size_ = sizeof(T);
-    for (size_t i = 0; i < output_shape.size(); i++) {
+    output_size_ = 1;
+    for (int i = 0; i < SizeToInt(output_shape.size()); i++) {
       output_size_ *= output_shape[i];
+      if (i > axis_) {
+        all_size_before_axis_ *= output_shape[i];
+        all_size_axis_ *= output_shape[i];
+      }
+      if (i == axis_) {
+        all_size_before_axis_ *= output_shape[i];
+      }
     }
-    output_size_list_.push_back(output_size_);
+    output_size_list_.push_back(IntToSize(output_size_ * sizeof(T)));
 
     InitSizeLists();
     return true;
@@ -103,11 +108,6 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
 
  private:
   bool CheckParam(const CNodePtr &kernel_node) {
-    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num < 2 || input_num > 4) {
-      MS_LOG(ERROR) << "Input number is " << input_num << ", but ConcatV2GpuFwdKernel needs inputs between 2 and 4.";
-      return false;
-    }
     size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
     if (output_num != 1) {
       MS_LOG(ERROR) << "Output number is " << output_num << ", but ConcatV2GpuFwdKernel needs 1 output.";
@@ -115,9 +115,13 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
     }
     return true;
   }
-  int w_[4] = {1, 1, 1, 1};
   int axis_;
-  size_t output_size_;
+  int input_num_;
+  int output_size_;
+  int all_size_before_axis_;
+  int all_size_axis_;
+  std::unique_ptr<T *[]> inputs_host_;
+  std::unique_ptr<int[]> len_axis_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.cc
new file mode 100644
index 000000000..0101f6500
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.cc
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  SplitGpuFwdKernel, float)
+MS_REG_GPU_KERNEL_ONE(Split,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      SplitGpuFwdKernel, int)
+MS_REG_GPU_KERNEL_ONE(
+  Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  SplitGpuFwdKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h
new file mode 100644
index 000000000..b26c01ee1
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h
@@ -0,0 +1,153 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SPLIT_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_KERNEL_GPU_SPLIT_GPU_KERNEL_H
+
+#include <vector>
+#include <memory>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class SplitGpuFwdKernel : public GpuKernel {
+ public:
+  SplitGpuFwdKernel()
+      : axis_(0),
+        output_num_(1),
+        input_size_(1),
+        axis_step_(1),
+        all_size_before_axis_(1),
+        all_size_axis_(1),
+        outputs_host_(nullptr) {}
+  ~SplitGpuFwdKernel() override = default;
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *input = GetDeviceAddress<T>(inputs, 0);
+    T **outputs_device = GetDeviceAddress<T *>(workspace, 0);
+    for (size_t i = 0; i < outputs.size(); i++) {
+      outputs_host_[i] = GetDeviceAddress<T>(outputs, i);
+    }
+    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(outputs_device, outputs_host_.get(), sizeof(T *) * output_num_,
+                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "Split opt cudaMemcpyAsync outputs failed");
+    SplitKernel(input_size_, axis_step_, all_size_before_axis_, all_size_axis_, input, outputs_device,
+                reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    axis_ = GetAttr<int>(kernel_node, "axis");
+    if (axis_ < 0) {
+      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+      axis_ += SizeToInt(input_shape.size());
+    }
+    output_num_ = GetAttr<int>(kernel_node, "output_num");
+
+    if (!CheckParam(kernel_node)) {
+      return false;
+    }
+
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    input_size_ = 1;
+    all_size_before_axis_ = 1;
+    all_size_axis_ = 1;
+
+    for (int i = 0; i < SizeToInt(input_shape.size()); i++) {
+      input_size_ *= input_shape[i];
+      if (i > axis_) {
+        all_size_before_axis_ *= input_shape[i];
+        all_size_axis_ *= input_shape[i];
+      }
+      if (i == axis_) {
+        all_size_before_axis_ *= input_shape[i];
+      }
+    }
+    input_size_list_.push_back(IntToSize(input_size_ * sizeof(T)));
+    axis_step_ = input_shape[axis_] / output_num_;
+
+    for (int i = 0; i < output_num_; i++) {
+      size_t output_size = 1;
+      auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, i);
+      for (size_t j = 0; j < output_shape.size(); j++) {
+        output_size *= output_shape[j];
+      }
+      output_size_list_.push_back(output_size * sizeof(T));
+    }
+    workspace_size_list_.push_back(sizeof(T *) * output_num_);
+    InitSizeLists();
+    outputs_host_ = std::make_unique<T *[]>(output_num_);
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {}
+
+ private:
+  bool CheckParam(const CNodePtr &kernel_node) {
+    auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    int dims = SizeToInt(input_shape.size());
+    int output_num = SizeToInt(AnfAlgo::GetOutputTensorNum(kernel_node));
+
+    if (input_num != 1) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but Split needs 1 input.";
+      return false;
+    }
+    if (dims == 0) {
+      MS_LOG(ERROR) << "Input dims is " << dims << ", scalar is not supported.";
+      return false;
+    }
+    if (axis_ < -dims || axis_ >= dims) {
+      MS_LOG(ERROR) << "Attr axis " << axis_ << " must be in " << -dims << "~" << dims;
+      return false;
+    }
+    if (output_num_ > SizeToInt(input_shape[axis_])) {
+      MS_LOG(ERROR) << "Attr output_num " << output_num_ << "must less than" << input_shape[axis_];
+      return false;
+    }
+    if (input_shape[axis_] % output_num_ != 0) {
+      MS_LOG(ERROR) << "Attr output_num " << output_num_ << "must be divided by" << input_shape[axis_];
+      return false;
+    }
+    if (output_num_ != output_num) {
+      MS_LOG(ERROR) << "Output num is " << output_num << ", but need " << output_num_;
+      return false;
+    }
+    return true;
+  }
+  int axis_;
+  int output_num_;
+  int input_size_;
+  int axis_step_;
+  int all_size_before_axis_;
+  int all_size_axis_;
+  std::unique_ptr<T *[]> outputs_host_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_SPLIT_GPU_KERNEL_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu
index 147782591..c3a77d186 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu
@@ -19,90 +19,51 @@
 #include <cuda_runtime.h>
 #include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh"
 template <typename T>
-__global__ void Concat(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
-    int n = pos / (w1 + w2);
-    int m = pos % (w1 + w2);
-    output[pos] = m >= w1 ? input_2[n * w2 + m - w1] : input_1[n * w1 + m];
+__global__ void Concat(const int size, const int input_num,
+                       const int all_size_before_axis, const int all_size_axis,
+                       int* len_axis, T** inputs, T* output) {
+  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
+    int num = pos % all_size_before_axis / all_size_axis;
+    int block = -1;
+    int axis_inc = 0;
+    int block_len = 0;
+    for (int i = 0; i < input_num; i++) {
+      if (axis_inc <= num) {
+        block++;
+        axis_inc += len_axis[i];
+      } else {
+        break;
+      }
+    }
+    block_len = len_axis[block];
+    axis_inc -= len_axis[block];
+    int block_pos = pos / all_size_before_axis * block_len * all_size_axis +
+                    (num - axis_inc) * all_size_axis + pos % all_size_axis;;
+    output[pos] = inputs[block][block_pos];
   }
   return;
 }
 
 template <typename T>
-__global__ void Concat(const size_t size, const int w1, const int w2, const int w3,
-                       const T* input_1, const T* input_2, const T* input_3, T* output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
-    int n = pos / (w1 + w2 + w3);
-    int m = pos % (w1 + w2 + w3);
-    output[pos] = m < w1 ? input_1[n * w1 + m] :
-                    m < w1 + w2 ? input_2[n * w2 + m - w1] :
-                      input_3[n * w3 + m - w1 - w2];
-  }
-  return;
-}
-
-template <typename T>
-__global__ void Concat(const size_t size, const int w1, const int w2, const int w3, const int w4,
-                       const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
-    int n = pos / (w1 + w2 + w3 + w4);
-    int m = pos % (w1 + w2 + w3 + w4);
-    output[pos] = m < w1 ? input_1[n * w1 + m] :
-                    m < w1 + w2 ? input_2[n * w2 + m - w1]:
-                      m < w1 + w2 + w3 ? input_3[n * w3 + m - w1 - w2]:
-                        input_4[n * w4 + m - w1 - w2 - w3];
-  }
-  return;
-}
-
-template <typename T>
-void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output,
-                 cudaStream_t cuda_stream) {
-  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, input_1, input_2, output);
-  return;
-}
-
-template <typename T>
-void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
-                  const T* input_1, const T* input_2, const T* input_3, T* output,
+void ConcatKernel(const int size, const int input_num,
+                  const int all_size_before_axis, const int all_size_axis,
+                  int* len_axis, T** inputs, T* output,
                   cudaStream_t cuda_stream) {
-  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, input_1, input_2, input_3, output);
+  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_num,
+                                                            all_size_before_axis, all_size_axis,
+                                                            len_axis, inputs, output);
   return;
 }
 
-template <typename T>
-void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
-                  const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output,
-                  cudaStream_t cuda_stream) {
-  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, w4, input_1,
-                                                            input_2, input_3, input_4, output);
-  return;
-}
-
-template void ConcatKernel(const size_t size, const int w1, const int w2, const float* input_1, const float* input_2,
-                           float* output, cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int* input_1, const int* input_2,
-                           int* output, cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int w1, const int w2, const half* input_1, const half* input_2,
-                           half* output, cudaStream_t cuda_stream);
-
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
-                           const float* input_1, const float* input_2, const float* input_3,
-                           float* output, cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
-                           const int* input_1, const int* input_2, const int* input_3,
-                           int* output, cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
-                           const half* input_1, const half* input_2, const half* input_3,
-                           half* output, cudaStream_t cuda_stream);
-
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
-                           const float* input_1, const float* input_2, const float* input_3, const float* input_4,
-                           float* output, cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
-                           const int* input_1, const int* input_2, const int* input_3, const int* input_4,
-                           int* output, cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
-                           const half* input_1, const half* input_2, const half* input_3, const half* input_4,
-                           half* output, cudaStream_t cuda_stream);
-
+template void ConcatKernel(const int size, const int input_num,
+                           const int all_size_before_axis, const int all_size_axis,
+                           int* len_axis, float** inputs, float* output,
+                           cudaStream_t cuda_stream);
+template void ConcatKernel(const int size, const int input_num,
+                           const int all_size_before_axis, const int all_size_axis,
+                           int* len_axis, int** inputs, int* output,
+                           cudaStream_t cuda_stream);
+template void ConcatKernel(const int size, const int input_num,
+                           const int all_size_before_axis, const int all_size_axis,
+                           int* len_axis, half** inputs, half* output,
+                           cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh
index 7bd32c140..010e2977e 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh
@@ -19,13 +19,8 @@
 
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
-void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output,
-                  cudaStream_t cuda_stream);
-template <typename T>
-void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
-                  const T* input_1, const T* input_2, const T* input_3, T* output, cudaStream_t cuda_stream);
-template <typename T>
-void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
-                  const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output,
+void ConcatKernel(const int size, const int input_num,
+                  const int all_size_before_axis, const int all_size_axis,
+                  int* len_axis, T** inputs, T* output,
                   cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cu
new file mode 100755
index 000000000..a24229086
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cu
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <stdio.h>
+#include <stdint.h>
+#include <cuda_runtime.h>
+#include "backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh"
+template <typename T>
+__global__ void Split(const int size, const int axis_step, const int all_size_before_axis,
+                      const int all_size_axis, const T* input, T** outputs) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    int num = pos % all_size_before_axis / all_size_axis;
+    int block = num / axis_step;
+    int block_pos = pos / all_size_before_axis * axis_step * all_size_axis +
+                    num % axis_step * all_size_axis + pos % all_size_axis;
+    outputs[block][block_pos] = input[pos];
+  }
+  return;
+}
+
+template <typename T>
+void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
+                 const int all_size_axis, const T* input, T** outputs, cudaStream_t cuda_stream) {
+  Split<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, axis_step, all_size_before_axis,
+                                                           all_size_axis, input, outputs);
+  return;
+}
+
+template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
+                          const int all_size_axis, const float* input, float** outputs,
+                          cudaStream_t cuda_stream);
+template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
+                          const int all_size_axis, const int* input, int** outputs,
+                          cudaStream_t cuda_stream);
+template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
+                          const int all_size_axis, const half* input, half** outputs,
+                          cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh
new file mode 100755
index 000000000..5306648da
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh
@@ -0,0 +1,24 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T>
+void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
+                 const int all_size_axis, const T* input, T** outputs, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_
diff --git a/tests/st/ops/gpu/test_split.py b/tests/st/ops/gpu/test_split.py
new file mode 100644
index 000000000..f9e3cfce2
--- /dev/null
+++ b/tests/st/ops/gpu/test_split.py
@@ -0,0 +1,58 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore.context as context
+from mindspore import Tensor
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+
+
+class Net(nn.Cell):
+    def __init__(self, axis=0, out_nums=1):
+        super(Net, self).__init__()
+        self.split = P.Split(axis, out_nums)
+
+    def construct(self, x):
+        return self.split(x)
+
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_split():
+    x = np.array([[[1, -1, 1], [2, -2, 2]],
+                  [[3, -3, 3], [4, -4, 4]],
+                  [[5, -5, 5], [6, -6, 6]]]).astype(np.float32)
+
+    split_op = Net(0, 3)
+    outputs = split_op(Tensor(x))
+    for i, out in enumerate(outputs):
+        assert (out.asnumpy() == x[i]).all()
+
+
+def test_split_4d():
+    x_np = np.random.randn(2, 6, 4, 4).astype(np.float32)
+    y = np.split(x_np, 3, axis=1)
+
+    split_op = Net(1, 3)
+    outputs = split_op(Tensor(x_np))
+
+    for i, out in enumerate(outputs):
+        assert (out.asnumpy() == y[i]).all()
-- 
GitLab