提交 64551105 编写于 作者: M Megvii Engine Team

feat(cuda): add convbias ptx algo

GitOrigin-RevId: 08e9f666412f568f76a806bdfe3978f1aa60f9bf
上级 8395a459
......@@ -92,6 +92,12 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
fill_dwconv_algos();
all_algos.push_back(&int8_chwn4_dotprod);
all_algos.push_back(&fallback_nchw_qs8);
fill_ptx_algos();
for (auto&& algo : algo_ptx_conv2d_u4_s4) {
all_algos.push_back(&algo);
}
for (size_t i = all_algo_size; i < all_algos.size(); ++i) {
non_cudnn_algos.push_back(all_algos[i]);
}
......@@ -364,6 +370,15 @@ void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() {
int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 1, 1, 4, 2});
}
void ConvBiasForwardImpl::AlgoPack::fill_ptx_algos() {
algo_ptx_conv2d_u4_s4.emplace_back(
AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm{128, 256, 256});
algo_ptx_conv2d_u4_s4.emplace_back(
AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm{128, 128, 128});
algo_ptx_conv2d_u4_s4.emplace_back(
AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm{256, 64, 128});
}
ConvBiasForwardImpl::AlgoBase* ConvBiasForwardImpl::AlgoPack::cudnn_conv_from_enum(
cudnnConvolutionFwdAlgo_t algo) {
for (auto&& i : cudnn_convs) {
......
......@@ -78,6 +78,7 @@ public:
CUDA_SIMPLE_INT1,
CUDA_CUDNN_CONV_V8,
CUDA_CUDNN_CONVBIAS_V8,
CUDA_IMPLICIT_GEMM_PTX_NCHW64_IMMA_UINT4_INT4,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
......@@ -1203,6 +1204,45 @@ private:
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};
class ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm final
: public AlgoBase {
public:
AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm(
unsigned int tile_nhw, unsigned int tile_oc, unsigned int threads)
: m_tile_nhw{tile_nhw}, m_tile_oc{tile_oc}, m_threads{threads} {
m_name = ConvBias::algo_name<ConvBias::DirectParam>(
ssprintf(
"PTX_UINT4_INT4_NCHW64_IMMA_IMPLICIT_GEMM_%uX%u_%u", m_tile_nhw,
m_tile_oc, m_threads),
ConvBias::DirectParam{});
}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override;
SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
const SizeArgs& args) const override;
void exec_preprocess(const ExecArgs& args) const override;
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_PTX_NCHW64_IMMA_UINT4_INT4)
std::string param() const override {
std::string ret;
serialize_write_pod(m_tile_nhw, ret);
serialize_write_pod(m_tile_oc, ret);
serialize_write_pod(m_threads, ret);
return ret;
}
private:
std::string kernel_key(const SizeArgs& args) const;
unsigned int m_tile_nhw, m_tile_oc, m_threads;
std::string m_name;
void reorder_filter_bias(
const ExecArgs& args, void* reduce_filter, void* reordered_filter,
void* reordered_bias) const;
};
class ConvBiasForwardImpl::AlgoPack : NonCopyableObj {
private:
AlgoBase::Mapper m_all_algos_map;
......@@ -1251,6 +1291,7 @@ public:
AlgoCUDNNConvV8 cudnn_conv_v8;
AlgoCUDNNConvBiasActivationV8 cudnn_conv_bias_activation_v8;
#endif
std::vector<AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm> algo_ptx_conv2d_u4_s4;
AlgoBase* cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo);
......@@ -1265,6 +1306,7 @@ private:
void fill_cudnn_algos();
void fill_dp4a_algos();
void fill_dwconv_algos();
void fill_ptx_algos();
};
} // namespace cuda
......
......@@ -72,6 +72,7 @@ public:
class AlgoCUDNNConvV8;
class AlgoCUDNNConvBiasActivationV8;
#endif
class AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm;
class AlgoPack;
......
#include "src/cuda/conv_bias/ptx_helper.cuh"
#include "src/cuda/integer_subbyte_utils.cuh"
#include "src/cuda/query_blocksize.cuh"
using namespace megdnn;
using namespace cuda;
using namespace ptx;
namespace {
template <uint32_t size_bits, uint32_t interleaved>
__device__ __forceinline__ void reorder_imma_filter_func(
int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH,
uint32_t FW, uint32_t lane) {
static constexpr uint32_t elements_per_lane = 128 / size_bits;
uint32_t elements = lane * elements_per_lane;
uint32_t row = elements / (IC * FH * FW);
uint32_t col = elements - row * IC * FH * FW;
uint32_t sec = row / 4;
uint32_t res = col & (interleaved - 1);
uint32_t sec_sec = row & 3;
uint32_t sec_res = (row & 15) / 4;
uint32_t crosswise_offset = ((sec_sec >> 1) * 2 * interleaved) +
(((sec_sec & 1) ^ (sec_res >> 1)) * interleaved);
uint32_t residue_offset =
((res / elements_per_lane) ^ (sec_res & 1)) * elements_per_lane;
uint32_t dst_offset =
(sec / 2) * 8 * FH * FW * IC + (col / interleaved) * (8 * interleaved) +
(sec & 1) * (4 * interleaved) + crosswise_offset + residue_offset;
static constexpr uint32_t instruction_shape_col = 8;
// 4 threads per Quad
static constexpr uint32_t elements_per_thread = instruction_shape_col / 4;
// 4 threads per Quad
static constexpr uint32_t reordered_elements_per_thread = interleaved / 4;
uint32_t elem_in_interleaved = row % interleaved;
uint32_t elem_in_interleaved_pack = elem_in_interleaved / elements_per_thread;
int elem_new = (row / interleaved * interleaved +
elem_in_interleaved_pack % 4 * reordered_elements_per_thread +
elem_in_interleaved_pack / 4 * elements_per_thread +
elem_in_interleaved % elements_per_thread) *
(IC * FH * FW) +
col;
*(reinterpret_cast<int4*>(dst + (dst_offset * size_bits / 8))) =
*(reinterpret_cast<const int4*>(src + (elem_new * size_bits / 8)));
}
template <uint32_t interleaved>
__device__ __forceinline__ void reorder_imma_bias_func(
float* __restrict__ dst, float src_value, uint32_t OC, uint32_t lane) {
dst[lane] = src_value;
}
template <uint32_t size_bits, uint32_t interleaved>
__global__ void reorder_imma_filter_bias_kernel(
int8_t* __restrict__ dst_filter, float* __restrict__ dst_bias,
const int8_t* __restrict__ src_filter, const int32_t* __restrict__ src_bias,
float bias_scale, uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW) {
static constexpr uint32_t elements_per_lane = 128 / size_bits;
const uint32_t size1 = OC * IC * FH * FW / elements_per_lane;
const uint32_t size2 = OC;
uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x;
if (lane < size1) {
reorder_imma_filter_func<size_bits, interleaved>(
dst_filter, src_filter, OC, IC, FH, FW, lane);
} else if (lane < size1 + size2) {
lane = lane - size1;
float src_bias_value = src_bias[lane] * bias_scale;
reorder_imma_bias_func<interleaved>(dst_bias, src_bias_value, OC, lane);
}
}
template <uint32_t size_bits, uint32_t interleaved>
__global__ void reorder_imma_filter_bias_fusion_zero_point_kernel(
int8_t* __restrict__ dst_filter, float* __restrict__ dst_bias,
const int8_t* __restrict__ src_filter, const int32_t* __restrict__ src_bias,
float bias_scale, const int32_t* reduce_filter, float zero_point, uint32_t OC,
uint32_t IC, uint32_t FH, uint32_t FW) {
static constexpr uint32_t elements_per_lane = 128 / size_bits;
const uint32_t size1 = OC * IC * FH * FW / elements_per_lane;
const uint32_t size2 = OC;
uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x;
if (lane < size1) {
reorder_imma_filter_func<size_bits, interleaved>(
dst_filter, src_filter, OC, IC, FH, FW, lane);
} else if (lane < size1 + size2) {
lane = lane - size1;
// fusion bias and zero_point
// zero_point = zero_point * src_scale * filter_scale
float src_bias_value =
src_bias[lane] * bias_scale - reduce_filter[lane] * zero_point;
reorder_imma_bias_func<interleaved>(dst_bias, src_bias_value, OC, lane);
}
}
} // namespace
template <uint32_t size_bits, uint32_t interleaved>
void megdnn::cuda::ptx::reorder_imma_filter_bias(
int8_t* dst_filter, float* dst_bias, const int8_t* src_filter,
const int32_t* src_bias, float bias_scale, uint32_t OC, uint32_t IC,
uint32_t FH, uint32_t FW, cudaStream_t stream) {
static constexpr uint32_t elements_per_lane = 128 / size_bits;
uint32_t nr_threads = query_blocksize_for_kernel(reinterpret_cast<const void*>(
reorder_imma_filter_bias_kernel<size_bits, interleaved>));
uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_lane) + OC;
nr_threads = std::min(nr_threads, vthreads);
uint32_t nr_blocks = DIVUP(vthreads, nr_threads);
reorder_imma_filter_bias_kernel<size_bits, interleaved>
<<<nr_blocks, nr_threads, 0, stream>>>(
dst_filter, dst_bias, src_filter, src_bias, bias_scale, OC, IC, FH,
FW);
after_kernel_launch();
}
template <uint32_t size_bits, uint32_t interleaved>
void megdnn::cuda::ptx::reorder_imma_filter_bias_fusion_zero_point(
int8_t* dst_filter, float* dst_bias, const int8_t* src_filter,
const int32_t* src_bias, float bias_scale, const int32_t* reduce_filter,
float zero_point, uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW,
cudaStream_t stream) {
static constexpr uint32_t elements_per_lane = 128 / size_bits;
uint32_t nr_threads = query_blocksize_for_kernel(reinterpret_cast<const void*>(
reorder_imma_filter_bias_fusion_zero_point_kernel<size_bits, interleaved>));
uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_lane) + OC;
nr_threads = std::min(nr_threads, vthreads);
uint32_t nr_blocks = DIVUP(vthreads, nr_threads);
reorder_imma_filter_bias_fusion_zero_point_kernel<size_bits, interleaved>
<<<nr_blocks, nr_threads, 0, stream>>>(
dst_filter, dst_bias, src_filter, src_bias, bias_scale,
reduce_filter, zero_point, OC, IC, FH, FW);
after_kernel_launch();
}
#define INST(_size_bits, _interleaved) \
template void \
megdnn::cuda::ptx::reorder_imma_filter_bias<_size_bits, _interleaved>( \
int8_t * dst_filter, float* dst_bias, const int8_t* src_filter, \
const int32_t* src_bias, float bias_scale, uint32_t OC, uint32_t IC, \
uint32_t FH, uint32_t FW, cudaStream_t stream);
INST(8, 32)
INST(4, 64)
#undef INST
#define INST(_size_bits, _interleaved) \
template void megdnn::cuda::ptx::reorder_imma_filter_bias_fusion_zero_point< \
_size_bits, _interleaved>( \
int8_t * dst_filter, float* dst_bias, const int8_t* src_filter, \
const int32_t* src_bias, float bias_scale, const int32_t* reduce_filter, \
float zero_point, uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, \
cudaStream_t stream);
INST(4, 64)
#undef INST
// vim: syntax=cuda.doxygen
#pragma once
#include "src/cuda/int_fastdiv.cuh"
#include "src/cuda/utils.cuh"
namespace megdnn {
namespace cuda {
namespace ptx {
struct Conv2dInt8Param {
uint32_t n, ic, ih, iw, fh, fw, sh, sw, ph, pw, oc, oh, ow;
uint32_t ibs, ics, ihs;
uint32_t obs, ocs, ohs;
uint32_t icfhfw;
uint32_t nhw;
uint32_t oc32;
Uint32Fastdiv div_ohow;
Uint32Fastdiv div_ow;
Conv2dInt8Param(
uint32_t n, uint32_t ic, uint32_t ih, uint32_t iw, uint32_t fh, uint32_t fw,
uint32_t sh, uint32_t sw, uint32_t ph, uint32_t pw, uint32_t oc,
uint32_t oh, uint32_t ow, uint32_t interleaved)
: n(n),
ic(ic),
ih(ih),
iw(iw),
fh(fh),
fw(fw),
sh(sh),
sw(sw),
ph(ph),
pw(pw),
oc(oc),
oh(oh),
ow(ow) {
ibs = ic * ih * iw;
ics = ih * iw * interleaved;
ihs = iw * interleaved;
obs = oc * oh * ow;
ocs = oh * ow * interleaved;
ohs = ow * interleaved;
icfhfw = ic * fh * fw;
div_ohow = oh * ow;
div_ow = ow;
nhw = n * oh * ow;
// used for dp4a kernel, reduce usage of register file
oc32 = oc * 32;
}
};
struct Conv2dInt4Param {
uint32_t n, ic, ih, iw, fh, fw, sh, sw, ph, pw, oc, oh, ow;
uint32_t ibs, ics, ihs;
uint32_t obs, ocs, ohs;
uint32_t icfhfw;
uint32_t nhw;
Uint32Fastdiv div_ohow;
Uint32Fastdiv div_ow;
Conv2dInt4Param(
uint32_t n, uint32_t ic, uint32_t ih, uint32_t iw, uint32_t fh, uint32_t fw,
uint32_t sh, uint32_t sw, uint32_t ph, uint32_t pw, uint32_t oc,
uint32_t oh, uint32_t ow, uint32_t interleaved = 64)
: n(n),
ic(ic),
ih(ih),
iw(iw),
fh(fh),
fw(fw),
sh(sh),
sw(sw),
ph(ph),
pw(pw),
oc(oc),
oh(oh),
ow(ow) {
constexpr uint32_t size_bits = 4;
// all stride size in bytes
ibs = ic * ih * iw * size_bits / 8;
ics = ih * iw * interleaved * size_bits / 8;
ihs = iw * interleaved * size_bits / 8;
obs = oc * oh * ow * size_bits / 8;
ocs = oh * ow * interleaved * size_bits / 8;
ohs = ow * interleaved * size_bits / 8;
icfhfw = ic * fh * fw;
nhw = n * oh * ow;
div_ohow = oh * ow;
div_ow = ow;
}
};
struct Conv2dConstantOffsetParam {
int32_t begin;
int32_t size;
int32_t max;
int32_t rewind;
};
#define CONSTANT_BUFFER_SIZE 848
struct Conv2dConstantOffset {
Conv2dConstantOffsetParam c_offset_param;
int c_offset[CONSTANT_BUFFER_SIZE];
};
template <uint32_t size_bits, uint32_t interleaved>
void reorder_imma_filter_bias(
int8_t* dst_filter, float* dst_bias, const int8_t* src_filter,
const int32_t* src_bias, float bias_scale, uint32_t OC, uint32_t IC,
uint32_t FH, uint32_t FW, cudaStream_t stream);
template <uint32_t size_bits, uint32_t interleaved>
void reorder_imma_filter_bias_fusion_zero_point(
int8_t* dst_filter, float* dst_bias, const int8_t* src_filter,
const int32_t* src_bias, float bias_scale, const int32_t* reduce_filter,
float zero_point, uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW,
cudaStream_t stream);
} // namespace ptx
} // namespace cuda
} // namespace megdnn
// vim: syntax=cuda.doxygen
/**
* \file dnn/src/cuda/conv_bias/ptx_implicit_gemm_uint4_int4_nchw64_imma.cpp
*/
#include "./algo.h"
#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/ptx_helper.cuh"
#include "src/cuda/conv_bias/reduce_filter.cuh"
#include "src/cuda/ptx/uint4_int4/kern.cuh"
#include "src/cuda/ptx_loader.h"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
using namespace ptx;
namespace {
// all stride are in bytes
void compute_conv2d_offset(
size_t fh, size_t fw, size_t ics, size_t ihs,
Conv2dConstantOffset& constant_offset) {
constexpr int interleaved = 64;
constexpr int size_bits = 4;
constexpr int threablock_k = 128;
constexpr int inc_step = threablock_k / interleaved;
size_t i = 0;
int* s32 = &(constant_offset.c_offset[0]);
for (; i < inc_step; i++) {
int c = i / (fh * fw);
int khkw = i % (fh * fw);
int kh = khkw / fw;
int kw = khkw % fw;
s32[2 * i] = c * ics + kh * ihs + kw * interleaved * size_bits / 8;
int8_t* s8 = reinterpret_cast<int8_t*>(&(s32[2 * i + 1]));
s8[0] = kh;
s8[1] = kw;
s8[2] = -kh;
s8[3] = -kw;
}
for (; i < (inc_step + fh * fw * inc_step); i++) {
int c = i / (fh * fw);
int khkw = i % (fh * fw);
int kh = khkw / fw;
int kw = khkw % fw;
s32[2 * i] = c * ics + kh * ihs + kw * interleaved * size_bits / 8;
int8_t* s8 = reinterpret_cast<int8_t*>(&(s32[2 * i + 1]));
s8[0] = kh;
s8[1] = kw;
s8[2] = -kh;
s8[3] = -kw;
int i_ = i - inc_step;
c = i_ / (fh * fw);
khkw = i_ % (fh * fw);
kh = khkw / fw;
kw = khkw % fw;
s32[2 * i] -= c * ics + kh * ihs + kw * interleaved * size_bits / 8;
}
}
}; // namespace
std::string ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm::kernel_key(
const SizeArgs& args) const {
std::string kernel_key;
using NonlineMode = Param::NonlineMode;
auto&& param = args.opr->param();
if (args.z_layout->ndim > 0) {
kernel_key = ssprintf(
"%s_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_%ux%u",
current_device_arch_name(), m_tile_nhw, m_tile_oc);
} else {
kernel_key = ssprintf(
"%s_conv_bias_uint4_int4_imma8832_ldg16_%ux%u",
current_device_arch_name(), m_tile_nhw, m_tile_oc);
}
megdnn_assert(
param.nonlineMode == NonlineMode::RELU ||
param.nonlineMode == NonlineMode::IDENTITY);
kernel_key += "_relu";
return kernel_key;
}
bool ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm::is_available(
const SizeArgs& args) const {
if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) {
return false;
}
if (args.bias_layout->ndim <= 0)
return false;
using Param = param::ConvBias;
using Format = Param::Format;
using Sparse = Param::Sparse;
using Mode = Param::Mode;
using NonlineMode = Param::NonlineMode;
bool available = true;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
if (!check_bias_share_in_channel(*(args.bias_layout), param.format))
return false;
if (param.format != Format::NCHW64)
return false;
UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout), param);
// TODO support group conv
available &= param.sparse == Sparse::DENSE;
// mode must be cross correlation
available &= param.mode == Mode::CROSS_CORRELATION;
// nonlineMode must be RELU or IDENTITY
available &=
(param.nonlineMode == NonlineMode::RELU ||
param.nonlineMode == NonlineMode::IDENTITY);
// check data type
auto src_dtype = args.src_layout->dtype, filter_dtype = args.filter_layout->dtype,
bias_dtype = args.bias_layout->dtype, dst_dtype = args.dst_layout->dtype;
available &=
(src_dtype.enumv() == DTypeEnum::Quantized4Asymm &&
filter_dtype.enumv() == DTypeEnum::QuantizedS4 &&
bias_dtype.enumv() == DTypeEnum::QuantizedS32 &&
dst_dtype.enumv() == DTypeEnum::Quantized4Asymm);
// TODO: support dialtion
available &= dh == 1 && dw == 1;
// ensure precomputed offsets are positive integers
available &= hi >= fh && wi >= fw;
// only support sm_86 or later, platform should have tensorcore int4
// support
available &=
(is_compute_capability_equalto(8, 0) ||
is_compute_capability_equalto(8, 6));
// param buffer size is 4K, use 3K to store precomputed offset
size_t kMaxFilterPixels = CONSTANT_BUFFER_SIZE / (2 * 128 / 64) - 1;
available &= fh * fw <= kMaxFilterPixels;
return available;
}
size_t ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm::
get_workspace_in_bytes(const SizeArgs& args) const {
if (args.preprocessed_filter == nullptr) {
size_t OC = args.filter_layout->operator[](0),
IC = args.filter_layout->operator[](1) * 64,
FH = args.filter_layout->operator[](2),
FW = args.filter_layout->operator[](3);
size_t ws_size_reduce_filter = OC * sizeof(int32_t);
// for reduce filter
{
size_t A = OC, B = IC * FH * FW / 8, C = 1;
ws_size_reduce_filter += do_dispatch_reduce_workspace_in_bytes(A, B, C);
}
return args.filter_layout->span().dist_byte() +
args.bias_layout->span().dist_byte() + ws_size_reduce_filter;
}
return 0_z;
}
void ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm::exec(
const ExecArgs& args) const {
using Format = Param::Format;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout), param);
auto&& stream = cuda_stream(args.opr->handle());
constexpr int interleaved = 64;
void* bias_ptr = nullptr;
void* filter_ptr = nullptr;
if (args.preprocessed_filter) {
megdnn_assert(args.preprocessed_filter->tensors.size() == 2);
filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr();
bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr();
} else {
// reorder filter and bias
filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr);
bias_ptr = reinterpret_cast<void*>(
args.workspace.raw_ptr + args.filter_layout->span().dist_byte());
void* reduce_filter_ptr = reinterpret_cast<void*>(
args.workspace.raw_ptr + args.filter_layout->span().dist_byte() +
args.bias_layout->span().dist_byte());
reorder_filter_bias(args, reduce_filter_ptr, filter_ptr, bias_ptr);
}
uint32_t u32_n = n, u32_ci = ci, u32_hi = hi, u32_wi = wi, u32_fh = fh, u32_fw = fw,
u32_sh = sh, u32_sw = sw, u32_ph = ph, u32_pw = pw, u32_co = co,
u32_ho = ho, u32_wo = wo;
Conv2dInt4Param kern_param(
u32_n, u32_ci, u32_hi, u32_wi, u32_fh, u32_fw, u32_sh, u32_sw, u32_ph,
u32_pw, u32_co, u32_ho, u32_wo, interleaved);
Conv2dConstantOffset kern_coffset;
compute_conv2d_offset(fh, fw, kern_param.ics, kern_param.ihs, kern_coffset);
// begin is not need
kern_coffset.c_offset_param.begin = param_buffer_start_address();
kern_coffset.c_offset_param.size = 4 * (1 + fh * fw);
kern_coffset.c_offset_param.max = 4 * fh * fw;
kern_coffset.c_offset_param.rewind = 4 * (1 - fh * fw);
float src_scale = args.src_layout->dtype.param<dtype::Quantized4Asymm>().scale,
dst_scale = args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale,
filter_scale = args.filter_layout->dtype.param<dtype::QuantizedS4>().scale;
uint32_t src_zero_point =
(uint32_t)(args.src_layout->dtype.param<dtype::Quantized4Asymm>()
.zero_point);
uint32_t pk_src_zero_point = 0;
for (int i = 0; i < 8; i++) {
pk_src_zero_point <<= 4;
pk_src_zero_point |= (src_zero_point & 0xF);
}
float dst_zero_point =
(float)(args.dst_layout->dtype.param<dtype::Quantized4Asymm>().zero_point);
float alpha = src_scale * filter_scale / dst_scale, beta = 1.f;
unsigned int tx = m_threads, ty = 1;
unsigned int gridx =
div_ceil<unsigned int>(static_cast<unsigned int>(n * ho * wo), m_tile_nhw);
unsigned int gridy =
div_ceil<unsigned int>(static_cast<unsigned int>(co), m_tile_oc);
void* src_ptr = const_cast<void*>(args.src_tensor->raw_ptr());
void* dst_ptr = const_cast<void*>(args.dst_tensor->raw_ptr());
using NonlineMode = Param::NonlineMode;
auto kern_key = kernel_key(args);
auto&& kernel = PTXKernelLoader::instance().get_kernel(kern_key);
if (args.z_layout->ndim > 0) {
void* z_ptr = const_cast<void*>(args.z_tensor->raw_ptr());
auto z_param = args.z_layout->dtype.param<dtype::Quantized4Asymm>();
int32_t z_zero_point = (int32_t)z_param.zero_point;
float z_scale = z_param.scale;
float gamma = z_scale / dst_scale;
std::vector<void*> params = {&src_ptr, &filter_ptr, &bias_ptr, &z_ptr,
&dst_ptr, &alpha, &beta, &gamma};
kern_coffset.c_offset_param.begin += sizeof(src_ptr) + sizeof(filter_ptr) +
sizeof(bias_ptr) + sizeof(z_ptr) +
sizeof(dst_ptr) + sizeof(alpha) +
sizeof(beta) + sizeof(gamma);
kern_coffset.c_offset_param.begin += sizeof(pk_src_zero_point);
params.push_back(&pk_src_zero_point);
kern_coffset.c_offset_param.begin += sizeof(z_zero_point);
params.push_back(&z_zero_point);
kern_coffset.c_offset_param.begin += sizeof(dst_zero_point);
params.push_back(&dst_zero_point);
uint32_t relu = param.nonlineMode == NonlineMode::RELU ? 1 : 0;
params.push_back(&relu);
kern_coffset.c_offset_param.begin += sizeof(relu);
params.push_back(&kern_param);
kern_coffset.c_offset_param.begin += sizeof(kern_param);
kern_coffset.c_offset_param.begin += sizeof(kern_coffset.c_offset_param);
params.push_back(&kern_coffset);
dim3 grid(gridx, gridy, 1);
dim3 block(tx, ty, 1);
kernel(grid, block, stream, params.data());
} else {
std::vector<void*> params = {&src_ptr, &filter_ptr, &bias_ptr,
&dst_ptr, &alpha, &beta};
kern_coffset.c_offset_param.begin += sizeof(src_ptr) + sizeof(filter_ptr) +
sizeof(bias_ptr) + sizeof(dst_ptr) +
sizeof(alpha) + sizeof(beta);
kern_coffset.c_offset_param.begin += sizeof(pk_src_zero_point);
params.push_back(&pk_src_zero_point);
kern_coffset.c_offset_param.begin += sizeof(dst_zero_point);
params.push_back(&dst_zero_point);
uint32_t relu = param.nonlineMode == NonlineMode::RELU ? 1 : 0;
params.push_back(&relu);
kern_coffset.c_offset_param.begin += sizeof(relu);
params.push_back(&kern_param);
kern_coffset.c_offset_param.begin += sizeof(kern_param);
kern_coffset.c_offset_param.begin += sizeof(kern_coffset.c_offset_param);
params.push_back(&kern_coffset);
dim3 grid(gridx, gridy, 1);
dim3 block(tx, ty, 1);
kernel(grid, block, stream, params.data());
}
after_kernel_launch();
}
size_t ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm::
get_preprocess_workspace_in_bytes(const SizeArgs& args) const {
size_t OC = args.filter_layout->operator[](0),
IC = args.filter_layout->operator[](1) * 64,
FH = args.filter_layout->operator[](2),
FW = args.filter_layout->operator[](3);
size_t ws_size_reduce_filter = OC * sizeof(int32_t);
// for reduce filter
{
size_t A = OC, B = IC * FH * FW / 8, C = 1;
ws_size_reduce_filter += do_dispatch_reduce_workspace_in_bytes(A, B, C);
}
return ws_size_reduce_filter;
}
SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm::
deduce_preprocessed_filter_layout(const SizeArgs& args) const {
return {args.filter_layout->collapse_contiguous(),
args.bias_layout->collapse_contiguous()};
}
void ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm::reorder_filter_bias(
const ExecArgs& args, void* reduce_filter, void* reordered_filter,
void* reordered_bias) const {
using Format = Param::Format;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout), param);
auto&& stream = cuda_stream(args.opr->handle());
float src_scale = args.src_layout->dtype.param<dtype::Quantized4Asymm>().scale,
filter_scale = args.filter_layout->dtype.param<dtype::QuantizedS4>().scale,
bias_scale = args.bias_layout->dtype.param<dtype::QuantizedS32>().scale,
dst_scale = args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale;
float scaled_src_zero_point =
args.src_layout->dtype.param<dtype::Quantized4Asymm>().zero_point *
src_scale * filter_scale / dst_scale;
// NCHW64 reduce CHW64
do_dispatch_reduce_with_scale_filter_4bit<true>(
reinterpret_cast<uint8_t*>(args.filter_tensor->raw_ptr()), 1, co,
ci * fh * fw / 8, static_cast<int32_t*>(reduce_filter), stream);
reorder_imma_filter_bias_fusion_zero_point<4, 64>(
reinterpret_cast<int8_t*>(reordered_filter),
reinterpret_cast<float*>(reordered_bias),
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr()),
args.bias_tensor->compatible_ptr<int32_t>(), bias_scale / dst_scale,
static_cast<int32_t*>(reduce_filter), scaled_src_zero_point, co, ci, fh, fw,
stream);
}
void ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm::exec_preprocess(
const ExecArgs& args) const {
reorder_filter_bias(
args, args.workspace.raw_ptr,
args.preprocessed_filter->tensors[0].raw_ptr(),
args.preprocessed_filter->tensors[1].raw_ptr());
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cuda/ptx_loader.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/ptx_loader.h"
using namespace megdnn;
using namespace cuda;
// ******************* PTXKernelLoader *********************
const std::unordered_map<std::string, PTXKernelLoader::kernel> PTXKernelLoader::KERNEL_MAP =
{{"ampere_conv_bias_uint4_int4_imma8832_ldg16_256x64_relu",
ptx::run_ampere_conv_bias_uint4_int4_imma8832_ldg16_256x64_relu},
{"ampere_conv_bias_uint4_int4_imma8832_ldg16_128x128_relu",
ptx::run_ampere_conv_bias_uint4_int4_imma8832_ldgsts16_128x128_relu},
{"ampere_conv_bias_uint4_int4_imma8832_ldg16_128x256_relu",
ptx::run_ampere_conv_bias_uint4_int4_imma8832_ldg16_128x256_relu},
{"ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_256x64_relu",
ptx::run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_256x64_relu},
{"ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_128x128_relu",
ptx::run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldgsts16_128x128_relu},
{"ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_128x256_relu",
ptx::run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_128x256_relu}};
PTXKernelLoader& PTXKernelLoader::instance() {
static PTXKernelLoader ins;
return ins;
}
const PTXKernelLoader::kernel PTXKernelLoader::get_kernel(
const std::string& kernel_name) {
decltype(KERNEL_MAP.begin()) kernel_iter;
kernel_iter = KERNEL_MAP.find(kernel_name);
megdnn_throw_if(
kernel_iter == KERNEL_MAP.end(), megdnn_error,
ssprintf("kernel name %s not found in KERNEL_MAP", kernel_name.c_str())
.c_str());
return kernel_iter->second;
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file dnn/src/cuda/ptx_loader.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <mutex>
#include <unordered_map>
#include "src/cuda/ptx/uint4_int4/kern.cuh"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
class PTXKernelLoader {
private:
PTXKernelLoader() = default;
using kernel = std::function<void(const dim3, const dim3, cudaStream_t, void**)>;
public:
PTXKernelLoader(const PTXKernelLoader&) = delete;
const PTXKernelLoader& operator=(const PTXKernelLoader&) = delete;
static PTXKernelLoader& instance();
const kernel get_kernel(const std::string& kernel_name);
static const std::unordered_map<std::string, kernel> KERNEL_MAP;
};
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册