提交 1f8e930e 编写于 作者: M Megvii Engine Team

feat(cuda): add int4 ptx 128x128 mma kernel

GitOrigin-RevId: 5a8b9c3f8eab59ed8d1daf9bbaf2c81cdc82ca5b
上级 1a2ed8c4
#include "./base.cuh"
using namespace convolution;
Uint32Fastdiv::Uint32Fastdiv() {
memset(this, 0, sizeof(Uint32Fastdiv));
}
Uint32Fastdiv& Uint32Fastdiv::operator=(uint32_t d) {
m_divisor = d;
constexpr uint32_t MAX_U32 = ~0u;
m_inc_dividend = 0;
m_divisor_is_not_1 = ~0u;
if (!(d & (d - 1))) {
// power of 2
m_mul = 1u << 31;
int p = 0;
while ((1u << p) < d)
++p;
m_shift = p ? p - 1 : 0;
if (d == 1)
m_divisor_is_not_1 = 0;
return *this;
}
auto n_bound = uint64_t(d / 2 + 1) * MAX_U32;
uint32_t shift = 32;
while ((1ull << shift) < n_bound)
++shift;
uint64_t mdst = 1ull << shift;
int64_t delta = d - mdst % d;
m_mul = mdst / d + 1;
if ((uint64_t)delta > d / 2) {
delta -= d;
--m_mul;
m_inc_dividend = 1;
}
m_shift = shift - 32;
return *this;
}
#pragma once
#include <cuda_runtime.h>
#include <stdint.h>
#if ((__CUDACC_VER_MAJOR__ > 11) || \
(__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
#define SM80_SUPPORTED
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
#define SM80_ENABLED
#endif
#endif
namespace convolution {
class Uint32Fastdiv {
uint32_t m_mul, m_divisor, m_divisor_is_not_1, m_inc_dividend, m_shift;
public:
Uint32Fastdiv();
Uint32Fastdiv(uint32_t d) { operator=(d); }
//! set the divisor to be d
Uint32Fastdiv& operator=(uint32_t d);
//! caller must ensure that dividend would not exceed this number
static constexpr uint32_t MAX_DIVIDEND = ~0u - 1;
__device__ __forceinline__ uint32_t divisor() const { return m_divisor; }
__device__ __forceinline__ uint32_t divide(uint32_t dividend) const {
uint32_t ans_for_one = dividend & ~m_divisor_is_not_1,
dfix = dividend + m_inc_dividend,
#if __CUDA_ARCH__
hi32 = __umulhi(dfix, m_mul),
#else
hi32 = ((uint64_t)dfix * m_mul) >> 32,
#endif
ans = hi32 >> m_shift;
return (ans & m_divisor_is_not_1) | ans_for_one;
}
};
static __forceinline__ __device__ uint32_t
operator/(uint32_t a, const Uint32Fastdiv& d) {
return d.divide(a);
}
static __forceinline__ __device__ uint32_t
operator%(uint32_t a, const Uint32Fastdiv& d) {
return a - d.divisor() * d.divide(a);
}
struct Conv2dInt4Param {
uint32_t n, ic, ih, iw, fh, fw, sh, sw, ph, pw, oc, oh, ow;
uint32_t ibs, ics, ihs;
uint32_t obs, ocs, ohs;
uint32_t icfhfw;
uint32_t nhw;
Uint32Fastdiv div_ohow;
Uint32Fastdiv div_ow;
Conv2dInt4Param(
uint32_t n, uint32_t ic, uint32_t ih, uint32_t iw, uint32_t fh, uint32_t fw,
uint32_t sh, uint32_t sw, uint32_t ph, uint32_t pw, uint32_t oc,
uint32_t oh, uint32_t ow, uint32_t interleaved)
: n(n),
ic(ic),
ih(ih),
iw(iw),
fh(fh),
fw(fw),
sh(sh),
sw(sw),
ph(ph),
pw(pw),
oc(oc),
oh(oh),
ow(ow) {
constexpr uint32_t size_bits = 4;
// all stride size in bytes
ibs = ic * ih * iw * size_bits / 8;
ics = ih * iw * interleaved * size_bits / 8;
ihs = iw * interleaved * size_bits / 8;
obs = oc * oh * ow * size_bits / 8;
ocs = oh * ow * interleaved * size_bits / 8;
ohs = ow * interleaved * size_bits / 8;
icfhfw = ic * fh * fw;
nhw = n * oh * ow;
div_ohow = oh * ow;
div_ow = ow;
}
};
struct Conv2dConstantOffsetParam {
int32_t begin;
int32_t size;
int32_t max;
int32_t rewind;
};
#define CONSTANT_BUFFER_SIZE 848
struct Conv2dConstantOffset {
Conv2dConstantOffsetParam c_offset_param;
int c_offset[CONSTANT_BUFFER_SIZE];
};
} // namespace convolution
#include <cuda_runtime.h>
#include <stdio.h>
#include "./imma8832_128x128.cuh"
#include "./kern.cuh"
#include "./macro.cuh"
#include "./tools.cuh"
using namespace convolution;
namespace {
#ifdef SM80_ENABLED
extern "C" __device__ void g2s_int4(const int4* gm, int4* sm) {
unsigned sm_addr = get_smem_pointer(sm);
const int SizeInBytes = 16;
#if ENABLE_L2_PREFETCH
asm volatile(
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(sm_addr),
"l"(gm), "n"(SizeInBytes));
#else
asm volatile(
"cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(sm_addr), "l"(gm),
"n"(SizeInBytes));
#endif
}
/// Establishes an ordering w.r.t previously issued cp.async instructions. Does
/// not block.
#define cp_async_fence() asm volatile("cp.async.commit_group;\n" ::)
/// Blocks until all but <N> previous cp.async.commit_group operations have
/// committed.
#define cp_async_wait(N) asm volatile("cp.async.wait_group %0;\n" ::"n"(N))
#endif
extern "C" __global__ void __launch_bounds__(256)
ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldgsts16_128x128_relu(
const int8_t* __restrict__ src, const int8_t* __restrict__ filter,
const float* __restrict__ bias, const int8_t* __restrict__ z,
int8_t* __restrict__ dst, float alpha, float beta, float gamma,
uint32_t pk_src_zero_point, int32_t z_zero_point, float dst_zero_point,
uint32_t relu, Conv2dInt4Param param,
Conv2dConstantOffset conv2d_constant) {
#ifdef SM80_ENABLED
const int stages = 3;
const uint32_t tid = threadIdx.x;
const uint32_t bidx = blockIdx.x;
const uint32_t bidy = blockIdx.y;
extern __shared__ int32_t smem[]; // (128 + 128)*128/8*stages
int2 reg_acc[reg_m][reg_n];
int4 reg_src[2][reg_nd4];
int4 reg_flt[2][reg_md4];
// use in other way, maybe use reg_ser/flt
int4 reg_src_cache[2];
int4 reg_filter_cache[4];
uint32_t tid127 = (tid & 127);
uint32_t section = (tid127 >> 1);
uint32_t residue = ((tid127 << 5) & 63);
uint32_t nhw = bidx * BN + section;
uint32_t tn, hw, toh, tow;
int tih, tiw;
int h_start[2];
int h_end[2];
int w_start[2];
int w_end[2];
bool g[2];
const int8_t* __restrict__ g_src_ptr[4];
for (int i = 0; i < 2; i++) {
if (i != 0) {
nhw += 64;
}
tn = nhw / param.div_ohow;
hw = nhw % param.div_ohow;
toh = hw / param.div_ow;
tow = hw % param.div_ow;
tih = toh * param.sh - param.ph;
tiw = tow * param.sw - param.pw;
g[i] = tn < param.n;
h_start[i] = -tih;
h_end[i] = param.ih - tih;
w_start[i] = -tiw;
w_end[i] = param.iw - tiw;
// param's members have been converted to byte offset and int4 offset need to
// div 2
int src_offset = tn * param.ibs + tih * param.ihs +
((int)(tiw * packed_channel + residue) >> 1);
g_src_ptr[i * 2] = src + src_offset;
g_src_ptr[i * 2 + 1] = g_src_ptr[i * 2];
}
const uint32_t section_section = (section >> 2);
const uint32_t section_residue = (section & 3);
const uint32_t section_factor = ((section & 15) >> 2);
const uint32_t crosswise_offset =
((section_residue >> 1) << 4) +
(((section_residue & 1) ^ (section_factor >> 1)) << 3);
const uint32_t residue_offset = ((residue >> 5) ^ (section_factor & 1)) << 2;
// next + 64 * BK / 8
int32_t* write_src_s[2];
write_src_s[0] =
smem + section_section * BK / 2 + crosswise_offset + residue_offset;
write_src_s[1] = write_src_s[0] + 32;
int iter = (param.icfhfw >> 6);
uint32_t tid31 = (tid & 31);
uint32_t warp_idx = (tid >> 5);
uint32_t warp_strided = (warp_idx << 2);
uint32_t htid = (tid31 >> 4);
const uint32_t flt_strided = bidy * BM / 8 + warp_strided;
bool guard = flt_strided * 8 < param.oc && iter > htid;
// icfhfw * 8/2 is a stride
const int8_t* __restrict__ g_filter_ptr0 =
filter + flt_strided * (param.icfhfw * 4) + (tid31 << 4);
const int8_t* __restrict__ g_filter_ptr1 = g_filter_ptr0 + (param.icfhfw * 4);
const int8_t* __restrict__ g_filter_ptr2 = g_filter_ptr0 + (param.icfhfw * 8);
const int8_t* __restrict__ g_filter_ptr3 = g_filter_ptr0 + (param.icfhfw * 12);
// next + BK * 8 / (INT32/INT4)
uint32_t q = (tid31 >> 3);
uint32_t r = (tid31 & 7);
int32_t* write_flt_s = smem + BN * BK / 8 + warp_strided * BK + ((q & 1) << 6) +
((q >> 1) << 5) + (r << 2);
uint32_t quad_idx = (tid31 >> 2);
uint32_t idx_in_quad = (tid & 3);
uint32_t quad_factor = ((tid & 15) >> 2);
uint32_t crosswise =
((idx_in_quad >> 1) << 4) + (((idx_in_quad & 1) ^ (quad_factor >> 1)) << 3);
uint32_t warp_x = (warp_idx >> 1);
uint32_t warp_y = (warp_idx & 1);
int32_t* read_src_s_0 = smem + (warp_x * 8 * BK) + (quad_idx * BK / 2) + crosswise +
((0 ^ (quad_factor & 1)) << 2);
int32_t* read_src_s_1 = smem + (warp_x * 8 * BK) + (quad_idx * BK / 2) + crosswise +
((1 ^ (quad_factor & 1)) << 2);
int32_t* read_flt_s_0 = smem + BN * BK / 8 + (warp_y * 8 * BK) +
(quad_idx * BK / 2) + crosswise +
((0 ^ (quad_factor & 1)) << 2);
int32_t* read_flt_s_1 = smem + BN * BK / 8 + (warp_y * 8 * BK) +
(quad_idx * BK / 2) + crosswise +
((1 ^ (quad_factor & 1)) << 2);
#pragma unroll
for (int i = 0; i < reg_m; i++) {
#pragma unroll
for (int j = 0; j < reg_n; j++) {
reg_acc[i][j] = make_int2(0, 0);
}
}
const int smem_switch = 4096;
const int smem_switch_back = -smem_switch * (stages - 1);
int stage = 0;
uint32_t offset[2] = {0, 2}; // high & low
int src_step[2], x[2], y[2];
// global mem --> shared mem, stage 0
for (int i = 0; i < 2; i++) {
src_step[i] = conv2d_constant.c_offset[offset[i]];
uint32_t spatial = *(reinterpret_cast<const uint32_t*>(
&(conv2d_constant.c_offset[offset[i] + 1])));
x[i] = (spatial & 0xff);
y[i] = ((spatial >> 8) & 0xff);
if (offset[i] < conv2d_constant.c_offset_param.max) {
offset[i] += 4;
} else {
offset[i] += conv2d_constant.c_offset_param.rewind;
}
}
bool guard0[2], guard1[2];
guard0[0] = g[0] && x[0] >= h_start[0] && x[0] < h_end[0] && y[0] >= w_start[0] &&
y[0] < w_end[0] && iter > 0;
guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] && y[1] >= w_start[0] &&
y[1] < w_end[0] && iter > 1;
guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] && y[0] >= w_start[1] &&
y[0] < w_end[1] && iter > 0;
guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] && y[1] >= w_start[1] &&
y[1] < w_end[1] && iter > 1;
g_src_ptr[0] += src_step[0];
g_src_ptr[1] += src_step[1];
g_src_ptr[2] += src_step[0];
g_src_ptr[3] += src_step[1];
if (guard0[0]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[0]),
reinterpret_cast<int4*>(write_src_s[0]));
} else {
*(reinterpret_cast<int4*>(write_src_s[0])) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard0[1]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[1]),
reinterpret_cast<int4*>(write_src_s[1]));
} else {
*(reinterpret_cast<int4*>(write_src_s[1])) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard1[0]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[2]),
reinterpret_cast<int4*>(write_src_s[0] + 8 * BK));
} else {
*(reinterpret_cast<int4*>(write_src_s[0] + 8 * BK)) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard1[1]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[3]),
reinterpret_cast<int4*>(write_src_s[1] + 8 * BK));
} else {
*(reinterpret_cast<int4*>(write_src_s[1] + 8 * BK)) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard) {
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr0),
reinterpret_cast<int4*>(write_flt_s));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr1),
reinterpret_cast<int4*>(write_flt_s + BK));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr2),
reinterpret_cast<int4*>(write_flt_s + 2 * BK));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr3),
reinterpret_cast<int4*>(write_flt_s + 3 * BK));
} else {
*(reinterpret_cast<int4*>(write_flt_s)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + BK)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + 2 * BK)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + 3 * BK)) = make_int4(0, 0, 0, 0);
}
stage++;
cp_async_fence();
// global mem --> shared mem, stage 1 -> stage n
iter -= 2;
for (int i = 0; i < 2; i++) {
src_step[i] = conv2d_constant.c_offset[offset[i]];
uint32_t spatial = *(reinterpret_cast<const uint32_t*>(
&(conv2d_constant.c_offset[offset[i] + 1])));
x[i] = (spatial & 0xff);
y[i] = ((spatial >> 8) & 0xff);
if (offset[i] < conv2d_constant.c_offset_param.max) {
offset[i] += 4;
} else {
offset[i] += conv2d_constant.c_offset_param.rewind;
}
}
guard0[0] = g[0] && x[0] >= h_start[0] && x[0] < h_end[0] && y[0] >= w_start[0] &&
y[0] < w_end[0];
guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] && y[1] >= w_start[0] &&
y[1] < w_end[0];
guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] && y[0] >= w_start[1] &&
y[0] < w_end[1];
guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] && y[1] >= w_start[1] &&
y[1] < w_end[1];
g_src_ptr[0] += src_step[0];
g_src_ptr[1] += src_step[1];
g_src_ptr[2] += src_step[0];
g_src_ptr[3] += src_step[1];
g_filter_ptr0 += 8 * 64;
g_filter_ptr1 += 8 * 64;
g_filter_ptr2 += 8 * 64;
g_filter_ptr3 += 8 * 64;
write_src_s[0] += smem_switch;
write_src_s[1] += smem_switch;
write_flt_s += smem_switch;
for (; iter >= 2 && stage < stages - 1; iter -= 2) {
if (guard0[0]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[0]),
reinterpret_cast<int4*>(write_src_s[0]));
} else {
*(reinterpret_cast<int4*>(write_src_s[0])) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard0[1]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[1]),
reinterpret_cast<int4*>(write_src_s[1]));
} else {
*(reinterpret_cast<int4*>(write_src_s[1])) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard1[0]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[2]),
reinterpret_cast<int4*>(write_src_s[0] + 8 * BK));
} else {
*(reinterpret_cast<int4*>(write_src_s[0] + 8 * BK)) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard1[1]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[3]),
reinterpret_cast<int4*>(write_src_s[1] + 8 * BK));
} else {
*(reinterpret_cast<int4*>(write_src_s[1] + 8 * BK)) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard) {
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr0),
reinterpret_cast<int4*>(write_flt_s));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr1),
reinterpret_cast<int4*>(write_flt_s + BK));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr2),
reinterpret_cast<int4*>(write_flt_s + 2 * BK));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr3),
reinterpret_cast<int4*>(write_flt_s + 3 * BK));
} else {
*(reinterpret_cast<int4*>(write_flt_s)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + BK)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + 2 * BK)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + 3 * BK)) = make_int4(0, 0, 0, 0);
}
stage++;
cp_async_fence();
for (int i = 0; i < 2; i++) {
src_step[i] = conv2d_constant.c_offset[offset[i]];
uint32_t spatial = *(reinterpret_cast<const uint32_t*>(
&(conv2d_constant.c_offset[offset[i] + 1])));
x[i] = (spatial & 0xff);
y[i] = ((spatial >> 8) & 0xff);
if (offset[i] < conv2d_constant.c_offset_param.max) {
offset[i] += 4;
} else {
offset[i] += conv2d_constant.c_offset_param.rewind;
}
}
guard0[0] = g[0] && x[0] >= h_start[0] && x[0] < h_end[0] &&
y[0] >= w_start[0] && y[0] < w_end[0];
guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] &&
y[1] >= w_start[0] && y[1] < w_end[0];
guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] &&
y[0] >= w_start[1] && y[0] < w_end[1];
guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] &&
y[1] >= w_start[1] && y[1] < w_end[1];
g_src_ptr[0] += src_step[0];
g_src_ptr[1] += src_step[1];
g_src_ptr[2] += src_step[0];
g_src_ptr[3] += src_step[1];
g_filter_ptr0 += 8 * 64;
g_filter_ptr1 += 8 * 64;
g_filter_ptr2 += 8 * 64;
g_filter_ptr3 += 8 * 64;
write_src_s[0] += smem_switch;
write_src_s[1] += smem_switch;
write_flt_s += smem_switch;
}
bool is_copy = false;
if (iter == 1 && stage != stages - 1) {
if (guard0[0]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[0]),
reinterpret_cast<int4*>(write_src_s[0]));
} else {
*(reinterpret_cast<int4*>(write_src_s[0])) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard0[1] && iter > 1) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[1]),
reinterpret_cast<int4*>(write_src_s[1]));
} else {
*(reinterpret_cast<int4*>(write_src_s[1])) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard1[0]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[2]),
reinterpret_cast<int4*>(write_src_s[0] + 8 * BK));
} else {
*(reinterpret_cast<int4*>(write_src_s[0] + 8 * BK)) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard1[1] && iter > 1) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[3]),
reinterpret_cast<int4*>(write_src_s[1] + 8 * BK));
} else {
*(reinterpret_cast<int4*>(write_src_s[1] + 8 * BK)) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard && iter > htid) {
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr0),
reinterpret_cast<int4*>(write_flt_s));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr1),
reinterpret_cast<int4*>(write_flt_s + BK));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr2),
reinterpret_cast<int4*>(write_flt_s + 2 * BK));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr3),
reinterpret_cast<int4*>(write_flt_s + 3 * BK));
} else {
*(reinterpret_cast<int4*>(write_flt_s)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + BK)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + 2 * BK)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + 3 * BK)) = make_int4(0, 0, 0, 0);
}
stage++;
is_copy = true;
cp_async_fence();
}
bool only_one_stage = (stage == 1) ? true : false;
if (stage >= 2) {
cp_async_wait(stages - 2);
} else {
cp_async_wait(0);
}
__syncthreads();
// read fuse_z
int2 reg_fuse_z[reg_m] = {make_int2(z_zero_point, z_zero_point),
make_int2(z_zero_point, z_zero_point),
make_int2(z_zero_point, z_zero_point),
make_int2(z_zero_point, z_zero_point),
make_int2(z_zero_point, z_zero_point),
make_int2(z_zero_point, z_zero_point),
make_int2(z_zero_point, z_zero_point),
make_int2(z_zero_point, z_zero_point)};
int d_offset = (bidy * (BM >> 6) + warp_y) * param.ocs + (idx_in_quad << 3);
const int8_t* __restrict__ g_z_ptr = z + d_offset;
section = tid31 >> 2;
size_t nhw_post0 = bidx * BN + warp_x * 64 + section;
size_t nhw_post1 = nhw_post0 + 8;
size_t nhw_post2 = nhw_post0 + 16;
size_t nhw_post3 = nhw_post0 + 24;
size_t stg_oc = bidy * BM + (warp_y << 6);
int* g_offset = ((int*)&reg_filter_cache);
bool stg_guard[8];
#pragma unroll
for (int y = 0; y < reg_m; y += 4) {
LDG_4x1(reg_fuse_z, g_offset, y)
nhw_post0 += 32;
nhw_post1 += 32;
nhw_post2 += 32;
nhw_post3 += 32;
}
for (; iter >= 2; iter -= 2) {
if (guard0[0]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[0]),
reinterpret_cast<int4*>(write_src_s[0]));
} else {
*(reinterpret_cast<int4*>(write_src_s[0])) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard0[1]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[1]),
reinterpret_cast<int4*>(write_src_s[1]));
} else {
*(reinterpret_cast<int4*>(write_src_s[1])) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard1[0]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[2]),
reinterpret_cast<int4*>(write_src_s[0] + 8 * BK));
} else {
*(reinterpret_cast<int4*>(write_src_s[0] + 8 * BK)) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard1[1]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[3]),
reinterpret_cast<int4*>(write_src_s[1] + 8 * BK));
} else {
*(reinterpret_cast<int4*>(write_src_s[1] + 8 * BK)) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard) {
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr0),
reinterpret_cast<int4*>(write_flt_s));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr1),
reinterpret_cast<int4*>(write_flt_s + BK));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr2),
reinterpret_cast<int4*>(write_flt_s + 2 * BK));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr3),
reinterpret_cast<int4*>(write_flt_s + 3 * BK));
} else {
*(reinterpret_cast<int4*>(write_flt_s)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + BK)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + 2 * BK)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + 3 * BK)) = make_int4(0, 0, 0, 0);
}
stage++;
cp_async_fence();
#pragma unroll // low
for (int i = 0; i < reg_nd4; ++i) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK); // BK*32/8
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_src[0][i] = make_int4(x, y, z, w);
}
#pragma unroll
for (int j = 0; j < reg_md4; ++j) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK);
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_flt[0][j] = make_int4(x, y, z, w);
}
#pragma unroll
for (int k_inner = 0; k_inner < BKd32; k_inner++) {
int comp = (k_inner & 0x1);
int load = 1 - comp;
if (k_inner < BKd32 - 1) {
int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1;
int32_t* read_flt_s = (k_inner & 1) ? read_flt_s_0 : read_flt_s_1;
read_src_s += 32 * ((k_inner + 1) >> 1);
read_flt_s += 32 * ((k_inner + 1) >> 1);
#pragma unroll // low
for (int i = 0; i < reg_nd4; ++i) {
int x, y, z, w;
unsigned addr =
get_smem_pointer(read_src_s + i * 4 * BK); // BK*32/8
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
"%3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_src[load][i] = make_int4(x, y, z, w);
}
#pragma unroll
for (int j = 0; j < reg_md4; ++j) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK);
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
"%3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_flt[load][j] = make_int4(x, y, z, w);
}
}
int* A = reinterpret_cast<int*>(&reg_flt[comp][0]);
int* B = reinterpret_cast<int*>(&reg_src[comp][0]);
#pragma unroll
for (int x = 0; x < reg_n; x++) {
#pragma unroll
for (int y = 0; y < reg_m; y++) {
int* D = reinterpret_cast<int*>(&reg_acc[y][x]);
int* C = reinterpret_cast<int*>(&reg_acc[y][x]);
asm volatile(
"mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4."
"s32 "
"{%0,%1}, {%2}, {%3}, "
"{%4,%5};\n"
: "=r"(D[0]), "=r"(D[1])
: "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1]));
}
}
}
if (stage == stages) {
stage = 0;
write_src_s[0] += smem_switch_back;
write_src_s[1] += smem_switch_back;
write_flt_s += smem_switch_back;
read_src_s_0 += smem_switch;
read_src_s_1 += smem_switch;
read_flt_s_0 += smem_switch;
read_flt_s_1 += smem_switch;
} else if (stage == stages - 1) {
write_src_s[0] += smem_switch;
write_src_s[1] += smem_switch;
write_flt_s += smem_switch;
read_src_s_0 += smem_switch_back;
read_src_s_1 += smem_switch_back;
read_flt_s_0 += smem_switch_back;
read_flt_s_1 += smem_switch_back;
} else {
write_src_s[0] += smem_switch;
write_src_s[1] += smem_switch;
write_flt_s += smem_switch;
read_src_s_0 += smem_switch;
read_src_s_1 += smem_switch;
read_flt_s_0 += smem_switch;
read_flt_s_1 += smem_switch;
}
int src_step[2];
for (int i = 0; i < 2; i++) {
src_step[i] = conv2d_constant.c_offset[offset[i]];
uint32_t spatial = *(reinterpret_cast<const uint32_t*>(
&(conv2d_constant.c_offset[offset[i] + 1])));
x[i] = (spatial & 0xff);
y[i] = ((spatial >> 8) & 0xff);
if (offset[i] < conv2d_constant.c_offset_param.max) {
offset[i] += 4;
} else {
offset[i] += conv2d_constant.c_offset_param.rewind;
}
}
guard0[0] = g[0] && x[0] >= h_start[0] && x[0] < h_end[0] &&
y[0] >= w_start[0] && y[0] < w_end[0];
guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] &&
y[1] >= w_start[0] && y[1] < w_end[0];
guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] &&
y[0] >= w_start[1] && y[0] < w_end[1];
guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] &&
y[1] >= w_start[1] && y[1] < w_end[1];
g_src_ptr[0] += src_step[0];
g_src_ptr[1] += src_step[1];
g_src_ptr[2] += src_step[0];
g_src_ptr[3] += src_step[1];
g_filter_ptr0 += 8 * 64;
g_filter_ptr1 += 8 * 64;
g_filter_ptr2 += 8 * 64;
g_filter_ptr3 += 8 * 64;
cp_async_wait(stages - 2);
__syncthreads();
}
if (iter > 0) {
if (!is_copy) {
if (guard0[0]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[0]),
reinterpret_cast<int4*>(write_src_s[0]));
} else {
*(reinterpret_cast<int4*>(write_src_s[0])) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard0[1] && iter > 1) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[1]),
reinterpret_cast<int4*>(write_src_s[1]));
} else {
*(reinterpret_cast<int4*>(write_src_s[1])) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard1[0]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[2]),
reinterpret_cast<int4*>(write_src_s[0] + 8 * BK));
} else {
*(reinterpret_cast<int4*>(write_src_s[0] + 8 * BK)) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard1[1] && iter > 1) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[3]),
reinterpret_cast<int4*>(write_src_s[1] + 8 * BK));
} else {
*(reinterpret_cast<int4*>(write_src_s[1] + 8 * BK)) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard && iter > htid) {
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr0),
reinterpret_cast<int4*>(write_flt_s));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr1),
reinterpret_cast<int4*>(write_flt_s + BK));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr2),
reinterpret_cast<int4*>(write_flt_s + 2 * BK));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr3),
reinterpret_cast<int4*>(write_flt_s + 3 * BK));
} else {
*(reinterpret_cast<int4*>(write_flt_s)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + BK)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + 2 * BK)) =
make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + 3 * BK)) =
make_int4(0, 0, 0, 0);
}
cp_async_fence();
}
#pragma unroll // low
for (int i = 0; i < reg_nd4; ++i) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK); // BK*32/8
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_src[0][i] = make_int4(x, y, z, w);
}
#pragma unroll
for (int j = 0; j < reg_md4; ++j) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK);
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_flt[0][j] = make_int4(x, y, z, w);
}
#pragma unroll
for (int k_inner = 0; k_inner < BKd32; k_inner++) {
int comp = (k_inner & 0x1);
int load = 1 - comp;
if (k_inner < BKd32 - 1) {
int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1;
int32_t* read_flt_s = (k_inner & 1) ? read_flt_s_0 : read_flt_s_1;
read_src_s += 32 * ((k_inner + 1) >> 1);
read_flt_s += 32 * ((k_inner + 1) >> 1);
#pragma unroll // low
for (int i = 0; i < reg_nd4; ++i) {
int x, y, z, w;
unsigned addr =
get_smem_pointer(read_src_s + i * 4 * BK); // BK*32/8
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
"%3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_src[load][i] = make_int4(x, y, z, w);
}
#pragma unroll
for (int j = 0; j < reg_md4; ++j) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK);
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
"%3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_flt[load][j] = make_int4(x, y, z, w);
}
}
int* A = reinterpret_cast<int*>(&reg_flt[comp][0]);
int* B = reinterpret_cast<int*>(&reg_src[comp][0]);
#pragma unroll
for (int x = 0; x < reg_n; x++) {
#pragma unroll
for (int y = 0; y < reg_m; y++) {
int* D = reinterpret_cast<int*>(&reg_acc[y][x]);
int* C = reinterpret_cast<int*>(&reg_acc[y][x]);
asm volatile(
"mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4."
"s32 "
"{%0,%1}, {%2}, {%3}, "
"{%4,%5};\n"
: "=r"(D[0]), "=r"(D[1])
: "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1]));
}
}
}
stage++;
if (stage == stages) {
stage = 0;
read_src_s_0 += smem_switch;
read_src_s_1 += smem_switch;
read_flt_s_0 += smem_switch;
read_flt_s_1 += smem_switch;
} else if (stage == stages - 1) {
read_src_s_0 += smem_switch_back;
read_src_s_1 += smem_switch_back;
read_flt_s_0 += smem_switch_back;
read_flt_s_1 += smem_switch_back;
} else {
read_src_s_0 += smem_switch;
read_src_s_1 += smem_switch;
read_flt_s_0 += smem_switch;
read_flt_s_1 += smem_switch;
}
cp_async_wait(stages - 2);
}
if (!only_one_stage) {
#pragma unroll // low
for (int i = 0; i < reg_nd4; ++i) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK); // BK*32/8
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_src[0][i] = make_int4(x, y, z, w);
}
#pragma unroll
for (int j = 0; j < reg_md4; ++j) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK);
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_flt[0][j] = make_int4(x, y, z, w);
}
#pragma unroll
for (int k_inner = 0; k_inner < BKd32; k_inner++) {
int comp = (k_inner & 0x1);
int load = 1 - comp;
if (k_inner < BKd32 - 1) {
int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1;
int32_t* read_flt_s = (k_inner & 1) ? read_flt_s_0 : read_flt_s_1;
read_src_s += 32 * ((k_inner + 1) >> 1);
read_flt_s += 32 * ((k_inner + 1) >> 1);
#pragma unroll // low
for (int i = 0; i < reg_nd4; ++i) {
int x, y, z, w;
unsigned addr =
get_smem_pointer(read_src_s + i * 4 * BK); // BK*32/8
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
"%3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_src[load][i] = make_int4(x, y, z, w);
}
#pragma unroll
for (int j = 0; j < reg_md4; ++j) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK);
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
"%3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_flt[load][j] = make_int4(x, y, z, w);
}
}
int* A = reinterpret_cast<int*>(&reg_flt[comp][0]);
int* B = reinterpret_cast<int*>(&reg_src[comp][0]);
#pragma unroll
for (int x = 0; x < reg_n; x++) {
#pragma unroll
for (int y = 0; y < reg_m; y++) {
int* D = reinterpret_cast<int*>(&reg_acc[y][x]);
int* C = reinterpret_cast<int*>(&reg_acc[y][x]);
asm volatile(
"mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4."
"s32 "
"{%0,%1}, {%2}, {%3}, "
"{%4,%5};\n"
: "=r"(D[0]), "=r"(D[1])
: "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1]));
}
}
}
stage++;
if (stage == stages) {
stage = 0;
read_src_s_0 += smem_switch;
read_src_s_1 += smem_switch;
read_flt_s_0 += smem_switch;
read_flt_s_1 += smem_switch;
} else if (stage == stages - 1) {
read_src_s_0 += smem_switch_back;
read_src_s_1 += smem_switch_back;
read_flt_s_0 += smem_switch_back;
read_flt_s_1 += smem_switch_back;
} else {
read_src_s_0 += smem_switch;
read_src_s_1 += smem_switch;
read_flt_s_0 += smem_switch;
read_flt_s_1 += smem_switch;
}
cp_async_wait(0);
}
guard = iter < 0;
#pragma unroll // low
for (int i = 0; i < reg_nd4; ++i) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK); // BK*32/8
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
"%3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_src[0][i] = make_int4(x, y, z, w);
}
#pragma unroll
for (int j = 0; j < reg_md4; ++j) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK);
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
"%3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_flt[0][j] = make_int4(x, y, z, w);
}
// compute
#pragma unroll
for (int k_inner = 0; k_inner < BKd32; k_inner++) {
int comp = (k_inner & 0x1);
int load = 1 - comp;
if (k_inner < BKd32 - 1 && !(k_inner == 1 && guard)) {
int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1;
int32_t* read_flt_s = (k_inner & 1) ? read_flt_s_0 : read_flt_s_1;
read_src_s += 32 * ((k_inner + 1) >> 1);
read_flt_s += 32 * ((k_inner + 1) >> 1);
#pragma unroll // low
for (int i = 0; i < reg_nd4; ++i) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_src_s + i * 4 * BK); // BK*32/8
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
"%3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_src[load][i] = make_int4(x, y, z, w);
}
#pragma unroll
for (int j = 0; j < reg_md4; ++j) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK);
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
"%3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_flt[load][j] = make_int4(x, y, z, w);
}
}
int* A = reinterpret_cast<int*>(&reg_flt[comp][0]);
int* B = reinterpret_cast<int*>(&reg_src[comp][0]);
#pragma unroll
for (int x = 0; x < reg_n; x++) {
#pragma unroll
for (int y = 0; y < reg_m; y++) {
int* D = reinterpret_cast<int*>(&reg_acc[y][x]);
int* C = reinterpret_cast<int*>(&reg_acc[y][x]);
asm volatile(
"mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4."
"s32 "
"{%0,%1}, {%2}, {%3}, "
"{%4,%5};\n"
: "=r"(D[0]), "=r"(D[1])
: "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1]));
}
}
if (k_inner == 1 && guard) {
break;
}
}
__syncthreads();
/// output
size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
const float* bias_ptr = bias + oc;
int4 load_bias0 = make_int4(0, 0, 0, 0);
int4 load_bias1 = make_int4(0, 0, 0, 0);
int4 load_bias2 = make_int4(0, 0, 0, 0);
int4 load_bias3 = make_int4(0, 0, 0, 0);
if (oc < param.oc) {
load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
mul_v4(load_bias0, load_bias0, beta);
mul_v4(load_bias1, load_bias1, beta);
mul_v4(load_bias2, load_bias2, beta);
mul_v4(load_bias3, load_bias3, beta);
}
int8_t* __restrict__ g_dst_ptr = dst + d_offset;
#pragma unroll
for (int y = 0; y < reg_m; y += 4) {
I2F_4x8(reg_acc, y, 0);
FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
FUSE_Z_4x8(reg_acc, y, 0, reg_fuse_z, gamma, z_zero_point);
PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
STG_AFTER_LDG_4x1(g_offset, reg_acc, y, 0);
}
#endif
}
} // namespace
namespace megdnn {
namespace cuda {
namespace ptx {
void run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldgsts16_128x128_relu(
const dim3 grid, const dim3 block, cudaStream_t stream, void** params) {
#ifdef SM80_SUPPORTED
cudaFuncSetAttribute(
ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldgsts16_128x128_relu,
cudaFuncAttributeMaxDynamicSharedMemorySize, 49152);
ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldgsts16_128x128_relu<<<
grid, block, 49152, stream>>>(
*((int8_t**)params[0]), *((int8_t**)params[1]), *((float**)params[2]),
*((int8_t**)params[3]), *((int8_t**)params[4]), *((float*)params[5]),
*((float*)params[6]), *((float*)params[7]), *((uint32_t*)params[8]),
*((uint32_t*)params[9]), *((float*)params[10]), *((uint32_t*)params[11]),
*((Conv2dInt4Param*)params[12]), *((Conv2dConstantOffset*)params[13]));
#endif
}
} // namespace ptx
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
#pragma once
#include "./base.cuh"
#define TX 128
#define TY 1
#define BM 128
#define BN 128
#define BK 128
#define mma_m 16
#define mma_n 8
#define mma_k 64
#define reg_m 8
#define reg_n 8
#define packed_channel 64
#define BKd32 (BK / 32)
#define BKd64 (BK / 64)
#define reg_md4 (reg_m >> 2)
#define WARPS (TX / 32)
#define cache_per_warp 128
#define reg_nd4 (reg_n >> 2)
#define ldg_src (BN * BK / (16 * TX))
#define ldg_filter (BM * BK / (16 * TX))
#define ldg_width 16
// vim: syntax=cpp.doxygen
#include <cuda_runtime.h>
#include <stdio.h>
#include "./imma8832_128x128.cuh"
#include "./kern.cuh"
#include "./macro.cuh"
#include "./tools.cuh"
using namespace convolution;
namespace {
#ifdef SM80_ENABLED
extern "C" __device__ void g2s_int4(const int4* gm, int4* sm) {
unsigned sm_addr = get_smem_pointer(sm);
const int SizeInBytes = 16;
#if ENABLE_L2_PREFETCH
asm volatile(
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(sm_addr),
"l"(gm), "n"(SizeInBytes));
#else
asm volatile(
"cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(sm_addr), "l"(gm),
"n"(SizeInBytes));
#endif
}
/// Establishes an ordering w.r.t previously issued cp.async instructions. Does
/// not block.
#define cp_async_fence() asm volatile("cp.async.commit_group;\n" ::)
/// Blocks until all but <N> previous cp.async.commit_group operations have
/// committed.
#define cp_async_wait(N) asm volatile("cp.async.wait_group %0;\n" ::"n"(N))
#endif
extern "C" __global__ void __launch_bounds__(256)
ampere_conv_bias_uint4_int4_imma8832_ldgsts16_128x128_relu(
const int8_t* __restrict__ src, int8_t* __restrict__ filter,
const float* __restrict__ bias, int8_t* __restrict__ dst, float alpha,
float beta, uint32_t pk_src_zero_point, float dst_zero_point,
uint32_t relu, Conv2dInt4Param param,
Conv2dConstantOffset conv2d_constant) {
#ifdef SM80_ENABLED
const int stages = 3;
const uint32_t tid = threadIdx.x;
const uint32_t bidx = blockIdx.x;
const uint32_t bidy = blockIdx.y;
extern __shared__ int32_t smem[]; // (128+128)*128/8*stages
int2 reg_acc[reg_m][reg_n];
int4 reg_src[2][reg_nd4];
int4 reg_flt[2][reg_md4];
// use in other way, maybe use reg_ser/flt
int4 reg_src_cache[2];
int4 reg_filter_cache[4];
uint32_t tid127 = (tid & 127);
uint32_t section = (tid127 >> 1);
uint32_t residue = ((tid127 << 5) & 63);
uint32_t nhw = bidx * BN + section;
uint32_t tn, hw, toh, tow;
int tih, tiw;
int h_start[2];
int h_end[2];
int w_start[2];
int w_end[2];
bool g[2];
const int8_t* __restrict__ g_src_ptr[4];
for (int i = 0; i < 2; i++) {
if (i != 0) {
nhw += 64;
}
tn = nhw / param.div_ohow;
hw = nhw % param.div_ohow;
toh = hw / param.div_ow;
tow = hw % param.div_ow;
tih = toh * param.sh - param.ph;
tiw = tow * param.sw - param.pw;
g[i] = tn < param.n;
h_start[i] = -tih;
h_end[i] = param.ih - tih;
w_start[i] = -tiw;
w_end[i] = param.iw - tiw;
// param's members have been converted to byte offset and int4 offset need to
// div 2
int src_offset = tn * param.ibs + tih * param.ihs +
((int)(tiw * packed_channel + residue) >> 1);
g_src_ptr[i * 2] = src + src_offset;
g_src_ptr[i * 2 + 1] = g_src_ptr[i * 2];
}
const uint32_t section_section = (section >> 2);
const uint32_t section_residue = (section & 3);
const uint32_t section_factor = ((section & 15) >> 2);
const uint32_t crosswise_offset =
((section_residue >> 1) << 4) +
(((section_residue & 1) ^ (section_factor >> 1)) << 3);
const uint32_t residue_offset = ((residue >> 5) ^ (section_factor & 1)) << 2;
// next + 64 * BK / 8
int32_t* write_src_s[2];
write_src_s[0] =
smem + section_section * BK / 2 + crosswise_offset + residue_offset;
write_src_s[1] = write_src_s[0] + 32;
int iter = (param.icfhfw >> 6);
uint32_t tid31 = (tid & 31);
uint32_t warp_idx = (tid >> 5);
uint32_t warp_strided = (warp_idx << 2);
uint32_t htid = (tid31 >> 4);
const uint32_t flt_strided = bidy * BM / 8 + warp_strided;
bool guard = flt_strided * 8 < param.oc && iter > htid;
// icfhfw * 8/2 is a stride
int8_t* __restrict__ g_filter_ptr0 =
filter + flt_strided * (param.icfhfw * 4) + (tid31 << 4);
int8_t* __restrict__ g_filter_ptr1 = g_filter_ptr0 + (param.icfhfw * 4);
int8_t* __restrict__ g_filter_ptr2 = g_filter_ptr0 + (param.icfhfw * 8);
int8_t* __restrict__ g_filter_ptr3 = g_filter_ptr0 + (param.icfhfw * 12);
// next + BK * 8 / (INT32/INT4)
uint32_t q = (tid31 >> 3);
uint32_t r = (tid31 & 7);
int32_t* write_flt_s = smem + BN * BK / 8 + warp_strided * BK + ((q & 1) << 6) +
((q >> 1) << 5) + (r << 2);
uint32_t quad_idx = (tid31 >> 2);
uint32_t idx_in_quad = (tid & 3);
uint32_t quad_factor = ((tid & 15) >> 2);
uint32_t crosswise =
((idx_in_quad >> 1) << 4) + (((idx_in_quad & 1) ^ (quad_factor >> 1)) << 3);
uint32_t warp_x = (warp_idx >> 1);
uint32_t warp_y = (warp_idx & 1);
int32_t* read_src_s_0 = smem + (warp_x * 8 * BK) + (quad_idx * BK / 2) + crosswise +
((0 ^ (quad_factor & 1)) << 2);
int32_t* read_src_s_1 = smem + (warp_x * 8 * BK) + (quad_idx * BK / 2) + crosswise +
((1 ^ (quad_factor & 1)) << 2);
int32_t* read_flt_s_0 = smem + BN * BK / 8 + (warp_y * 8 * BK) +
(quad_idx * BK / 2) + crosswise +
((0 ^ (quad_factor & 1)) << 2);
int32_t* read_flt_s_1 = smem + BN * BK / 8 + (warp_y * 8 * BK) +
(quad_idx * BK / 2) + crosswise +
((1 ^ (quad_factor & 1)) << 2);
#pragma unroll
for (int i = 0; i < reg_m; i++) {
#pragma unroll
for (int j = 0; j < reg_n; j++) {
reg_acc[i][j] = make_int2(0, 0);
}
}
const int smem_switch = 4096;
const int smem_switch_back = -smem_switch * (stages - 1);
int stage = 0;
uint32_t offset[2] = {0, 2}; // high & low
int src_step[2], x[2], y[2];
// global mem --> shared mem, stage 0
for (int i = 0; i < 2; i++) {
src_step[i] = conv2d_constant.c_offset[offset[i]];
uint32_t spatial = *(reinterpret_cast<const uint32_t*>(
&(conv2d_constant.c_offset[offset[i] + 1])));
x[i] = (spatial & 0xff);
y[i] = ((spatial >> 8) & 0xff);
if (offset[i] < conv2d_constant.c_offset_param.max) {
offset[i] += 4;
} else {
offset[i] += conv2d_constant.c_offset_param.rewind;
}
}
bool guard0[2], guard1[2];
guard0[0] = g[0] && x[0] >= h_start[0] && x[0] < h_end[0] && y[0] >= w_start[0] &&
y[0] < w_end[0] && iter > 0;
guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] && y[1] >= w_start[0] &&
y[1] < w_end[0] && iter > 1;
guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] && y[0] >= w_start[1] &&
y[0] < w_end[1] && iter > 0;
guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] && y[1] >= w_start[1] &&
y[1] < w_end[1] && iter > 1;
g_src_ptr[0] += src_step[0];
g_src_ptr[1] += src_step[1];
g_src_ptr[2] += src_step[0];
g_src_ptr[3] += src_step[1];
if (guard0[0]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[0]),
reinterpret_cast<int4*>(write_src_s[0]));
} else {
*(reinterpret_cast<int4*>(write_src_s[0])) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard0[1]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[1]),
reinterpret_cast<int4*>(write_src_s[1]));
} else {
*(reinterpret_cast<int4*>(write_src_s[1])) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard1[0]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[2]),
reinterpret_cast<int4*>(write_src_s[0] + 8 * BK));
} else {
*(reinterpret_cast<int4*>(write_src_s[0] + 8 * BK)) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard1[1]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[3]),
reinterpret_cast<int4*>(write_src_s[1] + 8 * BK));
} else {
*(reinterpret_cast<int4*>(write_src_s[1] + 8 * BK)) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard) {
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr0),
reinterpret_cast<int4*>(write_flt_s));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr1),
reinterpret_cast<int4*>(write_flt_s + BK));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr2),
reinterpret_cast<int4*>(write_flt_s + 2 * BK));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr3),
reinterpret_cast<int4*>(write_flt_s + 3 * BK));
} else {
*(reinterpret_cast<int4*>(write_flt_s)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + BK)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + 2 * BK)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + 3 * BK)) = make_int4(0, 0, 0, 0);
}
stage++;
cp_async_fence();
// global mem --> shared mem, stage 1 -> stage n
iter -= 2;
for (int i = 0; i < 2; i++) {
src_step[i] = conv2d_constant.c_offset[offset[i]];
uint32_t spatial = *(reinterpret_cast<const uint32_t*>(
&(conv2d_constant.c_offset[offset[i] + 1])));
x[i] = (spatial & 0xff);
y[i] = ((spatial >> 8) & 0xff);
if (offset[i] < conv2d_constant.c_offset_param.max) {
offset[i] += 4;
} else {
offset[i] += conv2d_constant.c_offset_param.rewind;
}
}
guard0[0] = g[0] && x[0] >= h_start[0] && x[0] < h_end[0] && y[0] >= w_start[0] &&
y[0] < w_end[0];
guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] && y[1] >= w_start[0] &&
y[1] < w_end[0];
guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] && y[0] >= w_start[1] &&
y[0] < w_end[1];
guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] && y[1] >= w_start[1] &&
y[1] < w_end[1];
g_src_ptr[0] += src_step[0];
g_src_ptr[1] += src_step[1];
g_src_ptr[2] += src_step[0];
g_src_ptr[3] += src_step[1];
g_filter_ptr0 += 8 * 64;
g_filter_ptr1 += 8 * 64;
g_filter_ptr2 += 8 * 64;
g_filter_ptr3 += 8 * 64;
write_src_s[0] += smem_switch;
write_src_s[1] += smem_switch;
write_flt_s += smem_switch;
for (; iter >= 2 && stage < stages - 1; iter -= 2) {
if (guard0[0]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[0]),
reinterpret_cast<int4*>(write_src_s[0]));
} else {
*(reinterpret_cast<int4*>(write_src_s[0])) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard0[1]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[1]),
reinterpret_cast<int4*>(write_src_s[1]));
} else {
*(reinterpret_cast<int4*>(write_src_s[1])) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard1[0]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[2]),
reinterpret_cast<int4*>(write_src_s[0] + 8 * BK));
} else {
*(reinterpret_cast<int4*>(write_src_s[0] + 8 * BK)) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard1[1]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[3]),
reinterpret_cast<int4*>(write_src_s[1] + 8 * BK));
} else {
*(reinterpret_cast<int4*>(write_src_s[1] + 8 * BK)) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard) {
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr0),
reinterpret_cast<int4*>(write_flt_s));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr1),
reinterpret_cast<int4*>(write_flt_s + BK));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr2),
reinterpret_cast<int4*>(write_flt_s + 2 * BK));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr3),
reinterpret_cast<int4*>(write_flt_s + 3 * BK));
} else {
*(reinterpret_cast<int4*>(write_flt_s)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + BK)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + 2 * BK)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + 3 * BK)) = make_int4(0, 0, 0, 0);
}
stage++;
cp_async_fence();
for (int i = 0; i < 2; i++) {
src_step[i] = conv2d_constant.c_offset[offset[i]];
uint32_t spatial = *(reinterpret_cast<const uint32_t*>(
&(conv2d_constant.c_offset[offset[i] + 1])));
x[i] = (spatial & 0xff);
y[i] = ((spatial >> 8) & 0xff);
if (offset[i] < conv2d_constant.c_offset_param.max) {
offset[i] += 4;
} else {
offset[i] += conv2d_constant.c_offset_param.rewind;
}
}
guard0[0] = g[0] && x[0] >= h_start[0] && x[0] < h_end[0] &&
y[0] >= w_start[0] && y[0] < w_end[0];
guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] &&
y[1] >= w_start[0] && y[1] < w_end[0];
guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] &&
y[0] >= w_start[1] && y[0] < w_end[1];
guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] &&
y[1] >= w_start[1] && y[1] < w_end[1];
g_src_ptr[0] += src_step[0];
g_src_ptr[1] += src_step[1];
g_src_ptr[2] += src_step[0];
g_src_ptr[3] += src_step[1];
g_filter_ptr0 += 8 * 64;
g_filter_ptr1 += 8 * 64;
g_filter_ptr2 += 8 * 64;
g_filter_ptr3 += 8 * 64;
write_src_s[0] += smem_switch;
write_src_s[1] += smem_switch;
write_flt_s += smem_switch;
}
bool is_copy = false;
if (iter == 1 && stage != stages - 1) {
if (guard0[0]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[0]),
reinterpret_cast<int4*>(write_src_s[0]));
} else {
*(reinterpret_cast<int4*>(write_src_s[0])) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard0[1] && iter > 1) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[1]),
reinterpret_cast<int4*>(write_src_s[1]));
} else {
*(reinterpret_cast<int4*>(write_src_s[1])) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard1[0]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[2]),
reinterpret_cast<int4*>(write_src_s[0] + 8 * BK));
} else {
*(reinterpret_cast<int4*>(write_src_s[0] + 8 * BK)) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard1[1] && iter > 1) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[3]),
reinterpret_cast<int4*>(write_src_s[1] + 8 * BK));
} else {
*(reinterpret_cast<int4*>(write_src_s[1] + 8 * BK)) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard && iter > htid) {
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr0),
reinterpret_cast<int4*>(write_flt_s));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr1),
reinterpret_cast<int4*>(write_flt_s + BK));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr2),
reinterpret_cast<int4*>(write_flt_s + 2 * BK));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr3),
reinterpret_cast<int4*>(write_flt_s + 3 * BK));
} else {
*(reinterpret_cast<int4*>(write_flt_s)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + BK)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + 2 * BK)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + 3 * BK)) = make_int4(0, 0, 0, 0);
}
stage++;
is_copy = true;
cp_async_fence();
}
// compute offset
int d_offset = (bidy * (BM >> 6) + warp_y) * param.ocs + (idx_in_quad << 3);
section = tid31 >> 2;
size_t nhw_post0 = bidx * BN + warp_x * 64 + section;
size_t nhw_post1 = nhw_post0 + 8;
size_t nhw_post2 = nhw_post0 + 16;
size_t nhw_post3 = nhw_post0 + 24;
size_t stg_oc = bidy * BM + (warp_y << 6);
int* g_offset = ((int*)&reg_filter_cache);
bool stg_guard[8];
#pragma unroll
for (int y = 0; y < reg_m; y += 4) {
COMPUTE_OFFSET_4x1(reg_fuse_z, g_offset, y)
nhw_post0 += 32;
nhw_post1 += 32;
nhw_post2 += 32;
nhw_post3 += 32;
}
bool only_one_stage = (stage == 1) ? true : false;
if (stage >= 2) {
cp_async_wait(stages - 2);
} else {
cp_async_wait(0);
}
__syncthreads();
for (; iter >= 2; iter -= 2) {
if (guard0[0]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[0]),
reinterpret_cast<int4*>(write_src_s[0]));
} else {
*(reinterpret_cast<int4*>(write_src_s[0])) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard0[1]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[1]),
reinterpret_cast<int4*>(write_src_s[1]));
} else {
*(reinterpret_cast<int4*>(write_src_s[1])) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard1[0]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[2]),
reinterpret_cast<int4*>(write_src_s[0] + 8 * BK));
} else {
*(reinterpret_cast<int4*>(write_src_s[0] + 8 * BK)) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard1[1]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[3]),
reinterpret_cast<int4*>(write_src_s[1] + 8 * BK));
} else {
*(reinterpret_cast<int4*>(write_src_s[1] + 8 * BK)) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard) {
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr0),
reinterpret_cast<int4*>(write_flt_s));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr1),
reinterpret_cast<int4*>(write_flt_s + BK));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr2),
reinterpret_cast<int4*>(write_flt_s + 2 * BK));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr3),
reinterpret_cast<int4*>(write_flt_s + 3 * BK));
} else {
*(reinterpret_cast<int4*>(write_flt_s)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + BK)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + 2 * BK)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + 3 * BK)) = make_int4(0, 0, 0, 0);
}
stage++;
cp_async_fence();
#pragma unroll // low
for (int i = 0; i < reg_nd4; ++i) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK); // BK*32/8
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_src[0][i] = make_int4(x, y, z, w);
}
#pragma unroll
for (int j = 0; j < reg_md4; ++j) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK);
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_flt[0][j] = make_int4(x, y, z, w);
}
#pragma unroll
for (int k_inner = 0; k_inner < BKd32; k_inner++) {
int comp = (k_inner & 0x1);
int load = 1 - comp;
if (k_inner < BKd32 - 1) {
int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1;
int32_t* read_flt_s = (k_inner & 1) ? read_flt_s_0 : read_flt_s_1;
read_src_s += 32 * ((k_inner + 1) >> 1);
read_flt_s += 32 * ((k_inner + 1) >> 1);
#pragma unroll // low
for (int i = 0; i < reg_nd4; ++i) {
int x, y, z, w;
unsigned addr =
get_smem_pointer(read_src_s + i * 4 * BK); // BK*32/8
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
"%3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_src[load][i] = make_int4(x, y, z, w);
}
#pragma unroll
for (int j = 0; j < reg_md4; ++j) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK);
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
"%3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_flt[load][j] = make_int4(x, y, z, w);
}
}
int* A = reinterpret_cast<int*>(&reg_flt[comp][0]);
int* B = reinterpret_cast<int*>(&reg_src[comp][0]);
#pragma unroll
for (int x = 0; x < reg_n; x++) {
#pragma unroll
for (int y = 0; y < reg_m; y++) {
int* D = reinterpret_cast<int*>(&reg_acc[y][x]);
int* C = reinterpret_cast<int*>(&reg_acc[y][x]);
asm volatile(
"mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4."
"s32 "
"{%0,%1}, {%2}, {%3}, "
"{%4,%5};\n"
: "=r"(D[0]), "=r"(D[1])
: "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1]));
}
}
}
if (stage == stages) {
stage = 0;
write_src_s[0] += smem_switch_back;
write_src_s[1] += smem_switch_back;
write_flt_s += smem_switch_back;
read_src_s_0 += smem_switch;
read_src_s_1 += smem_switch;
read_flt_s_0 += smem_switch;
read_flt_s_1 += smem_switch;
} else if (stage == stages - 1) {
write_src_s[0] += smem_switch;
write_src_s[1] += smem_switch;
write_flt_s += smem_switch;
read_src_s_0 += smem_switch_back;
read_src_s_1 += smem_switch_back;
read_flt_s_0 += smem_switch_back;
read_flt_s_1 += smem_switch_back;
} else {
write_src_s[0] += smem_switch;
write_src_s[1] += smem_switch;
write_flt_s += smem_switch;
read_src_s_0 += smem_switch;
read_src_s_1 += smem_switch;
read_flt_s_0 += smem_switch;
read_flt_s_1 += smem_switch;
}
int src_step[2];
for (int i = 0; i < 2; i++) {
src_step[i] = conv2d_constant.c_offset[offset[i]];
uint32_t spatial = *(reinterpret_cast<const uint32_t*>(
&(conv2d_constant.c_offset[offset[i] + 1])));
x[i] = (spatial & 0xff);
y[i] = ((spatial >> 8) & 0xff);
if (offset[i] < conv2d_constant.c_offset_param.max) {
offset[i] += 4;
} else {
offset[i] += conv2d_constant.c_offset_param.rewind;
}
}
guard0[0] = g[0] && x[0] >= h_start[0] && x[0] < h_end[0] &&
y[0] >= w_start[0] && y[0] < w_end[0];
guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] &&
y[1] >= w_start[0] && y[1] < w_end[0];
guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] &&
y[0] >= w_start[1] && y[0] < w_end[1];
guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] &&
y[1] >= w_start[1] && y[1] < w_end[1];
g_src_ptr[0] += src_step[0];
g_src_ptr[1] += src_step[1];
g_src_ptr[2] += src_step[0];
g_src_ptr[3] += src_step[1];
g_filter_ptr0 += 8 * 64;
g_filter_ptr1 += 8 * 64;
g_filter_ptr2 += 8 * 64;
g_filter_ptr3 += 8 * 64;
cp_async_wait(stages - 2);
__syncthreads();
}
if (iter > 0) {
if (!is_copy) {
if (guard0[0]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[0]),
reinterpret_cast<int4*>(write_src_s[0]));
} else {
*(reinterpret_cast<int4*>(write_src_s[0])) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard0[1] && iter > 1) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[1]),
reinterpret_cast<int4*>(write_src_s[1]));
} else {
*(reinterpret_cast<int4*>(write_src_s[1])) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard1[0]) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[2]),
reinterpret_cast<int4*>(write_src_s[0] + 8 * BK));
} else {
*(reinterpret_cast<int4*>(write_src_s[0] + 8 * BK)) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard1[1] && iter > 1) {
g2s_int4(
reinterpret_cast<const int4*>(g_src_ptr[3]),
reinterpret_cast<int4*>(write_src_s[1] + 8 * BK));
} else {
*(reinterpret_cast<int4*>(write_src_s[1] + 8 * BK)) = make_int4(
pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
pk_src_zero_point);
}
if (guard && iter > htid) {
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr0),
reinterpret_cast<int4*>(write_flt_s));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr1),
reinterpret_cast<int4*>(write_flt_s + BK));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr2),
reinterpret_cast<int4*>(write_flt_s + 2 * BK));
g2s_int4(
reinterpret_cast<const int4*>(g_filter_ptr3),
reinterpret_cast<int4*>(write_flt_s + 3 * BK));
} else {
*(reinterpret_cast<int4*>(write_flt_s)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + BK)) = make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + 2 * BK)) =
make_int4(0, 0, 0, 0);
*(reinterpret_cast<int4*>(write_flt_s + 3 * BK)) =
make_int4(0, 0, 0, 0);
}
cp_async_fence();
}
#pragma unroll // low
for (int i = 0; i < reg_nd4; ++i) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK); // BK*32/8
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_src[0][i] = make_int4(x, y, z, w);
}
#pragma unroll
for (int j = 0; j < reg_md4; ++j) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK);
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_flt[0][j] = make_int4(x, y, z, w);
}
#pragma unroll
for (int k_inner = 0; k_inner < BKd32; k_inner++) {
int comp = (k_inner & 0x1);
int load = 1 - comp;
if (k_inner < BKd32 - 1) {
int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1;
int32_t* read_flt_s = (k_inner & 1) ? read_flt_s_0 : read_flt_s_1;
read_src_s += 32 * ((k_inner + 1) >> 1);
read_flt_s += 32 * ((k_inner + 1) >> 1);
#pragma unroll // low
for (int i = 0; i < reg_nd4; ++i) {
int x, y, z, w;
unsigned addr =
get_smem_pointer(read_src_s + i * 4 * BK); // BK*32/8
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
"%3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_src[load][i] = make_int4(x, y, z, w);
}
#pragma unroll
for (int j = 0; j < reg_md4; ++j) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK);
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
"%3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_flt[load][j] = make_int4(x, y, z, w);
}
}
int* A = reinterpret_cast<int*>(&reg_flt[comp][0]);
int* B = reinterpret_cast<int*>(&reg_src[comp][0]);
#pragma unroll
for (int x = 0; x < reg_n; x++) {
#pragma unroll
for (int y = 0; y < reg_m; y++) {
int* D = reinterpret_cast<int*>(&reg_acc[y][x]);
int* C = reinterpret_cast<int*>(&reg_acc[y][x]);
asm volatile(
"mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4."
"s32 "
"{%0,%1}, {%2}, {%3}, "
"{%4,%5};\n"
: "=r"(D[0]), "=r"(D[1])
: "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1]));
}
}
}
stage++;
if (stage == stages) {
stage = 0;
read_src_s_0 += smem_switch;
read_src_s_1 += smem_switch;
read_flt_s_0 += smem_switch;
read_flt_s_1 += smem_switch;
} else if (stage == stages - 1) {
read_src_s_0 += smem_switch_back;
read_src_s_1 += smem_switch_back;
read_flt_s_0 += smem_switch_back;
read_flt_s_1 += smem_switch_back;
} else {
read_src_s_0 += smem_switch;
read_src_s_1 += smem_switch;
read_flt_s_0 += smem_switch;
read_flt_s_1 += smem_switch;
}
cp_async_wait(stages - 2);
}
if (!only_one_stage) {
#pragma unroll // low
for (int i = 0; i < reg_nd4; ++i) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK); // BK*32/8
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_src[0][i] = make_int4(x, y, z, w);
}
#pragma unroll
for (int j = 0; j < reg_md4; ++j) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK);
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_flt[0][j] = make_int4(x, y, z, w);
}
#pragma unroll
for (int k_inner = 0; k_inner < BKd32; k_inner++) {
int comp = (k_inner & 0x1);
int load = 1 - comp;
if (k_inner < BKd32 - 1) {
int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1;
int32_t* read_flt_s = (k_inner & 1) ? read_flt_s_0 : read_flt_s_1;
read_src_s += 32 * ((k_inner + 1) >> 1);
read_flt_s += 32 * ((k_inner + 1) >> 1);
#pragma unroll // low
for (int i = 0; i < reg_nd4; ++i) {
int x, y, z, w;
unsigned addr =
get_smem_pointer(read_src_s + i * 4 * BK); // BK*32/8
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
"%3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_src[load][i] = make_int4(x, y, z, w);
}
#pragma unroll
for (int j = 0; j < reg_md4; ++j) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK);
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
"%3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_flt[load][j] = make_int4(x, y, z, w);
}
}
int* A = reinterpret_cast<int*>(&reg_flt[comp][0]);
int* B = reinterpret_cast<int*>(&reg_src[comp][0]);
#pragma unroll
for (int x = 0; x < reg_n; x++) {
#pragma unroll
for (int y = 0; y < reg_m; y++) {
int* D = reinterpret_cast<int*>(&reg_acc[y][x]);
int* C = reinterpret_cast<int*>(&reg_acc[y][x]);
asm volatile(
"mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4."
"s32 "
"{%0,%1}, {%2}, {%3}, "
"{%4,%5};\n"
: "=r"(D[0]), "=r"(D[1])
: "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1]));
}
}
}
stage++;
if (stage == stages) {
stage = 0;
read_src_s_0 += smem_switch;
read_src_s_1 += smem_switch;
read_flt_s_0 += smem_switch;
read_flt_s_1 += smem_switch;
} else if (stage == stages - 1) {
read_src_s_0 += smem_switch_back;
read_src_s_1 += smem_switch_back;
read_flt_s_0 += smem_switch_back;
read_flt_s_1 += smem_switch_back;
} else {
read_src_s_0 += smem_switch;
read_src_s_1 += smem_switch;
read_flt_s_0 += smem_switch;
read_flt_s_1 += smem_switch;
}
cp_async_wait(0);
}
guard = iter < 0;
#pragma unroll // low
for (int i = 0; i < reg_nd4; ++i) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK); // BK*32/8
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
"%3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_src[0][i] = make_int4(x, y, z, w);
}
#pragma unroll
for (int j = 0; j < reg_md4; ++j) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK);
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
"%3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_flt[0][j] = make_int4(x, y, z, w);
}
// compute
#pragma unroll
for (int k_inner = 0; k_inner < BKd32; k_inner++) {
int comp = (k_inner & 0x1);
int load = 1 - comp;
if (k_inner < BKd32 - 1 && !(k_inner == 1 && guard)) {
int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1;
int32_t* read_flt_s = (k_inner & 1) ? read_flt_s_0 : read_flt_s_1;
read_src_s += 32 * ((k_inner + 1) >> 1);
read_flt_s += 32 * ((k_inner + 1) >> 1);
#pragma unroll // low
for (int i = 0; i < reg_nd4; ++i) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_src_s + i * 4 * BK); // BK*32/8
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
"%3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_src[load][i] = make_int4(x, y, z, w);
}
#pragma unroll
for (int j = 0; j < reg_md4; ++j) {
int x, y, z, w;
unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK);
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
"%3}, "
"[%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(addr));
reg_flt[load][j] = make_int4(x, y, z, w);
}
}
int* A = reinterpret_cast<int*>(&reg_flt[comp][0]);
int* B = reinterpret_cast<int*>(&reg_src[comp][0]);
#pragma unroll
for (int x = 0; x < reg_n; x++) {
#pragma unroll
for (int y = 0; y < reg_m; y++) {
int* D = reinterpret_cast<int*>(&reg_acc[y][x]);
int* C = reinterpret_cast<int*>(&reg_acc[y][x]);
asm volatile(
"mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4."
"s32 "
"{%0,%1}, {%2}, {%3}, "
"{%4,%5};\n"
: "=r"(D[0]), "=r"(D[1])
: "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1]));
}
}
if (k_inner == 1 && guard) {
break;
}
}
__syncthreads();
/// output
size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
const float* bias_ptr = bias + oc;
int4 load_bias0 = make_int4(0, 0, 0, 0);
int4 load_bias1 = make_int4(0, 0, 0, 0);
int4 load_bias2 = make_int4(0, 0, 0, 0);
int4 load_bias3 = make_int4(0, 0, 0, 0);
if (oc < param.oc) {
load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
mul_v4(load_bias0, load_bias0, beta);
mul_v4(load_bias1, load_bias1, beta);
mul_v4(load_bias2, load_bias2, beta);
mul_v4(load_bias3, load_bias3, beta);
}
int8_t* __restrict__ g_dst_ptr = dst + d_offset;
#pragma unroll
for (int y = 0; y < reg_m; y += 4) {
I2F_4x8(reg_acc, y, 0);
FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
STG_AFTER_LDG_4x1(g_offset, reg_acc, y, 0);
nhw_post0 += 32;
nhw_post1 += 32;
nhw_post2 += 32;
nhw_post3 += 32;
}
#endif
}
} // namespace
namespace megdnn {
namespace cuda {
namespace ptx {
void run_ampere_conv_bias_uint4_int4_imma8832_ldgsts16_128x128_relu(
const dim3 grid, const dim3 block, cudaStream_t stream, void** params) {
#ifdef SM80_SUPPORTED
cudaFuncSetAttribute(
ampere_conv_bias_uint4_int4_imma8832_ldgsts16_128x128_relu,
cudaFuncAttributeMaxDynamicSharedMemorySize, 49152);
ampere_conv_bias_uint4_int4_imma8832_ldgsts16_128x128_relu<<<
grid, block, 49152, stream>>>(
*((int8_t**)params[0]), *((int8_t**)params[1]), *((float**)params[2]),
*((int8_t**)params[3]), *((float*)params[4]), *((float*)params[5]),
*((uint32_t*)params[6]), *((float*)params[7]), *((uint32_t*)params[8]),
*((Conv2dInt4Param*)params[9]), *((Conv2dConstantOffset*)params[10]));
#endif
}
} // namespace ptx
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
#include <cuda_runtime.h>
namespace megdnn {
namespace cuda {
namespace ptx {
void run_ampere_conv_bias_uint4_int4_imma8832_ldg16_256x64_relu(
const dim3 grid, const dim3 block, cudaStream_t stream, void** params);
void run_ampere_conv_bias_uint4_int4_imma8832_ldgsts16_128x128_relu(
const dim3 grid, const dim3 block, cudaStream_t stream, void** params);
void run_ampere_conv_bias_uint4_int4_imma8832_ldg16_128x256_relu(
const dim3 grid, const dim3 block, cudaStream_t stream, void** params);
void run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_256x64_relu(
const dim3 grid, const dim3 block, cudaStream_t stream, void** params);
void run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldgsts16_128x128_relu(
const dim3 grid, const dim3 block, cudaStream_t stream, void** params);
void run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_128x256_relu(
const dim3 grid, const dim3 block, cudaStream_t stream, void** params);
} // namespace ptx
} // namespace cuda
} // namespace megdnn
#pragma once
//! ============= i2f ===============
__device__ __forceinline__ void i2f(int2& a) {
((float*)&a)[0] = static_cast<float>(a.x);
((float*)&a)[1] = static_cast<float>(a.y);
}
//! ============= mul ===============
template <typename T>
__device__ __forceinline__ void mul_v4(int4& c, const int4 a, const T alpha);
template <>
__device__ __forceinline__ void mul_v4<float>(
int4& c, const int4 a, const float alpha) {
((float*)&c)[0] = ((float*)&a)[0] * alpha;
((float*)&c)[1] = ((float*)&a)[1] * alpha;
((float*)&c)[2] = ((float*)&a)[2] * alpha;
((float*)&c)[3] = ((float*)&a)[3] * alpha;
}
//! ============= fma ===============
__device__ __forceinline__ void fma2(
int2& c0, const int2 a0, int2& c1, const int2 a1, const float alpha,
const int4 b) {
asm("fma.rz.f32 %0, %1, %2, %3;"
: "=f"(((float*)&c0)[0])
: "f"(((float*)&a0)[0]), "f"(alpha), "f"(((float*)&b)[0]));
asm("fma.rz.f32 %0, %1, %2, %3;"
: "=f"(((float*)&c0)[1])
: "f"(((float*)&a0)[1]), "f"(alpha), "f"(((float*)&b)[1]));
asm("fma.rz.f32 %0, %1, %2, %3;"
: "=f"(((float*)&c1)[0])
: "f"(((float*)&a1)[0]), "f"(alpha), "f"(((float*)&b)[2]));
asm("fma.rz.f32 %0, %1, %2, %3;"
: "=f"(((float*)&c1)[1])
: "f"(((float*)&a1)[1]), "f"(alpha), "f"(((float*)&b)[3]));
}
__device__ __forceinline__ void fuse_z_1x8(
int4* a, const int& j, const int4& fuse_z, const float& gamma,
const int32_t& zero_point) {
const int2 z[2] = {
*reinterpret_cast<const int2*>(&fuse_z),
*(reinterpret_cast<const int2*>(&fuse_z) + 1)};
for (int k = 0; k < 4; k++) {
int f = ((z[0].x >> (k * 8)) & 15);
f = (f << 28) >> 28;
((float*)&(a[j + k]))[0] += (f - zero_point) * gamma;
f = ((z[0].x >> (k * 8 + 4)) & 15);
f = (f << 28) >> 28;
((float*)&(a[j + k]))[1] += (f - zero_point) * gamma;
f = ((z[1].x >> (k * 8)) & 15);
f = (f << 28) >> 28;
((float*)&(a[j + k]))[2] += (f - zero_point) * gamma;
f = ((z[1].x >> (k * 8 + 4)) & 15);
f = (f << 28) >> 28;
((float*)&(a[j + k]))[3] += (f - zero_point) * gamma;
}
for (int k = 0; k < 4; k++) {
int f = ((z[0].y >> (k * 8)) & 15);
f = (f << 28) >> 28;
((float*)&(a[j + k + 4]))[0] += (f - zero_point) * gamma;
f = ((z[0].y >> (k * 8 + 4)) & 15);
f = (f << 28) >> 28;
((float*)&(a[j + k + 4]))[1] += (f - zero_point) * gamma;
f = ((z[1].y >> (k * 8)) & 15);
f = (f << 28) >> 28;
((float*)&(a[j + k + 4]))[2] += (f - zero_point) * gamma;
f = ((z[1].y >> (k * 8 + 4)) & 15);
f = (f << 28) >> 28;
((float*)&(a[j + k + 4]))[3] += (f - zero_point) * gamma;
}
}
__device__ __forceinline__ void fuse_z_1x8(
int2* a, const int& j, const int2& fuse_z, const float& gamma,
const int32_t& zero_point) {
#pragma unroll
for (int k = 0; k < 4; k++) {
int f = ((fuse_z.x >> (k * 8)) & 15);
f = (f << 28) >> 28;
((float*)&(a[j + k]))[0] += (f - zero_point) * gamma;
f = ((fuse_z.x >> (k * 8 + 4)) & 15);
f = (f << 28) >> 28;
((float*)&(a[j + k]))[1] += (f - zero_point) * gamma;
}
#pragma unroll
for (int k = 0; k < 4; k++) {
int f = ((fuse_z.y >> (k * 8)) & 15);
f = (f << 28) >> 28;
((float*)&(a[j + k + 4]))[0] += (f - zero_point) * gamma;
f = ((fuse_z.y >> (k * 8 + 4)) & 15);
f = (f << 28) >> 28;
((float*)&(a[j + k + 4]))[1] += (f - zero_point) * gamma;
}
}
__device__ __forceinline__ void pack_f2i(
int& d0, int& d1, const int4 s0, const int4 s1, const int4 s2, const int4 s3,
const uint32_t relu, float& dst_zero_point) {
// uint32_t ix, iy, iz, iw;
uint32_t x0, y0, z0, w0;
uint32_t x1, y1, z1, w1;
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x0) : "f"(((float*)&s0)[0]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(y0) : "f"(((float*)&s0)[1]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(z0) : "f"(((float*)&s1)[0]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(w0) : "f"(((float*)&s1)[1]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x1) : "f"(((float*)&s2)[0]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(y1) : "f"(((float*)&s2)[1]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(z1) : "f"(((float*)&s3)[0]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(w1) : "f"(((float*)&s3)[1]));
asm volatile(
"{ .reg .u32 r4;"
"cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;"
"cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;"
"cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;"
"cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;"
"}"
: "=r"(d0)
: "r"(x0), "r"(y0), "r"(z0), "r"(w0), "r"(x1), "r"(y1), "r"(z1), "r"(w1));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x0) : "f"(((float*)&s0)[2]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(y0) : "f"(((float*)&s0)[3]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(z0) : "f"(((float*)&s1)[2]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(w0) : "f"(((float*)&s1)[3]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x1) : "f"(((float*)&s2)[2]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(y1) : "f"(((float*)&s2)[3]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(z1) : "f"(((float*)&s3)[2]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(w1) : "f"(((float*)&s3)[3]));
asm volatile(
"{ .reg .u32 r4;"
"cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;"
"cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;"
"cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;"
"cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;"
"}"
: "=r"(d1)
: "r"(x0), "r"(y0), "r"(z0), "r"(w0), "r"(x1), "r"(y1), "r"(z1), "r"(w1));
}
__device__ __forceinline__ void pack_f2i_with_relu(
int& d0, const int2 s0, const int2 s1, const int2 s2, const int2 s3,
const uint32_t relu, float& dst_zero_point) {
uint32_t x[8];
if (relu > 0) {
asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[0]) : "f"(((float*)&s0)[0]));
asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[1]) : "f"(((float*)&s0)[1]));
asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[2]) : "f"(((float*)&s1)[0]));
asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[3]) : "f"(((float*)&s1)[1]));
asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[4]) : "f"(((float*)&s2)[0]));
asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[5]) : "f"(((float*)&s2)[1]));
asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[6]) : "f"(((float*)&s3)[0]));
asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[7]) : "f"(((float*)&s3)[1]));
x[0] += dst_zero_point;
x[1] += dst_zero_point;
x[2] += dst_zero_point;
x[3] += dst_zero_point;
x[4] += dst_zero_point;
x[5] += dst_zero_point;
x[6] += dst_zero_point;
x[7] += dst_zero_point;
} else if (relu == 0) {
((float*)&s0)[0] += dst_zero_point;
((float*)&s0)[1] += dst_zero_point;
((float*)&s1)[0] += dst_zero_point;
((float*)&s1)[1] += dst_zero_point;
((float*)&s2)[0] += dst_zero_point;
((float*)&s2)[1] += dst_zero_point;
((float*)&s3)[0] += dst_zero_point;
((float*)&s3)[1] += dst_zero_point;
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[0]) : "f"(((float*)&s0)[0]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[1]) : "f"(((float*)&s0)[1]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[2]) : "f"(((float*)&s1)[0]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[3]) : "f"(((float*)&s1)[1]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[4]) : "f"(((float*)&s2)[0]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[5]) : "f"(((float*)&s2)[1]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[6]) : "f"(((float*)&s3)[0]));
asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[7]) : "f"(((float*)&s3)[1]));
}
if (relu > 1) {
int r1, r2;
r1 = (x[0] >= relu);
x[0] *= r1;
r2 = (x[1] >= relu);
x[1] *= r2;
r1 = (x[2] >= relu);
x[2] *= r1;
r2 = (x[3] >= relu);
x[3] *= r2;
r1 = (x[4] >= relu);
x[4] *= r1;
r2 = (x[5] >= relu);
x[5] *= r2;
r1 = (x[6] >= relu);
x[6] *= r1;
r2 = (x[7] >= relu);
x[7] *= r2;
}
asm volatile(
"{ .reg .u32 r4;"
"cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;"
"cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;"
"cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;"
"cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;"
"}"
: "=r"(d0)
: "r"(x[0]), "r"(x[1]), "r"(x[2]), "r"(x[3]), "r"(x[4]), "r"(x[5]),
"r"(x[6]), "r"(x[7]));
}
#define I2F_1x8(a, i, j) \
i2f(a[i][j]); \
i2f(a[i][j + 1]); \
i2f(a[i][j + 2]); \
i2f(a[i][j + 3]); \
i2f(a[i][j + 4]); \
i2f(a[i][j + 5]); \
i2f(a[i][j + 6]); \
i2f(a[i][j + 7]);
#define I2F_4x8(a, i, j) \
I2F_1x8(a, i, j) I2F_1x8(a, i + 1, j) I2F_1x8(a, i + 2, j) I2F_1x8(a, i + 3, j)
#define FMA_1x8(a, i, j, alpha, bias0, bias1, bias2, bias3) \
fma2(a[i][j], reg_acc[i][j], a[i][j + 1], reg_acc[i][j + 1], alpha, bias0); \
fma2(a[i][j + 2], reg_acc[i][j + 2], a[i][j + 3], reg_acc[i][j + 3], alpha, \
bias1); \
fma2(a[i][j + 4], reg_acc[i][j + 4], a[i][j + 5], reg_acc[i][j + 5], alpha, \
bias2); \
fma2(a[i][j + 6], reg_acc[i][j + 6], a[i][j + 7], reg_acc[i][j + 7], alpha, bias3);
#define FMA_4x8(a, i, j, alpha, bias0, bias1, bias2, bias3) \
FMA_1x8(a, i, j, alpha, bias0, bias1, bias2, bias3) \
FMA_1x8(a, i + 1, j, alpha, bias0, bias1, bias2, bias3) \
FMA_1x8(a, i + 2, j, alpha, bias0, bias1, bias2, bias3) \
FMA_1x8(a, i + 3, j, alpha, bias0, bias1, bias2, bias3)
// pack 1x(8 int2) to int2
#define PACK_F2I_WITH_RELU_1x8(a, i, j, relu, dst_zero_point) \
pack_f2i_with_relu( \
a[i][j].x, a[i][j], a[i][j + 1], a[i][j + 2], a[i][j + 3], relu, \
dst_zero_point); \
pack_f2i_with_relu( \
a[i][j].y, a[i][j + 4], a[i][j + 5], a[i][j + 6], a[i][j + 7], relu, \
dst_zero_point);
// pack 4x8 int2 float to 4 int2
#define PACK_F2I_WITH_RELU_4x8(a, i, j, relu, dst_zero_point) \
PACK_F2I_WITH_RELU_1x8(a, i, j, relu, dst_zero_point) \
PACK_F2I_WITH_RELU_1x8(a, i + 1, j, relu, dst_zero_point) \
PACK_F2I_WITH_RELU_1x8(a, i + 2, j, relu, dst_zero_point) \
PACK_F2I_WITH_RELU_1x8(a, i + 3, j, relu, dst_zero_point)
#define STG(d, s, idx, n_reuse, hw_reuse, g) \
n_reuse = nhw_post##idx / param.div_ohow; \
hw_reuse = nhw_post##idx % param.div_ohow; \
d = g_dst_ptr + n_reuse * param.obs + hw_reuse * (packed_channel >> 1); \
g = nhw_post##idx < param.nhw; \
if (stg_oc < param.oc && g) { \
*(reinterpret_cast<int2*>(d)) = *(reinterpret_cast<int2*>(&s)); \
}
#define STG_4x1(d, a, i, j) \
STG(d[0], a[i][j], 0, reg_src_cache[0].x, reg_src_cache[1].x, stg_guard[i]) \
STG(d[1], a[i + 1][j], 1, reg_src_cache[0].y, reg_src_cache[1].y, \
stg_guard[i + 1]) \
STG(d[2], a[i + 2][j], 2, reg_src_cache[0].z, reg_src_cache[1].z, \
stg_guard[i + 2]) \
STG(d[3], a[i + 3][j], 3, reg_src_cache[0].w, reg_src_cache[1].w, stg_guard[i + 3])
#define FUSE_Z_4x8(a, i, j, fuse_z, gamma, zero_point) \
fuse_z_1x8(a[i], j, fuse_z[i], gamma, zero_point); \
fuse_z_1x8(a[i + 1], j, fuse_z[i + 1], gamma, zero_point); \
fuse_z_1x8(a[i + 2], j, fuse_z[i + 2], gamma, zero_point); \
fuse_z_1x8(a[i + 3], j, fuse_z[i + 3], gamma, zero_point);
#define FUSE_Z_4x8(a, i, j, fuse_z, gamma, zero_point) \
fuse_z_1x8(a[i], j, fuse_z[i], gamma, zero_point); \
fuse_z_1x8(a[i + 1], j, fuse_z[i + 1], gamma, zero_point); \
fuse_z_1x8(a[i + 2], j, fuse_z[i + 2], gamma, zero_point); \
fuse_z_1x8(a[i + 3], j, fuse_z[i + 3], gamma, zero_point);
// 1x8 1x(2x8 int2) to 2 int2
#define PACK_F2I_1x8(a, i, j) \
pack_f2i(a[i][j].x, a[i][j].z, a[i][j], a[i][j + 1], a[i][j + 2], a[i][j + 3]); \
pack_f2i(a[i][j].y, a[i][j].w, a[i][j + 4], a[i][j + 5], a[i][j + 6], a[i][j + 7]);
// 4x8 int4
#define PACK_F2I_4x8(a, i, j) \
PACK_F2I_1x8(a, i, j) PACK_F2I_1x8(a, i + 1, j) PACK_F2I_1x8(a, i + 2, j) \
PACK_F2I_1x8(a, i + 3, j)
#define LDG(d, s, idx, n_reuse, hw_reuse, g) \
n_reuse = nhw_post##idx / param.div_ohow; \
hw_reuse = nhw_post##idx % param.div_ohow; \
s = n_reuse * param.obs + hw_reuse * (packed_channel >> 1); \
g = nhw_post##idx < param.nhw; \
if (stg_oc < param.oc && g) { \
*(reinterpret_cast<int2*>(&d)) = \
*(reinterpret_cast<const int2*>(g_z_ptr + s)); \
}
#define LDG_4x1(d, s, i) \
LDG(d[i], s[i], 0, reg_src_cache[0].x, reg_src_cache[1].x, stg_guard[i]) \
LDG(d[i + 1], s[i + 1], 1, reg_src_cache[0].y, reg_src_cache[1].y, \
stg_guard[i + 1]) \
LDG(d[i + 2], s[i + 2], 2, reg_src_cache[0].z, reg_src_cache[1].z, \
stg_guard[i + 2]) \
LDG(d[i + 3], s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w, stg_guard[i + 3])
#define COMPUTE_OFFSET(d, s, idx, n_reuse, hw_reuse, g) \
n_reuse = nhw_post##idx / param.div_ohow; \
hw_reuse = nhw_post##idx % param.div_ohow; \
s = n_reuse * param.obs + hw_reuse * (packed_channel >> 1); \
g = nhw_post##idx < param.nhw;
#define COMPUTE_OFFSET_4x1(d, s, i) \
COMPUTE_OFFSET( \
d[i], s[i], 0, reg_src_cache[0].x, reg_src_cache[1].x, stg_guard[i]) \
COMPUTE_OFFSET( \
d[i + 1], s[i + 1], 1, reg_src_cache[0].y, reg_src_cache[1].y, \
stg_guard[i + 1]) \
COMPUTE_OFFSET( \
d[i + 2], s[i + 2], 2, reg_src_cache[0].z, reg_src_cache[1].z, \
stg_guard[i + 2]) \
COMPUTE_OFFSET( \
d[i + 3], s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w, \
stg_guard[i + 3])
#define STG_AFTER_LDG(d, s, g) \
if (stg_oc < param.oc && g) { \
*(reinterpret_cast<int2*>(g_dst_ptr + d)) = *(reinterpret_cast<int2*>(&s)); \
}
#define STG_AFTER_LDG_4x1(d, a, i, j) \
STG_AFTER_LDG(d[i], a[i][j], stg_guard[i]) \
STG_AFTER_LDG(d[i + 1], a[i + 1][j], stg_guard[i + 1]) \
STG_AFTER_LDG(d[i + 2], a[i + 2][j], stg_guard[i + 2]) \
STG_AFTER_LDG(d[i + 3], a[i + 3][j], stg_guard[i + 3])
// vim: syntax=cpp.doxygen
#include <cuda_runtime.h>
#if (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)
extern "C" {
//
// This NVVM intrinsic is subject to change in future versions of CUDA.
// Clients should not call it directly. Rather, they should use the
// cutlass::arch::ldsm<>() template.
//
__device__ uint32_t __nvvm_get_smem_pointer(void*);
}
#endif
inline __device__ unsigned get_smem_pointer(void* ptr) {
#if (defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 11)
//
// This NVVM intrinsic converts an address in shared memory to a plain
// unsigned integer. This is necessary to pass to shared memory instructions
// in inline PTX.
//
// In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only
// available in 10.2].
//
//__device__ size_t __cvta_generic_to_shared(void* ptr);
/// CUTLASS helper to get SMEM pointer
return static_cast<unsigned>(__cvta_generic_to_shared(ptr));
#elif (defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && \
__CUDACC_VER_MINOR__ >= 2)
return __nvvm_get_smem_pointer(ptr);
#elif defined(__CUDA_ARCH__)
uint32_t smem_ptr;
asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 "
"%0, smem_ptr; }\n"
: "=r"(smem_ptr)
: "l"(ptr));
return smem_ptr;
#else
return 0;
#endif
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册