From 5c0962acfa36928324cefb856d17ef98773fd963 Mon Sep 17 00:00:00 2001 From: zhaoting <zhaoting23@huawei.com> Date: Wed, 15 Jul 2020 11:35:34 +0800 Subject: [PATCH] add gpu split and restructure gpu concat --- .../gpu/arrays/concatv2_gpu_kernel.h | 92 ++++++----- .../gpu/arrays/split_gpu_kernel.cc | 31 ++++ .../gpu/arrays/split_gpu_kernel.h | 153 ++++++++++++++++++ .../gpu/cuda_impl/concatv2_impl.cu | 117 +++++--------- .../gpu/cuda_impl/concatv2_impl.cuh | 11 +- .../gpu/cuda_impl/split_impl.cu | 50 ++++++ .../gpu/cuda_impl/split_impl.cuh | 24 +++ tests/st/ops/gpu/test_split.py | 58 +++++++ 8 files changed, 406 insertions(+), 130 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h create mode 100755 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cu create mode 100755 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh create mode 100644 tests/st/ops/gpu/test_split.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h index 15ccedcae..bae315d1c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h @@ -18,6 +18,7 @@ #define MINDSPORE_CCSRC_KERNEL_GPU_CONCATV2_GPU_KERNEL_H #include <vector> +#include <memory> #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh" @@ -27,40 +28,35 @@ namespace kernel { template <typename T> class ConcatV2GpuFwdKernel : public GpuKernel { public: - ConcatV2GpuFwdKernel() : axis_(0), output_size_(0) {} + ConcatV2GpuFwdKernel() + : axis_(0), + input_num_(1), + output_size_(0), + all_size_before_axis_(1), + all_size_axis_(1), + inputs_host_(nullptr), + len_axis_(nullptr) {} ~ConcatV2GpuFwdKernel() override = default; const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, + bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs, void *stream_ptr) override { - if (inputs.size() == 2) { - T *input_0 = GetDeviceAddress<T>(inputs, 0); - T *input_1 = GetDeviceAddress<T>(inputs, 1); - T *output = GetDeviceAddress<T>(outputs, 0); - ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], input_0, input_1, output, - reinterpret_cast<cudaStream_t>(stream_ptr)); - } - - if (inputs.size() == 3) { - T *input_0 = GetDeviceAddress<T>(inputs, 0); - T *input_1 = GetDeviceAddress<T>(inputs, 1); - T *input_2 = GetDeviceAddress<T>(inputs, 2); - T *output = GetDeviceAddress<T>(outputs, 0); - ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], input_0, input_1, input_2, output, - reinterpret_cast<cudaStream_t>(stream_ptr)); - } - - if (inputs.size() == 4) { - T *input_0 = GetDeviceAddress<T>(inputs, 0); - T *input_1 = GetDeviceAddress<T>(inputs, 1); - T *input_2 = GetDeviceAddress<T>(inputs, 2); - T *input_3 = GetDeviceAddress<T>(inputs, 3); - T *output = GetDeviceAddress<T>(outputs, 0); - 
ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], w_[3], input_0, input_1, input_2, input_3, output, - reinterpret_cast<cudaStream_t>(stream_ptr)); + T *output = GetDeviceAddress<T>(outputs, 0); + T **inputs_device = GetDeviceAddress<T *>(workspace, 0); + int *len_axis_device = GetDeviceAddress<int>(workspace, 1); + for (size_t i = 0; i < inputs.size(); i++) { + inputs_host_[i] = GetDeviceAddress<T>(inputs, i); } + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(inputs_device, inputs_host_.get(), sizeof(T *) * input_num_, + cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), + "ConcatV2 opt cudaMemcpyAsync inputs failed"); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(len_axis_device, len_axis_.get(), sizeof(int) * input_num_, + cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), + "ConcatV2 opt cudaMemcpyAsync length on axis failed"); + ConcatKernel(output_size_, input_num_, all_size_before_axis_, all_size_axis_, len_axis_device, inputs_device, + output, reinterpret_cast<cudaStream_t>(stream_ptr)); return true; } bool Init(const CNodePtr &kernel_node) override { @@ -74,25 +70,34 @@ class ConcatV2GpuFwdKernel : public GpuKernel { axis_ += SizeToInt(input_shape.size()); } - auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); - for (size_t i = 0; i < input_num; i++) { - auto input_size = sizeof(T); + input_num_ = SizeToInt(AnfAlgo::GetInputTensorNum(kernel_node)); + inputs_host_ = std::make_unique<T *[]>(input_num_); + len_axis_ = std::make_unique<int[]>(input_num_); + for (int i = 0; i < input_num_; i++) { + int input_size = 1; auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); for (size_t j = 0; j < input_shape.size(); j++) { input_size *= SizeToInt(input_shape[j]); - if (j >= IntToSize(axis_)) { - w_[i] *= SizeToInt(input_shape[j]); - } - input_size_list_.push_back(input_size); } + input_size_list_.push_back(IntToSize(input_size * sizeof(T))); + len_axis_[i] = SizeToInt(input_shape[axis_]); } + workspace_size_list_.push_back(sizeof(T *) * input_num_); + workspace_size_list_.push_back(sizeof(int) * input_num_); auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - output_size_ = sizeof(T); - for (size_t i = 0; i < output_shape.size(); i++) { + output_size_ = 1; + for (int i = 0; i < SizeToInt(output_shape.size()); i++) { output_size_ *= output_shape[i]; + if (i > axis_) { + all_size_before_axis_ *= output_shape[i]; + all_size_axis_ *= output_shape[i]; + } + if (i == axis_) { + all_size_before_axis_ *= output_shape[i]; + } } - output_size_list_.push_back(output_size_); + output_size_list_.push_back(IntToSize(output_size_ * sizeof(T))); InitSizeLists(); return true; @@ -103,11 +108,6 @@ class ConcatV2GpuFwdKernel : public GpuKernel { private: bool CheckParam(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num < 2 || input_num > 4) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but ConcatV2GpuFwdKernel needs inputs between 2 and 4."; - return false; - } size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); if (output_num != 1) { MS_LOG(ERROR) << "Output number is " << output_num << ", but ConcatV2GpuFwdKernel needs 1 output."; @@ -115,9 +115,13 @@ class ConcatV2GpuFwdKernel : public GpuKernel { } return true; } - int w_[4] = {1, 1, 1, 1}; int axis_; - size_t output_size_; + int input_num_; + int output_size_; + int all_size_before_axis_; + int all_size_axis_; + std::unique_ptr<T *[]> inputs_host_; + std::unique_ptr<int[]> len_axis_; 
std::vector<size_t> input_size_list_; std::vector<size_t> output_size_list_; std::vector<size_t> workspace_size_list_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.cc new file mode 100644 index 000000000..0101f6500 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SplitGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(Split, + KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + SplitGpuFwdKernel, int) +MS_REG_GPU_KERNEL_ONE( + Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + SplitGpuFwdKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h new file mode 100644 index 000000000..b26c01ee1 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h @@ -0,0 +1,153 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SPLIT_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_SPLIT_GPU_KERNEL_H + +#include <vector> +#include <memory> +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh" + +namespace mindspore { +namespace kernel { +template <typename T> +class SplitGpuFwdKernel : public GpuKernel { + public: + SplitGpuFwdKernel() + : axis_(0), + output_num_(1), + input_size_(1), + axis_step_(1), + all_size_before_axis_(1), + all_size_axis_(1), + outputs_host_(nullptr) {} + ~SplitGpuFwdKernel() override = default; + const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } + const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } + const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, + const std::vector<AddressPtr> &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress<T>(inputs, 0); + T **outputs_device = GetDeviceAddress<T *>(workspace, 0); + for (size_t i = 0; i < outputs.size(); i++) { + outputs_host_[i] = GetDeviceAddress<T>(outputs, i); + } + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(outputs_device, outputs_host_.get(), sizeof(T *) * output_num_, + cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), + "Split opt cudaMemcpyAsync outputs failed"); + SplitKernel(input_size_, axis_step_, all_size_before_axis_, all_size_axis_, input, outputs_device, + reinterpret_cast<cudaStream_t>(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + axis_ = GetAttr<int>(kernel_node, "axis"); + if (axis_ < 0) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + axis_ += SizeToInt(input_shape.size()); + } + output_num_ = GetAttr<int>(kernel_node, "output_num"); + + if (!CheckParam(kernel_node)) { + return false; + } + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + input_size_ = 1; + all_size_before_axis_ = 1; + all_size_axis_ = 1; + + for (int i = 0; i < SizeToInt(input_shape.size()); i++) { + input_size_ *= input_shape[i]; + if (i > axis_) { + all_size_before_axis_ *= input_shape[i]; + all_size_axis_ *= input_shape[i]; + } + if (i == axis_) { + all_size_before_axis_ *= input_shape[i]; + } + } + input_size_list_.push_back(IntToSize(input_size_ * sizeof(T))); + axis_step_ = input_shape[axis_] / output_num_; + + for (int i = 0; i < output_num_; i++) { + size_t output_size = 1; + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, i); + for (size_t j = 0; j < output_shape.size(); j++) { + output_size *= output_shape[j]; + } + output_size_list_.push_back(output_size * sizeof(T)); + } + workspace_size_list_.push_back(sizeof(T *) * output_num_); + InitSizeLists(); + outputs_host_ = std::make_unique<T *[]>(output_num_); + return true; + } + + protected: + void InitSizeLists() override {} + + private: + bool CheckParam(const CNodePtr &kernel_node) { + auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + int dims = SizeToInt(input_shape.size()); + int output_num = SizeToInt(AnfAlgo::GetOutputTensorNum(kernel_node)); + + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but Split needs 1 input."; + return false; + } + if 
(dims == 0) { + MS_LOG(ERROR) << "Input dims is " << dims << ", scalar is not supported."; + return false; + } + if (axis_ < -dims || axis_ >= dims) { + MS_LOG(ERROR) << "Attr axis " << axis_ << " must be in " << -dims << "~" << dims; + return false; + } + if (output_num_ > SizeToInt(input_shape[axis_])) { + MS_LOG(ERROR) << "Attr output_num " << output_num_ << "must less than" << input_shape[axis_]; + return false; + } + if (input_shape[axis_] % output_num_ != 0) { + MS_LOG(ERROR) << "Attr output_num " << output_num_ << "must be divided by" << input_shape[axis_]; + return false; + } + if (output_num_ != output_num) { + MS_LOG(ERROR) << "Output num is " << output_num << ", but need " << output_num_; + return false; + } + return true; + } + int axis_; + int output_num_; + int input_size_; + int axis_step_; + int all_size_before_axis_; + int all_size_axis_; + std::unique_ptr<T *[]> outputs_host_; + std::vector<size_t> input_size_list_; + std::vector<size_t> output_size_list_; + std::vector<size_t> workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_SPLIT_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu index 147782591..c3a77d186 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu @@ -19,90 +19,51 @@ #include <cuda_runtime.h> #include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh" template <typename T> -__global__ void Concat(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - int n = pos / (w1 + w2); - int m = pos % (w1 + w2); - output[pos] = m >= w1 ? input_2[n * w2 + m - w1] : input_1[n * w1 + m]; +__global__ void Concat(const int size, const int input_num, + const int all_size_before_axis, const int all_size_axis, + int* len_axis, T** inputs, T* output) { + for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + int num = pos % all_size_before_axis / all_size_axis; + int block = -1; + int axis_inc = 0; + int block_len = 0; + for (int i = 0; i < input_num; i++) { + if (axis_inc <= num) { + block++; + axis_inc += len_axis[i]; + } else { + break; + } + } + block_len = len_axis[block]; + axis_inc -= len_axis[block]; + int block_pos = pos / all_size_before_axis * block_len * all_size_axis + + (num - axis_inc) * all_size_axis + pos % all_size_axis;; + output[pos] = inputs[block][block_pos]; } return; } template <typename T> -__global__ void Concat(const size_t size, const int w1, const int w2, const int w3, - const T* input_1, const T* input_2, const T* input_3, T* output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - int n = pos / (w1 + w2 + w3); - int m = pos % (w1 + w2 + w3); - output[pos] = m < w1 ? input_1[n * w1 + m] : - m < w1 + w2 ? 
input_2[n * w2 + m - w1] : - input_3[n * w3 + m - w1 - w2]; - } - return; -} - -template <typename T> -__global__ void Concat(const size_t size, const int w1, const int w2, const int w3, const int w4, - const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - int n = pos / (w1 + w2 + w3 + w4); - int m = pos % (w1 + w2 + w3 + w4); - output[pos] = m < w1 ? input_1[n * w1 + m] : - m < w1 + w2 ? input_2[n * w2 + m - w1]: - m < w1 + w2 + w3 ? input_3[n * w3 + m - w1 - w2]: - input_4[n * w4 + m - w1 - w2 - w3]; - } - return; -} - -template <typename T> -void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output, - cudaStream_t cuda_stream) { - Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, input_1, input_2, output); - return; -} - -template <typename T> -void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, - const T* input_1, const T* input_2, const T* input_3, T* output, +void ConcatKernel(const int size, const int input_num, + const int all_size_before_axis, const int all_size_axis, + int* len_axis, T** inputs, T* output, cudaStream_t cuda_stream) { - Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, input_1, input_2, input_3, output); + Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_num, + all_size_before_axis, all_size_axis, + len_axis, inputs, output); return; } -template <typename T> -void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, - const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output, - cudaStream_t cuda_stream) { - Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, w4, input_1, - input_2, input_3, input_4, output); - return; -} - -template void ConcatKernel(const size_t size, const int w1, const int w2, const float* input_1, const float* input_2, - float* output, cudaStream_t cuda_stream); -template void ConcatKernel(const size_t size, const int w1, const int w2, const int* input_1, const int* input_2, - int* output, cudaStream_t cuda_stream); -template void ConcatKernel(const size_t size, const int w1, const int w2, const half* input_1, const half* input_2, - half* output, cudaStream_t cuda_stream); - -template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, - const float* input_1, const float* input_2, const float* input_3, - float* output, cudaStream_t cuda_stream); -template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, - const int* input_1, const int* input_2, const int* input_3, - int* output, cudaStream_t cuda_stream); -template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, - const half* input_1, const half* input_2, const half* input_3, - half* output, cudaStream_t cuda_stream); - -template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, - const float* input_1, const float* input_2, const float* input_3, const float* input_4, - float* output, cudaStream_t cuda_stream); -template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, - const int* input_1, const int* input_2, const int* input_3, const int* input_4, - int* output, cudaStream_t cuda_stream); -template void ConcatKernel(const size_t size, const int w1, const 
int w2, const int w3, const int w4, - const half* input_1, const half* input_2, const half* input_3, const half* input_4, - half* output, cudaStream_t cuda_stream); - +template void ConcatKernel(const int size, const int input_num, + const int all_size_before_axis, const int all_size_axis, + int* len_axis, float** inputs, float* output, + cudaStream_t cuda_stream); +template void ConcatKernel(const int size, const int input_num, + const int all_size_before_axis, const int all_size_axis, + int* len_axis, int** inputs, int* output, + cudaStream_t cuda_stream); +template void ConcatKernel(const int size, const int input_num, + const int all_size_before_axis, const int all_size_axis, + int* len_axis, half** inputs, half* output, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh index 7bd32c140..010e2977e 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh @@ -19,13 +19,8 @@ #include "runtime/device/gpu/cuda_common.h" template <typename T> -void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output, - cudaStream_t cuda_stream); -template <typename T> -void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, - const T* input_1, const T* input_2, const T* input_3, T* output, cudaStream_t cuda_stream); -template <typename T> -void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, - const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output, +void ConcatKernel(const int size, const int input_num, + const int all_size_before_axis, const int all_size_axis, + int* len_axis, T** inputs, T* output, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cu new file mode 100755 index 000000000..a24229086 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cu @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <stdio.h> +#include <stdint.h> +#include <cuda_runtime.h> +#include "backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh" +template <typename T> +__global__ void Split(const int size, const int axis_step, const int all_size_before_axis, + const int all_size_axis, const T* input, T** outputs) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + int num = pos % all_size_before_axis / all_size_axis; + int block = num / axis_step; + int block_pos = pos / all_size_before_axis * axis_step * all_size_axis + + num % axis_step * all_size_axis + pos % all_size_axis; + outputs[block][block_pos] = input[pos]; + } + return; +} + +template <typename T> +void SplitKernel(const int size, const int axis_step, const int all_size_before_axis, + const int all_size_axis, const T* input, T** outputs, cudaStream_t cuda_stream) { + Split<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, axis_step, all_size_before_axis, + all_size_axis, input, outputs); + return; +} + +template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis, + const int all_size_axis, const float* input, float** outputs, + cudaStream_t cuda_stream); +template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis, + const int all_size_axis, const int* input, int** outputs, + cudaStream_t cuda_stream); +template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis, + const int all_size_axis, const half* input, half** outputs, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh new file mode 100755 index 000000000..5306648da --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_ + +#include "runtime/device/gpu/cuda_common.h" +template <typename T> +void SplitKernel(const int size, const int axis_step, const int all_size_before_axis, + const int all_size_axis, const T* input, T** outputs, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_ diff --git a/tests/st/ops/gpu/test_split.py b/tests/st/ops/gpu/test_split.py new file mode 100644 index 000000000..f9e3cfce2 --- /dev/null +++ b/tests/st/ops/gpu/test_split.py @@ -0,0 +1,58 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest + +import mindspore.context as context +from mindspore import Tensor +import mindspore.nn as nn +from mindspore.ops import operations as P + + +class Net(nn.Cell): + def __init__(self, axis=0, out_nums=1): + super(Net, self).__init__() + self.split = P.Split(axis, out_nums) + + def construct(self, x): + return self.split(x) + + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_split(): + x = np.array([[[1, -1, 1], [2, -2, 2]], + [[3, -3, 3], [4, -4, 4]], + [[5, -5, 5], [6, -6, 6]]]).astype(np.float32) + + split_op = Net(0, 3) + outputs = split_op(Tensor(x)) + for i, out in enumerate(outputs): + assert (out.asnumpy() == x[i]).all() + + +def test_split_4d(): + x_np = np.random.randn(2, 6, 4, 4).astype(np.float32) + y = np.split(x_np, 3, axis=1) + + split_op = Net(1, 3) + outputs = split_op(Tensor(x_np)) + + for i, out in enumerate(outputs): + assert (out.asnumpy() == y[i]).all() -- GitLab
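
Both the restructured Concat kernel and the new Split kernel address elements with the same flat-index decomposition: all_size_axis is the product of the dimensions after the axis, all_size_before_axis additionally includes the axis dimension itself, and axis_step is the length each output keeps along the axis. Split scatters input[pos] into outputs[block][block_pos]; Concat performs the inverse gather, scanning len_axis to find which input a given output position falls in. Because the per-tensor pointers are staged in inputs_host_ / outputs_host_ and copied into a workspace buffer with cudaMemcpyAsync before launch, the kernels now accept any number of tensors instead of the previous limit of 2 to 4 inputs on ConcatV2.

A minimal host-side NumPy sketch of the Split index mapping (the helper name split_index_map is illustrative, not part of the patch), checked against np.split the same way test_split_4d does:

import numpy as np

def split_index_map(x, axis, output_num):
    # Mirrors the parameters computed in SplitGpuFwdKernel::Init().
    shape = x.shape
    all_size_axis = int(np.prod(shape[axis + 1:], dtype=np.int64))     # dims after the axis
    all_size_before_axis = int(np.prod(shape[axis:], dtype=np.int64))  # axis dim included
    axis_step = shape[axis] // output_num

    flat_in = x.reshape(-1)
    outputs = [np.empty(x.size // output_num, dtype=x.dtype) for _ in range(output_num)]
    for pos in range(x.size):
        # Position along the split axis for this element.
        num = (pos % all_size_before_axis) // all_size_axis
        block = num // axis_step
        # Flat offset of the element inside its output block.
        block_pos = (pos // all_size_before_axis) * axis_step * all_size_axis \
                    + (num % axis_step) * all_size_axis + pos % all_size_axis
        outputs[block][block_pos] = flat_in[pos]

    out_shape = list(shape)
    out_shape[axis] = axis_step
    return [o.reshape(out_shape) for o in outputs]

x = np.random.randn(2, 6, 4, 4).astype(np.float32)
for ref, got in zip(np.split(x, 3, axis=1), split_index_map(x, axis=1, output_num=3)):
    assert (ref == got).all()

The Concat kernel applies the same arithmetic in reverse, except that the per-input lengths along the axis (len_axis) may differ, which is why it walks the len_axis array to locate the source block rather than dividing by a fixed axis_step.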