From 1f8e930e28ec52232cae91baaa93efc4f9011be4 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 21 Jul 2022 14:03:19 +0800 Subject: [PATCH] feat(cuda): add int4 ptx 128x128 mma kernel GitOrigin-RevId: 5a8b9c3f8eab59ed8d1daf9bbaf2c81cdc82ca5b --- dnn/src/cuda/ptx/uint4_int4/base.cu | 39 + dnn/src/cuda/ptx/uint4_int4/base.cuh | 109 ++ .../fuse_z_imma8832_ldgsts16_128x128_relu.cu | 1096 +++++++++++++++++ .../cuda/ptx/uint4_int4/imma8832_128x128.cuh | 26 + .../imma8832_ldgsts16_128x128_relu.cu | 1089 ++++++++++++++++ dnn/src/cuda/ptx/uint4_int4/kern.cuh | 20 + dnn/src/cuda/ptx/uint4_int4/macro.cuh | 348 ++++++ dnn/src/cuda/ptx/uint4_int4/tools.cuh | 49 + 8 files changed, 2776 insertions(+) create mode 100644 dnn/src/cuda/ptx/uint4_int4/base.cu create mode 100644 dnn/src/cuda/ptx/uint4_int4/base.cuh create mode 100644 dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldgsts16_128x128_relu.cu create mode 100644 dnn/src/cuda/ptx/uint4_int4/imma8832_128x128.cuh create mode 100644 dnn/src/cuda/ptx/uint4_int4/imma8832_ldgsts16_128x128_relu.cu create mode 100644 dnn/src/cuda/ptx/uint4_int4/kern.cuh create mode 100644 dnn/src/cuda/ptx/uint4_int4/macro.cuh create mode 100644 dnn/src/cuda/ptx/uint4_int4/tools.cuh diff --git a/dnn/src/cuda/ptx/uint4_int4/base.cu b/dnn/src/cuda/ptx/uint4_int4/base.cu new file mode 100644 index 000000000..db82c81ff --- /dev/null +++ b/dnn/src/cuda/ptx/uint4_int4/base.cu @@ -0,0 +1,39 @@ +#include "./base.cuh" + +using namespace convolution; + +Uint32Fastdiv::Uint32Fastdiv() { + memset(this, 0, sizeof(Uint32Fastdiv)); +} + +Uint32Fastdiv& Uint32Fastdiv::operator=(uint32_t d) { + m_divisor = d; + constexpr uint32_t MAX_U32 = ~0u; + m_inc_dividend = 0; + m_divisor_is_not_1 = ~0u; + if (!(d & (d - 1))) { + // power of 2 + m_mul = 1u << 31; + int p = 0; + while ((1u << p) < d) + ++p; + m_shift = p ? p - 1 : 0; + if (d == 1) + m_divisor_is_not_1 = 0; + return *this; + } + auto n_bound = uint64_t(d / 2 + 1) * MAX_U32; + uint32_t shift = 32; + while ((1ull << shift) < n_bound) + ++shift; + uint64_t mdst = 1ull << shift; + int64_t delta = d - mdst % d; + m_mul = mdst / d + 1; + if ((uint64_t)delta > d / 2) { + delta -= d; + --m_mul; + m_inc_dividend = 1; + } + m_shift = shift - 32; + return *this; +} diff --git a/dnn/src/cuda/ptx/uint4_int4/base.cuh b/dnn/src/cuda/ptx/uint4_int4/base.cuh new file mode 100644 index 000000000..861e01be7 --- /dev/null +++ b/dnn/src/cuda/ptx/uint4_int4/base.cuh @@ -0,0 +1,109 @@ +#pragma once + +#include +#include + +#if ((__CUDACC_VER_MAJOR__ > 11) || \ + (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) +#define SM80_SUPPORTED +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) +#define SM80_ENABLED +#endif +#endif + +namespace convolution { +class Uint32Fastdiv { + uint32_t m_mul, m_divisor, m_divisor_is_not_1, m_inc_dividend, m_shift; + +public: + Uint32Fastdiv(); + + Uint32Fastdiv(uint32_t d) { operator=(d); } + + //! set the divisor to be d + Uint32Fastdiv& operator=(uint32_t d); + + //! 
caller must ensure that dividend would not exceed this number + static constexpr uint32_t MAX_DIVIDEND = ~0u - 1; + + __device__ __forceinline__ uint32_t divisor() const { return m_divisor; } + + __device__ __forceinline__ uint32_t divide(uint32_t dividend) const { + uint32_t ans_for_one = dividend & ~m_divisor_is_not_1, + dfix = dividend + m_inc_dividend, +#if __CUDA_ARCH__ + hi32 = __umulhi(dfix, m_mul), +#else + hi32 = ((uint64_t)dfix * m_mul) >> 32, +#endif + ans = hi32 >> m_shift; + + return (ans & m_divisor_is_not_1) | ans_for_one; + } +}; + +static __forceinline__ __device__ uint32_t +operator/(uint32_t a, const Uint32Fastdiv& d) { + return d.divide(a); +} + +static __forceinline__ __device__ uint32_t +operator%(uint32_t a, const Uint32Fastdiv& d) { + return a - d.divisor() * d.divide(a); +} + +struct Conv2dInt4Param { + uint32_t n, ic, ih, iw, fh, fw, sh, sw, ph, pw, oc, oh, ow; + uint32_t ibs, ics, ihs; + uint32_t obs, ocs, ohs; + uint32_t icfhfw; + uint32_t nhw; + Uint32Fastdiv div_ohow; + Uint32Fastdiv div_ow; + Conv2dInt4Param( + uint32_t n, uint32_t ic, uint32_t ih, uint32_t iw, uint32_t fh, uint32_t fw, + uint32_t sh, uint32_t sw, uint32_t ph, uint32_t pw, uint32_t oc, + uint32_t oh, uint32_t ow, uint32_t interleaved) + : n(n), + ic(ic), + ih(ih), + iw(iw), + fh(fh), + fw(fw), + sh(sh), + sw(sw), + ph(ph), + pw(pw), + oc(oc), + oh(oh), + ow(ow) { + constexpr uint32_t size_bits = 4; + // all stride size in bytes + ibs = ic * ih * iw * size_bits / 8; + ics = ih * iw * interleaved * size_bits / 8; + ihs = iw * interleaved * size_bits / 8; + obs = oc * oh * ow * size_bits / 8; + ocs = oh * ow * interleaved * size_bits / 8; + ohs = ow * interleaved * size_bits / 8; + icfhfw = ic * fh * fw; + nhw = n * oh * ow; + div_ohow = oh * ow; + div_ow = ow; + } +}; + +struct Conv2dConstantOffsetParam { + int32_t begin; + int32_t size; + int32_t max; + int32_t rewind; +}; + +#define CONSTANT_BUFFER_SIZE 848 + +struct Conv2dConstantOffset { + Conv2dConstantOffsetParam c_offset_param; + int c_offset[CONSTANT_BUFFER_SIZE]; +}; + +} // namespace convolution diff --git a/dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldgsts16_128x128_relu.cu b/dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldgsts16_128x128_relu.cu new file mode 100644 index 000000000..9c7f4262b --- /dev/null +++ b/dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldgsts16_128x128_relu.cu @@ -0,0 +1,1096 @@ +#include + +#include +#include "./imma8832_128x128.cuh" +#include "./kern.cuh" +#include "./macro.cuh" +#include "./tools.cuh" + +using namespace convolution; + +namespace { +#ifdef SM80_ENABLED +extern "C" __device__ void g2s_int4(const int4* gm, int4* sm) { + unsigned sm_addr = get_smem_pointer(sm); + const int SizeInBytes = 16; +#if ENABLE_L2_PREFETCH + asm volatile( + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(sm_addr), + "l"(gm), "n"(SizeInBytes)); +#else + asm volatile( + "cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(sm_addr), "l"(gm), + "n"(SizeInBytes)); +#endif +} + +/// Establishes an ordering w.r.t previously issued cp.async instructions. Does +/// not block. +#define cp_async_fence() asm volatile("cp.async.commit_group;\n" ::) + +/// Blocks until all but previous cp.async.commit_group operations have +/// committed. 
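+/// Concretely, `cp_async_wait(N)` emits `cp.async.wait_group N`, which returns
+/// once at most N of the most recently committed cp.async groups are still in
+/// flight; `cp_async_wait(0)` drains all of them. With stages == 3, the
+/// `cp_async_wait(stages - 2)` calls in the main loop keep at most one prefetch
+/// stage pending while the current stage is being consumed.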
+#define cp_async_wait(N) asm volatile("cp.async.wait_group %0;\n" ::"n"(N)) +#endif + +extern "C" __global__ void __launch_bounds__(256) + ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldgsts16_128x128_relu( + const int8_t* __restrict__ src, const int8_t* __restrict__ filter, + const float* __restrict__ bias, const int8_t* __restrict__ z, + int8_t* __restrict__ dst, float alpha, float beta, float gamma, + uint32_t pk_src_zero_point, int32_t z_zero_point, float dst_zero_point, + uint32_t relu, Conv2dInt4Param param, + Conv2dConstantOffset conv2d_constant) { +#ifdef SM80_ENABLED + const int stages = 3; + const uint32_t tid = threadIdx.x; + const uint32_t bidx = blockIdx.x; + const uint32_t bidy = blockIdx.y; + extern __shared__ int32_t smem[]; // (128 + 128)*128/8*stages + int2 reg_acc[reg_m][reg_n]; + int4 reg_src[2][reg_nd4]; + int4 reg_flt[2][reg_md4]; + // use in other way, maybe use reg_ser/flt + int4 reg_src_cache[2]; + int4 reg_filter_cache[4]; + + uint32_t tid127 = (tid & 127); + uint32_t section = (tid127 >> 1); + uint32_t residue = ((tid127 << 5) & 63); + uint32_t nhw = bidx * BN + section; + + uint32_t tn, hw, toh, tow; + int tih, tiw; + int h_start[2]; + int h_end[2]; + int w_start[2]; + int w_end[2]; + bool g[2]; + const int8_t* __restrict__ g_src_ptr[4]; + for (int i = 0; i < 2; i++) { + if (i != 0) { + nhw += 64; + } + tn = nhw / param.div_ohow; + hw = nhw % param.div_ohow; + toh = hw / param.div_ow; + tow = hw % param.div_ow; + tih = toh * param.sh - param.ph; + tiw = tow * param.sw - param.pw; + g[i] = tn < param.n; + h_start[i] = -tih; + h_end[i] = param.ih - tih; + w_start[i] = -tiw; + w_end[i] = param.iw - tiw; + // param's members have been converted to byte offset and int4 offset need to + // div 2 + int src_offset = tn * param.ibs + tih * param.ihs + + ((int)(tiw * packed_channel + residue) >> 1); + g_src_ptr[i * 2] = src + src_offset; + g_src_ptr[i * 2 + 1] = g_src_ptr[i * 2]; + } + + const uint32_t section_section = (section >> 2); + const uint32_t section_residue = (section & 3); + const uint32_t section_factor = ((section & 15) >> 2); + const uint32_t crosswise_offset = + ((section_residue >> 1) << 4) + + (((section_residue & 1) ^ (section_factor >> 1)) << 3); + const uint32_t residue_offset = ((residue >> 5) ^ (section_factor & 1)) << 2; + + // next + 64 * BK / 8 + int32_t* write_src_s[2]; + write_src_s[0] = + smem + section_section * BK / 2 + crosswise_offset + residue_offset; + write_src_s[1] = write_src_s[0] + 32; + + int iter = (param.icfhfw >> 6); + + uint32_t tid31 = (tid & 31); + uint32_t warp_idx = (tid >> 5); + uint32_t warp_strided = (warp_idx << 2); + uint32_t htid = (tid31 >> 4); + const uint32_t flt_strided = bidy * BM / 8 + warp_strided; + bool guard = flt_strided * 8 < param.oc && iter > htid; + // icfhfw * 8/2 is a stride + const int8_t* __restrict__ g_filter_ptr0 = + filter + flt_strided * (param.icfhfw * 4) + (tid31 << 4); + const int8_t* __restrict__ g_filter_ptr1 = g_filter_ptr0 + (param.icfhfw * 4); + const int8_t* __restrict__ g_filter_ptr2 = g_filter_ptr0 + (param.icfhfw * 8); + const int8_t* __restrict__ g_filter_ptr3 = g_filter_ptr0 + (param.icfhfw * 12); + // next + BK * 8 / (INT32/INT4) + uint32_t q = (tid31 >> 3); + uint32_t r = (tid31 & 7); + int32_t* write_flt_s = smem + BN * BK / 8 + warp_strided * BK + ((q & 1) << 6) + + ((q >> 1) << 5) + (r << 2); + uint32_t quad_idx = (tid31 >> 2); + uint32_t idx_in_quad = (tid & 3); + uint32_t quad_factor = ((tid & 15) >> 2); + uint32_t crosswise = + ((idx_in_quad >> 1) << 4) + (((idx_in_quad & 1) 
^ (quad_factor >> 1)) << 3); + uint32_t warp_x = (warp_idx >> 1); + uint32_t warp_y = (warp_idx & 1); + + int32_t* read_src_s_0 = smem + (warp_x * 8 * BK) + (quad_idx * BK / 2) + crosswise + + ((0 ^ (quad_factor & 1)) << 2); + int32_t* read_src_s_1 = smem + (warp_x * 8 * BK) + (quad_idx * BK / 2) + crosswise + + ((1 ^ (quad_factor & 1)) << 2); + int32_t* read_flt_s_0 = smem + BN * BK / 8 + (warp_y * 8 * BK) + + (quad_idx * BK / 2) + crosswise + + ((0 ^ (quad_factor & 1)) << 2); + int32_t* read_flt_s_1 = smem + BN * BK / 8 + (warp_y * 8 * BK) + + (quad_idx * BK / 2) + crosswise + + ((1 ^ (quad_factor & 1)) << 2); + +#pragma unroll + for (int i = 0; i < reg_m; i++) { +#pragma unroll + for (int j = 0; j < reg_n; j++) { + reg_acc[i][j] = make_int2(0, 0); + } + } + + const int smem_switch = 4096; + const int smem_switch_back = -smem_switch * (stages - 1); + int stage = 0; + uint32_t offset[2] = {0, 2}; // high & low + int src_step[2], x[2], y[2]; + + // global mem --> shared mem, stage 0 + for (int i = 0; i < 2; i++) { + src_step[i] = conv2d_constant.c_offset[offset[i]]; + uint32_t spatial = *(reinterpret_cast( + &(conv2d_constant.c_offset[offset[i] + 1]))); + x[i] = (spatial & 0xff); + y[i] = ((spatial >> 8) & 0xff); + if (offset[i] < conv2d_constant.c_offset_param.max) { + offset[i] += 4; + } else { + offset[i] += conv2d_constant.c_offset_param.rewind; + } + } + + bool guard0[2], guard1[2]; + guard0[0] = g[0] && x[0] >= h_start[0] && x[0] < h_end[0] && y[0] >= w_start[0] && + y[0] < w_end[0] && iter > 0; + guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] && y[1] >= w_start[0] && + y[1] < w_end[0] && iter > 1; + guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] && y[0] >= w_start[1] && + y[0] < w_end[1] && iter > 0; + guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] && y[1] >= w_start[1] && + y[1] < w_end[1] && iter > 1; + g_src_ptr[0] += src_step[0]; + g_src_ptr[1] += src_step[1]; + g_src_ptr[2] += src_step[0]; + g_src_ptr[3] += src_step[1]; + + if (guard0[0]) { + g2s_int4( + reinterpret_cast(g_src_ptr[0]), + reinterpret_cast(write_src_s[0])); + } else { + *(reinterpret_cast(write_src_s[0])) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard0[1]) { + g2s_int4( + reinterpret_cast(g_src_ptr[1]), + reinterpret_cast(write_src_s[1])); + } else { + *(reinterpret_cast(write_src_s[1])) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard1[0]) { + g2s_int4( + reinterpret_cast(g_src_ptr[2]), + reinterpret_cast(write_src_s[0] + 8 * BK)); + } else { + *(reinterpret_cast(write_src_s[0] + 8 * BK)) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard1[1]) { + g2s_int4( + reinterpret_cast(g_src_ptr[3]), + reinterpret_cast(write_src_s[1] + 8 * BK)); + } else { + *(reinterpret_cast(write_src_s[1] + 8 * BK)) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + + if (guard) { + g2s_int4( + reinterpret_cast(g_filter_ptr0), + reinterpret_cast(write_flt_s)); + g2s_int4( + reinterpret_cast(g_filter_ptr1), + reinterpret_cast(write_flt_s + BK)); + g2s_int4( + reinterpret_cast(g_filter_ptr2), + reinterpret_cast(write_flt_s + 2 * BK)); + g2s_int4( + reinterpret_cast(g_filter_ptr3), + reinterpret_cast(write_flt_s + 3 * BK)); + } else { + *(reinterpret_cast(write_flt_s)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + BK)) = make_int4(0, 0, 0, 0); + 
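+        // guard == false: the output-channel block addressed by this warp lies
+        // beyond param.oc (or there is no k-iteration left for this half-warp),
+        // so its whole BK slice is zero-filled and the mma accumulates zeros for it.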
*(reinterpret_cast(write_flt_s + 2 * BK)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + 3 * BK)) = make_int4(0, 0, 0, 0); + } + stage++; + cp_async_fence(); + + // global mem --> shared mem, stage 1 -> stage n + iter -= 2; + for (int i = 0; i < 2; i++) { + src_step[i] = conv2d_constant.c_offset[offset[i]]; + uint32_t spatial = *(reinterpret_cast( + &(conv2d_constant.c_offset[offset[i] + 1]))); + x[i] = (spatial & 0xff); + y[i] = ((spatial >> 8) & 0xff); + if (offset[i] < conv2d_constant.c_offset_param.max) { + offset[i] += 4; + } else { + offset[i] += conv2d_constant.c_offset_param.rewind; + } + } + guard0[0] = g[0] && x[0] >= h_start[0] && x[0] < h_end[0] && y[0] >= w_start[0] && + y[0] < w_end[0]; + guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] && y[1] >= w_start[0] && + y[1] < w_end[0]; + guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] && y[0] >= w_start[1] && + y[0] < w_end[1]; + guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] && y[1] >= w_start[1] && + y[1] < w_end[1]; + + g_src_ptr[0] += src_step[0]; + g_src_ptr[1] += src_step[1]; + g_src_ptr[2] += src_step[0]; + g_src_ptr[3] += src_step[1]; + g_filter_ptr0 += 8 * 64; + g_filter_ptr1 += 8 * 64; + g_filter_ptr2 += 8 * 64; + g_filter_ptr3 += 8 * 64; + + write_src_s[0] += smem_switch; + write_src_s[1] += smem_switch; + write_flt_s += smem_switch; + + for (; iter >= 2 && stage < stages - 1; iter -= 2) { + if (guard0[0]) { + g2s_int4( + reinterpret_cast(g_src_ptr[0]), + reinterpret_cast(write_src_s[0])); + } else { + *(reinterpret_cast(write_src_s[0])) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard0[1]) { + g2s_int4( + reinterpret_cast(g_src_ptr[1]), + reinterpret_cast(write_src_s[1])); + } else { + *(reinterpret_cast(write_src_s[1])) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard1[0]) { + g2s_int4( + reinterpret_cast(g_src_ptr[2]), + reinterpret_cast(write_src_s[0] + 8 * BK)); + } else { + *(reinterpret_cast(write_src_s[0] + 8 * BK)) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard1[1]) { + g2s_int4( + reinterpret_cast(g_src_ptr[3]), + reinterpret_cast(write_src_s[1] + 8 * BK)); + } else { + *(reinterpret_cast(write_src_s[1] + 8 * BK)) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + + if (guard) { + g2s_int4( + reinterpret_cast(g_filter_ptr0), + reinterpret_cast(write_flt_s)); + g2s_int4( + reinterpret_cast(g_filter_ptr1), + reinterpret_cast(write_flt_s + BK)); + g2s_int4( + reinterpret_cast(g_filter_ptr2), + reinterpret_cast(write_flt_s + 2 * BK)); + g2s_int4( + reinterpret_cast(g_filter_ptr3), + reinterpret_cast(write_flt_s + 3 * BK)); + } else { + *(reinterpret_cast(write_flt_s)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + BK)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + 2 * BK)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + 3 * BK)) = make_int4(0, 0, 0, 0); + } + + stage++; + cp_async_fence(); + + for (int i = 0; i < 2; i++) { + src_step[i] = conv2d_constant.c_offset[offset[i]]; + uint32_t spatial = *(reinterpret_cast( + &(conv2d_constant.c_offset[offset[i] + 1]))); + x[i] = (spatial & 0xff); + y[i] = ((spatial >> 8) & 0xff); + if (offset[i] < conv2d_constant.c_offset_param.max) { + offset[i] += 4; + } else { + offset[i] += conv2d_constant.c_offset_param.rewind; + } + } + guard0[0] = g[0] && x[0] >= 
h_start[0] && x[0] < h_end[0] && + y[0] >= w_start[0] && y[0] < w_end[0]; + guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] && + y[1] >= w_start[0] && y[1] < w_end[0]; + guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] && + y[0] >= w_start[1] && y[0] < w_end[1]; + guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] && + y[1] >= w_start[1] && y[1] < w_end[1]; + + g_src_ptr[0] += src_step[0]; + g_src_ptr[1] += src_step[1]; + g_src_ptr[2] += src_step[0]; + g_src_ptr[3] += src_step[1]; + g_filter_ptr0 += 8 * 64; + g_filter_ptr1 += 8 * 64; + g_filter_ptr2 += 8 * 64; + g_filter_ptr3 += 8 * 64; + + write_src_s[0] += smem_switch; + write_src_s[1] += smem_switch; + write_flt_s += smem_switch; + } + bool is_copy = false; + if (iter == 1 && stage != stages - 1) { + if (guard0[0]) { + g2s_int4( + reinterpret_cast(g_src_ptr[0]), + reinterpret_cast(write_src_s[0])); + } else { + *(reinterpret_cast(write_src_s[0])) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard0[1] && iter > 1) { + g2s_int4( + reinterpret_cast(g_src_ptr[1]), + reinterpret_cast(write_src_s[1])); + } else { + *(reinterpret_cast(write_src_s[1])) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard1[0]) { + g2s_int4( + reinterpret_cast(g_src_ptr[2]), + reinterpret_cast(write_src_s[0] + 8 * BK)); + } else { + *(reinterpret_cast(write_src_s[0] + 8 * BK)) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard1[1] && iter > 1) { + g2s_int4( + reinterpret_cast(g_src_ptr[3]), + reinterpret_cast(write_src_s[1] + 8 * BK)); + } else { + *(reinterpret_cast(write_src_s[1] + 8 * BK)) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + + if (guard && iter > htid) { + g2s_int4( + reinterpret_cast(g_filter_ptr0), + reinterpret_cast(write_flt_s)); + g2s_int4( + reinterpret_cast(g_filter_ptr1), + reinterpret_cast(write_flt_s + BK)); + g2s_int4( + reinterpret_cast(g_filter_ptr2), + reinterpret_cast(write_flt_s + 2 * BK)); + g2s_int4( + reinterpret_cast(g_filter_ptr3), + reinterpret_cast(write_flt_s + 3 * BK)); + } else { + *(reinterpret_cast(write_flt_s)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + BK)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + 2 * BK)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + 3 * BK)) = make_int4(0, 0, 0, 0); + } + stage++; + is_copy = true; + cp_async_fence(); + } + + bool only_one_stage = (stage == 1) ? 
true : false; + if (stage >= 2) { + cp_async_wait(stages - 2); + } else { + cp_async_wait(0); + } + + __syncthreads(); + + // read fuse_z + int2 reg_fuse_z[reg_m] = {make_int2(z_zero_point, z_zero_point), + make_int2(z_zero_point, z_zero_point), + make_int2(z_zero_point, z_zero_point), + make_int2(z_zero_point, z_zero_point), + make_int2(z_zero_point, z_zero_point), + make_int2(z_zero_point, z_zero_point), + make_int2(z_zero_point, z_zero_point), + make_int2(z_zero_point, z_zero_point)}; + int d_offset = (bidy * (BM >> 6) + warp_y) * param.ocs + (idx_in_quad << 3); + const int8_t* __restrict__ g_z_ptr = z + d_offset; + section = tid31 >> 2; + size_t nhw_post0 = bidx * BN + warp_x * 64 + section; + size_t nhw_post1 = nhw_post0 + 8; + size_t nhw_post2 = nhw_post0 + 16; + size_t nhw_post3 = nhw_post0 + 24; + size_t stg_oc = bidy * BM + (warp_y << 6); + int* g_offset = ((int*)®_filter_cache); + bool stg_guard[8]; +#pragma unroll + for (int y = 0; y < reg_m; y += 4) { + LDG_4x1(reg_fuse_z, g_offset, y) + + nhw_post0 += 32; + nhw_post1 += 32; + nhw_post2 += 32; + nhw_post3 += 32; + } + + for (; iter >= 2; iter -= 2) { + if (guard0[0]) { + g2s_int4( + reinterpret_cast(g_src_ptr[0]), + reinterpret_cast(write_src_s[0])); + } else { + *(reinterpret_cast(write_src_s[0])) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard0[1]) { + g2s_int4( + reinterpret_cast(g_src_ptr[1]), + reinterpret_cast(write_src_s[1])); + } else { + *(reinterpret_cast(write_src_s[1])) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard1[0]) { + g2s_int4( + reinterpret_cast(g_src_ptr[2]), + reinterpret_cast(write_src_s[0] + 8 * BK)); + } else { + *(reinterpret_cast(write_src_s[0] + 8 * BK)) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard1[1]) { + g2s_int4( + reinterpret_cast(g_src_ptr[3]), + reinterpret_cast(write_src_s[1] + 8 * BK)); + } else { + *(reinterpret_cast(write_src_s[1] + 8 * BK)) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + + if (guard) { + g2s_int4( + reinterpret_cast(g_filter_ptr0), + reinterpret_cast(write_flt_s)); + g2s_int4( + reinterpret_cast(g_filter_ptr1), + reinterpret_cast(write_flt_s + BK)); + g2s_int4( + reinterpret_cast(g_filter_ptr2), + reinterpret_cast(write_flt_s + 2 * BK)); + g2s_int4( + reinterpret_cast(g_filter_ptr3), + reinterpret_cast(write_flt_s + 3 * BK)); + } else { + *(reinterpret_cast(write_flt_s)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + BK)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + 2 * BK)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + 3 * BK)) = make_int4(0, 0, 0, 0); + } + stage++; + cp_async_fence(); + +#pragma unroll // low + for (int i = 0; i < reg_nd4; ++i) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK); // BK*32/8 + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_src[0][i] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int j = 0; j < reg_md4; ++j) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK); + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_flt[0][j] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int k_inner = 0; 
k_inner < BKd32; k_inner++) { + int comp = (k_inner & 0x1); + int load = 1 - comp; + if (k_inner < BKd32 - 1) { + int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1; + int32_t* read_flt_s = (k_inner & 1) ? read_flt_s_0 : read_flt_s_1; + read_src_s += 32 * ((k_inner + 1) >> 1); + read_flt_s += 32 * ((k_inner + 1) >> 1); + +#pragma unroll // low + for (int i = 0; i < reg_nd4; ++i) { + int x, y, z, w; + unsigned addr = + get_smem_pointer(read_src_s + i * 4 * BK); // BK*32/8 + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, " + "%3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_src[load][i] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int j = 0; j < reg_md4; ++j) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK); + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, " + "%3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_flt[load][j] = make_int4(x, y, z, w); + } + } + + int* A = reinterpret_cast(®_flt[comp][0]); + int* B = reinterpret_cast(®_src[comp][0]); +#pragma unroll + for (int x = 0; x < reg_n; x++) { +#pragma unroll + for (int y = 0; y < reg_m; y++) { + int* D = reinterpret_cast(®_acc[y][x]); + int* C = reinterpret_cast(®_acc[y][x]); + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4." + "s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1])); + } + } + } + + if (stage == stages) { + stage = 0; + write_src_s[0] += smem_switch_back; + write_src_s[1] += smem_switch_back; + write_flt_s += smem_switch_back; + + read_src_s_0 += smem_switch; + read_src_s_1 += smem_switch; + read_flt_s_0 += smem_switch; + read_flt_s_1 += smem_switch; + } else if (stage == stages - 1) { + write_src_s[0] += smem_switch; + write_src_s[1] += smem_switch; + write_flt_s += smem_switch; + + read_src_s_0 += smem_switch_back; + read_src_s_1 += smem_switch_back; + read_flt_s_0 += smem_switch_back; + read_flt_s_1 += smem_switch_back; + } else { + write_src_s[0] += smem_switch; + write_src_s[1] += smem_switch; + write_flt_s += smem_switch; + + read_src_s_0 += smem_switch; + read_src_s_1 += smem_switch; + read_flt_s_0 += smem_switch; + read_flt_s_1 += smem_switch; + } + + int src_step[2]; + for (int i = 0; i < 2; i++) { + src_step[i] = conv2d_constant.c_offset[offset[i]]; + uint32_t spatial = *(reinterpret_cast( + &(conv2d_constant.c_offset[offset[i] + 1]))); + x[i] = (spatial & 0xff); + y[i] = ((spatial >> 8) & 0xff); + if (offset[i] < conv2d_constant.c_offset_param.max) { + offset[i] += 4; + } else { + offset[i] += conv2d_constant.c_offset_param.rewind; + } + } + guard0[0] = g[0] && x[0] >= h_start[0] && x[0] < h_end[0] && + y[0] >= w_start[0] && y[0] < w_end[0]; + guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] && + y[1] >= w_start[0] && y[1] < w_end[0]; + guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] && + y[0] >= w_start[1] && y[0] < w_end[1]; + guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] && + y[1] >= w_start[1] && y[1] < w_end[1]; + g_src_ptr[0] += src_step[0]; + g_src_ptr[1] += src_step[1]; + g_src_ptr[2] += src_step[0]; + g_src_ptr[3] += src_step[1]; + g_filter_ptr0 += 8 * 64; + g_filter_ptr1 += 8 * 64; + g_filter_ptr2 += 8 * 64; + g_filter_ptr3 += 8 * 64; + cp_async_wait(stages - 2); + __syncthreads(); + } + + if (iter > 0) { + if (!is_copy) { + if (guard0[0]) { + g2s_int4( + reinterpret_cast(g_src_ptr[0]), + reinterpret_cast(write_src_s[0])); + 
} else { + *(reinterpret_cast(write_src_s[0])) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard0[1] && iter > 1) { + g2s_int4( + reinterpret_cast(g_src_ptr[1]), + reinterpret_cast(write_src_s[1])); + } else { + *(reinterpret_cast(write_src_s[1])) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard1[0]) { + g2s_int4( + reinterpret_cast(g_src_ptr[2]), + reinterpret_cast(write_src_s[0] + 8 * BK)); + } else { + *(reinterpret_cast(write_src_s[0] + 8 * BK)) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard1[1] && iter > 1) { + g2s_int4( + reinterpret_cast(g_src_ptr[3]), + reinterpret_cast(write_src_s[1] + 8 * BK)); + } else { + *(reinterpret_cast(write_src_s[1] + 8 * BK)) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + + if (guard && iter > htid) { + g2s_int4( + reinterpret_cast(g_filter_ptr0), + reinterpret_cast(write_flt_s)); + g2s_int4( + reinterpret_cast(g_filter_ptr1), + reinterpret_cast(write_flt_s + BK)); + g2s_int4( + reinterpret_cast(g_filter_ptr2), + reinterpret_cast(write_flt_s + 2 * BK)); + g2s_int4( + reinterpret_cast(g_filter_ptr3), + reinterpret_cast(write_flt_s + 3 * BK)); + } else { + *(reinterpret_cast(write_flt_s)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + BK)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + 2 * BK)) = + make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + 3 * BK)) = + make_int4(0, 0, 0, 0); + } + cp_async_fence(); + } +#pragma unroll // low + for (int i = 0; i < reg_nd4; ++i) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK); // BK*32/8 + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_src[0][i] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int j = 0; j < reg_md4; ++j) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK); + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_flt[0][j] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int k_inner = 0; k_inner < BKd32; k_inner++) { + int comp = (k_inner & 0x1); + int load = 1 - comp; + if (k_inner < BKd32 - 1) { + int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1; + int32_t* read_flt_s = (k_inner & 1) ? 
read_flt_s_0 : read_flt_s_1; + read_src_s += 32 * ((k_inner + 1) >> 1); + read_flt_s += 32 * ((k_inner + 1) >> 1); + +#pragma unroll // low + for (int i = 0; i < reg_nd4; ++i) { + int x, y, z, w; + unsigned addr = + get_smem_pointer(read_src_s + i * 4 * BK); // BK*32/8 + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, " + "%3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_src[load][i] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int j = 0; j < reg_md4; ++j) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK); + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, " + "%3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_flt[load][j] = make_int4(x, y, z, w); + } + } + + int* A = reinterpret_cast(®_flt[comp][0]); + int* B = reinterpret_cast(®_src[comp][0]); +#pragma unroll + for (int x = 0; x < reg_n; x++) { +#pragma unroll + for (int y = 0; y < reg_m; y++) { + int* D = reinterpret_cast(®_acc[y][x]); + int* C = reinterpret_cast(®_acc[y][x]); + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4." + "s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1])); + } + } + } + + stage++; + if (stage == stages) { + stage = 0; + + read_src_s_0 += smem_switch; + read_src_s_1 += smem_switch; + read_flt_s_0 += smem_switch; + read_flt_s_1 += smem_switch; + } else if (stage == stages - 1) { + read_src_s_0 += smem_switch_back; + read_src_s_1 += smem_switch_back; + read_flt_s_0 += smem_switch_back; + read_flt_s_1 += smem_switch_back; + } else { + read_src_s_0 += smem_switch; + read_src_s_1 += smem_switch; + read_flt_s_0 += smem_switch; + read_flt_s_1 += smem_switch; + } + cp_async_wait(stages - 2); + } + + if (!only_one_stage) { +#pragma unroll // low + for (int i = 0; i < reg_nd4; ++i) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK); // BK*32/8 + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_src[0][i] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int j = 0; j < reg_md4; ++j) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK); + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_flt[0][j] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int k_inner = 0; k_inner < BKd32; k_inner++) { + int comp = (k_inner & 0x1); + int load = 1 - comp; + if (k_inner < BKd32 - 1) { + int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1; + int32_t* read_flt_s = (k_inner & 1) ? 
read_flt_s_0 : read_flt_s_1; + read_src_s += 32 * ((k_inner + 1) >> 1); + read_flt_s += 32 * ((k_inner + 1) >> 1); + +#pragma unroll // low + for (int i = 0; i < reg_nd4; ++i) { + int x, y, z, w; + unsigned addr = + get_smem_pointer(read_src_s + i * 4 * BK); // BK*32/8 + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, " + "%3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_src[load][i] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int j = 0; j < reg_md4; ++j) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK); + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, " + "%3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_flt[load][j] = make_int4(x, y, z, w); + } + } + + int* A = reinterpret_cast(®_flt[comp][0]); + int* B = reinterpret_cast(®_src[comp][0]); +#pragma unroll + for (int x = 0; x < reg_n; x++) { +#pragma unroll + for (int y = 0; y < reg_m; y++) { + int* D = reinterpret_cast(®_acc[y][x]); + int* C = reinterpret_cast(®_acc[y][x]); + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4." + "s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1])); + } + } + } + + stage++; + if (stage == stages) { + stage = 0; + + read_src_s_0 += smem_switch; + read_src_s_1 += smem_switch; + read_flt_s_0 += smem_switch; + read_flt_s_1 += smem_switch; + } else if (stage == stages - 1) { + read_src_s_0 += smem_switch_back; + read_src_s_1 += smem_switch_back; + read_flt_s_0 += smem_switch_back; + read_flt_s_1 += smem_switch_back; + } else { + read_src_s_0 += smem_switch; + read_src_s_1 += smem_switch; + read_flt_s_0 += smem_switch; + read_flt_s_1 += smem_switch; + } + cp_async_wait(0); + } + + guard = iter < 0; +#pragma unroll // low + for (int i = 0; i < reg_nd4; ++i) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK); // BK*32/8 + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, " + "%3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_src[0][i] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int j = 0; j < reg_md4; ++j) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK); + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, " + "%3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_flt[0][j] = make_int4(x, y, z, w); + } + +// compute +#pragma unroll + for (int k_inner = 0; k_inner < BKd32; k_inner++) { + int comp = (k_inner & 0x1); + int load = 1 - comp; + if (k_inner < BKd32 - 1 && !(k_inner == 1 && guard)) { + int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1; + int32_t* read_flt_s = (k_inner & 1) ? 
read_flt_s_0 : read_flt_s_1; + read_src_s += 32 * ((k_inner + 1) >> 1); + read_flt_s += 32 * ((k_inner + 1) >> 1); + +#pragma unroll // low + for (int i = 0; i < reg_nd4; ++i) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_src_s + i * 4 * BK); // BK*32/8 + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, " + "%3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_src[load][i] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int j = 0; j < reg_md4; ++j) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK); + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, " + "%3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_flt[load][j] = make_int4(x, y, z, w); + } + } + + int* A = reinterpret_cast(®_flt[comp][0]); + int* B = reinterpret_cast(®_src[comp][0]); +#pragma unroll + for (int x = 0; x < reg_n; x++) { +#pragma unroll + for (int y = 0; y < reg_m; y++) { + int* D = reinterpret_cast(®_acc[y][x]); + int* C = reinterpret_cast(®_acc[y][x]); + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4." + "s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1])); + } + } + if (k_inner == 1 && guard) { + break; + } + } + + __syncthreads(); + + /// output + size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad; + const float* bias_ptr = bias + oc; + + int4 load_bias0 = make_int4(0, 0, 0, 0); + int4 load_bias1 = make_int4(0, 0, 0, 0); + int4 load_bias2 = make_int4(0, 0, 0, 0); + int4 load_bias3 = make_int4(0, 0, 0, 0); + if (oc < param.oc) { + load_bias0 = *(reinterpret_cast(bias_ptr)); + load_bias1 = *(reinterpret_cast(bias_ptr + 4)); + load_bias2 = *(reinterpret_cast(bias_ptr + 8)); + load_bias3 = *(reinterpret_cast(bias_ptr + 12)); + mul_v4(load_bias0, load_bias0, beta); + mul_v4(load_bias1, load_bias1, beta); + mul_v4(load_bias2, load_bias2, beta); + mul_v4(load_bias3, load_bias3, beta); + } + + int8_t* __restrict__ g_dst_ptr = dst + d_offset; + +#pragma unroll + for (int y = 0; y < reg_m; y += 4) { + I2F_4x8(reg_acc, y, 0); + FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3); + FUSE_Z_4x8(reg_acc, y, 0, reg_fuse_z, gamma, z_zero_point); + PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point); + STG_AFTER_LDG_4x1(g_offset, reg_acc, y, 0); + } +#endif +} +} // namespace + +namespace megdnn { +namespace cuda { +namespace ptx { +void run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldgsts16_128x128_relu( + const dim3 grid, const dim3 block, cudaStream_t stream, void** params) { +#ifdef SM80_SUPPORTED + cudaFuncSetAttribute( + ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldgsts16_128x128_relu, + cudaFuncAttributeMaxDynamicSharedMemorySize, 49152); + + ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldgsts16_128x128_relu<<< + grid, block, 49152, stream>>>( + *((int8_t**)params[0]), *((int8_t**)params[1]), *((float**)params[2]), + *((int8_t**)params[3]), *((int8_t**)params[4]), *((float*)params[5]), + *((float*)params[6]), *((float*)params[7]), *((uint32_t*)params[8]), + *((uint32_t*)params[9]), *((float*)params[10]), *((uint32_t*)params[11]), + *((Conv2dInt4Param*)params[12]), *((Conv2dConstantOffset*)params[13])); +#endif +} +} // namespace ptx +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/ptx/uint4_int4/imma8832_128x128.cuh b/dnn/src/cuda/ptx/uint4_int4/imma8832_128x128.cuh new file mode 100644 index 
000000000..0f3ec79c5 --- /dev/null +++ b/dnn/src/cuda/ptx/uint4_int4/imma8832_128x128.cuh @@ -0,0 +1,26 @@ +#pragma once + +#include "./base.cuh" + +#define TX 128 +#define TY 1 +#define BM 128 +#define BN 128 +#define BK 128 +#define mma_m 16 +#define mma_n 8 +#define mma_k 64 +#define reg_m 8 +#define reg_n 8 +#define packed_channel 64 +#define BKd32 (BK / 32) +#define BKd64 (BK / 64) +#define reg_md4 (reg_m >> 2) +#define WARPS (TX / 32) +#define cache_per_warp 128 +#define reg_nd4 (reg_n >> 2) +#define ldg_src (BN * BK / (16 * TX)) +#define ldg_filter (BM * BK / (16 * TX)) +#define ldg_width 16 + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/ptx/uint4_int4/imma8832_ldgsts16_128x128_relu.cu b/dnn/src/cuda/ptx/uint4_int4/imma8832_ldgsts16_128x128_relu.cu new file mode 100644 index 000000000..eeacceda6 --- /dev/null +++ b/dnn/src/cuda/ptx/uint4_int4/imma8832_ldgsts16_128x128_relu.cu @@ -0,0 +1,1089 @@ +#include + +#include +#include "./imma8832_128x128.cuh" +#include "./kern.cuh" +#include "./macro.cuh" +#include "./tools.cuh" + +using namespace convolution; + +namespace { +#ifdef SM80_ENABLED +extern "C" __device__ void g2s_int4(const int4* gm, int4* sm) { + unsigned sm_addr = get_smem_pointer(sm); + const int SizeInBytes = 16; +#if ENABLE_L2_PREFETCH + asm volatile( + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(sm_addr), + "l"(gm), "n"(SizeInBytes)); +#else + asm volatile( + "cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(sm_addr), "l"(gm), + "n"(SizeInBytes)); +#endif +} + +/// Establishes an ordering w.r.t previously issued cp.async instructions. Does +/// not block. +#define cp_async_fence() asm volatile("cp.async.commit_group;\n" ::) + +/// Blocks until all but previous cp.async.commit_group operations have +/// committed. 
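+/// (Same helper as in the fuse_z variant: `cp.async.wait_group N` blocks until at
+/// most N of the most recently committed cp.async groups remain pending.)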
+#define cp_async_wait(N) asm volatile("cp.async.wait_group %0;\n" ::"n"(N)) +#endif + +extern "C" __global__ void __launch_bounds__(256) + ampere_conv_bias_uint4_int4_imma8832_ldgsts16_128x128_relu( + const int8_t* __restrict__ src, int8_t* __restrict__ filter, + const float* __restrict__ bias, int8_t* __restrict__ dst, float alpha, + float beta, uint32_t pk_src_zero_point, float dst_zero_point, + uint32_t relu, Conv2dInt4Param param, + Conv2dConstantOffset conv2d_constant) { +#ifdef SM80_ENABLED + const int stages = 3; + const uint32_t tid = threadIdx.x; + const uint32_t bidx = blockIdx.x; + const uint32_t bidy = blockIdx.y; + extern __shared__ int32_t smem[]; // (128+128)*128/8*stages + int2 reg_acc[reg_m][reg_n]; + int4 reg_src[2][reg_nd4]; + int4 reg_flt[2][reg_md4]; + // use in other way, maybe use reg_ser/flt + int4 reg_src_cache[2]; + int4 reg_filter_cache[4]; + + uint32_t tid127 = (tid & 127); + uint32_t section = (tid127 >> 1); + uint32_t residue = ((tid127 << 5) & 63); + uint32_t nhw = bidx * BN + section; + + uint32_t tn, hw, toh, tow; + int tih, tiw; + int h_start[2]; + int h_end[2]; + int w_start[2]; + int w_end[2]; + bool g[2]; + const int8_t* __restrict__ g_src_ptr[4]; + for (int i = 0; i < 2; i++) { + if (i != 0) { + nhw += 64; + } + tn = nhw / param.div_ohow; + hw = nhw % param.div_ohow; + toh = hw / param.div_ow; + tow = hw % param.div_ow; + tih = toh * param.sh - param.ph; + tiw = tow * param.sw - param.pw; + g[i] = tn < param.n; + h_start[i] = -tih; + h_end[i] = param.ih - tih; + w_start[i] = -tiw; + w_end[i] = param.iw - tiw; + // param's members have been converted to byte offset and int4 offset need to + // div 2 + int src_offset = tn * param.ibs + tih * param.ihs + + ((int)(tiw * packed_channel + residue) >> 1); + g_src_ptr[i * 2] = src + src_offset; + g_src_ptr[i * 2 + 1] = g_src_ptr[i * 2]; + } + + const uint32_t section_section = (section >> 2); + const uint32_t section_residue = (section & 3); + const uint32_t section_factor = ((section & 15) >> 2); + const uint32_t crosswise_offset = + ((section_residue >> 1) << 4) + + (((section_residue & 1) ^ (section_factor >> 1)) << 3); + const uint32_t residue_offset = ((residue >> 5) ^ (section_factor & 1)) << 2; + + // next + 64 * BK / 8 + int32_t* write_src_s[2]; + write_src_s[0] = + smem + section_section * BK / 2 + crosswise_offset + residue_offset; + write_src_s[1] = write_src_s[0] + 32; + + int iter = (param.icfhfw >> 6); + + uint32_t tid31 = (tid & 31); + uint32_t warp_idx = (tid >> 5); + uint32_t warp_strided = (warp_idx << 2); + uint32_t htid = (tid31 >> 4); + const uint32_t flt_strided = bidy * BM / 8 + warp_strided; + bool guard = flt_strided * 8 < param.oc && iter > htid; + // icfhfw * 8/2 is a stride + int8_t* __restrict__ g_filter_ptr0 = + filter + flt_strided * (param.icfhfw * 4) + (tid31 << 4); + int8_t* __restrict__ g_filter_ptr1 = g_filter_ptr0 + (param.icfhfw * 4); + int8_t* __restrict__ g_filter_ptr2 = g_filter_ptr0 + (param.icfhfw * 8); + int8_t* __restrict__ g_filter_ptr3 = g_filter_ptr0 + (param.icfhfw * 12); + // next + BK * 8 / (INT32/INT4) + uint32_t q = (tid31 >> 3); + uint32_t r = (tid31 & 7); + int32_t* write_flt_s = smem + BN * BK / 8 + warp_strided * BK + ((q & 1) << 6) + + ((q >> 1) << 5) + (r << 2); + uint32_t quad_idx = (tid31 >> 2); + uint32_t idx_in_quad = (tid & 3); + uint32_t quad_factor = ((tid & 15) >> 2); + uint32_t crosswise = + ((idx_in_quad >> 1) << 4) + (((idx_in_quad & 1) ^ (quad_factor >> 1)) << 3); + uint32_t warp_x = (warp_idx >> 1); + uint32_t warp_y = (warp_idx & 1); + + 
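+    // Shared-memory layout of one pipeline stage: the BN x BK uint4 src tile fills
+    // the first BN * BK / 8 int32 words and the BM x BK int4 filter tile follows,
+    // i.e. 4096 int32 words (16 KB) per stage; with stages == 3 this adds up to the
+    // 48 KB (49152-byte) dynamic shared-memory budget requested by the host launcher.
+    // read_*_s_0 / read_*_s_1 below are the two xor-swizzled base pointers that the
+    // ldmatrix loads alternate between on successive k_inner steps.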
int32_t* read_src_s_0 = smem + (warp_x * 8 * BK) + (quad_idx * BK / 2) + crosswise + + ((0 ^ (quad_factor & 1)) << 2); + int32_t* read_src_s_1 = smem + (warp_x * 8 * BK) + (quad_idx * BK / 2) + crosswise + + ((1 ^ (quad_factor & 1)) << 2); + int32_t* read_flt_s_0 = smem + BN * BK / 8 + (warp_y * 8 * BK) + + (quad_idx * BK / 2) + crosswise + + ((0 ^ (quad_factor & 1)) << 2); + int32_t* read_flt_s_1 = smem + BN * BK / 8 + (warp_y * 8 * BK) + + (quad_idx * BK / 2) + crosswise + + ((1 ^ (quad_factor & 1)) << 2); + +#pragma unroll + for (int i = 0; i < reg_m; i++) { +#pragma unroll + for (int j = 0; j < reg_n; j++) { + reg_acc[i][j] = make_int2(0, 0); + } + } + + const int smem_switch = 4096; + const int smem_switch_back = -smem_switch * (stages - 1); + int stage = 0; + uint32_t offset[2] = {0, 2}; // high & low + int src_step[2], x[2], y[2]; + + // global mem --> shared mem, stage 0 + for (int i = 0; i < 2; i++) { + src_step[i] = conv2d_constant.c_offset[offset[i]]; + uint32_t spatial = *(reinterpret_cast( + &(conv2d_constant.c_offset[offset[i] + 1]))); + x[i] = (spatial & 0xff); + y[i] = ((spatial >> 8) & 0xff); + if (offset[i] < conv2d_constant.c_offset_param.max) { + offset[i] += 4; + } else { + offset[i] += conv2d_constant.c_offset_param.rewind; + } + } + + bool guard0[2], guard1[2]; + guard0[0] = g[0] && x[0] >= h_start[0] && x[0] < h_end[0] && y[0] >= w_start[0] && + y[0] < w_end[0] && iter > 0; + guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] && y[1] >= w_start[0] && + y[1] < w_end[0] && iter > 1; + guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] && y[0] >= w_start[1] && + y[0] < w_end[1] && iter > 0; + guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] && y[1] >= w_start[1] && + y[1] < w_end[1] && iter > 1; + g_src_ptr[0] += src_step[0]; + g_src_ptr[1] += src_step[1]; + g_src_ptr[2] += src_step[0]; + g_src_ptr[3] += src_step[1]; + + if (guard0[0]) { + g2s_int4( + reinterpret_cast(g_src_ptr[0]), + reinterpret_cast(write_src_s[0])); + } else { + *(reinterpret_cast(write_src_s[0])) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard0[1]) { + g2s_int4( + reinterpret_cast(g_src_ptr[1]), + reinterpret_cast(write_src_s[1])); + } else { + *(reinterpret_cast(write_src_s[1])) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard1[0]) { + g2s_int4( + reinterpret_cast(g_src_ptr[2]), + reinterpret_cast(write_src_s[0] + 8 * BK)); + } else { + *(reinterpret_cast(write_src_s[0] + 8 * BK)) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard1[1]) { + g2s_int4( + reinterpret_cast(g_src_ptr[3]), + reinterpret_cast(write_src_s[1] + 8 * BK)); + } else { + *(reinterpret_cast(write_src_s[1] + 8 * BK)) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + + if (guard) { + g2s_int4( + reinterpret_cast(g_filter_ptr0), + reinterpret_cast(write_flt_s)); + g2s_int4( + reinterpret_cast(g_filter_ptr1), + reinterpret_cast(write_flt_s + BK)); + g2s_int4( + reinterpret_cast(g_filter_ptr2), + reinterpret_cast(write_flt_s + 2 * BK)); + g2s_int4( + reinterpret_cast(g_filter_ptr3), + reinterpret_cast(write_flt_s + 3 * BK)); + } else { + *(reinterpret_cast(write_flt_s)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + BK)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + 2 * BK)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + 3 * BK)) = make_int4(0, 
0, 0, 0); + } + stage++; + cp_async_fence(); + + // global mem --> shared mem, stage 1 -> stage n + iter -= 2; + for (int i = 0; i < 2; i++) { + src_step[i] = conv2d_constant.c_offset[offset[i]]; + uint32_t spatial = *(reinterpret_cast( + &(conv2d_constant.c_offset[offset[i] + 1]))); + x[i] = (spatial & 0xff); + y[i] = ((spatial >> 8) & 0xff); + if (offset[i] < conv2d_constant.c_offset_param.max) { + offset[i] += 4; + } else { + offset[i] += conv2d_constant.c_offset_param.rewind; + } + } + guard0[0] = g[0] && x[0] >= h_start[0] && x[0] < h_end[0] && y[0] >= w_start[0] && + y[0] < w_end[0]; + guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] && y[1] >= w_start[0] && + y[1] < w_end[0]; + guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] && y[0] >= w_start[1] && + y[0] < w_end[1]; + guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] && y[1] >= w_start[1] && + y[1] < w_end[1]; + + g_src_ptr[0] += src_step[0]; + g_src_ptr[1] += src_step[1]; + g_src_ptr[2] += src_step[0]; + g_src_ptr[3] += src_step[1]; + g_filter_ptr0 += 8 * 64; + g_filter_ptr1 += 8 * 64; + g_filter_ptr2 += 8 * 64; + g_filter_ptr3 += 8 * 64; + + write_src_s[0] += smem_switch; + write_src_s[1] += smem_switch; + write_flt_s += smem_switch; + + for (; iter >= 2 && stage < stages - 1; iter -= 2) { + if (guard0[0]) { + g2s_int4( + reinterpret_cast(g_src_ptr[0]), + reinterpret_cast(write_src_s[0])); + } else { + *(reinterpret_cast(write_src_s[0])) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard0[1]) { + g2s_int4( + reinterpret_cast(g_src_ptr[1]), + reinterpret_cast(write_src_s[1])); + } else { + *(reinterpret_cast(write_src_s[1])) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard1[0]) { + g2s_int4( + reinterpret_cast(g_src_ptr[2]), + reinterpret_cast(write_src_s[0] + 8 * BK)); + } else { + *(reinterpret_cast(write_src_s[0] + 8 * BK)) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard1[1]) { + g2s_int4( + reinterpret_cast(g_src_ptr[3]), + reinterpret_cast(write_src_s[1] + 8 * BK)); + } else { + *(reinterpret_cast(write_src_s[1] + 8 * BK)) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + + if (guard) { + g2s_int4( + reinterpret_cast(g_filter_ptr0), + reinterpret_cast(write_flt_s)); + g2s_int4( + reinterpret_cast(g_filter_ptr1), + reinterpret_cast(write_flt_s + BK)); + g2s_int4( + reinterpret_cast(g_filter_ptr2), + reinterpret_cast(write_flt_s + 2 * BK)); + g2s_int4( + reinterpret_cast(g_filter_ptr3), + reinterpret_cast(write_flt_s + 3 * BK)); + } else { + *(reinterpret_cast(write_flt_s)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + BK)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + 2 * BK)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + 3 * BK)) = make_int4(0, 0, 0, 0); + } + + stage++; + cp_async_fence(); + + for (int i = 0; i < 2; i++) { + src_step[i] = conv2d_constant.c_offset[offset[i]]; + uint32_t spatial = *(reinterpret_cast( + &(conv2d_constant.c_offset[offset[i] + 1]))); + x[i] = (spatial & 0xff); + y[i] = ((spatial >> 8) & 0xff); + if (offset[i] < conv2d_constant.c_offset_param.max) { + offset[i] += 4; + } else { + offset[i] += conv2d_constant.c_offset_param.rewind; + } + } + guard0[0] = g[0] && x[0] >= h_start[0] && x[0] < h_end[0] && + y[0] >= w_start[0] && y[0] < w_end[0]; + guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] 
&& + y[1] >= w_start[0] && y[1] < w_end[0]; + guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] && + y[0] >= w_start[1] && y[0] < w_end[1]; + guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] && + y[1] >= w_start[1] && y[1] < w_end[1]; + + g_src_ptr[0] += src_step[0]; + g_src_ptr[1] += src_step[1]; + g_src_ptr[2] += src_step[0]; + g_src_ptr[3] += src_step[1]; + g_filter_ptr0 += 8 * 64; + g_filter_ptr1 += 8 * 64; + g_filter_ptr2 += 8 * 64; + g_filter_ptr3 += 8 * 64; + + write_src_s[0] += smem_switch; + write_src_s[1] += smem_switch; + write_flt_s += smem_switch; + } + bool is_copy = false; + + if (iter == 1 && stage != stages - 1) { + if (guard0[0]) { + g2s_int4( + reinterpret_cast(g_src_ptr[0]), + reinterpret_cast(write_src_s[0])); + } else { + *(reinterpret_cast(write_src_s[0])) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard0[1] && iter > 1) { + g2s_int4( + reinterpret_cast(g_src_ptr[1]), + reinterpret_cast(write_src_s[1])); + } else { + *(reinterpret_cast(write_src_s[1])) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard1[0]) { + g2s_int4( + reinterpret_cast(g_src_ptr[2]), + reinterpret_cast(write_src_s[0] + 8 * BK)); + } else { + *(reinterpret_cast(write_src_s[0] + 8 * BK)) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard1[1] && iter > 1) { + g2s_int4( + reinterpret_cast(g_src_ptr[3]), + reinterpret_cast(write_src_s[1] + 8 * BK)); + } else { + *(reinterpret_cast(write_src_s[1] + 8 * BK)) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + + if (guard && iter > htid) { + g2s_int4( + reinterpret_cast(g_filter_ptr0), + reinterpret_cast(write_flt_s)); + g2s_int4( + reinterpret_cast(g_filter_ptr1), + reinterpret_cast(write_flt_s + BK)); + g2s_int4( + reinterpret_cast(g_filter_ptr2), + reinterpret_cast(write_flt_s + 2 * BK)); + g2s_int4( + reinterpret_cast(g_filter_ptr3), + reinterpret_cast(write_flt_s + 3 * BK)); + } else { + *(reinterpret_cast(write_flt_s)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + BK)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + 2 * BK)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + 3 * BK)) = make_int4(0, 0, 0, 0); + } + stage++; + is_copy = true; + cp_async_fence(); + } + + // compute offset + int d_offset = (bidy * (BM >> 6) + warp_y) * param.ocs + (idx_in_quad << 3); + section = tid31 >> 2; + size_t nhw_post0 = bidx * BN + warp_x * 64 + section; + size_t nhw_post1 = nhw_post0 + 8; + size_t nhw_post2 = nhw_post0 + 16; + size_t nhw_post3 = nhw_post0 + 24; + size_t stg_oc = bidy * BM + (warp_y << 6); + int* g_offset = ((int*)®_filter_cache); + bool stg_guard[8]; +#pragma unroll + for (int y = 0; y < reg_m; y += 4) { + COMPUTE_OFFSET_4x1(reg_fuse_z, g_offset, y) + + nhw_post0 += 32; + nhw_post1 += 32; + nhw_post2 += 32; + nhw_post3 += 32; + } + + bool only_one_stage = (stage == 1) ? 
true : false; + if (stage >= 2) { + cp_async_wait(stages - 2); + } else { + cp_async_wait(0); + } + + __syncthreads(); + + for (; iter >= 2; iter -= 2) { + if (guard0[0]) { + g2s_int4( + reinterpret_cast(g_src_ptr[0]), + reinterpret_cast(write_src_s[0])); + } else { + *(reinterpret_cast(write_src_s[0])) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard0[1]) { + g2s_int4( + reinterpret_cast(g_src_ptr[1]), + reinterpret_cast(write_src_s[1])); + } else { + *(reinterpret_cast(write_src_s[1])) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard1[0]) { + g2s_int4( + reinterpret_cast(g_src_ptr[2]), + reinterpret_cast(write_src_s[0] + 8 * BK)); + } else { + *(reinterpret_cast(write_src_s[0] + 8 * BK)) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard1[1]) { + g2s_int4( + reinterpret_cast(g_src_ptr[3]), + reinterpret_cast(write_src_s[1] + 8 * BK)); + } else { + *(reinterpret_cast(write_src_s[1] + 8 * BK)) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + + if (guard) { + g2s_int4( + reinterpret_cast(g_filter_ptr0), + reinterpret_cast(write_flt_s)); + g2s_int4( + reinterpret_cast(g_filter_ptr1), + reinterpret_cast(write_flt_s + BK)); + g2s_int4( + reinterpret_cast(g_filter_ptr2), + reinterpret_cast(write_flt_s + 2 * BK)); + g2s_int4( + reinterpret_cast(g_filter_ptr3), + reinterpret_cast(write_flt_s + 3 * BK)); + } else { + *(reinterpret_cast(write_flt_s)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + BK)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + 2 * BK)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + 3 * BK)) = make_int4(0, 0, 0, 0); + } + stage++; + cp_async_fence(); + +#pragma unroll // low + for (int i = 0; i < reg_nd4; ++i) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK); // BK*32/8 + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_src[0][i] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int j = 0; j < reg_md4; ++j) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK); + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_flt[0][j] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int k_inner = 0; k_inner < BKd32; k_inner++) { + int comp = (k_inner & 0x1); + int load = 1 - comp; + if (k_inner < BKd32 - 1) { + int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1; + int32_t* read_flt_s = (k_inner & 1) ? 
read_flt_s_0 : read_flt_s_1; + read_src_s += 32 * ((k_inner + 1) >> 1); + read_flt_s += 32 * ((k_inner + 1) >> 1); + +#pragma unroll // low + for (int i = 0; i < reg_nd4; ++i) { + int x, y, z, w; + unsigned addr = + get_smem_pointer(read_src_s + i * 4 * BK); // BK*32/8 + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, " + "%3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_src[load][i] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int j = 0; j < reg_md4; ++j) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK); + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, " + "%3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_flt[load][j] = make_int4(x, y, z, w); + } + } + + int* A = reinterpret_cast(®_flt[comp][0]); + int* B = reinterpret_cast(®_src[comp][0]); +#pragma unroll + for (int x = 0; x < reg_n; x++) { +#pragma unroll + for (int y = 0; y < reg_m; y++) { + int* D = reinterpret_cast(®_acc[y][x]); + int* C = reinterpret_cast(®_acc[y][x]); + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4." + "s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1])); + } + } + } + + if (stage == stages) { + stage = 0; + write_src_s[0] += smem_switch_back; + write_src_s[1] += smem_switch_back; + write_flt_s += smem_switch_back; + + read_src_s_0 += smem_switch; + read_src_s_1 += smem_switch; + read_flt_s_0 += smem_switch; + read_flt_s_1 += smem_switch; + } else if (stage == stages - 1) { + write_src_s[0] += smem_switch; + write_src_s[1] += smem_switch; + write_flt_s += smem_switch; + + read_src_s_0 += smem_switch_back; + read_src_s_1 += smem_switch_back; + read_flt_s_0 += smem_switch_back; + read_flt_s_1 += smem_switch_back; + } else { + write_src_s[0] += smem_switch; + write_src_s[1] += smem_switch; + write_flt_s += smem_switch; + + read_src_s_0 += smem_switch; + read_src_s_1 += smem_switch; + read_flt_s_0 += smem_switch; + read_flt_s_1 += smem_switch; + } + + int src_step[2]; + for (int i = 0; i < 2; i++) { + src_step[i] = conv2d_constant.c_offset[offset[i]]; + uint32_t spatial = *(reinterpret_cast( + &(conv2d_constant.c_offset[offset[i] + 1]))); + x[i] = (spatial & 0xff); + y[i] = ((spatial >> 8) & 0xff); + if (offset[i] < conv2d_constant.c_offset_param.max) { + offset[i] += 4; + } else { + offset[i] += conv2d_constant.c_offset_param.rewind; + } + } + guard0[0] = g[0] && x[0] >= h_start[0] && x[0] < h_end[0] && + y[0] >= w_start[0] && y[0] < w_end[0]; + guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] && + y[1] >= w_start[0] && y[1] < w_end[0]; + guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] && + y[0] >= w_start[1] && y[0] < w_end[1]; + guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] && + y[1] >= w_start[1] && y[1] < w_end[1]; + g_src_ptr[0] += src_step[0]; + g_src_ptr[1] += src_step[1]; + g_src_ptr[2] += src_step[0]; + g_src_ptr[3] += src_step[1]; + g_filter_ptr0 += 8 * 64; + g_filter_ptr1 += 8 * 64; + g_filter_ptr2 += 8 * 64; + g_filter_ptr3 += 8 * 64; + cp_async_wait(stages - 2); + __syncthreads(); + } + + if (iter > 0) { + if (!is_copy) { + if (guard0[0]) { + g2s_int4( + reinterpret_cast(g_src_ptr[0]), + reinterpret_cast(write_src_s[0])); + } else { + *(reinterpret_cast(write_src_s[0])) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard0[1] && iter > 1) { + g2s_int4( + 
reinterpret_cast(g_src_ptr[1]), + reinterpret_cast(write_src_s[1])); + } else { + *(reinterpret_cast(write_src_s[1])) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard1[0]) { + g2s_int4( + reinterpret_cast(g_src_ptr[2]), + reinterpret_cast(write_src_s[0] + 8 * BK)); + } else { + *(reinterpret_cast(write_src_s[0] + 8 * BK)) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + if (guard1[1] && iter > 1) { + g2s_int4( + reinterpret_cast(g_src_ptr[3]), + reinterpret_cast(write_src_s[1] + 8 * BK)); + } else { + *(reinterpret_cast(write_src_s[1] + 8 * BK)) = make_int4( + pk_src_zero_point, pk_src_zero_point, pk_src_zero_point, + pk_src_zero_point); + } + + if (guard && iter > htid) { + g2s_int4( + reinterpret_cast(g_filter_ptr0), + reinterpret_cast(write_flt_s)); + g2s_int4( + reinterpret_cast(g_filter_ptr1), + reinterpret_cast(write_flt_s + BK)); + g2s_int4( + reinterpret_cast(g_filter_ptr2), + reinterpret_cast(write_flt_s + 2 * BK)); + g2s_int4( + reinterpret_cast(g_filter_ptr3), + reinterpret_cast(write_flt_s + 3 * BK)); + } else { + *(reinterpret_cast(write_flt_s)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + BK)) = make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + 2 * BK)) = + make_int4(0, 0, 0, 0); + *(reinterpret_cast(write_flt_s + 3 * BK)) = + make_int4(0, 0, 0, 0); + } + cp_async_fence(); + } +#pragma unroll // low + for (int i = 0; i < reg_nd4; ++i) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK); // BK*32/8 + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_src[0][i] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int j = 0; j < reg_md4; ++j) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK); + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_flt[0][j] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int k_inner = 0; k_inner < BKd32; k_inner++) { + int comp = (k_inner & 0x1); + int load = 1 - comp; + if (k_inner < BKd32 - 1) { + int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1; + int32_t* read_flt_s = (k_inner & 1) ? read_flt_s_0 : read_flt_s_1; + read_src_s += 32 * ((k_inner + 1) >> 1); + read_flt_s += 32 * ((k_inner + 1) >> 1); + +#pragma unroll // low + for (int i = 0; i < reg_nd4; ++i) { + int x, y, z, w; + unsigned addr = + get_smem_pointer(read_src_s + i * 4 * BK); // BK*32/8 + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, " + "%3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_src[load][i] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int j = 0; j < reg_md4; ++j) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK); + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, " + "%3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_flt[load][j] = make_int4(x, y, z, w); + } + } + + int* A = reinterpret_cast(®_flt[comp][0]); + int* B = reinterpret_cast(®_src[comp][0]); +#pragma unroll + for (int x = 0; x < reg_n; x++) { +#pragma unroll + for (int y = 0; y < reg_m; y++) { + int* D = reinterpret_cast(®_acc[y][x]); + int* C = reinterpret_cast(®_acc[y][x]); + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4." 
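+                    // The accumulation step here is a single warp-wide tensor-core
+                    // instruction: unsigned 4-bit activations times signed 4-bit
+                    // weights, accumulated into saturating s32. A minimal wrapper
+                    // sketch; the name `mma_u4s4` is illustrative and not part of
+                    // this patch.
+#if 0   // illustrative sketch only, compiled out
+__device__ __forceinline__ void mma_u4s4(
+        int& d0, int& d1, int a_u4, int b_s4, int c0, int c1) {
+    // m8n8k32: per thread, a_u4 packs eight unsigned 4-bit values and b_s4 packs
+    // eight signed 4-bit values; their dot products are accumulated into two
+    // signed 32-bit registers (d0, d1) with saturation.
+    asm volatile(
+            "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 "
+            "{%0,%1}, {%2}, {%3}, {%4,%5};\n"
+            : "=r"(d0), "=r"(d1)
+            : "r"(a_u4), "r"(b_s4), "r"(c0), "r"(c1));
+}
+#endif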
+ "s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1])); + } + } + } + + stage++; + if (stage == stages) { + stage = 0; + + read_src_s_0 += smem_switch; + read_src_s_1 += smem_switch; + read_flt_s_0 += smem_switch; + read_flt_s_1 += smem_switch; + } else if (stage == stages - 1) { + read_src_s_0 += smem_switch_back; + read_src_s_1 += smem_switch_back; + read_flt_s_0 += smem_switch_back; + read_flt_s_1 += smem_switch_back; + } else { + read_src_s_0 += smem_switch; + read_src_s_1 += smem_switch; + read_flt_s_0 += smem_switch; + read_flt_s_1 += smem_switch; + } + cp_async_wait(stages - 2); + } + + if (!only_one_stage) { +#pragma unroll // low + for (int i = 0; i < reg_nd4; ++i) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK); // BK*32/8 + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_src[0][i] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int j = 0; j < reg_md4; ++j) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK); + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_flt[0][j] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int k_inner = 0; k_inner < BKd32; k_inner++) { + int comp = (k_inner & 0x1); + int load = 1 - comp; + if (k_inner < BKd32 - 1) { + int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1; + int32_t* read_flt_s = (k_inner & 1) ? read_flt_s_0 : read_flt_s_1; + read_src_s += 32 * ((k_inner + 1) >> 1); + read_flt_s += 32 * ((k_inner + 1) >> 1); + +#pragma unroll // low + for (int i = 0; i < reg_nd4; ++i) { + int x, y, z, w; + unsigned addr = + get_smem_pointer(read_src_s + i * 4 * BK); // BK*32/8 + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, " + "%3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_src[load][i] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int j = 0; j < reg_md4; ++j) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK); + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, " + "%3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_flt[load][j] = make_int4(x, y, z, w); + } + } + + int* A = reinterpret_cast(®_flt[comp][0]); + int* B = reinterpret_cast(®_src[comp][0]); +#pragma unroll + for (int x = 0; x < reg_n; x++) { +#pragma unroll + for (int y = 0; y < reg_m; y++) { + int* D = reinterpret_cast(®_acc[y][x]); + int* C = reinterpret_cast(®_acc[y][x]); + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4." 
+ "s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1])); + } + } + } + + stage++; + if (stage == stages) { + stage = 0; + + read_src_s_0 += smem_switch; + read_src_s_1 += smem_switch; + read_flt_s_0 += smem_switch; + read_flt_s_1 += smem_switch; + } else if (stage == stages - 1) { + read_src_s_0 += smem_switch_back; + read_src_s_1 += smem_switch_back; + read_flt_s_0 += smem_switch_back; + read_flt_s_1 += smem_switch_back; + } else { + read_src_s_0 += smem_switch; + read_src_s_1 += smem_switch; + read_flt_s_0 += smem_switch; + read_flt_s_1 += smem_switch; + } + cp_async_wait(0); + } + guard = iter < 0; +#pragma unroll // low + for (int i = 0; i < reg_nd4; ++i) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK); // BK*32/8 + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, " + "%3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_src[0][i] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int j = 0; j < reg_md4; ++j) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK); + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, " + "%3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_flt[0][j] = make_int4(x, y, z, w); + } + +// compute +#pragma unroll + for (int k_inner = 0; k_inner < BKd32; k_inner++) { + int comp = (k_inner & 0x1); + int load = 1 - comp; + if (k_inner < BKd32 - 1 && !(k_inner == 1 && guard)) { + int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1; + int32_t* read_flt_s = (k_inner & 1) ? read_flt_s_0 : read_flt_s_1; + read_src_s += 32 * ((k_inner + 1) >> 1); + read_flt_s += 32 * ((k_inner + 1) >> 1); + +#pragma unroll // low + for (int i = 0; i < reg_nd4; ++i) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_src_s + i * 4 * BK); // BK*32/8 + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, " + "%3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_src[load][i] = make_int4(x, y, z, w); + } + +#pragma unroll + for (int j = 0; j < reg_md4; ++j) { + int x, y, z, w; + unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK); + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, " + "%3}, " + "[%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(addr)); + reg_flt[load][j] = make_int4(x, y, z, w); + } + } + + int* A = reinterpret_cast(®_flt[comp][0]); + int* B = reinterpret_cast(®_src[comp][0]); +#pragma unroll + for (int x = 0; x < reg_n; x++) { +#pragma unroll + for (int y = 0; y < reg_m; y++) { + int* D = reinterpret_cast(®_acc[y][x]); + int* C = reinterpret_cast(®_acc[y][x]); + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4." 
+ "s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1])); + } + } + if (k_inner == 1 && guard) { + break; + } + } + + __syncthreads(); + + /// output + size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad; + const float* bias_ptr = bias + oc; + + int4 load_bias0 = make_int4(0, 0, 0, 0); + int4 load_bias1 = make_int4(0, 0, 0, 0); + int4 load_bias2 = make_int4(0, 0, 0, 0); + int4 load_bias3 = make_int4(0, 0, 0, 0); + if (oc < param.oc) { + load_bias0 = *(reinterpret_cast(bias_ptr)); + load_bias1 = *(reinterpret_cast(bias_ptr + 4)); + load_bias2 = *(reinterpret_cast(bias_ptr + 8)); + load_bias3 = *(reinterpret_cast(bias_ptr + 12)); + mul_v4(load_bias0, load_bias0, beta); + mul_v4(load_bias1, load_bias1, beta); + mul_v4(load_bias2, load_bias2, beta); + mul_v4(load_bias3, load_bias3, beta); + } + + int8_t* __restrict__ g_dst_ptr = dst + d_offset; + +#pragma unroll + for (int y = 0; y < reg_m; y += 4) { + I2F_4x8(reg_acc, y, 0); + FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3); + PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point); + STG_AFTER_LDG_4x1(g_offset, reg_acc, y, 0); + + nhw_post0 += 32; + nhw_post1 += 32; + nhw_post2 += 32; + nhw_post3 += 32; + } +#endif +} +} // namespace + +namespace megdnn { +namespace cuda { +namespace ptx { +void run_ampere_conv_bias_uint4_int4_imma8832_ldgsts16_128x128_relu( + const dim3 grid, const dim3 block, cudaStream_t stream, void** params) { +#ifdef SM80_SUPPORTED + cudaFuncSetAttribute( + ampere_conv_bias_uint4_int4_imma8832_ldgsts16_128x128_relu, + cudaFuncAttributeMaxDynamicSharedMemorySize, 49152); + + ampere_conv_bias_uint4_int4_imma8832_ldgsts16_128x128_relu<<< + grid, block, 49152, stream>>>( + *((int8_t**)params[0]), *((int8_t**)params[1]), *((float**)params[2]), + *((int8_t**)params[3]), *((float*)params[4]), *((float*)params[5]), + *((uint32_t*)params[6]), *((float*)params[7]), *((uint32_t*)params[8]), + *((Conv2dInt4Param*)params[9]), *((Conv2dConstantOffset*)params[10])); +#endif +} +} // namespace ptx +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/ptx/uint4_int4/kern.cuh b/dnn/src/cuda/ptx/uint4_int4/kern.cuh new file mode 100644 index 000000000..80defa3df --- /dev/null +++ b/dnn/src/cuda/ptx/uint4_int4/kern.cuh @@ -0,0 +1,20 @@ +#include + +namespace megdnn { +namespace cuda { +namespace ptx { +void run_ampere_conv_bias_uint4_int4_imma8832_ldg16_256x64_relu( + const dim3 grid, const dim3 block, cudaStream_t stream, void** params); +void run_ampere_conv_bias_uint4_int4_imma8832_ldgsts16_128x128_relu( + const dim3 grid, const dim3 block, cudaStream_t stream, void** params); +void run_ampere_conv_bias_uint4_int4_imma8832_ldg16_128x256_relu( + const dim3 grid, const dim3 block, cudaStream_t stream, void** params); +void run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_256x64_relu( + const dim3 grid, const dim3 block, cudaStream_t stream, void** params); +void run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldgsts16_128x128_relu( + const dim3 grid, const dim3 block, cudaStream_t stream, void** params); +void run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_128x256_relu( + const dim3 grid, const dim3 block, cudaStream_t stream, void** params); +} // namespace ptx +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/ptx/uint4_int4/macro.cuh b/dnn/src/cuda/ptx/uint4_int4/macro.cuh new file mode 100644 index 000000000..812a0b150 --- /dev/null +++ 
b/dnn/src/cuda/ptx/uint4_int4/macro.cuh @@ -0,0 +1,348 @@ +#pragma once + +//! ============= i2f =============== +__device__ __forceinline__ void i2f(int2& a) { + ((float*)&a)[0] = static_cast(a.x); + ((float*)&a)[1] = static_cast(a.y); +} + +//! ============= mul =============== +template +__device__ __forceinline__ void mul_v4(int4& c, const int4 a, const T alpha); + +template <> +__device__ __forceinline__ void mul_v4( + int4& c, const int4 a, const float alpha) { + ((float*)&c)[0] = ((float*)&a)[0] * alpha; + ((float*)&c)[1] = ((float*)&a)[1] * alpha; + ((float*)&c)[2] = ((float*)&a)[2] * alpha; + ((float*)&c)[3] = ((float*)&a)[3] * alpha; +} + +//! ============= fma =============== +__device__ __forceinline__ void fma2( + int2& c0, const int2 a0, int2& c1, const int2 a1, const float alpha, + const int4 b) { + asm("fma.rz.f32 %0, %1, %2, %3;" + : "=f"(((float*)&c0)[0]) + : "f"(((float*)&a0)[0]), "f"(alpha), "f"(((float*)&b)[0])); + asm("fma.rz.f32 %0, %1, %2, %3;" + : "=f"(((float*)&c0)[1]) + : "f"(((float*)&a0)[1]), "f"(alpha), "f"(((float*)&b)[1])); + asm("fma.rz.f32 %0, %1, %2, %3;" + : "=f"(((float*)&c1)[0]) + : "f"(((float*)&a1)[0]), "f"(alpha), "f"(((float*)&b)[2])); + asm("fma.rz.f32 %0, %1, %2, %3;" + : "=f"(((float*)&c1)[1]) + : "f"(((float*)&a1)[1]), "f"(alpha), "f"(((float*)&b)[3])); +} + +__device__ __forceinline__ void fuse_z_1x8( + int4* a, const int& j, const int4& fuse_z, const float& gamma, + const int32_t& zero_point) { + const int2 z[2] = { + *reinterpret_cast(&fuse_z), + *(reinterpret_cast(&fuse_z) + 1)}; + for (int k = 0; k < 4; k++) { + int f = ((z[0].x >> (k * 8)) & 15); + f = (f << 28) >> 28; + ((float*)&(a[j + k]))[0] += (f - zero_point) * gamma; + f = ((z[0].x >> (k * 8 + 4)) & 15); + f = (f << 28) >> 28; + ((float*)&(a[j + k]))[1] += (f - zero_point) * gamma; + + f = ((z[1].x >> (k * 8)) & 15); + f = (f << 28) >> 28; + ((float*)&(a[j + k]))[2] += (f - zero_point) * gamma; + f = ((z[1].x >> (k * 8 + 4)) & 15); + f = (f << 28) >> 28; + ((float*)&(a[j + k]))[3] += (f - zero_point) * gamma; + } + for (int k = 0; k < 4; k++) { + int f = ((z[0].y >> (k * 8)) & 15); + f = (f << 28) >> 28; + ((float*)&(a[j + k + 4]))[0] += (f - zero_point) * gamma; + f = ((z[0].y >> (k * 8 + 4)) & 15); + f = (f << 28) >> 28; + ((float*)&(a[j + k + 4]))[1] += (f - zero_point) * gamma; + + f = ((z[1].y >> (k * 8)) & 15); + f = (f << 28) >> 28; + ((float*)&(a[j + k + 4]))[2] += (f - zero_point) * gamma; + f = ((z[1].y >> (k * 8 + 4)) & 15); + f = (f << 28) >> 28; + ((float*)&(a[j + k + 4]))[3] += (f - zero_point) * gamma; + } +} + +__device__ __forceinline__ void fuse_z_1x8( + int2* a, const int& j, const int2& fuse_z, const float& gamma, + const int32_t& zero_point) { +#pragma unroll + for (int k = 0; k < 4; k++) { + int f = ((fuse_z.x >> (k * 8)) & 15); + f = (f << 28) >> 28; + ((float*)&(a[j + k]))[0] += (f - zero_point) * gamma; + f = ((fuse_z.x >> (k * 8 + 4)) & 15); + f = (f << 28) >> 28; + ((float*)&(a[j + k]))[1] += (f - zero_point) * gamma; + } +#pragma unroll + for (int k = 0; k < 4; k++) { + int f = ((fuse_z.y >> (k * 8)) & 15); + f = (f << 28) >> 28; + ((float*)&(a[j + k + 4]))[0] += (f - zero_point) * gamma; + f = ((fuse_z.y >> (k * 8 + 4)) & 15); + f = (f << 28) >> 28; + ((float*)&(a[j + k + 4]))[1] += (f - zero_point) * gamma; + } +} + +__device__ __forceinline__ void pack_f2i( + int& d0, int& d1, const int4 s0, const int4 s1, const int4 s2, const int4 s3, + const uint32_t relu, float& dst_zero_point) { + // uint32_t ix, iy, iz, iw; + uint32_t x0, y0, z0, w0; + uint32_t 
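+    // fuse_z_1x8 above extracts 4-bit lanes from a 32-bit word and sign-extends
+    // them with the shift pair (f << 28) >> 28. A stand-alone sketch of that
+    // trick for the plain nibble layout; the helper name `unpack_s4` is
+    // illustrative and not part of this patch.
+#if 0   // illustrative sketch only, compiled out
+__device__ __forceinline__ int unpack_s4(uint32_t packed, int lane) {
+    // lane 0 sits in bits [3:0], lane 1 in bits [7:4], ..., lane 7 in bits [31:28]
+    int f = (packed >> (lane * 4)) & 0xf;
+    // shift the nibble to the top of the 32-bit int, then arithmetic-shift it
+    // back so that values 8..15 become -8..-1
+    return (f << 28) >> 28;
+}
+#endif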
x1, y1, z1, w1; + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x0) : "f"(((float*)&s0)[0])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(y0) : "f"(((float*)&s0)[1])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(z0) : "f"(((float*)&s1)[0])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(w0) : "f"(((float*)&s1)[1])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x1) : "f"(((float*)&s2)[0])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(y1) : "f"(((float*)&s2)[1])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(z1) : "f"(((float*)&s3)[0])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(w1) : "f"(((float*)&s3)[1])); + + asm volatile( + "{ .reg .u32 r4;" + "cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;" + "cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;" + "cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;" + "cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;" + "}" + : "=r"(d0) + : "r"(x0), "r"(y0), "r"(z0), "r"(w0), "r"(x1), "r"(y1), "r"(z1), "r"(w1)); + + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x0) : "f"(((float*)&s0)[2])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(y0) : "f"(((float*)&s0)[3])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(z0) : "f"(((float*)&s1)[2])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(w0) : "f"(((float*)&s1)[3])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x1) : "f"(((float*)&s2)[2])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(y1) : "f"(((float*)&s2)[3])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(z1) : "f"(((float*)&s3)[2])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(w1) : "f"(((float*)&s3)[3])); + + asm volatile( + "{ .reg .u32 r4;" + "cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;" + "cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;" + "cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;" + "cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;" + "}" + : "=r"(d1) + : "r"(x0), "r"(y0), "r"(z0), "r"(w0), "r"(x1), "r"(y1), "r"(z1), "r"(w1)); +} + +__device__ __forceinline__ void pack_f2i_with_relu( + int& d0, const int2 s0, const int2 s1, const int2 s2, const int2 s3, + const uint32_t relu, float& dst_zero_point) { + uint32_t x[8]; + + if (relu > 0) { + asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[0]) : "f"(((float*)&s0)[0])); + asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[1]) : "f"(((float*)&s0)[1])); + asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[2]) : "f"(((float*)&s1)[0])); + asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[3]) : "f"(((float*)&s1)[1])); + asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[4]) : "f"(((float*)&s2)[0])); + asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[5]) : "f"(((float*)&s2)[1])); + asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[6]) : "f"(((float*)&s3)[0])); + asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[7]) : "f"(((float*)&s3)[1])); + x[0] += dst_zero_point; + x[1] += dst_zero_point; + x[2] += dst_zero_point; + x[3] += dst_zero_point; + x[4] += dst_zero_point; + x[5] += dst_zero_point; + x[6] += dst_zero_point; + x[7] += dst_zero_point; + } else if (relu == 0) { + ((float*)&s0)[0] += dst_zero_point; + ((float*)&s0)[1] += dst_zero_point; + ((float*)&s1)[0] += dst_zero_point; + ((float*)&s1)[1] += dst_zero_point; + ((float*)&s2)[0] += dst_zero_point; + ((float*)&s2)[1] += dst_zero_point; + ((float*)&s3)[0] += dst_zero_point; + ((float*)&s3)[1] += dst_zero_point; + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[0]) : "f"(((float*)&s0)[0])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[1]) : "f"(((float*)&s0)[1])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[2]) : "f"(((float*)&s1)[0])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[3]) : 
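+        // Both pack_f2i and pack_f2i_with_relu end by funnelling eight 32-bit
+        // lanes into one 32-bit word of 4-bit outputs with cvt.pack.sat.u4, which
+        // saturates each value to [0, 15] while packing. A stand-alone sketch of
+        // that step; the helper name `pack8_u4` is illustrative and not part of
+        // this patch.
+#if 0   // illustrative sketch only, compiled out
+__device__ __forceinline__ uint32_t pack8_u4(const uint32_t (&v)[8]) {
+    uint32_t d;
+    // each cvt.pack.sat.u4 appends two saturated 4-bit values to the running
+    // result in r4; after four steps v[0] sits in the lowest nibble of d and
+    // v[7] in the highest
+    asm volatile(
+            "{ .reg .u32 r4;"
+            "cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;"
+            "cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;"
+            "cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;"
+            "cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;"
+            "}"
+            : "=r"(d)
+            : "r"(v[0]), "r"(v[1]), "r"(v[2]), "r"(v[3]), "r"(v[4]), "r"(v[5]),
+              "r"(v[6]), "r"(v[7]));
+    return d;
+}
+#endif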
"f"(((float*)&s1)[1])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[4]) : "f"(((float*)&s2)[0])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[5]) : "f"(((float*)&s2)[1])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[6]) : "f"(((float*)&s3)[0])); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[7]) : "f"(((float*)&s3)[1])); + } + + if (relu > 1) { + int r1, r2; + r1 = (x[0] >= relu); + x[0] *= r1; + r2 = (x[1] >= relu); + x[1] *= r2; + r1 = (x[2] >= relu); + x[2] *= r1; + r2 = (x[3] >= relu); + x[3] *= r2; + r1 = (x[4] >= relu); + x[4] *= r1; + r2 = (x[5] >= relu); + x[5] *= r2; + r1 = (x[6] >= relu); + x[6] *= r1; + r2 = (x[7] >= relu); + x[7] *= r2; + } + + asm volatile( + "{ .reg .u32 r4;" + "cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;" + "cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;" + "cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;" + "cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;" + "}" + : "=r"(d0) + : "r"(x[0]), "r"(x[1]), "r"(x[2]), "r"(x[3]), "r"(x[4]), "r"(x[5]), + "r"(x[6]), "r"(x[7])); +} + +#define I2F_1x8(a, i, j) \ + i2f(a[i][j]); \ + i2f(a[i][j + 1]); \ + i2f(a[i][j + 2]); \ + i2f(a[i][j + 3]); \ + i2f(a[i][j + 4]); \ + i2f(a[i][j + 5]); \ + i2f(a[i][j + 6]); \ + i2f(a[i][j + 7]); + +#define I2F_4x8(a, i, j) \ + I2F_1x8(a, i, j) I2F_1x8(a, i + 1, j) I2F_1x8(a, i + 2, j) I2F_1x8(a, i + 3, j) + +#define FMA_1x8(a, i, j, alpha, bias0, bias1, bias2, bias3) \ + fma2(a[i][j], reg_acc[i][j], a[i][j + 1], reg_acc[i][j + 1], alpha, bias0); \ + fma2(a[i][j + 2], reg_acc[i][j + 2], a[i][j + 3], reg_acc[i][j + 3], alpha, \ + bias1); \ + fma2(a[i][j + 4], reg_acc[i][j + 4], a[i][j + 5], reg_acc[i][j + 5], alpha, \ + bias2); \ + fma2(a[i][j + 6], reg_acc[i][j + 6], a[i][j + 7], reg_acc[i][j + 7], alpha, bias3); + +#define FMA_4x8(a, i, j, alpha, bias0, bias1, bias2, bias3) \ + FMA_1x8(a, i, j, alpha, bias0, bias1, bias2, bias3) \ + FMA_1x8(a, i + 1, j, alpha, bias0, bias1, bias2, bias3) \ + FMA_1x8(a, i + 2, j, alpha, bias0, bias1, bias2, bias3) \ + FMA_1x8(a, i + 3, j, alpha, bias0, bias1, bias2, bias3) + +// pack 1x(8 int2) to int2 +#define PACK_F2I_WITH_RELU_1x8(a, i, j, relu, dst_zero_point) \ + pack_f2i_with_relu( \ + a[i][j].x, a[i][j], a[i][j + 1], a[i][j + 2], a[i][j + 3], relu, \ + dst_zero_point); \ + pack_f2i_with_relu( \ + a[i][j].y, a[i][j + 4], a[i][j + 5], a[i][j + 6], a[i][j + 7], relu, \ + dst_zero_point); + +// pack 4x8 int2 float to 4 int2 +#define PACK_F2I_WITH_RELU_4x8(a, i, j, relu, dst_zero_point) \ + PACK_F2I_WITH_RELU_1x8(a, i, j, relu, dst_zero_point) \ + PACK_F2I_WITH_RELU_1x8(a, i + 1, j, relu, dst_zero_point) \ + PACK_F2I_WITH_RELU_1x8(a, i + 2, j, relu, dst_zero_point) \ + PACK_F2I_WITH_RELU_1x8(a, i + 3, j, relu, dst_zero_point) + +#define STG(d, s, idx, n_reuse, hw_reuse, g) \ + n_reuse = nhw_post##idx / param.div_ohow; \ + hw_reuse = nhw_post##idx % param.div_ohow; \ + d = g_dst_ptr + n_reuse * param.obs + hw_reuse * (packed_channel >> 1); \ + g = nhw_post##idx < param.nhw; \ + if (stg_oc < param.oc && g) { \ + *(reinterpret_cast(d)) = *(reinterpret_cast(&s)); \ + } + +#define STG_4x1(d, a, i, j) \ + STG(d[0], a[i][j], 0, reg_src_cache[0].x, reg_src_cache[1].x, stg_guard[i]) \ + STG(d[1], a[i + 1][j], 1, reg_src_cache[0].y, reg_src_cache[1].y, \ + stg_guard[i + 1]) \ + STG(d[2], a[i + 2][j], 2, reg_src_cache[0].z, reg_src_cache[1].z, \ + stg_guard[i + 2]) \ + STG(d[3], a[i + 3][j], 3, reg_src_cache[0].w, reg_src_cache[1].w, stg_guard[i + 3]) + +#define FUSE_Z_4x8(a, i, j, fuse_z, gamma, zero_point) \ + fuse_z_1x8(a[i], j, fuse_z[i], gamma, zero_point); \ + 
fuse_z_1x8(a[i + 1], j, fuse_z[i + 1], gamma, zero_point); \ + fuse_z_1x8(a[i + 2], j, fuse_z[i + 2], gamma, zero_point); \ + fuse_z_1x8(a[i + 3], j, fuse_z[i + 3], gamma, zero_point); + +#define FUSE_Z_4x8(a, i, j, fuse_z, gamma, zero_point) \ + fuse_z_1x8(a[i], j, fuse_z[i], gamma, zero_point); \ + fuse_z_1x8(a[i + 1], j, fuse_z[i + 1], gamma, zero_point); \ + fuse_z_1x8(a[i + 2], j, fuse_z[i + 2], gamma, zero_point); \ + fuse_z_1x8(a[i + 3], j, fuse_z[i + 3], gamma, zero_point); + +// 1x8 1x(2x8 int2) to 2 int2 +#define PACK_F2I_1x8(a, i, j) \ + pack_f2i(a[i][j].x, a[i][j].z, a[i][j], a[i][j + 1], a[i][j + 2], a[i][j + 3]); \ + pack_f2i(a[i][j].y, a[i][j].w, a[i][j + 4], a[i][j + 5], a[i][j + 6], a[i][j + 7]); + +// 4x8 int4 +#define PACK_F2I_4x8(a, i, j) \ + PACK_F2I_1x8(a, i, j) PACK_F2I_1x8(a, i + 1, j) PACK_F2I_1x8(a, i + 2, j) \ + PACK_F2I_1x8(a, i + 3, j) + +#define LDG(d, s, idx, n_reuse, hw_reuse, g) \ + n_reuse = nhw_post##idx / param.div_ohow; \ + hw_reuse = nhw_post##idx % param.div_ohow; \ + s = n_reuse * param.obs + hw_reuse * (packed_channel >> 1); \ + g = nhw_post##idx < param.nhw; \ + if (stg_oc < param.oc && g) { \ + *(reinterpret_cast(&d)) = \ + *(reinterpret_cast(g_z_ptr + s)); \ + } + +#define LDG_4x1(d, s, i) \ + LDG(d[i], s[i], 0, reg_src_cache[0].x, reg_src_cache[1].x, stg_guard[i]) \ + LDG(d[i + 1], s[i + 1], 1, reg_src_cache[0].y, reg_src_cache[1].y, \ + stg_guard[i + 1]) \ + LDG(d[i + 2], s[i + 2], 2, reg_src_cache[0].z, reg_src_cache[1].z, \ + stg_guard[i + 2]) \ + LDG(d[i + 3], s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w, stg_guard[i + 3]) + +#define COMPUTE_OFFSET(d, s, idx, n_reuse, hw_reuse, g) \ + n_reuse = nhw_post##idx / param.div_ohow; \ + hw_reuse = nhw_post##idx % param.div_ohow; \ + s = n_reuse * param.obs + hw_reuse * (packed_channel >> 1); \ + g = nhw_post##idx < param.nhw; + +#define COMPUTE_OFFSET_4x1(d, s, i) \ + COMPUTE_OFFSET( \ + d[i], s[i], 0, reg_src_cache[0].x, reg_src_cache[1].x, stg_guard[i]) \ + COMPUTE_OFFSET( \ + d[i + 1], s[i + 1], 1, reg_src_cache[0].y, reg_src_cache[1].y, \ + stg_guard[i + 1]) \ + COMPUTE_OFFSET( \ + d[i + 2], s[i + 2], 2, reg_src_cache[0].z, reg_src_cache[1].z, \ + stg_guard[i + 2]) \ + COMPUTE_OFFSET( \ + d[i + 3], s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w, \ + stg_guard[i + 3]) + +#define STG_AFTER_LDG(d, s, g) \ + if (stg_oc < param.oc && g) { \ + *(reinterpret_cast(g_dst_ptr + d)) = *(reinterpret_cast(&s)); \ + } + +#define STG_AFTER_LDG_4x1(d, a, i, j) \ + STG_AFTER_LDG(d[i], a[i][j], stg_guard[i]) \ + STG_AFTER_LDG(d[i + 1], a[i + 1][j], stg_guard[i + 1]) \ + STG_AFTER_LDG(d[i + 2], a[i + 2][j], stg_guard[i + 2]) \ + STG_AFTER_LDG(d[i + 3], a[i + 3][j], stg_guard[i + 3]) +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/ptx/uint4_int4/tools.cuh b/dnn/src/cuda/ptx/uint4_int4/tools.cuh new file mode 100644 index 000000000..49d82ad25 --- /dev/null +++ b/dnn/src/cuda/ptx/uint4_int4/tools.cuh @@ -0,0 +1,49 @@ +#include + +#if (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) +extern "C" { +// +// This NVVM intrinsic is subject to change in future versions of CUDA. +// Clients should not call it directly. Rather, they should use the +// cutlass::arch::ldsm<>() template. +// +__device__ uint32_t __nvvm_get_smem_pointer(void*); +} +#endif + +inline __device__ unsigned get_smem_pointer(void* ptr) { +#if (defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 11) + // + // This NVVM intrinsic converts an address in shared memory to a plain + // unsigned integer. 
This is necessary to pass to shared memory instructions
+    // in inline PTX.
+    //
+    // In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only
+    // available in 10.2].
+    //
+    //__device__ size_t __cvta_generic_to_shared(void* ptr);
+
+    /// CUTLASS helper to get SMEM pointer
+    return static_cast<unsigned>(__cvta_generic_to_shared(ptr));
+
+#elif (defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && \
+       __CUDACC_VER_MINOR__ >= 2)
+
+    return __nvvm_get_smem_pointer(ptr);
+
+#elif defined(__CUDA_ARCH__)
+
+    uint32_t smem_ptr;
+
+    asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 "
+        "%0, smem_ptr; }\n"
+        : "=r"(smem_ptr)
+        : "l"(ptr));
+
+    return smem_ptr;
+
+#else
+
+    return 0;
+#endif
+}
-- 
GitLab
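
A note for reviewers on the calling convention of the run_* entry points declared in
kern.cuh: each launcher receives its arguments through an opaque void** array, and the
128x128 launcher above dereferences params[0]..params[10] with the casts shown there.
The host-side sketch below only illustrates how such an array could be packed: the slot
types follow those casts, but the per-slot names (d_src, d_filter, alpha, relu, ...) are
guesses at each slot's meaning, since the kernel signature itself is not visible in this
hunk.

    // Host-side sketch (CUDA C++); assumes kern.cuh and base.cuh are in scope,
    // e.g. with `using namespace convolution;`, and that buffers, descriptors,
    // launch shape and stream were prepared by the caller. All names are
    // illustrative.
    void launch_u4s4_conv_128x128(
            int8_t* d_src, int8_t* d_filter, float* d_bias, int8_t* d_dst,
            float alpha, float beta, uint32_t pk_src_zero_point, float dst_scale,
            uint32_t relu, Conv2dInt4Param& kern_param,
            Conv2dConstantOffset& c_offset, dim3 grid, dim3 block,
            cudaStream_t stream) {
        // slot types per the launcher's casts; slot meanings are guesses
        void* params[11] = {&d_src,  &d_filter, &d_bias, &d_dst,
                            &alpha,  &beta,     &pk_src_zero_point,
                            &dst_scale, &relu,  &kern_param, &c_offset};
        megdnn::cuda::ptx::run_ampere_conv_bias_uint4_int4_imma8832_ldgsts16_128x128_relu(
                grid, block, stream, params);
    }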