From fb405ee6f4ec553d9737f336e84fffb1d6068ae7 Mon Sep 17 00:00:00 2001
From: linqingke
Date: Thu, 6 Aug 2020 17:13:34 +0800
Subject: [PATCH] broadcast, slice, scatter_nd ops optimizer.

---
 .../gpu/arrays/array_reduce_gpu_kernel.h      |  57 +++++--
 .../gpu/arrays/scatter_nd_gpu_kernel.h        |   4 +
 .../gpu/cuda_impl/broadcast_impl.cu           | 153 ++++++++++--------
 .../gpu/cuda_impl/broadcast_impl.cuh          |   7 +-
 .../gpu/cuda_impl/check_valid_impl.cu         |  10 +-
 .../kernel_compiler/gpu/cuda_impl/iou_impl.cu |  37 ++---
 .../backend/kernel_compiler/gpu/gpu_kernel.h  |  34 ++++
 .../gpu/math/broadcast_gpu_kernel.h           |  19 ++-
 .../gpu/nn/activation_gpu_kernel.h            |  15 +-
 .../gpu/nn/activation_grad_kernel.h           |  14 +-
 .../gpu/other/check_valid_gpu_kernel.cc       |   4 +
 .../gpu/other/iou_gpu_kernel.cc               |   3 +
 tests/st/ops/gpu/test_floordiv_op.py          |   4 +-
 tests/st/ops/gpu/test_reduce_mean_op.py       |   2 +-
 14 files changed, 240 insertions(+), 123 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h
index 0f3251995..8a273965e 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h
@@ -182,30 +182,59 @@ class ArrayReduceGpuKernel : public GpuKernel {
   void InferInAndOutDesc(const std::vector<size_t> &input_shape, const std::vector<size_t> &output_shape) {
     std::vector<int> inputA;
     std::vector<size_t> outputC_shape = output_shape;
-    ShapeNdTo4d(input_shape, &inputA);
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_, inputA[0],
-                                                           inputA[1], inputA[2], inputA[3]),
-                                "cudnnSetTensor4dDescriptor failed");
+    const int split_dim = 4;
+
+    if (input_shape.size() <= split_dim) {
+      ShapeNdTo4d(input_shape, &inputA);
+      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
+                                                             inputA[0], inputA[1], inputA[2], inputA[3]),
+                                  "cudnnSetTensor4dDescriptor failed");
+    } else {
+      CudnnSetTensorNdDescriptor(input_shape, inputA_descriptor_, data_type_);
+      for (auto dim : input_shape) {
+        inputA.emplace_back(SizeToInt(dim));
+      }
+    }
     if (axis_[0] == -1) {
-      CHECK_CUDNN_RET_WITH_EXCEPT(
-        cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, 1, 1, 1, 1),
-        "cudnnSetTensor4dDescriptor failed");
-      if (inputA[0] == 1 && inputA[1] == 1 && inputA[2] == 1 && inputA[3] == 1) {
-        all_match_ = true;
+      outputC_shape.resize(input_shape.size(), 1);
+      if (outputC_shape.size() <= split_dim) {
+        CHECK_CUDNN_RET_WITH_EXCEPT(
+          cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, 1, 1, 1, 1),
+          "cudnnSetTensor4dDescriptor failed");
+      } else {
+        CudnnSetTensorNdDescriptor(outputC_shape, outputC_descriptor_, data_type_);
       }
+
+      for (auto dim : inputA) {
+        if (dim != 1) {
+          return;
+        }
+      }
+
+      all_match_ = true;
       return;
     }
+
+    std::vector<int> outputC;
     if (!keep_dims_) {
       for (auto i : axis_) {
         (void)(outputC_shape.insert(outputC_shape.begin() + i, 1));
       }
     }
-    std::vector<int> outputC;
-    ShapeNdTo4d(outputC_shape, &outputC);
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
-                                                           outputC[0], outputC[1], outputC[2], outputC[3]),
-                                "cudnnSetTensor4dDescriptor failed");
+
+    if (outputC_shape.size() <= split_dim) {
+      ShapeNdTo4d(outputC_shape, &outputC);
+      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
+                                                             outputC[0], outputC[1], outputC[2], outputC[3]),
+                                  "cudnnSetTensor4dDescriptor failed");
+    } else {
+      CudnnSetTensorNdDescriptor(outputC_shape, outputC_descriptor_, data_type_);
+      for (auto dim : outputC_shape) {
+        outputC.emplace_back(SizeToInt(dim));
+      }
+    }
+
     if (inputA == outputC) {
       all_match_ = true;
     }
   }
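The reduce kernel previously forced every input through ShapeNdTo4d; shapes above rank 4 now go to the generic N-D descriptor path instead. For reference, a minimal sketch of the 4-D padding that ShapeNdTo4d presumably performs (right-aligned, leading dims filled with 1s; the real helper lives elsewhere in gpu_kernel.h, so this is an illustration, not the patch's code):

#include <cstddef>
#include <vector>

// Right-align the source dims into a rank-4 shape: [3, 4] -> [1, 1, 3, 4].
// Assumes src.size() <= 4, matching the split_dim guard in the patch.
std::vector<int> PadShapeTo4d(const std::vector<std::size_t> &src) {
  std::vector<int> dst(4, 1);
  const std::size_t offset = 4 - src.size();
  for (std::size_t i = 0; i < src.size(); ++i) {
    dst[offset + i] = static_cast<int>(src[i]);
  }
  return dst;
}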
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h
index 7cc0d1f85..9bd4d1f18 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h
@@ -69,6 +69,10 @@ class ScatterNdGpuFwdKernel : public GpuKernel {
       memcpy_flag_ = true;
     }
 
+    CHECK_CUDA_RET_WITH_EXCEPT(
+      cudaMemsetAsync(output, static_cast<T>(0.0), output_size_, reinterpret_cast<cudaStream_t>(stream_ptr)),
+      "cudaMemSet failed in ScatterNdGpuFwdKernel::Launch.");
+
     const size_t input_size = input_size_ / sizeof(T);
     const size_t output_size = output_size_ / sizeof(T);
 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
index 827bec11f..52d8ab740 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 
+#include <vector>
 #include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"
 #include "runtime/device/gpu/cuda_common.h"
 
@@ -107,69 +108,97 @@ __device__ __forceinline__ int Index(const int &index, const int &dim) { return
 template <typename T, typename S, typename Func>
 __device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3,
-                                                  const int &r0, const int &r1, const int &r2, const int &r3,
-                                                  const int &d0, const int &d1, const int &d2, const int &d3,
-                                                  const T *input0, const T *input1, S *output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3; pos += blockDim.x * gridDim.x) {
-    int i = pos / (d1 * d2 * d3) % d0;
-    int j = pos / (d2 * d3) % d1;
-    int k = pos / d3 % d2;
-    int l = pos % d3;
+                                                  const int &l4, const int &l5, const int &l6, const int &r0,
+                                                  const int &r1, const int &r2, const int &r3, const int &r4,
+                                                  const int &r5, const int &r6, const int &d0, const int &d1,
+                                                  const int &d2, const int &d3, const int &d4, const int &d5,
+                                                  const int &d6, const T *input0, const T *input1, S *output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6;
+       pos += blockDim.x * gridDim.x) {
+    int i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
+    int j = pos / (d2 * d3 * d4 * d5 * d6) % d1;
+    int k = pos / (d3 * d4 * d5 * d6) % d2;
+    int l = pos / (d4 * d5 * d6) % d3;
+    int m = pos / (d5 * d6) % d4;
+    int n = pos / d6 % d5;
+    int o = pos % d6;
 
-    int l_index = Index(i, l0) * l1 * l2 * l3 + Index(j, l1) * l2 * l3 + Index(k, l2) * l3 + Index(l, l3);
-    int r_index = Index(i, r0) * r1 * r2 * r3 + Index(j, r1) * r2 * r3 + Index(k, r2) * r3 + Index(l, r3);
+    int l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6;
+    l_index += Index(j, l1) * l2 * l3 * l4 * l5 * l6;
+    l_index += Index(k, l2) * l3 * l4 * l5 * l6;
+    l_index += Index(l, l3) * l4 * l5 * l6;
+    l_index += Index(m, l4) * l5 * l6;
+    l_index += Index(n, l5) * l6;
+    l_index += Index(o, l6);
+    int r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6;
+    r_index += Index(j, r1) * r2 * r3 * r4 * r5 * r6;
+    r_index += Index(k, r2) * r3 * r4 * r5 * r6;
+    r_index += Index(l, r3) * r4 * r5 * r6;
+    r_index += Index(m, r4) * r5 * r6;
+    r_index += Index(n, r5) * r6;
+    r_index += Index(o, r6);
     output[pos] = Func()(input0[l_index], input1[r_index]);
   }
 }
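The operator body decomposes the flat output position into seven coordinates and folds them back into per-operand offsets; Index() collapses any dimension of size 1 to offset 0, and that collapse is the entire broadcast. A host-side sketch of the same arithmetic, reduced to three dims for readability (illustrative only, not part of the patch):

#include <cstdio>

// Mirrors the device helper: a dim of size 1 always maps to index 0.
int Index(int i, int dim) { return dim == 1 ? 0 : i; }

int main() {
  const int d0 = 2, d1 = 3, d2 = 4;  // output shape
  const int l0 = 2, l1 = 1, l2 = 4;  // lhs shape, broadcast along dim 1
  for (int pos = 0; pos < d0 * d1 * d2; ++pos) {
    int i = pos / (d1 * d2) % d0;
    int j = pos / d2 % d1;
    int k = pos % d2;
    // lhs strides are l1 * l2 and l2; Index(j, 1) == 0 pins the broadcast dim.
    int l_index = Index(i, l0) * l1 * l2 + Index(j, l1) * l2 + Index(k, l2);
    std::printf("out %2d <- lhs %2d\n", pos, l_index);
  }
  return 0;
}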
 template <typename T, typename S>
-__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1,
-                                const int r2, const int r3, const int d0, const int d1, const int d2, const int d3,
-                                enum BroadcastOpType op, const T *input0, const T *input1, S *output) {
+__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5,
+                                const int l6, const int r0, const int r1, const int r2, const int r3, const int r4,
+                                const int r5, const int r6, const int d0, const int d1, const int d2, const int d3,
+                                const int d4, const int d5, const int d6, enum BroadcastOpType op, const T *input0,
+                                const T *input1, S *output) {
   switch (op) {
     case BROADCAST_TYPE_GREATER:
-      return BroadcastOperator<T, S, GreaterFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
-                                                        output);
+      return BroadcastOperator<T, S, GreaterFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
+                                                        d2, d3, d4, d5, d6, input0, input1, output);
     case BROADCAST_TYPE_LESS:
-      return BroadcastOperator<T, S, LessFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
-                                                     output);
+      return BroadcastOperator<T, S, LessFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
+                                                     d3, d4, d5, d6, input0, input1, output);
     case BROADCAST_TYPE_MINIMUM:
-      return BroadcastOperator<T, S, MinimumFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
-                                                        output);
+      return BroadcastOperator<T, S, MinimumFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
+                                                        d2, d3, d4, d5, d6, input0, input1, output);
     case BROADCAST_TYPE_MAXIMUM:
-      return BroadcastOperator<T, S, MaximumFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
-                                                        output);
+      return BroadcastOperator<T, S, MaximumFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
+                                                        d2, d3, d4, d5, d6, input0, input1, output);
     case BROADCAST_TYPE_POWER:
-      return BroadcastOperator<T, S, PowerFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
-                                                      output);
+      return BroadcastOperator<T, S, PowerFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
+                                                      d2, d3, d4, d5, d6, input0, input1, output);
     case BROADCAST_TYPE_REALDIV:
-      return BroadcastOperator<T, S, RealDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
-                                                        output);
+      return BroadcastOperator<T, S, RealDivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
+                                                        d2, d3, d4, d5, d6, input0, input1, output);
     case BROADCAST_TYPE_MUL:
-      return BroadcastOperator<T, S, MulFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
-                                                    output);
+      return BroadcastOperator<T, S, MulFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
+                                                    d3, d4, d5, d6, input0, input1, output);
     case BROADCAST_TYPE_SUB:
-      return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
-                                                    output);
+      return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
+                                                    d3, d4, d5, d6, input0, input1, output);
     case BROADCAST_TYPE_ADD:
-      return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
-                                                    output);
+      return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
+                                                    d3, d4, d5, d6, input0, input1, output);
     case BROADCAST_TYPE_FLOORDIV:
-      return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
-                                                         output);
+      return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
+                                                         d2, d3, d4, d5, d6, input0, input1, output);
     case BROADCAST_TYPE_ABSGRAD:
-      return BroadcastOperator<T, S, AbsGradFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
-                                                        output);
+      return BroadcastOperator<T, S, AbsGradFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
+                                                        d2, d3, d4, d5, d6, input0, input1, output);
   }
 }
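Every case forwards to the same operator with a different functor; the functor definitions sit earlier in broadcast_impl.cu and are untouched by this patch. They presumably follow this shape (an illustrative sketch matching the enum naming, not the patch's code):

template <typename T, typename S>
struct GreaterFunc {
  // Element-wise comparison; S is bool for the comparison ops.
  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs > rhs; }
};

Instantiating the functor inside the loop (Func()(...)) is free: it is an empty struct, so the compiler inlines the call with no per-element overhead.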
 template <typename T, typename S>
-void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2,
-               const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op,
-               const T *input0, const T *input1, S *output, cudaStream_t stream) {
-  int size = d0 * d1 * d2 * d3;
-  BroadcastKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, op,
-                                                                input0, input1, output);
+void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
+               const std::vector<int> &output_shape, enum BroadcastOpType op, const T *input0, const T *input1,
+               S *output, cudaStream_t stream) {
+  int size = 1;
+  for (auto d : output_shape) {
+    size *= d;
+  }
+  BroadcastKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(lhs_shape[0], lhs_shape[1], lhs_shape[2], lhs_shape[3],
+                                                                lhs_shape[4], lhs_shape[5], lhs_shape[6], rhs_shape[0],
+                                                                rhs_shape[1], rhs_shape[2], rhs_shape[3], rhs_shape[4],
+                                                                rhs_shape[5], rhs_shape[6], output_shape[0],
+                                                                output_shape[1], output_shape[2], output_shape[3],
+                                                                output_shape[4], output_shape[5], output_shape[6],
+                                                                op, input0, input1, output);
 }
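The host wrapper now takes rank-7 shape vectors instead of four scalars per operand, so every caller must pad its shapes out to rank 7 with trailing 1s. That is what broadcast_gpu_kernel.h does below with resize(MAX_DIMS, 1); as a standalone sketch (kMaxDims is an illustrative stand-in for that constant):

#include <cstddef>
#include <vector>

constexpr int kMaxDims = 7;

// Left-aligned padding with trailing 1s, mirroring the kernel's resize logic.
std::vector<int> PadTo7d(const std::vector<int> &shape) {
  std::vector<int> padded(kMaxDims, 1);
  for (std::size_t i = 0; i < shape.size(); ++i) {
    padded[i] = shape[i];
  }
  return padded;
}

Because lhs, rhs, and output are all padded the same way, the relative alignment of their dimensions is preserved, and the padded size-1 dims are no-ops inside Index().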
@@ -236,30 +265,24 @@ void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, con
                                                            output_addr);
 }
 
-template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
-                        const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
-                        enum BroadcastOpType op, const float *input0, const float *input1, bool *output,
-                        cudaStream_t stream);
-template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
-                        const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
-                        enum BroadcastOpType op, const float *input0, const float *input1, float *output,
-                        cudaStream_t stream);
-template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
-                        const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
-                        enum BroadcastOpType op, const half *input0, const half *input1, bool *output,
-                        cudaStream_t stream);
-template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
-                        const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
-                        enum BroadcastOpType op, const half *input0, const half *input1, half *output,
-                        cudaStream_t stream);
-template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
-                        const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
-                        enum BroadcastOpType op, const int *input0, const int *input1, int *output,
-                        cudaStream_t stream);
-template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
-                        const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
-                        enum BroadcastOpType op, const int *input0, const int *input1, bool *output,
-                        cudaStream_t stream);
+template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
+                        const std::vector<int> &output_shape, enum BroadcastOpType op, const float *input0,
+                        const float *input1, bool *output, cudaStream_t stream);
+template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
+                        const std::vector<int> &output_shape, enum BroadcastOpType op, const float *input0,
+                        const float *input1, float *output, cudaStream_t stream);
+template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
+                        const std::vector<int> &output_shape, enum BroadcastOpType op, const half *input0,
+                        const half *input1, bool *output, cudaStream_t stream);
+template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
+                        const std::vector<int> &output_shape, enum BroadcastOpType op, const half *input0,
+                        const half *input1, half *output, cudaStream_t stream);
+template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
+                        const std::vector<int> &output_shape, enum BroadcastOpType op, const int *input0,
+                        const int *input1, int *output, cudaStream_t stream);
+template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
+                        const std::vector<int> &output_shape, enum BroadcastOpType op, const int *input0,
+                        const int *input1, bool *output, cudaStream_t stream);
 template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
                           bool *output, cudaStream_t stream);
 template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh
index 7d762c34d..b0d80ed04 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh
@@ -17,6 +17,7 @@
 #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
 #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
 
+#include <vector>
 #include "runtime/device/gpu/cuda_common.h"
 
 enum BroadcastOpType {
@@ -35,9 +36,9 @@ enum BroadcastOpType {
 };
 
 template <typename T, typename S>
-void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2,
-               const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op,
-               const T *input0, const T *input1, S *output, cudaStream_t stream);
+void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
+               const std::vector<int> &output_shape, enum BroadcastOpType op, const T *input0, const T *input1,
+               S *output, cudaStream_t stream);
 
 template <typename T, typename S>
 void NoBroadcast(const int &size, enum BroadcastOpType op, const T *input0, const T *input1, S *output,
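The explicit instantiation lines exist because the template definitions are only visible inside broadcast_impl.cu, which nvcc compiles separately; every (T, S) pair a caller may request has to be instantiated there. Supporting a new type would mean one more line of the same form, e.g. (hypothetical, double is not part of this patch):

template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
                        const std::vector<int> &output_shape, enum BroadcastOpType op, const double *input0,
                        const double *input1, double *output, cudaStream_t stream);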
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cu
index 588f8c60e..b45d2749a 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cu
@@ -25,10 +25,10 @@ __global__ void CheckValidKernel(const size_t size, const T *box, const T *img_m
     const size_t right_y = i * 4 + 3;
 
     S valid_flag = false;
-    valid_flag |= !(box[left_x] >= 0.f);
-    valid_flag |= !(box[left_y] >= 0.f);
-    valid_flag |= !(img_metas[0] * img_metas[2] - 1.f >= box[right_x]);
-    valid_flag |= !(img_metas[1] * img_metas[2] - 1.f >= box[right_y]);
+    valid_flag |= !(box[left_x] >= static_cast<T>(0.0));
+    valid_flag |= !(box[left_y] >= static_cast<T>(0.0));
+    valid_flag |= !(img_metas[1] * img_metas[2] - static_cast<T>(1.0) >= box[right_x]);
+    valid_flag |= !(img_metas[0] * img_metas[2] - static_cast<T>(1.0) >= box[right_y]);
 
     valid[i] = !valid_flag;
   }
@@ -43,3 +43,5 @@ void CheckValid(const size_t &size, const T *box, const T *img_metas, S *valid,
 
 template void CheckValid(const size_t &size, const float *box, const float *img_metas, bool *valid,
                          cudaStream_t cuda_stream);
+template void CheckValid(const size_t &size, const half *box, const half *img_metas, bool *valid,
+                         cudaStream_t cuda_stream);
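The float literals had to become static_cast<T>(...) for the new half instantiation: __half arithmetic in device code does not mix with float literals without an explicit conversion. The hunk also swaps img_metas[0] and img_metas[1], so right_x is now checked against the width and right_y against the height, assuming img_metas is laid out as {h, w, ratio}. A sketch of the fixed right-edge test under that assumption (the function name is illustrative):

#include <cuda_fp16.h>

__device__ bool BoxRightEdgeValid(const half *img_metas, const half right_x) {
  // width * ratio - 1 >= right_x, evaluated entirely in half precision
  return img_metas[1] * img_metas[2] - static_cast<half>(1.0) >= right_x;
}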
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cu
index a3cdd7e13..3b785e265 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cu
@@ -16,27 +16,26 @@
 
 #include "backend/kernel_compiler/gpu/cuda_impl/iou_impl.cuh"
 
-template <typename T>
-__device__ T CoordinateMax(const T a, const T b) {
+__device__ float CoordinateMax(const float a, const float b) {
   return (a > b ? a : b);
 }
 
-template <typename T>
-__device__ T CoordinateMin(const T a, const T b) {
+__device__ float CoordinateMin(const float a, const float b) {
   return (a < b ? a : b);
 }
 
 template <typename T>
 __global__ void IOUKernel(const size_t size, const T *box1, const T *box2, T *iou_results, const size_t mode,
                           const size_t input_len_0) {
-  T location_coordinate[IOU_LOCATION_NUM][IOU_DIMENSION];
-  T overlaps_coordinate[IOU_DIMENSION];
-  const T epsilon = 1e-10;
+  float location_coordinate[IOU_LOCATION_NUM][IOU_DIMENSION];
+  float overlaps_coordinate[IOU_DIMENSION];
+  const float epsilon = 1e-10;
+  const float offset = 1.0;
 
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     for (size_t j = 0; j < IOU_DIMENSION; j++) {
-      location_coordinate[0][j] = box1[(i % input_len_0) * IOU_DIMENSION + j];
-      location_coordinate[1][j] = box2[(i / input_len_0) * IOU_DIMENSION + j];
+      location_coordinate[0][j] = static_cast<float>(box1[(i % input_len_0) * IOU_DIMENSION + j]);
+      location_coordinate[1][j] = static_cast<float>(box2[(i / input_len_0) * IOU_DIMENSION + j]);
     }
 
     overlaps_coordinate[0] = CoordinateMax(location_coordinate[0][0], location_coordinate[1][0]);
@@ -44,18 +43,18 @@ __global__ void IOUKernel(const size_t size, const T *box1, const T *box2, T *io
     overlaps_coordinate[2] = CoordinateMin(location_coordinate[0][2], location_coordinate[1][2]);
     overlaps_coordinate[3] = CoordinateMin(location_coordinate[0][3], location_coordinate[1][3]);
 
-    T overlaps_w = CoordinateMax(0.f, overlaps_coordinate[2] - overlaps_coordinate[0] + 1);
-    T overlaps_h = CoordinateMax(0.f, overlaps_coordinate[3] - overlaps_coordinate[1] + 1);
-    T overlaps = overlaps_w * overlaps_h;
+    float overlaps_w = CoordinateMax(0.0, overlaps_coordinate[2] - overlaps_coordinate[0] + offset);
+    float overlaps_h = CoordinateMax(0.0, overlaps_coordinate[3] - overlaps_coordinate[1] + offset);
+    float overlaps = overlaps_w * overlaps_h;
 
-    T area1 = (location_coordinate[0][2] - location_coordinate[0][0] + 1) * (location_coordinate[0][3] -
-              location_coordinate[0][1] + 1);
-    T area2 = (location_coordinate[1][2] - location_coordinate[1][0] + 1) * (location_coordinate[1][3] -
-              location_coordinate[1][1] + 1);
+    float area1 = (location_coordinate[0][2] - location_coordinate[0][0] + offset) * (location_coordinate[0][3] -
+                  location_coordinate[0][1] + offset);
+    float area2 = (location_coordinate[1][2] - location_coordinate[1][0] + offset) * (location_coordinate[1][3] -
+                  location_coordinate[1][1] + offset);
 
     if (mode == 0) {
-      iou_results[i] = overlaps / (area1 + area2 - overlaps + epsilon);
+      iou_results[i] = static_cast<T>(overlaps / (area1 + area2 - overlaps + epsilon));
     } else {
-      iou_results[i] = overlaps / (area2 + epsilon);
+      iou_results[i] = static_cast<T>(overlaps / (area2 + epsilon));
     }
   }
 
@@ -70,3 +69,5 @@ void IOU(const size_t &size, const T *box1, const T *box2, T *iou_results, const
 
 template void IOU(const size_t &size, const float *box1, const float *box2, float *iou_results, const size_t &mode,
                   const size_t &input_len_0, cudaStream_t cuda_stream);
+template void IOU(const size_t &size, const half *box1, const half *box2, half *iou_results, const size_t &mode,
+                  const size_t &input_len_0, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
index a6a25096f..9ee6ead1c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
@@ -84,6 +84,40 @@ class GpuKernel : public KernelMod {
     }
   }
 
+  // set the tensor descriptor for cudnn/cublas
+  void CudnnSetTensorNdDescriptor(const std::vector<size_t> &shape, cudnnTensorDescriptor_t descriptor,
+                                  cudnnDataType_t data_type) {
+    if (shape.size() < 3) {
+      MS_EXCEPTION(ValueError) << "cudnnSetTensorNdDescriptor doesn't support " << shape.size() << "D.";
+    }
+    const int nbDims = shape.size();
+    int *dim = new (std::nothrow) int[nbDims];
+    if (dim == nullptr) {
+      MS_LOG(EXCEPTION) << "malloc dim failed.";
+    }
+    int *stride = new (std::nothrow) int[nbDims];
+    if (stride == nullptr) {
+      MS_LOG(EXCEPTION) << "malloc stride failed.";
+    }
+
+    for (int i = 0; i < nbDims; i++) {
+      dim[i] = SizeToInt(shape[i]);
+      stride[i] = 1;
+    }
+
+    for (int i = nbDims - 2; i >= 0; i--) {
+      stride[i] = stride[i + 1] * SizeToInt(shape[i + 1]);
+    }
+
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(descriptor, data_type, nbDims, dim, stride),
+                                "cudnnSetTensorNdDescriptor failed");
+
+    delete[] dim;
+    dim = nullptr;
+    delete[] stride;
+    stride = nullptr;
+  }
+
   // choose the suitable datatype for cudnn/cublas
   inline cudnnDataType_t GetCudnnDataType(const std::string &Type) {
     auto type = kCudnnDtypeMap.find(Type);
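The new helper builds a packed stride array, which is what cudnnSetTensorNdDescriptor expects for a dense NCHW-style tensor: each stride is the product of all dimensions to its right. A standalone worked example of that loop:

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> shape = {2, 3, 4, 5};
  std::vector<int> stride(shape.size(), 1);
  // Innermost dim is contiguous; walk outward multiplying up the extents.
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; i--) {
    stride[i] = stride[i + 1] * shape[i + 1];
  }
  for (int s : stride) std::printf("%d ", s);  // prints: 60 20 5 1
  std::printf("\n");
  return 0;
}

The raw new[]/delete[] pair in the helper could equally be two std::vector<int>s whose data() is handed to cuDNN; that would also keep the early MS_EXCEPTION paths from leaking the first allocation.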
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
index b6ac5a368..82198b37d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
@@ -27,6 +27,7 @@
 #include "backend/kernel_compiler/gpu/kernel_constants.h"
 namespace mindspore {
 namespace kernel {
+constexpr int MAX_DIMS = 7;
 template <typename T, typename S>
 class BroadcastOpGpuKernel : public GpuKernel {
  public:
@@ -45,9 +46,8 @@ class BroadcastOpGpuKernel : public GpuKernel {
     S *output = GetDeviceAddress<S>(outputs, 0);
 
     if (need_broadcast_) {
-      Broadcast(lhs_shape_[0], lhs_shape_[1], lhs_shape_[2], lhs_shape_[3], rhs_shape_[0], rhs_shape_[1], rhs_shape_[2],
-                rhs_shape_[3], output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], op_type_, lhs,
-                rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
+      Broadcast(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
+                reinterpret_cast<cudaStream_t>(stream_ptr));
     } else {
       NoBroadcast(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
     }
@@ -60,10 +60,13 @@ class BroadcastOpGpuKernel : public GpuKernel {
     auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
     auto shape3 = AnfAlgo::GetOutputInferShape(kernel_node, 0);
     need_broadcast_ = IsBroadcast(shape1, shape2);
-    if (need_broadcast_ && shape1.size() > 4) {
-      MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 4";
+    if (need_broadcast_ && shape1.size() > 7) {
+      MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7";
     }
 
+    lhs_shape_.resize(MAX_DIMS, 1);
+    rhs_shape_.resize(MAX_DIMS, 1);
+    output_shape_.resize(MAX_DIMS, 1);
     for (size_t i = 0; i < shape3.size(); i++) {
       output_shape_[i] = shape3[i];
       output_num_ *= shape3[i];
@@ -127,9 +130,9 @@ class BroadcastOpGpuKernel : public GpuKernel {
   int input1_num_;
   int input2_num_;
   int output_num_;
-  int lhs_shape_[4] = {1, 1, 1, 1};
-  int rhs_shape_[4] = {1, 1, 1, 1};
-  int output_shape_[4] = {1, 1, 1, 1};
+  std::vector<int> lhs_shape_;
+  std::vector<int> rhs_shape_;
+  std::vector<int> output_shape_;
 
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h
index b434ddadd..01c1079a8 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h
@@ -83,12 +83,19 @@ class ActivationGpuFwdKernel : public GpuKernel {
       return true;
     }
     std::vector<int> shape;
-    ShapeNdTo4d(input_shape, &shape);
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, 0.0),
                                 "cudnnSetActivationDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
-                                                           shape[0], shape[1], shape[2], shape[3]),
-                                "cudnnSetTensor4dDescriptor failed");
+
+    const int split_dim = 4;
+    if (input_shape.size() <= split_dim) {
+      ShapeNdTo4d(input_shape, &shape);
+      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
+                                                             shape[0], shape[1], shape[2], shape[3]),
+                                  "cudnnSetTensor4dDescriptor failed");
+    } else {
+      CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_);
+    }
+
     InitSizeLists();
     return true;
   }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
index 2d7b2012f..47aadc70a 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
@@ -90,12 +90,18 @@ class ActivationGradGpuKernel : public GpuKernel {
       return true;
     }
     std::vector<int> shape;
-    ShapeNdTo4d(input_shape, &shape);
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, 0.0),
                                 "SetActivationDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
-                                                           shape[0], shape[1], shape[2], shape[3]),
-                                "SetTensor4dDescriptor failed");
+
+    const int split_dim = 4;
+    if (input_shape.size() <= split_dim) {
+      ShapeNdTo4d(input_shape, &shape);
+      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
+                                                             shape[0], shape[1], shape[2], shape[3]),
+                                  "SetTensor4dDescriptor failed");
+    } else {
+      CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_);
+    }
 
     InitSizeLists();
     return true;
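Both activation kernels now share the same dispatch: rank <= 4 keeps the cheaper fixed 4-D descriptor, anything higher falls through to the generic N-D setter. Extracted as a sketch for clarity (a hypothetical helper, not something this patch adds; it leans on the repo helpers and macros shown above):

void SetTensorDescriptor(const std::vector<size_t> &shape, cudnnTensorDescriptor_t desc, cudnnDataType_t dtype) {
  const int split_dim = 4;
  if (shape.size() <= split_dim) {
    std::vector<int> nchw;
    ShapeNdTo4d(shape, &nchw);  // pad to exactly four dims
    CHECK_CUDNN_RET_WITH_EXCEPT(
      cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, dtype, nchw[0], nchw[1], nchw[2], nchw[3]),
      "cudnnSetTensor4dDescriptor failed");
  } else {
    CudnnSetTensorNdDescriptor(shape, desc, dtype);
  }
}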
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.cc
index 208e217e1..35deb0cd8 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.cc
@@ -22,5 +22,9 @@ MS_REG_GPU_KERNEL_TWO(
   CheckValid,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
   CheckValidGpuKernel, float, bool)
+MS_REG_GPU_KERNEL_TWO(
+  CheckValid,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
+  CheckValidGpuKernel, half, bool)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/iou_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/iou_gpu_kernel.cc
index 5d3f0f202..081d12da3 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/iou_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/iou_gpu_kernel.cc
@@ -21,5 +21,8 @@ namespace kernel {
 MS_REG_GPU_KERNEL_ONE(
   IOU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   IOUGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  IOU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  IOUGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/tests/st/ops/gpu/test_floordiv_op.py b/tests/st/ops/gpu/test_floordiv_op.py
index dc7d76807..26c32f921 100644
--- a/tests/st/ops/gpu/test_floordiv_op.py
+++ b/tests/st/ops/gpu/test_floordiv_op.py
@@ -37,8 +37,8 @@ def test_floor_div():
     y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
     x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
     y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.float32)
-    x2_np = np.random.randint(1, 5, (2, 1, 1, 4)).astype(np.float32)
-    y2_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
+    x2_np = np.random.randint(1, 5, (2, 1, 1, 4, 9)).astype(np.float32)
+    y2_np = np.random.randint(1, 5, (2, 3, 4, 4, 9)).astype(np.float32)
     x3_np = np.random.randint(1, 5, 1).astype(np.float32)
    y3_np = np.random.randint(1, 5, 1).astype(np.float32)
     x4_np = np.array(768).astype(np.float32)
diff --git a/tests/st/ops/gpu/test_reduce_mean_op.py b/tests/st/ops/gpu/test_reduce_mean_op.py
index 867d7e8a3..7033449b5 100644
--- a/tests/st/ops/gpu/test_reduce_mean_op.py
+++ b/tests/st/ops/gpu/test_reduce_mean_op.py
@@ -70,7 +70,7 @@ x11 = np.random.rand(1, 1, 1, 1).astype(np.float32)
 axis11 = (0, 1, 2, 3)
 keep_dims11 = False
 
-x12 = np.random.rand(2, 3, 4, 4).astype(np.float32)
+x12 = np.random.rand(2, 3, 4, 4, 5, 6).astype(np.float32)
 axis12 = -2
 keep_dims12 = False
-- 
GitLab
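The test updates pin the new behaviour: test_floordiv_op now broadcasts a 5-D shape pair and test_reduce_mean_op reduces a 6-D input, both of which the old 4-D paths would have rejected. The new floordiv pair is broadcast-compatible in the usual sense (each dimension equal, or one side 1); a standalone check of that rule for the exact shapes used:

#include <cassert>
#include <cstddef>
#include <vector>

// The per-dimension broadcast rule, for equal-rank shapes as in the test.
bool BroadcastCompatible(const std::vector<int> &a, const std::vector<int> &b) {
  if (a.size() != b.size()) return false;
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (a[i] != b[i] && a[i] != 1 && b[i] != 1) return false;
  }
  return true;
}

int main() {
  assert(BroadcastCompatible({2, 1, 1, 4, 9}, {2, 3, 4, 4, 9}));   // the new 5-D case
  assert(!BroadcastCompatible({2, 2, 1, 4, 9}, {2, 3, 4, 4, 9}));  // dim 1: 2 vs 3 would fail
  return 0;
}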