Commit 6719169a authored by Peilin Wang

Added type support for atomic add and ScatterNd

fix ci

fix ci
Parent 0e27a04d
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
ScatterNd,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ScatterNdGpuFwdKernel, float, int)
MS_REG_GPU_KERNEL_TWO(
ScatterNd,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ScatterNdGpuFwdKernel, half, int)
MS_REG_GPU_KERNEL_TWO(
ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ScatterNdGpuFwdKernel, int, int)
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
ScatterNd,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ScatterNdGpuFwdKernel, float, int)
MS_REG_GPU_KERNEL_TWO(
ScatterNd,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ScatterNdGpuFwdKernel, half, int)
MS_REG_GPU_KERNEL_TWO(
ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ScatterNdGpuFwdKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
ScatterNdGpuFwdKernel, short, int) // NOLINT
MS_REG_GPU_KERNEL_TWO(
ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
ScatterNdGpuFwdKernel, uchar, int)
} // namespace kernel
} // namespace mindspore
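Note on the registration pattern above: each MS_REG_GPU_KERNEL_TWO entry pairs the int32 index input with one value dtype and instantiates ScatterNdGpuFwdKernel for that (value, index) combination, so widening dtype coverage is purely additive. As an illustration only (not part of this commit), a further int8 registration would follow the same shape; the kNumberTypeInt8/char pairing below is an assumption.
// Hypothetical int8 registration, shown only to illustrate the pattern; not in this commit.
MS_REG_GPU_KERNEL_TWO(
  ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
  ScatterNdGpuFwdKernel, char, int)
Such a registration would also need a matching template void ScatterNd<char, int>(...) instantiation in scatter_nd.cu; the char overload of MsAtomicAdd it would rely on is in fact added to util.cuh further down.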
@@ -23,9 +23,9 @@ struct MinimumGradFunc {
__device__ __forceinline__ void operator()(const T &x1, const T &x2, const bool &grad_x1, const bool &grad_x2,
const T &dy, T *dx1, T *dx2) {
if (grad_x1 && x1 < x2) {
ms_atomic_add(dx1, dy);
MsAtomicAdd(dx1, dy);
} else if (grad_x2 && x1 >= x2) {
ms_atomic_add(dx2, dy);
MsAtomicAdd(dx2, dy);
}
}
};
@@ -35,9 +35,9 @@ struct MaximumGradFunc {
__device__ __forceinline__ void operator()(const T &x1, const T &x2, const bool &grad_x1, const bool &grad_x2,
const T &dy, T *dx1, T *dx2) {
if (grad_x1 && x1 > x2) {
ms_atomic_add(dx1, dy);
MsAtomicAdd(dx1, dy);
} else if (grad_x2 && x1 <= x2) {
ms_atomic_add(dx2, dy);
MsAtomicAdd(dx2, dy);
}
}
};
......
@@ -15,6 +15,7 @@
*/
#include <vector>
#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
......
@@ -61,7 +61,7 @@ __global__ void ResizeNearestNeighborGrad(const int input_size, const T *input,
out_width - 1);
// pos_array[0] N, pos_array[1] C, out_y H, out_x W
output_pos = pos_array[0] * d2 * d3 * d4 + pos_array[1] * d3 * d4 + out_y * d4 + out_x;
ms_atomic_add(&output[output_pos], input[pos]);
MsAtomicAdd(&output[output_pos], input[pos]);
}
}
......
@@ -218,10 +218,10 @@ __global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes,
T *dx_3 = dx + offset + y_high * width + x_low;
T *dx_4 = dx + offset + y_high * width + x_high;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
ms_atomic_add(dx_1, g1);
ms_atomic_add(dx_2, g2);
ms_atomic_add(dx_3, g3);
ms_atomic_add(dx_4, g4);
MsAtomicAdd(dx_1, g1);
MsAtomicAdd(dx_2, g2);
MsAtomicAdd(dx_3, g3);
MsAtomicAdd(dx_4, g4);
}
}
}
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T, typename S>
__global__ void ScatterNdKernel(S *indices, T *update, T *output, const size_t block_size, const size_t input_size,
const size_t output_size, const size_t indices_dim_0, const size_t indices_dim_1,
S *indices_stride, S *work_shape) {
int i, j;
for (int read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
read_index += blockDim.x * gridDim.x) {
int write_index = 0;
bool out_bound = false;
i = read_index / block_size;
j = read_index % block_size;
for (size_t k = 0; k < indices_dim_1; k++) {
S indices_i = indices[i * indices_dim_1 + k];
out_bound |= indices_i >= work_shape[k];
write_index += indices_i * indices_stride[k];
}
write_index += j;
out_bound |= write_index >= output_size;
if (!out_bound) {
ms_atomic_add(&output[write_index], update[read_index]);
}
}
}
template <typename T, typename S>
void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
S *work_shape, cudaStream_t stream) {
ScatterNdKernel<<<GET_BLOCKS(output_size), GET_THREADS, 0, stream>>>(indices, update, output, block_size, input_size,
output_size, indices_dim_0, indices_dim_1,
indices_stride, work_shape);
return;
}
template void ScatterNd<float, int>(int *indices, float *update, float *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
template void ScatterNd<half, int>(int *indices, half *update, half *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
template void ScatterNd<int, int>(int *indices, int *update, int *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T, typename S>
__global__ void ScatterNdKernel(S *indices, T *update, T *output, const size_t block_size, const size_t input_size,
const size_t output_size, const size_t indices_dim_0, const size_t indices_dim_1,
S *indices_stride, S *work_shape) {
int i, j;
for (int read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
read_index += blockDim.x * gridDim.x) {
int write_index = 0;
bool out_bound = false;
i = read_index / block_size;
j = read_index % block_size;
for (size_t k = 0; k < indices_dim_1; k++) {
S indices_i = indices[i * indices_dim_1 + k];
out_bound |= indices_i >= work_shape[k];
write_index += indices_i * indices_stride[k];
}
write_index += j;
out_bound |= write_index >= output_size;
if (!out_bound) {
MsAtomicAdd(&output[write_index], update[read_index]);
}
}
}
template <typename T, typename S>
void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
S *work_shape, cudaStream_t stream) {
ScatterNdKernel<<<GET_BLOCKS(output_size), GET_THREADS, 0, stream>>>(indices, update, output, block_size, input_size,
output_size, indices_dim_0, indices_dim_1,
indices_stride, work_shape);
return;
}
template void ScatterNd<float, int>(int *indices, float *update, float *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
template void ScatterNd<half, int>(int *indices, half *update, half *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
template void ScatterNd<int, int>(int *indices, int *update, int *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
// NOLINTNEXTLINE
template void ScatterNd<short, int>(int *indices, short *update, short *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
template void ScatterNd<unsigned char, int>(int *indices, unsigned char *update, unsigned char *output,
const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
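To make the index arithmetic in ScatterNdKernel concrete: each thread handles one element of update, splits its linear read_index into an outer scatter slot i = read_index / block_size and an inner offset j = read_index % block_size, then scatters to write_index = sum over k of indices[i * indices_dim_1 + k] * indices_stride[k], plus j, skipping the write when any index component reaches work_shape[k] or the flattened target reaches output_size. Below is a minimal host-side launch sketch; the shapes, buffer names, and the LaunchScatterNdExample helper are assumptions for illustration, not part of this commit.
// For an output of shape (4, 3), indices of shape (2, 1) and updates of shape (2, 3):
//   block_size     = 3        // elements per inner update block (trailing output dims)
//   indices_stride = {3}      // output elements spanned by one step along dim 0
//   work_shape     = {4}      // upper bound for the out-of-range check
// A thread with read_index r then scatters to indices[r / 3] * 3 + (r % 3),
// accumulating duplicate indices via MsAtomicAdd instead of racing.
void LaunchScatterNdExample(int *d_indices, float *d_update, float *d_output,
                            int *d_indices_stride, int *d_work_shape, cudaStream_t stream) {
  // All device buffers are assumed to be allocated and filled by the caller;
  // d_output is assumed to be zero-initialized.
  const size_t block_size = 3, input_size = 6, output_size = 12;
  const size_t indices_dim_0 = 2, indices_dim_1 = 1;
  ScatterNd<float, int>(d_indices, d_update, d_output, block_size, input_size, output_size,
                        indices_dim_0, indices_dim_1, d_indices_stride, d_work_shape, stream);
}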
@@ -19,11 +19,41 @@
#include <cuda_fp16.h>
inline __device__ float ms_atomic_add(float *address, float val) { return atomicAdd(address, val); }
__device__ static inline float MsAtomicAdd(float *address, const float val) { return atomicAdd(address, val); }
inline __device__ int ms_atomic_add(int *address, int val) { return atomicAdd(address, val); }
__device__ static inline int MsAtomicAdd(int *address, int val) { return atomicAdd(address, val); }
inline __device__ half ms_atomic_add(half *address, half val) {
__device__ static inline unsigned int MsAtomicAdd(unsigned int *address, unsigned int val) {
return atomicAdd(address, val);
}
// Emulates a 16-bit atomic add on top of the 32-bit atomicCAS: load the 4-byte
// aligned word that contains the target short, add val into the corresponding
// half-word, and retry the CAS until no other thread has intervened. Returns
// the previous value stored at address.
__device__ static inline short MsAtomicAdd(short *address, short val) { // NOLINT
bool is_4_byte_aligned = ((size_t) address & 2) == 0;
unsigned int *aligned = (unsigned int *) ((size_t) address & ~2);
unsigned int old = *aligned;
unsigned int assumed;
do {
assumed = old;
unsigned int replacement;
if (is_4_byte_aligned) {
replacement = (old & 0xffff0000) | (((old & 0xffff) + val) & 0xffff);
} else {
replacement = old + ((unsigned int) val << 16);
}
old = atomicCAS(aligned, assumed, replacement);
} while (assumed != old);
if (is_4_byte_aligned) {
return (short) (old & 0xffff); // NOLINT
} else {
return (short) (old >> 16); // NOLINT
}
}
__device__ static inline half MsAtomicAdd(half *address, half val) {
unsigned int *aligned =
reinterpret_cast<unsigned int *>(reinterpret_cast<size_t>(address) - (reinterpret_cast<size_t>(address) & 2));
unsigned int old = *aligned;
@@ -42,4 +72,66 @@ inline __device__ half ms_atomic_add(half *address, half val) {
return half(raw);
}
__device__ static inline unsigned char MsAtomicAdd(unsigned char* address, unsigned char val) {
// We use cuda's atomicCAS(unsigned int*, unsigned int, unsigned int) to
// implement MsAtomicAdd. An unsigned char may not be 4 byte aligned, but
// unsigned int* must be 4 byte aligned. This variable contains the offset,
// in bytes, of the beginning of address, within the 4 byte aligned space that
// contains it.
size_t address_offset = (size_t) address & 3;
// Address of the 4 byte aligned space that contains address.
unsigned int* aligned = (unsigned int*) ((unsigned char*) address - address_offset);
// Constants which will be used later with __byte_perm. __byte_perm is a cuda
// function which takes 3 unsigned int's (x, y, selector) as parameters and
// returns an int. __byte_perm returns an integer by selecting bytes from x
and y based on the given selector. The selector 0x3210 would select all
// four bytes from x, preserving their original order. The position of the
// "4" in the selector indicates the position in the output where the first
// byte of y will end up.
unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210};
// Gets the selector that will select the bytes at address from aligned
unsigned int selector = selectors[address_offset];
unsigned int old = *aligned;
unsigned int assumed = 0;
do {
assumed = old;
// Selects the byte associated with address and put it as the first byte of
// this variable, so that we can add val to the value at address.
unsigned int sum = val + __byte_perm(old, 0, address_offset);
// Takes old and replaces the byte corresponding to address with the sum.
unsigned int replacement = __byte_perm(old, sum, selector);
// Try to replace the old value with the new value
old = atomicCAS(aligned, assumed, replacement);
} while (old != assumed);
// Select the single byte corresponding to address and return it.
return __byte_perm(old, 0, address_offset);
}
__device__ static inline char MsAtomicAdd(char* address, char val) {
size_t address_offset = (size_t) address & 3;
unsigned int* aligned = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - address_offset);
unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210};
unsigned int selector = selectors[address_offset];
unsigned int old = *aligned;
unsigned int assumed = 0;
do {
assumed = old;
unsigned int sum = val + __byte_perm(old, 0, address_offset);
unsigned int replacement = __byte_perm(old, sum, selector);
old = atomicCAS(aligned, assumed, replacement);
} while (old != assumed);
return __byte_perm(old, 0, address_offset);
}
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UTIL_H_
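A minimal standalone sketch of how the new sub-word overloads behave (assumptions: the header above is included via its repository path and the file is compiled with nvcc; BumpCounters and the buffer names are illustrative, not part of the commit). Two hundred threads each add 1 to a shared short and a shared unsigned char; the CAS loops serialize the half-word and byte updates, so both counters end at 200.
#include <cstdio>
#include <cuda_runtime.h>
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"  // assumed include path for MsAtomicAdd

// Every thread bumps the same 16-bit and 8-bit counters; a plain += would race.
__global__ void BumpCounters(short *s_count, unsigned char *c_count) {  // NOLINT
  MsAtomicAdd(s_count, static_cast<short>(1));
  MsAtomicAdd(c_count, static_cast<unsigned char>(1));
}

int main() {
  short *d_s = nullptr;
  unsigned char *d_c = nullptr;
  cudaMalloc(&d_s, sizeof(short));  // NOLINT
  cudaMalloc(&d_c, sizeof(unsigned char));
  cudaMemset(d_s, 0, sizeof(short));  // NOLINT
  cudaMemset(d_c, 0, sizeof(unsigned char));
  BumpCounters<<<1, 200>>>(d_s, d_c);  // 200 < 256, so the uchar counter cannot wrap
  short h_s = 0;
  unsigned char h_c = 0;
  cudaMemcpy(&h_s, d_s, sizeof(short), cudaMemcpyDeviceToHost);  // NOLINT
  cudaMemcpy(&h_c, d_c, sizeof(unsigned char), cudaMemcpyDeviceToHost);
  printf("short count = %d, uchar count = %d\n", h_s, h_c);  // expected: 200 and 200
  cudaFree(d_s);
  cudaFree(d_c);
  return 0;
}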