diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.cc
index 3a9aa6e075b69af077ccb54b81343933d47375e9..b2b3ee031ba9fb4474f7c8b131c95789a983397e 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.cc
@@ -1,33 +1,39 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h"
-
-namespace mindspore {
-namespace kernel {
-MS_REG_GPU_KERNEL_TWO(
-  ScatterNd,
-  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  ScatterNdGpuFwdKernel, float, int)
-MS_REG_GPU_KERNEL_TWO(
-  ScatterNd,
-  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  ScatterNdGpuFwdKernel, half, int)
-MS_REG_GPU_KERNEL_TWO(
-  ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  ScatterNdGpuFwdKernel, int, int)
-}  // namespace kernel
-}  // namespace mindspore
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(
+  ScatterNd,
+  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  ScatterNdGpuFwdKernel, float, int)
+MS_REG_GPU_KERNEL_TWO(
+  ScatterNd,
+  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  ScatterNdGpuFwdKernel, half, int)
+MS_REG_GPU_KERNEL_TWO(
+  ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  ScatterNdGpuFwdKernel, int, int)
+MS_REG_GPU_KERNEL_TWO(
+  ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
+  ScatterNdGpuFwdKernel, short, int)  // NOLINT
+MS_REG_GPU_KERNEL_TWO(
+  ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
+  ScatterNdGpuFwdKernel, uchar, int)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu
index 1e1a9e2da68c90e4c1eaf1038805e2ea8a19b386..fee1e3eb3b96f17c2dec09bf9dc349d3fd4b2b4c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu
@@ -23,9 +23,9 @@ struct MinimumGradFunc {
   __device__ __forceinline__ void operator()(const T &x1, const T &x2, const bool &grad_x1, const bool &grad_x2,
                                              const T &dy, T *dx1, T *dx2) {
     if (grad_x1 && x1 < x2) {
-      ms_atomic_add(dx1, dy);
+      MsAtomicAdd(dx1, dy);
     } else if (grad_x2 && x1 >= x2) {
-      ms_atomic_add(dx2, dy);
+      MsAtomicAdd(dx2, dy);
     }
   }
 };
@@ -35,9 +35,9 @@ struct MaximumGradFunc {
   __device__ __forceinline__ void operator()(const T &x1, const T &x2, const bool &grad_x1, const bool &grad_x2,
                                              const T &dy, T *dx1, T *dx2) {
     if (grad_x1 && x1 > x2) {
-      ms_atomic_add(dx1, dy);
+      MsAtomicAdd(dx1, dy);
     } else if (grad_x2 && x1 <= x2) {
-      ms_atomic_add(dx2, dy);
+      MsAtomicAdd(dx2, dy);
     }
   }
 };
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
index 040f57c2bf3f9908626ef291055293268b122453..9f8d30df0021572c606e9eb5490dd5f19b435230 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
@@ -15,6 +15,7 @@
  */
 
 #include <vector>
+
 #include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"
 #include "runtime/device/gpu/cuda_common.h"
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cu
index edb509a38d8906b2ce394ef8aa1ead9165eec2c9..8e61a46ad7be117c7a1c3cda4d5d8854ce5775a7 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cu
@@ -61,7 +61,7 @@ __global__ void ResizeNearestNeighborGrad(const int input_size, const T *input,
                       out_width - 1);
     // pos_array[0] N, pos_array[1] C, out_y H, out_x W
     output_pos = pos_array[0] * d2 * d3 * d4 + pos_array[1] * d3 * d4 + out_y * d4 + out_x;
-    ms_atomic_add(&output[output_pos], input[pos]);
+    MsAtomicAdd(&output[output_pos], input[pos]);
   }
 }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu
index 789abcf0f7367f80c11088c7d941c1acc4ef853d..6eeb2e533c72971fc27d0fc96b2b93d5d52a6ac3 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu
@@ -218,10 +218,10 @@ __global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes,
     T *dx_3 = dx + offset + y_high * width + x_low;
     T *dx_4 = dx + offset + y_high * width + x_high;
     if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
-      ms_atomic_add(dx_1, g1);
-      ms_atomic_add(dx_2, g2);
-      ms_atomic_add(dx_3, g3);
-      ms_atomic_add(dx_4, g4);
+      MsAtomicAdd(dx_1, g1);
+      MsAtomicAdd(dx_2, g2);
+      MsAtomicAdd(dx_3, g3);
+      MsAtomicAdd(dx_4, g4);
     }
   }
 }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cu
index 80258e718dd72e3365dfbcda7c75884b1eb80841..c34cd99084b2327c5225bc5c12564a9e8f39fcb1 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cu
@@ -1,70 +1,80 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh"
-#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
-#include "runtime/device/gpu/cuda_common.h"
-
-template <typename T, typename S>
-__global__ void ScatterNdKernel(S *indices, T *update, T *output, const size_t block_size, const size_t input_size,
-                                const size_t output_size, const size_t indices_dim_0, const size_t indices_dim_1,
-                                S *indices_stride, S *work_shape) {
-  int i, j;
-  for (int read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
-       read_index += blockDim.x * gridDim.x) {
-    int write_index = 0;
-    bool out_bound = false;
-
-    i = read_index / block_size;
-    j = read_index % block_size;
-
-    for (size_t k = 0; k < indices_dim_1; k++) {
-      S indices_i = indices[i * indices_dim_1 + k];
-      out_bound |= indices_i >= work_shape[k];
-      write_index += indices_i * indices_stride[k];
-    }
-
-    write_index += j;
-    out_bound |= write_index >= output_size;
-
-    if (!out_bound) {
-      ms_atomic_add(&output[write_index], update[read_index]);
-    }
-  }
-}
-
-template <typename T, typename S>
-void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
-               const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
-               S *work_shape, cudaStream_t stream) {
-  ScatterNdKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(indices, update, output, block_size, input_size,
-                                                                      output_size, indices_dim_0, indices_dim_1,
-                                                                      indices_stride, work_shape);
-  return;
-}
-
-template void ScatterNd<float, int>(int *indices, float *update, float *output, const size_t &block_size,
-                                    const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
-                                    const size_t &indices_dim_1, int *indices_stride, int *work_shape,
-                                    cudaStream_t stream);
-template void ScatterNd<half, int>(int *indices, half *update, half *output, const size_t &block_size,
-                                   const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
-                                   const size_t &indices_dim_1, int *indices_stride, int *work_shape,
-                                   cudaStream_t stream);
-template void ScatterNd<int, int>(int *indices, int *update, int *output, const size_t &block_size,
-                                  const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
-                                  const size_t &indices_dim_1, int *indices_stride, int *work_shape,
-                                  cudaStream_t stream);
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh"
+#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T, typename S>
+__global__ void ScatterNdKernel(S *indices, T *update, T *output, const size_t block_size, const size_t input_size,
+                                const size_t output_size, const size_t indices_dim_0, const size_t indices_dim_1,
+                                S *indices_stride, S *work_shape) {
+  int i, j;
+  for (int read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
+       read_index += blockDim.x * gridDim.x) {
+    int write_index = 0;
+    bool out_bound = false;
+
+    i = read_index / block_size;
+    j = read_index % block_size;
+
+    for (size_t k = 0; k < indices_dim_1; k++) {
+      S indices_i = indices[i * indices_dim_1 + k];
+      out_bound |= indices_i >= work_shape[k];
+      write_index += indices_i * indices_stride[k];
+    }
+
+    write_index += j;
+    out_bound |= write_index >= output_size;
+
+    if (!out_bound) {
+      MsAtomicAdd(&output[write_index], update[read_index]);
+    }
+  }
+}
+
+template <typename T, typename S>
+void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
+               const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
+               S *work_shape, cudaStream_t stream) {
+  ScatterNdKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(indices, update, output, block_size, input_size,
+                                                                      output_size, indices_dim_0, indices_dim_1,
+                                                                      indices_stride, work_shape);
+  return;
+}
+
+template void ScatterNd<float, int>(int *indices, float *update, float *output, const size_t &block_size,
+                                    const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
+                                    const size_t &indices_dim_1, int *indices_stride, int *work_shape,
+                                    cudaStream_t stream);
+template void ScatterNd<half, int>(int *indices, half *update, half *output, const size_t &block_size,
+                                   const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
+                                   const size_t &indices_dim_1, int *indices_stride, int *work_shape,
+                                   cudaStream_t stream);
+template void ScatterNd<int, int>(int *indices, int *update, int *output, const size_t &block_size,
+                                  const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
+                                  const size_t &indices_dim_1, int *indices_stride, int *work_shape,
+                                  cudaStream_t stream);
+// NOLINTNEXTLINE
+template void ScatterNd<short, int>(int *indices, short *update, short *output, const size_t &block_size,
+                                    const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
+                                    const size_t &indices_dim_1, int *indices_stride, int *work_shape,
+                                    cudaStream_t stream);
+template void ScatterNd<unsigned char, int>(int *indices, unsigned char *update, unsigned char *output,
+                                            const size_t &block_size, const size_t &input_size,
+                                            const size_t &output_size, const size_t &indices_dim_0,
+                                            const size_t &indices_dim_1, int *indices_stride, int *work_shape,
+                                            cudaStream_t stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh
index 2b216baa8e94a411c6b689a2c4064437f2b5749c..e5a9ded119be11352368b004b9d6f2d700bdf198 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh
@@ -19,11 +19,41 @@
 
 #include <cuda_fp16.h>
 
-inline __device__ float ms_atomic_add(float *address, float val) { return atomicAdd(address, val); }
+__device__ static inline float MsAtomicAdd(float *address, const float val) { return atomicAdd(address, val); }
 
-inline __device__ int ms_atomic_add(int *address, int val) { return atomicAdd(address, val); }
+__device__ static inline int MsAtomicAdd(int *address, int val) { return atomicAdd(address, val); }
 
-inline __device__ half ms_atomic_add(half *address, half val) {
+__device__ static inline unsigned int MsAtomicAdd(unsigned int *address, unsigned int val) {
+  return atomicAdd(address, val);
+}
+
+__device__ static inline short MsAtomicAdd(short *address, short val) {  // NOLINT
+  // Sub-word atomic: CAS on the enclosing 4-byte-aligned word, updating only
+  // the 16 bits that hold *address (little-endian layout assumed).
+  bool is_4_byte_aligned = ((size_t) address & 2) == 0;
+  unsigned int *aligned = (unsigned int *) ((size_t) address & ~2);
+  unsigned int old = *aligned;
+  unsigned int assumed;
+
+  do {
+    assumed = old;
+    unsigned int replacement;
+
+    if (is_4_byte_aligned) {
+      replacement = (old & 0xffff0000) | (((old & 0xffff) + val) & 0xffff);
+    } else {
+      replacement = old + ((unsigned int) val << 16);
+    }
+
+    old = atomicCAS(aligned, assumed, replacement);
+  } while (assumed != old);
+
+  if (is_4_byte_aligned) {
+    return (short) (old & 0xffff);  // NOLINT
+  } else {
+    return (short) (old >> 16);  // NOLINT
+  }
+}
+
+__device__ static inline half MsAtomicAdd(half *address, half val) {
   unsigned int *aligned =
     reinterpret_cast<unsigned int *>(reinterpret_cast<size_t>(address) - (reinterpret_cast<size_t>(address) & 2));
   unsigned int old = *aligned;
@@ -42,4 +72,66 @@ inline __device__ half ms_atomic_add(half *address, half val) {
   return half(raw);
 }
 
+__device__ static inline unsigned char MsAtomicAdd(unsigned char* address, unsigned char val) {
+  // We use cuda's atomicCAS(unsigned int*, unsigned int, unsigned int) to
+  // implement MsAtomicAdd. An unsigned char may not be 4 byte aligned, but
+  // an unsigned int* must be 4 byte aligned. This variable contains the
+  // offset, in bytes, of the beginning of address within the 4 byte aligned
+  // space that contains it.
+  size_t address_offset = (size_t) address & 3;
+
+  // Address of the 4 byte aligned space that contains address.
+  unsigned int* aligned = (unsigned int*) ((unsigned char*) address - address_offset);
+
+  // Constants which will be used later with __byte_perm. __byte_perm is a cuda
+  // function which takes three unsigned ints (x, y, selector) as parameters
+  // and returns an int. __byte_perm returns an integer by selecting bytes from
+  // x and y based on the given selector. The selector 0x3210 will select all
+  // four bytes from x, preserving their original order. The position of the
+  // "4" in the selector indicates the position in the output where the first
+  // byte of y will end up.
+  unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210};
+
+  // Gets the selector that will select the byte at address from aligned.
+  unsigned int selector = selectors[address_offset];
+
+  unsigned int old = *aligned;
+  unsigned int assumed = 0;
+
+  do {
+    assumed = old;
+
+    // Selects the byte associated with address and puts it in the first byte
+    // of this variable, so that we can add val to the value at address.
+    unsigned int sum = val + __byte_perm(old, 0, address_offset);
+
+    // Takes old and replaces the byte corresponding to address with the sum.
+    unsigned int replacement = __byte_perm(old, sum, selector);
+
+    // Try to replace the old value with the new value.
+    old = atomicCAS(aligned, assumed, replacement);
+  } while (old != assumed);
+  // Select the single byte corresponding to address and return it.
+  return __byte_perm(old, 0, address_offset);
+}
+
+__device__ static inline char MsAtomicAdd(char* address, char val) {
+  size_t address_offset = (size_t) address & 3;
+  unsigned int* aligned = reinterpret_cast<unsigned int *>(reinterpret_cast<size_t>(address) - address_offset);
+  unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210};
+  unsigned int selector = selectors[address_offset];
+  unsigned int old = *aligned;
+  unsigned int assumed = 0;
+
+  do {
+    assumed = old;
+
+    unsigned int sum = val + __byte_perm(old, 0, address_offset);
+    unsigned int replacement = __byte_perm(old, sum, selector);
+
+    old = atomicCAS(aligned, assumed, replacement);
+  } while (old != assumed);
+  return __byte_perm(old, 0, address_offset);
+}
+
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UTIL_H_
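Review note: the unsigned char and char overloads added to util.cuh implement byte-wide atomic add by CAS-ing the enclosing 4-byte word, since CUDA's atomicCAS only operates on 32-bit (or wider) values. The standalone sketch below exercises the same __byte_perm + atomicCAS pattern outside the MindSpore tree; ByteAtomicAdd, Increment, and all the launch parameters are invented for illustration, and only a standard CUDA toolchain is assumed.

#include <cstdio>
#include <cuda_runtime.h>

// Same technique as the patch: extract the target byte from the enclosing
// aligned word, add val, and splice the sum back in with __byte_perm before
// attempting the CAS.
__device__ unsigned char ByteAtomicAdd(unsigned char *address, unsigned char val) {
  size_t offset = (size_t) address & 3;                          // byte position inside the 4-byte word
  unsigned int *aligned = (unsigned int *) ((size_t) address & ~(size_t) 3);  // enclosing aligned word
  unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210};
  unsigned int old = *aligned;
  unsigned int assumed;
  do {
    assumed = old;
    unsigned int sum = val + __byte_perm(old, 0, offset);        // old byte + val
    old = atomicCAS(aligned, assumed, __byte_perm(old, sum, selectors[offset]));
  } while (old != assumed);
  return __byte_perm(old, 0, offset);                            // previous byte value
}

__global__ void Increment(unsigned char *counters, int n, int hits) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n * hits) {
    ByteAtomicAdd(&counters[i % n], 1);                          // hits increments per counter
  }
}

int main() {
  const int n = 5, hits = 64;
  unsigned char host[n];
  unsigned char *dev = nullptr;
  cudaMalloc(&dev, n);
  cudaMemset(dev, 0, n);
  Increment<<<(n * hits + 127) / 128, 128>>>(dev, n, hits);
  cudaMemcpy(host, dev, n, cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) {
    printf("counter[%d] = %d\n", i, host[i]);                    // expect 64 for every counter
  }
  cudaFree(dev);
  return 0;
}

Every counter should read 64: adjacent byte counters share a 4-byte word, so without the CAS loop concurrent updates to neighboring bytes would clobber each other.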
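A second note on the ScatterNdKernel in scatter_nd.cu: each update element's destination is write_index = sum over k of indices[i][k] * indices_stride[k], plus the offset j within the innermost block, where i = read_index / block_size selects the update slice. The host-side sketch below traces that arithmetic with made-up example shapes (none of these values come from the patch):

#include <cstdio>

int main() {
  // Updates of shape (2, 3) scattered into an output of shape (4, 3):
  // block_size = 3 (innermost slice length), indices_dim_1 = 1, and
  // indices_stride = {3} (elements per step along the indexed axis).
  const int block_size = 3, indices_dim_1 = 1, input_size = 6;
  int indices[] = {1, 3};             // target rows for the two update slices
  int indices_stride[] = {3};
  for (int read_index = 0; read_index < input_size; ++read_index) {
    int i = read_index / block_size;  // which update slice
    int j = read_index % block_size;  // offset within the slice
    int write_index = j;
    for (int k = 0; k < indices_dim_1; ++k) {
      write_index += indices[i * indices_dim_1 + k] * indices_stride[k];
    }
    printf("update[%d] -> output[%d]\n", read_index, write_index);
  }
  return 0;
}

With target rows {1, 3}, update slice 0 lands at output elements 3..5 and slice 1 at 9..11; the kernel's out_bound check rejects indices past work_shape so stray rows never write out of range.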