Commit 5f08b82f authored by Megvii Engine Team, committed by "wenjuan"

fix(dnn/cuda): fix ptx mma algo compute bugs

GitOrigin-RevId: 19628d0c94e93ff1072db2eb04547e6f8db5f809
Parent d3e786ef
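
This commit reworks the epilogue of the int4 ptx MMA convolution kernels: the guarded, vectorized bias load is hoisted ahead of the main loop, the separate I2F int-to-float conversion pass is dropped in favor of converting inside fma2, the 4-bit fuse_z unpacking no longer sign-extends the nibbles, and the store loop is software-pipelined. A minimal sketch of the hoisted bias-load pattern (load_bias_guarded is a hypothetical helper; BM, the warp indexing, and param come from the real kernels):

    // Each 16-channel slice fetches its bias with four 128-bit (int4) loads;
    // bias + oc is assumed 16-byte aligned, as in the kernels. The
    // oc < oc_limit guard zero-fills out-of-range slices instead of reading
    // past the end of the bias buffer.
    __device__ __forceinline__ void load_bias_guarded(
            const float* bias, size_t oc, size_t oc_limit,
            int4& b0, int4& b1, int4& b2, int4& b3) {
        b0 = b1 = b2 = b3 = make_int4(0, 0, 0, 0);
        if (oc < oc_limit) {
            const int4* p = reinterpret_cast<const int4*>(bias + oc);
            b0 = p[0];  // bias[oc +  0 ..  3], bit-cast into int4 lanes
            b1 = p[1];  // bias[oc +  4 ..  7]
            b2 = p[2];  // bias[oc +  8 .. 11]
            b3 = p[3];  // bias[oc + 12 .. 15]
        }
    }

Issuing these loads before the main loop lets their global-memory latency overlap the MMA work; the beta scaling (mul_v4) is applied just before the epilogue uses the values.
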
@@ -476,6 +476,20 @@ extern "C" __global__ void __launch_bounds__(256)
             __syncthreads();
         }
+        size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
+        const float* bias_ptr = bias + oc;
+        int4 load_bias0 = make_int4(0, 0, 0, 0);
+        int4 load_bias1 = make_int4(0, 0, 0, 0);
+        int4 load_bias2 = make_int4(0, 0, 0, 0);
+        int4 load_bias3 = make_int4(0, 0, 0, 0);
+        if (oc < param.oc) {
+            load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
+            load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
+            load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
+            load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
+        }
         // read fuse_z
         int2 reg_fuse_z[reg_m] = {make_int2(z_zero_point, z_zero_point),
                                   make_int2(z_zero_point, z_zero_point),
@@ -595,18 +609,7 @@ extern "C" __global__ void __launch_bounds__(256)
         __syncthreads();
         /// output
-        size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
-        const float* bias_ptr = bias + oc;
-        int4 load_bias0 = make_int4(0, 0, 0, 0);
-        int4 load_bias1 = make_int4(0, 0, 0, 0);
-        int4 load_bias2 = make_int4(0, 0, 0, 0);
-        int4 load_bias3 = make_int4(0, 0, 0, 0);
         if (oc < param.oc) {
-            load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
-            load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
-            load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
-            load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
             mul_v4(load_bias0, load_bias0, beta);
             mul_v4(load_bias1, load_bias1, beta);
             mul_v4(load_bias2, load_bias2, beta);
@@ -617,7 +620,6 @@ extern "C" __global__ void __launch_bounds__(256)
 #pragma unroll
         for (int y = 0; y < reg_m; y += 4) {
-            I2F_4x8(reg_acc, y, 0);
             FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
             FUSE_Z_4x8(reg_acc, y, 0, reg_fuse_z, gamma, z_zero_point);
             PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
...
@@ -657,6 +657,20 @@ extern "C" __global__ void __launch_bounds__(256)
             __syncthreads();
         }
+        size_t oc = bidy * BM + 16 * idx_in_quad;
+        const float* bias_ptr = bias + oc;
+        int4 load_bias0 = make_int4(0, 0, 0, 0);
+        int4 load_bias1 = make_int4(0, 0, 0, 0);
+        int4 load_bias2 = make_int4(0, 0, 0, 0);
+        int4 load_bias3 = make_int4(0, 0, 0, 0);
+        if (oc < param.oc) {
+            load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
+            load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
+            load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
+            load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
+        }
         // read fuse_z
         int2 reg_fuse_z[reg_m] = {make_int2(z_zero_point, z_zero_point),
                                   make_int2(z_zero_point, z_zero_point),
@@ -712,6 +726,14 @@ extern "C" __global__ void __launch_bounds__(256)
                 reg_flt[0][j] = make_int4(x, y, z, w);
             }
+            /// output
+            if (oc < param.oc) {
+                mul_v4(load_bias0, load_bias0, beta);
+                mul_v4(load_bias1, load_bias1, beta);
+                mul_v4(load_bias2, load_bias2, beta);
+                mul_v4(load_bias3, load_bias3, beta);
+            }
             // compute
 #pragma unroll
             for (int k_inner = 0; k_inner < BKd32; k_inner++) {
@@ -773,35 +795,20 @@ extern "C" __global__ void __launch_bounds__(256)
         __syncthreads();
-        /// output
-        size_t oc = bidy * BM + 16 * idx_in_quad;
-        const float* bias_ptr = bias + oc;
-        int4 load_bias0 = make_int4(0, 0, 0, 0);
-        int4 load_bias1 = make_int4(0, 0, 0, 0);
-        int4 load_bias2 = make_int4(0, 0, 0, 0);
-        int4 load_bias3 = make_int4(0, 0, 0, 0);
-        if (oc < param.oc) {
-            load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
-            load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
-            load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
-            load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
-            mul_v4(load_bias0, load_bias0, beta);
-            mul_v4(load_bias1, load_bias1, beta);
-            mul_v4(load_bias2, load_bias2, beta);
-            mul_v4(load_bias3, load_bias3, beta);
-        }
         int8_t* __restrict__ g_dst_ptr = dst + d_offset;
+        FMA_1x8(reg_acc, 0, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
+        fuse_z_1x8(reg_acc[0], 0, reg_fuse_z[0], gamma, z_zero_point);
+        PACK_F2I_WITH_RELU_1x8(reg_acc, 0, 0, relu, dst_zero_point);
 #pragma unroll
-        for (int y = 0; y < reg_m; y += 4) {
-            I2F_4x8(reg_acc, y, 0);
-            FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
-            FUSE_Z_4x8(reg_acc, y, 0, reg_fuse_z, gamma, z_zero_point);
-            PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
-            STG_AFTER_LDG_4x1(g_offset, reg_acc, y, 0);
+        for (int y = 1; y < reg_m; y += 1) {
+            FMA_1x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
+            fuse_z_1x8(reg_acc[y], 0, reg_fuse_z[y], gamma, z_zero_point);
+            PACK_F2I_WITH_RELU_1x8(reg_acc, y, 0, relu, dst_zero_point);
+            STG_AFTER_LDG(g_offset[y - 1], reg_acc[y - 1][0], stg_guard[y - 1]);
         }
+        STG_AFTER_LDG(g_offset[7], reg_acc[7][0], stg_guard[7]);
 #endif
 }
 } // namespace
...
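
The restructured epilogue above replaces the 4x8 blocked macros (I2F_4x8 / FMA_4x8 / FUSE_Z_4x8 / PACK_F2I_WITH_RELU_4x8 / STG_AFTER_LDG_4x1) with per-row 1x8 operations arranged as a software pipeline: row 0 is computed before the loop, each iteration computes row y while issuing the guarded store of the already-finished row y - 1, and the last row (reg_m - 1 == 7 here) is flushed after the loop. A generic sketch of the pattern, with hypothetical process/store stand-ins for the FMA/pack and STG macros:

    // Software-pipelined epilogue: overlap the store of the previous row
    // with the arithmetic of the current one, hiding store latency.
    template <int REG_M, typename Process, typename Store>
    __device__ __forceinline__ void pipelined_epilogue(Process process, Store store) {
        process(0);            // prologue: first row has no pending store
    #pragma unroll
        for (int y = 1; y < REG_M; ++y) {
            process(y);        // FMA + fuse_z + pack for row y
            store(y - 1);      // store the row finished last iteration
        }
        store(REG_M - 1);      // epilogue: flush the final row
    }
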
@@ -437,7 +437,7 @@ extern "C" __global__ void __launch_bounds__(256)
         cp_async_fence();
     }
-    bool only_one_stage = (stage == 1) ? true : false;
+    bool only_one_stage = (stage == 1);
     if (stage >= 2) {
         cp_async_wait(stages - 2);
     } else {
@@ -844,6 +844,20 @@ extern "C" __global__ void __launch_bounds__(256)
         cp_async_wait(stages - 2);
     }
+    size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
+    const float* bias_ptr = bias + oc;
+    int4 load_bias0 = make_int4(0, 0, 0, 0);
+    int4 load_bias1 = make_int4(0, 0, 0, 0);
+    int4 load_bias2 = make_int4(0, 0, 0, 0);
+    int4 load_bias3 = make_int4(0, 0, 0, 0);
+    if (oc < param.oc) {
+        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
+        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
+        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
+        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
+    }
     if (!only_one_stage) {
 #pragma unroll  // low
         for (int i = 0; i < reg_nd4; ++i) {
@@ -975,6 +989,13 @@ extern "C" __global__ void __launch_bounds__(256)
                 reg_flt[0][j] = make_int4(x, y, z, w);
             }
+            if (oc < param.oc) {
+                mul_v4(load_bias0, load_bias0, beta);
+                mul_v4(load_bias1, load_bias1, beta);
+                mul_v4(load_bias2, load_bias2, beta);
+                mul_v4(load_bias3, load_bias3, beta);
+            }
             // compute
 #pragma unroll
             for (int k_inner = 0; k_inner < BKd32; k_inner++) {
@@ -1038,34 +1059,20 @@ extern "C" __global__ void __launch_bounds__(256)
         __syncthreads();
         /// output
-        size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
-        const float* bias_ptr = bias + oc;
-        int4 load_bias0 = make_int4(0, 0, 0, 0);
-        int4 load_bias1 = make_int4(0, 0, 0, 0);
-        int4 load_bias2 = make_int4(0, 0, 0, 0);
-        int4 load_bias3 = make_int4(0, 0, 0, 0);
-        if (oc < param.oc) {
-            load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
-            load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
-            load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
-            load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
-            mul_v4(load_bias0, load_bias0, beta);
-            mul_v4(load_bias1, load_bias1, beta);
-            mul_v4(load_bias2, load_bias2, beta);
-            mul_v4(load_bias3, load_bias3, beta);
-        }
         int8_t* __restrict__ g_dst_ptr = dst + d_offset;
+        FMA_1x8(reg_acc, 0, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
+        fuse_z_1x8(reg_acc[0], 0, reg_fuse_z[0], gamma, z_zero_point);
+        PACK_F2I_WITH_RELU_1x8(reg_acc, 0, 0, relu, dst_zero_point);
 #pragma unroll
-        for (int y = 0; y < reg_m; y += 4) {
-            I2F_4x8(reg_acc, y, 0);
-            FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
-            FUSE_Z_4x8(reg_acc, y, 0, reg_fuse_z, gamma, z_zero_point);
-            PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
-            STG_AFTER_LDG_4x1(g_offset, reg_acc, y, 0);
+        for (int y = 1; y < reg_m; y += 1) {
+            FMA_1x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
+            fuse_z_1x8(reg_acc[y], 0, reg_fuse_z[y], gamma, z_zero_point);
+            PACK_F2I_WITH_RELU_1x8(reg_acc, y, 0, relu, dst_zero_point);
+            STG_AFTER_LDG(g_offset[y - 1], reg_acc[y - 1][0], stg_guard[y - 1]);
        }
+        STG_AFTER_LDG(g_offset[7], reg_acc[7][0], stg_guard[7]);
 #endif
 }
 } // namespace
...
@@ -475,6 +475,20 @@ extern "C" __global__ void __launch_bounds__(256)
             __syncthreads();
         }
+        size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
+        const float* bias_ptr = bias + oc;
+        int4 load_bias0 = make_int4(0, 0, 0, 0);
+        int4 load_bias1 = make_int4(0, 0, 0, 0);
+        int4 load_bias2 = make_int4(0, 0, 0, 0);
+        int4 load_bias3 = make_int4(0, 0, 0, 0);
+        if (oc < param.oc) {
+            load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
+            load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
+            load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
+            load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
+        }
         guard = iter < 0;
 #pragma unroll
         for (int i = 0; i < reg_nd4; ++i) {
@@ -574,18 +588,8 @@ extern "C" __global__ void __launch_bounds__(256)
         size_t nhw_post3 = nhw_post0 + 24;
         size_t stg_oc = bidy * BM + (warp_y << 6);
-        size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
-        const float* bias_ptr = bias + oc;
-        int4 load_bias0 = make_int4(0, 0, 0, 0);
-        int4 load_bias1 = make_int4(0, 0, 0, 0);
-        int4 load_bias2 = make_int4(0, 0, 0, 0);
-        int4 load_bias3 = make_int4(0, 0, 0, 0);
         if (oc < param.oc) {
-            load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
-            load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
-            load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
-            load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
             mul_v4(load_bias0, load_bias0, beta);
             mul_v4(load_bias1, load_bias1, beta);
             mul_v4(load_bias2, load_bias2, beta);
@@ -599,7 +603,6 @@ extern "C" __global__ void __launch_bounds__(256)
 #pragma unroll
         for (int y = 0; y < reg_m; y += 4) {
-            I2F_4x8(reg_acc, y, 0);
             FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
             PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
             STG_4x1(stg_ptr, reg_acc, y, 0);
...
@@ -659,6 +659,20 @@ extern "C" __global__ void __launch_bounds__(256)
             __syncthreads();
         }
+        size_t oc = bidy * BM + 16 * idx_in_quad;
+        const float* bias_ptr = bias + oc;
+        int4 load_bias0 = make_int4(0, 0, 0, 0);
+        int4 load_bias1 = make_int4(0, 0, 0, 0);
+        int4 load_bias2 = make_int4(0, 0, 0, 0);
+        int4 load_bias3 = make_int4(0, 0, 0, 0);
+        if (oc < param.oc) {
+            load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
+            load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
+            load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
+            load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
+        }
         guard = iter < 0;
 #pragma unroll  // low
         for (int i = 0; i < reg_nd4; ++i) {
@@ -755,18 +769,8 @@ extern "C" __global__ void __launch_bounds__(256)
         size_t nhw_post3 = nhw_post0 + 24;
         size_t stg_oc = bidy * BM;
-        size_t oc = bidy * BM + 16 * idx_in_quad;
-        const float* bias_ptr = bias + oc;
-        int4 load_bias0 = make_int4(0, 0, 0, 0);
-        int4 load_bias1 = make_int4(0, 0, 0, 0);
-        int4 load_bias2 = make_int4(0, 0, 0, 0);
-        int4 load_bias3 = make_int4(0, 0, 0, 0);
         if (oc < param.oc) {
-            load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
-            load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
-            load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
-            load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
             mul_v4(load_bias0, load_bias0, beta);
             mul_v4(load_bias1, load_bias1, beta);
             mul_v4(load_bias2, load_bias2, beta);
@@ -779,7 +783,6 @@ extern "C" __global__ void __launch_bounds__(256)
 #pragma unroll
         for (int y = 0; y < reg_m; y += 4) {
-            I2F_4x8(reg_acc, y, 0);
             FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
             PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
             STG_4x1(stg_ptr, reg_acc, y, 0);
...
@@ -449,7 +449,7 @@ extern "C" __global__ void __launch_bounds__(256)
     bool stg_guard[8];
 #pragma unroll
     for (int y = 0; y < reg_m; y += 4) {
-        COMPUTE_OFFSET_4x1(reg_fuse_z, g_offset, y)
+        COMPUTE_OFFSET_4x1(g_offset, y);
         nhw_post0 += 32;
         nhw_post1 += 32;
@@ -457,7 +457,7 @@ extern "C" __global__ void __launch_bounds__(256)
         nhw_post3 += 32;
     }
-    bool only_one_stage = (stage == 1) ? true : false;
+    bool only_one_stage = (stage == 1);
     if (stage >= 2) {
         cp_async_wait(stages - 2);
     } else {
@@ -835,6 +835,20 @@ extern "C" __global__ void __launch_bounds__(256)
         cp_async_wait(stages - 2);
     }
+    size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
+    const float* bias_ptr = bias + oc;
+    int4 load_bias0 = make_int4(0, 0, 0, 0);
+    int4 load_bias1 = make_int4(0, 0, 0, 0);
+    int4 load_bias2 = make_int4(0, 0, 0, 0);
+    int4 load_bias3 = make_int4(0, 0, 0, 0);
+    if (oc < param.oc) {
+        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
+        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
+        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
+        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
+    }
     if (!only_one_stage) {
 #pragma unroll  // low
         for (int i = 0; i < reg_nd4; ++i) {
@@ -965,6 +979,13 @@ extern "C" __global__ void __launch_bounds__(256)
                 reg_flt[0][j] = make_int4(x, y, z, w);
             }
+            if (oc < param.oc) {
+                mul_v4(load_bias0, load_bias0, beta);
+                mul_v4(load_bias1, load_bias1, beta);
+                mul_v4(load_bias2, load_bias2, beta);
+                mul_v4(load_bias3, load_bias3, beta);
+            }
             // compute
 #pragma unroll
             for (int k_inner = 0; k_inner < BKd32; k_inner++) {
@@ -1028,38 +1049,19 @@ extern "C" __global__ void __launch_bounds__(256)
         __syncthreads();
         /// output
-        size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
-        const float* bias_ptr = bias + oc;
-        int4 load_bias0 = make_int4(0, 0, 0, 0);
-        int4 load_bias1 = make_int4(0, 0, 0, 0);
-        int4 load_bias2 = make_int4(0, 0, 0, 0);
-        int4 load_bias3 = make_int4(0, 0, 0, 0);
-        if (oc < param.oc) {
-            load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
-            load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
-            load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
-            load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
-            mul_v4(load_bias0, load_bias0, beta);
-            mul_v4(load_bias1, load_bias1, beta);
-            mul_v4(load_bias2, load_bias2, beta);
-            mul_v4(load_bias3, load_bias3, beta);
-        }
         int8_t* __restrict__ g_dst_ptr = dst + d_offset;
+        FMA_1x8(reg_acc, 0, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
+        PACK_F2I_WITH_RELU_1x8(reg_acc, 0, 0, relu, dst_zero_point);
 #pragma unroll
-        for (int y = 0; y < reg_m; y += 4) {
-            I2F_4x8(reg_acc, y, 0);
-            FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
-            PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
-            STG_AFTER_LDG_4x1(g_offset, reg_acc, y, 0);
-            nhw_post0 += 32;
-            nhw_post1 += 32;
-            nhw_post2 += 32;
-            nhw_post3 += 32;
+        for (int y = 1; y < reg_m; y += 1) {
+            FMA_1x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
+            PACK_F2I_WITH_RELU_1x8(reg_acc, y, 0, relu, dst_zero_point);
+            STG_AFTER_LDG(g_offset[y - 1], reg_acc[y - 1][0], stg_guard[y - 1]);
         }
+        STG_AFTER_LDG(g_offset[7], reg_acc[7][0], stg_guard[7]);
 #endif
 }
 } // namespace
...
@@ -23,78 +23,26 @@ __device__ __forceinline__ void mul_v4<float>(
 __device__ __forceinline__ void fma2(
         int2& c0, const int2 a0, int2& c1, const int2 a1, const float alpha,
         const int4 b) {
-    asm("fma.rz.f32 %0, %1, %2, %3;"
-        : "=f"(((float*)&c0)[0])
-        : "f"(((float*)&a0)[0]), "f"(alpha), "f"(((float*)&b)[0]));
-    asm("fma.rz.f32 %0, %1, %2, %3;"
-        : "=f"(((float*)&c0)[1])
-        : "f"(((float*)&a0)[1]), "f"(alpha), "f"(((float*)&b)[1]));
-    asm("fma.rz.f32 %0, %1, %2, %3;"
-        : "=f"(((float*)&c1)[0])
-        : "f"(((float*)&a1)[0]), "f"(alpha), "f"(((float*)&b)[2]));
-    asm("fma.rz.f32 %0, %1, %2, %3;"
-        : "=f"(((float*)&c1)[1])
-        : "f"(((float*)&a1)[1]), "f"(alpha), "f"(((float*)&b)[3]));
+    ((float*)&c0)[0] = a0.x * alpha + ((float*)&b)[0];
+    ((float*)&c0)[1] = a0.y * alpha + ((float*)&b)[1];
+    ((float*)&c1)[0] = a1.x * alpha + ((float*)&b)[2];
+    ((float*)&c1)[1] = a1.y * alpha + ((float*)&b)[3];
 }
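
Two things change in fma2. First, the accumulators are now consumed as integer values: a0.x * alpha converts the int32 lane to float before the multiply, where the old inline PTX bit-cast the register contents as float; this is what makes the separate I2F_4x8 conversion pass redundant in the kernel epilogues above. Second, fma.rz.f32 forced round-toward-zero, while the plain expression compiles with the default round-to-nearest mode. If a fused operation with explicit rounding is wanted, the same computation could be written with the __fmaf_rn intrinsic; a sketch only, not what the commit itself uses:

    __device__ __forceinline__ void fma2_rn(
            int2& c0, const int2 a0, int2& c1, const int2 a1, const float alpha,
            const int4 b) {
        // int accumulators converted to float, then a fused multiply-add
        // with round-to-nearest-even.
        ((float*)&c0)[0] = __fmaf_rn((float)a0.x, alpha, ((float*)&b)[0]);
        ((float*)&c0)[1] = __fmaf_rn((float)a0.y, alpha, ((float*)&b)[1]);
        ((float*)&c1)[0] = __fmaf_rn((float)a1.x, alpha, ((float*)&b)[2]);
        ((float*)&c1)[1] = __fmaf_rn((float)a1.y, alpha, ((float*)&b)[3]);
    }
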
-__device__ __forceinline__ void fuse_z_1x8(
-        int4* a, const int& j, const int4& fuse_z, const float& gamma,
-        const int32_t& zero_point) {
-    const int2 z[2] = {
-            *reinterpret_cast<const int2*>(&fuse_z),
-            *(reinterpret_cast<const int2*>(&fuse_z) + 1)};
-    for (int k = 0; k < 4; k++) {
-        int f = ((z[0].x >> (k * 8)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k]))[0] += (f - zero_point) * gamma;
-        f = ((z[0].x >> (k * 8 + 4)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k]))[1] += (f - zero_point) * gamma;
-        f = ((z[1].x >> (k * 8)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k]))[2] += (f - zero_point) * gamma;
-        f = ((z[1].x >> (k * 8 + 4)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k]))[3] += (f - zero_point) * gamma;
-    }
-    for (int k = 0; k < 4; k++) {
-        int f = ((z[0].y >> (k * 8)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k + 4]))[0] += (f - zero_point) * gamma;
-        f = ((z[0].y >> (k * 8 + 4)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k + 4]))[1] += (f - zero_point) * gamma;
-        f = ((z[1].y >> (k * 8)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k + 4]))[2] += (f - zero_point) * gamma;
-        f = ((z[1].y >> (k * 8 + 4)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k + 4]))[3] += (f - zero_point) * gamma;
-    }
-}
 __device__ __forceinline__ void fuse_z_1x8(
         int2* a, const int& j, const int2& fuse_z, const float& gamma,
         const int32_t& zero_point) {
+    float x = zero_point * gamma;
 #pragma unroll
     for (int k = 0; k < 4; k++) {
         int f = ((fuse_z.x >> (k * 8)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k]))[0] += (f - zero_point) * gamma;
+        ((float*)&(a[j + k]))[0] += f * gamma - x;
         f = ((fuse_z.x >> (k * 8 + 4)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k]))[1] += (f - zero_point) * gamma;
-    }
-#pragma unroll
-    for (int k = 0; k < 4; k++) {
-        int f = ((fuse_z.y >> (k * 8)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k + 4]))[0] += (f - zero_point) * gamma;
+        ((float*)&(a[j + k]))[1] += f * gamma - x;
+        f = ((fuse_z.y >> (k * 8)) & 15);
+        ((float*)&(a[j + k + 4]))[0] += f * gamma - x;
         f = ((fuse_z.y >> (k * 8 + 4)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k + 4]))[1] += (f - zero_point) * gamma;
+        ((float*)&(a[j + k + 4]))[1] += f * gamma - x;
     }
 }
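
The int4 overload of fuse_z_1x8, unused by the restructured epilogues, is deleted outright. In the surviving int2 overload the nibble is no longer sign-extended: (f << 28) >> 28 mapped the 4-bit field from 0..15 to -8..7 as if the packed z tensor were signed int4, which is evidently wrong for the zero-point (unsigned) encoding these kernels use. The constant zero_point * gamma is also hoisted out of the loop, using (f - zp) * g == f * g - zp * g. A host-side sketch of the arithmetic, with made-up example values:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const float gamma = 0.25f;
        const int32_t zero_point = 8;
        const float x = zero_point * gamma;   // hoisted constant
        const uint32_t packed = 0x2F04A931u;  // eight example nibbles
        for (int k = 0; k < 8; ++k) {
            int f = (packed >> (k * 4)) & 15; // unsigned nibble, 0..15
            // identical to (f - zero_point) * gamma, but one FMA per lane
            printf("nibble %d: %d -> %f\n", k, f, f * gamma - x);
        }
        return 0;
    }
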
@@ -282,12 +230,6 @@ __device__ __forceinline__ void pack_f2i_with_relu(
     fuse_z_1x8(a[i + 2], j, fuse_z[i + 2], gamma, zero_point);     \
     fuse_z_1x8(a[i + 3], j, fuse_z[i + 3], gamma, zero_point);
-#define FUSE_Z_4x8(a, i, j, fuse_z, gamma, zero_point)             \
-    fuse_z_1x8(a[i], j, fuse_z[i], gamma, zero_point);             \
-    fuse_z_1x8(a[i + 1], j, fuse_z[i + 1], gamma, zero_point);     \
-    fuse_z_1x8(a[i + 2], j, fuse_z[i + 2], gamma, zero_point);     \
-    fuse_z_1x8(a[i + 3], j, fuse_z[i + 3], gamma, zero_point);
 // 1x8 1x(2x8 int2) to 2 int2
 #define PACK_F2I_1x8(a, i, j)                                      \
     pack_f2i(a[i][j].x, a[i][j].z, a[i][j], a[i][j + 1], a[i][j + 2], a[i][j + 3]); \
@@ -316,24 +258,20 @@ __device__ __forceinline__ void pack_f2i_with_relu(
             stg_guard[i + 2])                                                      \
     LDG(d[i + 3], s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w, stg_guard[i + 3])
-#define COMPUTE_OFFSET(d, s, idx, n_reuse, hw_reuse, g)             \
+#define COMPUTE_OFFSET(s, idx, n_reuse, hw_reuse, g)                \
     n_reuse = nhw_post##idx / param.div_ohow;                       \
     hw_reuse = nhw_post##idx % param.div_ohow;                      \
     s = n_reuse * param.obs + hw_reuse * (packed_channel >> 1);     \
     g = nhw_post##idx < param.nhw;
-#define COMPUTE_OFFSET_4x1(d, s, i)                                                \
-    COMPUTE_OFFSET(                                                                \
-            d[i], s[i], 0, reg_src_cache[0].x, reg_src_cache[1].x, stg_guard[i])   \
-    COMPUTE_OFFSET(                                                                \
-            d[i + 1], s[i + 1], 1, reg_src_cache[0].y, reg_src_cache[1].y,         \
-            stg_guard[i + 1])                                                      \
-    COMPUTE_OFFSET(                                                                \
-            d[i + 2], s[i + 2], 2, reg_src_cache[0].z, reg_src_cache[1].z,         \
-            stg_guard[i + 2])                                                      \
-    COMPUTE_OFFSET(                                                                \
-            d[i + 3], s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w,         \
-            stg_guard[i + 3])
+#define COMPUTE_OFFSET_4x1(s, i)                                                   \
+    COMPUTE_OFFSET(s[i], 0, reg_src_cache[0].x, reg_src_cache[1].x, stg_guard[i])  \
+    COMPUTE_OFFSET(                                                                \
+            s[i + 1], 1, reg_src_cache[0].y, reg_src_cache[1].y, stg_guard[i + 1]) \
+    COMPUTE_OFFSET(                                                                \
+            s[i + 2], 2, reg_src_cache[0].z, reg_src_cache[1].z, stg_guard[i + 2]) \
+    COMPUTE_OFFSET(                                                                \
+            s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w, stg_guard[i + 3])
 #define STG_AFTER_LDG(d, s, g)                                      \
     if (stg_oc < param.oc && g) {                                   \
...
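
COMPUTE_OFFSET decodes a linear output index nhw_post into an image index and an intra-image position, then derives the global store offset and its guard; its first macro parameter d was never referenced in the body (the old call site passed reg_fuse_z), so it is dropped along with the matching parameter of COMPUTE_OFFSET_4x1. A function-form sketch of what the macro computes (hypothetical names; in the real code param.div_ohow may be a fast-division helper rather than a plain divisor):

    __device__ __forceinline__ void compute_offset(
            size_t nhw, size_t div_ohow, size_t obs, int packed_channel,
            size_t nhw_total, size_t& offset, bool& guard) {
        size_t n = nhw / div_ohow;   // image index: nhw / (OH * OW)
        size_t hw = nhw % div_ohow;  // position inside the output image
        offset = n * obs + hw * (packed_channel >> 1);
        guard = nhw < nhw_total;     // rows past param.nhw must not be stored
    }
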