From 604bb2a569615a5a600821d36b19ddac488fb704 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 22 Jul 2021 16:45:06 +0800
Subject: [PATCH] feat(mgb/dnn): add int atomic add for megdnn

GitOrigin-RevId: 00d5d752d3f3d91f3fd581e816e3b1280bad4c31
---
 dnn/src/cuda/atomic_add.cuh | 192 ++++++++++++++++++++++++++++++++++++
 dnn/src/cuda/utils.cuh      |  67 +------------
 2 files changed, 196 insertions(+), 63 deletions(-)
 create mode 100644 dnn/src/cuda/atomic_add.cuh

diff --git a/dnn/src/cuda/atomic_add.cuh b/dnn/src/cuda/atomic_add.cuh
new file mode 100644
index 000000000..adab0ee67
--- /dev/null
+++ b/dnn/src/cuda/atomic_add.cuh
@@ -0,0 +1,192 @@
+/**
+ * \file dnn/src/cuda/atomic_add.cuh
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+#include "cuda.h"
+#include "include/megdnn/dtype.h"
+
+namespace megdnn {
+namespace cuda {
+
+#if MEGDNN_CC_CUDA
+template <typename T>
+static inline MEGDNN_DEVICE void atomic_add(T* address, T val);
+
+template <>
+MEGDNN_DEVICE void atomic_add(dt_float32* address, dt_float32 val) {
+    ::atomicAdd(reinterpret_cast<float*>(address), static_cast<float>(val));
+}
+
+// overload atomicAdd for half precision
+// Taken from:
+// https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomic.cuh
+template <>
+MEGDNN_DEVICE void atomic_add(dt_float16* address, dt_float16 val) {
+#if (__CUDA_ARCH__ < 700 || __CUDACC_VER_MAJOR__ <= 9)
+    unsigned int* address_as_ui = reinterpret_cast<unsigned int*>(
+            reinterpret_cast<char*>(address) -
+            (reinterpret_cast<size_t>(address) & 2));
+    unsigned int old = *address_as_ui;
+    unsigned int assumed;
+
+    do {
+        assumed = old;
+        unsigned short data = reinterpret_cast<size_t>(address) & 2
+                                      ? (old >> 16)
+                                      : (old & 0xffff);
+        dt_float16 hsum = *reinterpret_cast<dt_float16*>(&data);
+        hsum += val;
+        data = *reinterpret_cast<unsigned short*>(&hsum);
+        old = reinterpret_cast<size_t>(address) & 2
+                      ? (old & 0xffff) | (data << 16)
+                      : (old & 0xffff0000) | data;
+        old = ::atomicCAS(address_as_ui, assumed, old);
+    } while (assumed != old);
+#else
+    ::atomicAdd(reinterpret_cast<__half*>(address), static_cast<__half>(val));
+#endif
+}
+
+template <>
+MEGDNN_DEVICE void atomic_add(dt_bfloat16* address, dt_bfloat16 val) {
+    unsigned int* address_as_ui = reinterpret_cast<unsigned int*>(
+            reinterpret_cast<char*>(address) -
+            (reinterpret_cast<size_t>(address) & 2));
+    unsigned int old = *address_as_ui;
+    unsigned int assumed;
+
+    do {
+        assumed = old;
+        unsigned short data = reinterpret_cast<size_t>(address) & 2
+                                      ? (old >> 16)
+                                      : (old & 0xffff);
+        dt_bfloat16 hsum = *reinterpret_cast<dt_bfloat16*>(&data);
+        hsum += val;
+        data = *reinterpret_cast<unsigned short*>(&hsum);
+        old = reinterpret_cast<size_t>(address) & 2
+                      ? (old & 0xffff) | (data << 16)
+                      : (old & 0xffff0000) | data;
+        old = ::atomicCAS(address_as_ui, assumed, old);
+    } while (assumed != old);
+}
+
+template <typename T, size_t n>
+struct AtomicAddIntegerImpl;
+
+template <typename T>
+struct AtomicAddIntegerImpl<T, 1> {
+    inline __device__ void operator()(T* address, T val) {
+        size_t offset = (size_t)address & 3;
+        uint32_t* address_as_ui = (uint32_t*)((char*)address - offset);
+        uint32_t old = *address_as_ui;
+        uint32_t shift = offset * 8;
+        uint32_t old_byte;
+        uint32_t newval;
+        uint32_t assumed;
+        do {
+            assumed = old;
+            old_byte = (old >> shift) & 0xff;
+            // preserve size in initial cast. Casting directly to uint32_t pads
+            // negative signed values with 1's (e.g. signed -1 = unsigned ~0).
+            newval = static_cast<uint8_t>(static_cast<T>(val) +
+                                          static_cast<T>(old_byte));
+            // newval = static_cast<uint8_t>(THCNumerics<T>::add(val,
+            // old_byte));
+            newval = (old & ~(0x000000ff << shift)) | (newval << shift);
+            old = atomicCAS(address_as_ui, assumed, newval);
+        } while (assumed != old);
+    }
+};
+
+template <typename T>
+struct AtomicAddIntegerImpl<T, 2> {
+    inline __device__ void operator()(T* address, T val) {
+        size_t offset = (size_t)address & 2;
+        uint32_t* address_as_ui = (uint32_t*)((char*)address - offset);
+        bool is_32_align = offset;
+        uint32_t old = *address_as_ui;
+        uint32_t old_bytes;
+        uint32_t newval;
+        uint32_t assumed;
+        do {
+            assumed = old;
+            old_bytes = is_32_align ? old >> 16 : old & 0xffff;
+            // preserve size in initial cast. Casting directly to uint32_t pads
+            // negative signed values with 1's (e.g. signed -1 = unsigned ~0).
+            newval = static_cast<uint16_t>(static_cast<T>(val) +
+                                           static_cast<T>(old_bytes));
+            // newval = static_cast<uint16_t>(THCNumerics<T>::add(val,
+            // old_bytes));
+            newval = is_32_align ? (old & 0xffff) | (newval << 16)
+                                 : (old & 0xffff0000) | newval;
+            old = atomicCAS(address_as_ui, assumed, newval);
+        } while (assumed != old);
+    }
+};
+
+template <>
+MEGDNN_DEVICE void atomic_add(dt_int32* address, dt_int32 val) {
+    ::atomicAdd(reinterpret_cast<int*>(address), static_cast<int>(val));
+}
+
+// we assume quantized ints in the same tensor share the same scale
+template <>
+MEGDNN_DEVICE void atomic_add(dt_qint32* address, dt_qint32 val) {
+    ::atomicAdd(reinterpret_cast<int*>(address), val.as_int32());
+}
+
+template <>
+MEGDNN_DEVICE void atomic_add(dt_int16* address, dt_int16 val) {
+    AtomicAddIntegerImpl<dt_int16, sizeof(dt_int16)>()(address, val);
+}
+
+template <>
+MEGDNN_DEVICE void atomic_add(dt_uint16* address, dt_uint16 val) {
+    AtomicAddIntegerImpl<dt_uint16, sizeof(dt_uint16)>()(address, val);
+}
+
+// we assume quantized ints in the same tensor share the same scale
+template <>
+MEGDNN_DEVICE void atomic_add(dt_qint16* address, dt_qint16 val) {
+    AtomicAddIntegerImpl<dt_int16, sizeof(dt_int16)>()(
+            reinterpret_cast<dt_int16*>(address), val.as_int16());
+}
+// be careful! may cause overflow
+#if 0
+template <>
+MEGDNN_DEVICE void atomic_add(dt_int8* address, dt_int8 val) {
+    AtomicAddIntegerImpl<dt_int8, sizeof(dt_int8)>()(address, val);
+}
+
+template <>
+MEGDNN_DEVICE void atomic_add(dt_uint8* address, dt_uint8 val) {
+    AtomicAddIntegerImpl<dt_uint8, sizeof(dt_uint8)>()(address, val);
+}
+
+// we assume quantized ints in the same tensor share the same scale
+template <>
+MEGDNN_DEVICE void atomic_add(dt_quint8* address, dt_quint8 val) {
+    AtomicAddIntegerImpl<dt_uint8, sizeof(dt_uint8)>()(reinterpret_cast<dt_uint8*>(address), val.as_uint8());
+}
+
+// we assume quantized ints in the same tensor share the same scale
+template <>
+MEGDNN_DEVICE void atomic_add(dt_qint8* address, dt_qint8 val) {
+    AtomicAddIntegerImpl<dt_int8, sizeof(dt_int8)>()(reinterpret_cast<dt_int8*>(address), val.as_int8());
+}
+#endif
+
+#endif
+}  // namespace cuda
+}  // namespace megdnn
\ No newline at end of file

diff --git a/dnn/src/cuda/utils.cuh b/dnn/src/cuda/utils.cuh
index e3674ac03..0d7c4f9e5 100644
--- a/dnn/src/cuda/utils.cuh
+++ b/dnn/src/cuda/utils.cuh
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 
 #pragma once
@@ -19,8 +20,9 @@
 #include
 #include
 #include "cuda.h"
-#include "src/cuda/cudnn_with_check.h"
 #include "cutlass/cutlass.h"
+#include "src/cuda/cudnn_with_check.h"
+#include "src/cuda/atomic_add.cuh"
 
 #define cuda_check(_x)                                     \
     do {                                                   \
@@ -240,67 +242,6 @@ struct CudaDTypeParamImpl : DTypeParamImpl {
 };
 
 #if MEGDNN_CC_CUDA
-template <typename T>
-static inline MEGDNN_DEVICE void atomic_add(T* address, T val);
-
-template <>
-MEGDNN_DEVICE void atomic_add(dt_float32* address, dt_float32 val) {
-    ::atomicAdd(reinterpret_cast<float*>(address), static_cast<float>(val));
-}
-
-// overload atomicAdd for half precision
-// Taken from:
-// https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomic.cuh
-template <>
-MEGDNN_DEVICE void atomic_add(dt_float16* address, dt_float16 val) {
-#if (__CUDA_ARCH__ < 700 || __CUDACC_VER_MAJOR__ <= 9)
-    unsigned int* address_as_ui = reinterpret_cast<unsigned int*>(
-            reinterpret_cast<char*>(address) -
-            (reinterpret_cast<size_t>(address) & 2));
-    unsigned int old = *address_as_ui;
-    unsigned int assumed;
-
-    do {
-        assumed = old;
-        unsigned short data = reinterpret_cast<size_t>(address) & 2
-                                      ? (old >> 16)
-                                      : (old & 0xffff);
-        dt_float16 hsum = *reinterpret_cast<dt_float16*>(&data);
-        hsum += val;
-        data = *reinterpret_cast<unsigned short*>(&hsum);
-        old = reinterpret_cast<size_t>(address) & 2
-                      ? (old & 0xffff) | (data << 16)
-                      : (old & 0xffff0000) | data;
-        old = ::atomicCAS(address_as_ui, assumed, old);
-    } while (assumed != old);
-#else
-    ::atomicAdd(reinterpret_cast<__half*>(address), static_cast<__half>(val));
-#endif
-}
-
-template <>
-MEGDNN_DEVICE void atomic_add(dt_bfloat16* address, dt_bfloat16 val) {
-    unsigned int* address_as_ui = reinterpret_cast<unsigned int*>(
-            reinterpret_cast<char*>(address) -
-            (reinterpret_cast<size_t>(address) & 2));
-    unsigned int old = *address_as_ui;
-    unsigned int assumed;
-
-    do {
-        assumed = old;
-        unsigned short data = reinterpret_cast<size_t>(address) & 2
-                                      ? (old >> 16)
-                                      : (old & 0xffff);
-        dt_bfloat16 hsum = *reinterpret_cast<dt_bfloat16*>(&data);
-        hsum += val;
-        data = *reinterpret_cast<unsigned short*>(&hsum);
-        old = reinterpret_cast<size_t>(address) & 2
-                      ? (old & 0xffff) | (data << 16)
-                      : (old & 0xffff0000) | data;
-        old = ::atomicCAS(address_as_ui, assumed, old);
-    } while (assumed != old);
-}
-
 static inline MEGDNN_DEVICE void dot_prod(int a, int b, int c, int& d) {
 #if __CUDA_ARCH__ >= 610
     // clang-format off
--
GitLab
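
The snippet below is a small usage sketch, not part of the patch: it shows how device code might call the new megdnn::cuda::atomic_add overloads to scatter-add 16-bit values, relying on the atomicCAS loop in AtomicAddIntegerImpl for sub-word elements. The kernel and launcher names, the launch shape, and the assumption that the file is compiled inside the megdnn CUDA build (so MEGDNN_CC_CUDA, MEGDNN_DEVICE and dt_int16 are defined) are illustrative only.

// scatter_add_i16.cu -- illustrative sketch only, not part of the commit.
#include <cuda_runtime.h>
#include "src/cuda/atomic_add.cuh"

namespace {

// Each thread adds one int16 value into dst[idx[i]]. Several threads may hit
// the same destination element, or two elements sharing one aligned 32-bit
// word; atomic_add() resolves both cases via the CAS loop added by this patch.
__global__ void scatter_add_i16_kernel(
        const int* __restrict__ idx, const megdnn::dt_int16* __restrict__ src,
        megdnn::dt_int16* __restrict__ dst, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        megdnn::cuda::atomic_add(dst + idx[i], src[i]);
    }
}

}  // anonymous namespace

// Plain launcher wrapper; 256 threads per block is an arbitrary choice.
void scatter_add_i16(
        const int* idx, const megdnn::dt_int16* src, megdnn::dt_int16* dst,
        int n, cudaStream_t stream) {
    int threads = 256;
    int blocks = (n + threads - 1) / threads;
    scatter_add_i16_kernel<<<blocks, threads, 0, stream>>>(idx, src, dst, n);
}

Note the design choice visible in the patch: the 16-bit (and the disabled 8-bit) paths never issue a hardware atomic on the element itself; they atomicCAS the enclosing aligned 32-bit word, which keeps concurrent updates to neighbouring elements in the same word correct at the cost of retries under contention.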