From 3b09299b69d7bf4c1ee6dc92df890c5c991f87c4 Mon Sep 17 00:00:00 2001
From: wilfChen
Date: Fri, 12 Jun 2020 11:14:59 +0800
Subject: [PATCH] gpu codex warning fix

---
 .../kernel/gpu/cuda_impl/layer_norm_impl.cu   | 32 +++++++++----------
 .../gpu/data/dataset_iterator_kernel.cc       |  3 +-
 .../kernel/gpu/math/broadcast_gpu_kernel.h    |  8 ++---
 .../gpu/math/broadcast_grad_gpu_kernel.h      |  8 ++---
 .../ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc |  4 +--
 .../kernel/gpu/nn/dropout_grad_kernel.cc      |  4 +--
 6 files changed, 28 insertions(+), 31 deletions(-)

diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu
index db3367374..cef74dc8b 100644
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu
@@ -23,7 +23,7 @@ constexpr int NUM_PER_THREAD_REDUCE = 4;
 constexpr int WARP_SIZE = 32;
 
 template <typename T>
-inline __device__ void MeanAndVarAccumulation(T* mean, T* var, T* num, const T& val) {
+inline __device__ void MeanAndVarAccumulation(T *mean, T *var, T *num, const T &val) {
   // Welford Algorithm:
   // \mu_k = \mu_{k-1} + (x_k - \mu_{k-1})/k
   // \sigma_k^2 = \sigma_{k-1}^2 + (x_k - \mu_{k-1}) * (x_k - \mu_k)
@@ -34,7 +34,7 @@ inline __device__ void MeanAndVarAccumulation(T* mean, T* var, T* num, const T&
 }
 
 template <typename T>
-inline __device__ void MeanAndVarMerge(T* m1, T* v1, T* n1, const T& m2, const T& v2, const T& n2) {
+inline __device__ void MeanAndVarMerge(T *m1, T *v1, T *n1, const T &m2, const T &v2, const T &n2) {
   if (n2 == 0) {
     return;
   }
@@ -46,7 +46,7 @@ inline __device__ void MeanAndVarMerge(T* m1, T* v1, T* n1, const T& m2, const T
 }
 
 template <typename T>
-inline __device__ void ThreadReduce(const int& col_dim, const T* block_addr, T* mean, T* var, T* num) {
+inline __device__ void ThreadReduce(const int &col_dim, const T *block_addr, T *mean, T *var, T *num) {
   int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
   for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
     for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
@@ -60,7 +60,7 @@ inline __device__ void ThreadReduce(const int& col_dim, const T* block_addr, T*
 }
 
 template <typename T>
-inline __device__ void WarpReduce(T* mean, T* var, T* num) {
+inline __device__ void WarpReduce(T *mean, T *var, T *num) {
   for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
     T mean_other = __shfl_down_sync(0xffffffff, mean[0], delta);
     T var_other = __shfl_down_sync(0xffffffff, var[0], delta);
@@ -70,8 +70,8 @@ inline __device__ void WarpReduce(T* mean, T* var, T* num) {
 }
 
 template <typename T>
-inline __device__ void BlockReduce(const int& col_dim, T* mean, T* var, T* num, T* mean_addr, T* var_addr,
-                                   T* share_mem) {
+inline __device__ void BlockReduce(const int &col_dim, T *mean, T *var, T *num, T *mean_addr, T *var_addr,
+                                   T *share_mem) {
   if (threadIdx.x >= col_dim) {
     return;
   }
@@ -96,15 +96,15 @@ inline __device__ void BlockReduce(const int& col_dim, T* mean, T* var, T* num,
   __syncthreads();
 
   if (threadIdx.x == 0) {
-    mean_addr[blockIdx.x] = share_mem[0];  // todo: blockDim.x < row
+    mean_addr[blockIdx.x] = share_mem[0];
     share_mem[1] /= col_dim;
     var_addr[blockIdx.x] = share_mem[1];
   }
 }
 
 template <typename T>
-inline __device__ void LayerNorm(const int& row, const int& col_dim, const int& param_dim, const T* x,
-                                 const T* share_mem, const T* gamma, const T* beta, const T epsilon, T* y) {
+inline __device__ void LayerNorm(const int &row, const int &col_dim, const int &param_dim, const T *x,
+                                 const T *share_mem, const T *gamma, const T *beta, const T epsilon, T *y) {
   for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
     int pos = row * col_dim + col;
     int i = pos % param_dim;
@@ -113,13 +113,13 @@ inline __device__ void LayerNorm(const int& row, const int& col_dim, const int&
 }
 
 template <typename T>
-__global__ void LayerNormKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T* x,
-                                const T* gamma, const T* beta, T* y, T* mean_addr, T* var_addr) {
+__global__ void LayerNormKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T *x,
+                                const T *gamma, const T *beta, T *y, T *mean_addr, T *var_addr) {
   for (auto row = blockIdx.x; row < row_dim; row += gridDim.x) {
     T mean = 0;
     T var = 0;
     T num = 0;
-    const T* block_addr = x + row * col_dim;
+    const T *block_addr = x + row * col_dim;
     extern __shared__ T share_mem[];
 
     ThreadReduce(col_dim, block_addr, &mean, &var, &num);
@@ -132,8 +132,8 @@ __global__ void LayerNormKernel(const int row_dim, const int col_dim, const int
 }
 
 template <typename T>
-void LayerNorm(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* x,
-               const T* gamma, const T* beta, T* y, T* mean, T* var, cudaStream_t stream) {
+void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const T &epsilon, const T *x,
+               const T *gamma, const T *beta, T *y, T *mean, T *var, cudaStream_t stream) {
   const dim3 block(row_dim);
   const dim3 thread(256);
   // keep the mean/var/num after warp reduce
@@ -143,6 +143,6 @@ void LayerNorm(const int& row_dim, const int& col_dim, const int& param_dim, con
                                                              var);
 }
 
-template void LayerNorm(const int& row_dim, const int& col_dim, const int& param_dim, const float& epsilon,
-                        const float* x, const float* gamma, const float* beta, float* y, float* mean, float* var,
+template void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const float &epsilon,
+                        const float *x, const float *gamma, const float *beta, float *y, float *mean, float *var,
                         cudaStream_t stream);
diff --git a/mindspore/ccsrc/kernel/gpu/data/dataset_iterator_kernel.cc b/mindspore/ccsrc/kernel/gpu/data/dataset_iterator_kernel.cc
index d416d7df6..13ca191b0 100644
--- a/mindspore/ccsrc/kernel/gpu/data/dataset_iterator_kernel.cc
+++ b/mindspore/ccsrc/kernel/gpu/data/dataset_iterator_kernel.cc
@@ -96,7 +96,8 @@ bool DatasetIteratorKernel::Launch(const std::vector<AddressPtr> &, const std::v
   }
 
   for (size_t i = 0; i < output_size_list_.size(); i++) {
-    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(outputs[i]->addr, addr, output_size_list_[i], cudaMemcpyDeviceToDevice,
+    void *output_addr = GetDeviceAddress<void>(outputs, i);
+    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(output_addr, addr, output_size_list_[i], cudaMemcpyDeviceToDevice,
                                                reinterpret_cast<cudaStream_t>(stream)),
                                "Cuda Memcpy Failed");
     addr = reinterpret_cast<unsigned char *>(addr) + output_size_list_[i];
diff --git a/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h
index 314f992c2..be7d3a19d 100644
--- a/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h
@@ -68,14 +68,14 @@ class BroadcastOpGpuKernel : public GpuKernel {
       output_shape_[i] = shape3[i];
       output_num_ *= shape3[i];
     }
-    int offset = shape3.size() - shape1.size();
+    int lhs_offset = shape3.size() - shape1.size();
     for (size_t j = 0; j < shape1.size(); j++) {
-      lhs_shape_[j + offset] = shape1[j];
+      lhs_shape_[j + lhs_offset] = shape1[j];
       input1_num_ *= shape1[j];
     }
-    offset = shape3.size() - shape2.size();
+    int rhs_offset = shape3.size() - shape2.size();
     for (size_t k = 0; k < shape2.size(); k++) {
-      rhs_shape_[k + offset] = shape2[k];
+      rhs_shape_[k + rhs_offset] = shape2[k];
       input2_num_ *= shape2[k];
     }
 
diff --git a/mindspore/ccsrc/kernel/gpu/math/broadcast_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/broadcast_grad_gpu_kernel.h
index 3e1f91b5b..f1eb5fecf 100644
--- a/mindspore/ccsrc/kernel/gpu/math/broadcast_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/math/broadcast_grad_gpu_kernel.h
@@ -74,14 +74,14 @@ class BroadcastOpGradGpuKernel : public GpuKernel {
       dy_shape_[i] = shape3[i];
       output_num_ *= shape3[i];
     }
-    int offset = shape3.size() - shape1.size();
+    int x1_offset = shape3.size() - shape1.size();
     for (size_t i = 0; i < shape1.size(); i++) {
-      x1_shape_[i + offset] = shape1[i];
+      x1_shape_[i + x1_offset] = shape1[i];
       input1_num_ *= shape1[i];
     }
-    offset = shape3.size() - shape2.size();
+    int x2_offset = shape3.size() - shape2.size();
     for (size_t i = 0; i < shape2.size(); i++) {
-      x2_shape_[i + offset] = shape2[i];
+      x2_shape_[i + x2_offset] = shape2[i];
       input2_num_ *= shape2[i];
     }
 
diff --git a/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc
index 87783add6..b84dc628e 100644
--- a/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc
@@ -68,14 +68,12 @@ void DropoutGpuFwdKernel::DestroyResource() noexcept {}
 
 void DropoutGpuFwdKernel::InitSizeLists() {
   size_t input_size = num_count_ * sizeof(float);
-  size_t workspace_size = 0;
   input_size_list_.push_back(input_size);
   output_size_list_.push_back(input_size);  // output size: the same with input size
   output_size_list_.push_back(input_size);  // mask size: the same with input size
-  workspace_size_list_.push_back(workspace_size);
 }
 
-bool DropoutGpuFwdKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+bool DropoutGpuFwdKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
                                  const std::vector<AddressPtr> &outputs, void *stream_ptr) {
   if (is_null_input_) {
     return true;
diff --git a/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc
index 4517f1bb3..2194805e9 100644
--- a/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc
+++ b/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc
@@ -66,15 +66,13 @@ void DropoutGradGpuFwdKernel::InitSizeLists() {
   size_t dy_size = num_count_ * sizeof(float);
   size_t mask_size = dy_size;
   size_t dx_size = dy_size;
-  size_t workspace_size = 0;
 
   input_size_list_.push_back(dy_size);
   input_size_list_.push_back(mask_size);
   output_size_list_.push_back(dx_size);
-  workspace_size_list_.push_back(workspace_size);
 }
 
-bool DropoutGradGpuFwdKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+bool DropoutGradGpuFwdKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
                                      const std::vector<AddressPtr> &outputs, void *stream_ptr) {
   if (is_null_input_) {
     return true;
-- 
GitLab
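
Editor's note (not part of the patch): the layer_norm_impl.cu kernels above accumulate mean and variance with Welford's algorithm and then merge per-thread partials in the warp/block reduction. The host-side C++ sketch below illustrates that technique only; the merge shown is the standard Chan et al. parallel-merge form, which the patch does not display in full, and WelfordState, Accumulate, and Merge are hypothetical names, not MindSpore APIs.

// Illustrative host-side sketch of the Welford update and a parallel merge,
// assuming the standard formulation; names here are hypothetical.
#include <cstdio>
#include <initializer_list>

struct WelfordState {
  double mean = 0.0;  // running mean
  double m2 = 0.0;    // sum of squared deviations from the mean
  double num = 0.0;   // number of samples accumulated
};

// One-sample update, matching the patch's comment:
//   mu_k = mu_{k-1} + (x_k - mu_{k-1}) / k
//   M2_k = M2_{k-1} + (x_k - mu_{k-1}) * (x_k - mu_k)
void Accumulate(WelfordState *s, double val) {
  s->num += 1.0;
  double mean_new = s->mean + (val - s->mean) / s->num;
  s->m2 += (val - s->mean) * (val - mean_new);
  s->mean = mean_new;
}

// Merge two partial states (Chan et al.); this is the kind of merge the
// warp/block reduction performs on per-thread partial results.
void Merge(WelfordState *a, const WelfordState &b) {
  if (b.num == 0.0) {
    return;
  }
  double count = a->num + b.num;
  double delta = b.mean - a->mean;
  a->m2 += b.m2 + delta * delta * a->num * b.num / count;
  a->mean += delta * b.num / count;
  a->num = count;
}

int main() {
  WelfordState left, right;
  for (double x : {1.0, 2.0, 3.0}) Accumulate(&left, x);
  for (double x : {4.0, 5.0, 6.0}) Accumulate(&right, x);
  Merge(&left, right);
  // Population variance is m2 / num, analogous to the kernel's final
  // share_mem[1] /= col_dim step; prints mean=3.5, var~2.9167.
  std::printf("mean=%f var=%f\n", left.mean, left.m2 / left.num);
  return 0;
}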