Commit 96d90be1 authored by Megvii Engine Team

feat(dnn): fallback support int4 relayout

GitOrigin-RevId: 3625f5847055940e646358654f296922f05afa93
Parent eef0308b
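
Context for the changes below: every int4 path added in this commit assumes the same packing convention of two 4-bit elements per byte, with the element at an even logical index stored in the low nibble and the element at an odd index in the high nibble. A minimal standalone sketch of that convention follows; the helper names load_int4/store_int4 are illustrative only and not part of the MegDNN API.

#include <cstddef>
#include <cstdint>

// Element 2*k lives in the low nibble of byte k, element 2*k+1 in the high nibble.
inline uint8_t load_int4(const uint8_t* base, size_t idx) {
    uint8_t byte = base[idx >> 1];
    return (idx & 1) ? (byte >> 4) & 0xf : byte & 0xf;
}

inline void store_int4(uint8_t* base, size_t idx, uint8_t val) {
    uint8_t& byte = base[idx >> 1];
    byte = (idx & 1) ? ((byte & 0x0f) | uint8_t(val << 4)) : ((byte & 0xf0) | (val & 0x0f));
}
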
......@@ -625,6 +625,7 @@ struct log<1> {
::megdnn::dtype::log<sizeof(ctype)>::value; \
static MEGDNN_CONSTEXPR DTypeEnum enumv = DTypeEnum::_name; \
static MEGDNN_CONSTEXPR uint16_t low_bit = _bits; \
static MEGDNN_CONSTEXPR uint16_t bits = _bits == 0 ? sizeof(_ctype) * 8 : _bits; \
static MEGDNN_CONSTEXPR bool has_param = _has_param
#else
#define MEGDNN_DEF_DT_BASIC_FIELDS(_name, _ctype, _cat, _sign, _bits, _has_param) \
......@@ -632,7 +633,8 @@ struct log<1> {
typedef ::megdnn::dtype::_name dtype; \
static const uint16_t size_log = ::megdnn::dtype::log<sizeof(ctype)>::value; \
static MEGDNN_CONSTEXPR int enumv = DTypeEnum::_name; \
static MEGDNN_CONSTEXPR uint16_t low_bit = _bits
static MEGDNN_CONSTEXPR uint16_t low_bit = _bits; \
static MEGDNN_CONSTEXPR uint16_t bits = _bits == 0 ? sizeof(_ctype) * 8 : _bits;
#endif // MEGDNN_CC_HOST
#define MEGDNN_DEF_DT(_name, _ctype, _cat, _sign, _minval, _maxval) \
......
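
As a side note on the DType change above, here is a hedged illustration of the new bits field's fallback rule: _bits == 0 means "use the full storage width of _ctype", while a nonzero _bits (4 for the Quantized4 types) is taken verbatim. bits_field below is a standalone stand-in for the macro, not MegDNN code.

#include <cstdint>

template <typename T, uint16_t Bits>
struct bits_field {
    // Mirrors the macro: 0 means derive the width from the storage type, otherwise use _bits.
    static constexpr uint16_t bits = Bits == 0 ? sizeof(T) * 8 : Bits;
};
static_assert(bits_field<int8_t, 0>::bits == 8, "full-width 8-bit dtype");
static_assert(bits_field<int8_t, 4>::bits == 4, "4-bit dtype stored in an 8-bit ctype");
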
......@@ -8,12 +8,129 @@
using namespace megdnn;
using namespace fallback;
namespace megdnn {
namespace relayout {
namespace transpose_fallback {
template <>
struct transpose_traits<dt_qint4> {
static constexpr size_t block_size = BLOCK_LINE_SIZE_BYTES;
};
template <>
void transpose_block_fallback<dt_qint4>(
const dt_qint4* src, dt_qint4* dst, const size_t src_stride,
const size_t dst_stride, size_t block_h, size_t block_w) {
constexpr size_t block_size = transpose_traits<dt_qint4>::block_size;
uint8_t block[block_size][block_size];
uint8_t* src_ptr = (uint8_t*)src;
uint8_t* dst_ptr = (uint8_t*)dst;
for (size_t i = 0; i < block_h; ++i) {
size_t src_offset_base = i * src_stride;
for (size_t j = 0; j < block_w; ++j) {
size_t src_offset = src_offset_base + j;
size_t src_byte_offset = src_offset >> 1;
if (src_offset % 2 == 0) {
block[j][i] = src_ptr[src_byte_offset] & 0xf;
} else {
block[j][i] = ((src_ptr[src_byte_offset] & 0xf0) >> 4) & 0xf;
}
}
}
for (size_t i = 0; i < block_w; ++i) {
size_t dst_offset_base = i * dst_stride;
for (size_t j = 0; j < block_h; ++j) {
size_t dst_offset = dst_offset_base + j;
size_t dst_byte_offset = dst_offset >> 1;
uint8_t dst_temp = dst_ptr[dst_byte_offset];
uint8_t src_temp = block[i][j];
if (dst_offset % 2 == 0) {
dst_temp = (dst_temp & 0xf0) | src_temp;
} else {
dst_temp = (dst_temp & 0xf) | (src_temp << 4);
}
dst_ptr[dst_byte_offset] = dst_temp;
}
}
}
template <>
void transpose<dt_qint4>(
size_t batch, size_t m, size_t n, dt_qint4* src, dt_qint4* dst,
size_t stride_m) {
if (stride_m == 0) {
stride_m = n;
}
uint8_t* batch_src = (uint8_t*)(src);
uint8_t* batch_dst = (uint8_t*)(dst);
constexpr size_t B = transpose_traits<dt_qint4>::block_size;
auto work_block = [m, stride_m, &batch_src, &batch_dst](
const size_t i, const size_t j, const size_t h,
const size_t w) {
size_t src_offset = i * stride_m + j;
size_t dst_offset = j * m + i;
megdnn_assert(src_offset % 2 == 0 && dst_offset % 2 == 0);
auto src = batch_src + (src_offset >> 1);
auto dst = batch_dst + (dst_offset >> 1);
MIDOUT_BEGIN(transpose_fallback, midout_iv(0)) {
if (h == B && w == B) {
transpose_block((dt_qint4*)src, (dt_qint4*)dst, stride_m, m);
} else {
transpose_block((dt_qint4*)src, (dt_qint4*)dst, stride_m, m, h, w);
}
}
MIDOUT_END();
};
auto work_row = [&work_block, n](size_t i, size_t h) {
size_t j = 0;
for (; j + B <= n; j += B) {
work_block(i, j, h, B);
}
if (j < n) {
work_block(i, j, h, n - j);
}
};
for (size_t b = 0; b < batch; ++b) {
size_t i = 0;
for (; i + B <= m; i += B) {
work_row(i, B);
}
if (i < m) {
work_row(i, m - i);
}
size_t src_offset = m * stride_m;
size_t dst_offset = m * n;
megdnn_assert(src_offset % 2 == 0 && dst_offset % 2 == 0);
batch_src += (src_offset >> 1);
batch_dst += (dst_offset >> 1);
}
}
} // namespace transpose_fallback
} // namespace relayout
} // namespace megdnn
namespace {
bool is_lastdim_contig(const TensorLayout& layout) {
return layout.ndim <= 3 && layout.stride[layout.ndim - 1] == 1;
}
bool is_int4(const TensorLayout& layout) {
return layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
layout.dtype.enumv() == DTypeEnum::Quantized4Asymm;
}
inline bool check_dtype_support_transparam(
bool trans, bool is_bit4, const relayout::TransposeParam& param) {
if (trans && is_bit4) {
auto c = param.c;
return c == 1 || c == 2 || c == 4 || c == 8;
}
return trans;
}
template <size_t sz, typename T0 = char>
struct equiv_ctype_storage {
T0 _[sz];
......@@ -26,16 +143,111 @@ struct equiv_ctype {
alignof(typename DTypeTrait<dtype>::ctype)>;
};
typedef void (*memcpy_policy_t)(void* cont, void* non_cont, size_t);
typedef void (*memcpy_policy_t)(
void* cont, void* non_cont, size_t src_offset, size_t dst_offset, size_t size);
void memcpy_cont2noncont(void* cont, void* non_cont, size_t size) {
void memcpy_cont2noncont(void* cont, void* non_cont, size_t, size_t, size_t size) {
memcpy(non_cont, cont, size);
}
void memcpy_noncont2cont(void* cont, void* non_cont, size_t size) {
void memcpy_noncont2cont(void* cont, void* non_cont, size_t, size_t, size_t size) {
memcpy(cont, non_cont, size);
}
void memcpy_4bit(
void* cont, void* nocont, size_t cont_offset, size_t nocont_offset,
size_t size) {
if (size == 0)
return;
uint8_t* cont_u8 = (uint8_t*)cont;
uint8_t* nocont_u8 = (uint8_t*)nocont;
size_t cont_bytes = cont_offset >> 1;
size_t nocont_bytes = nocont_offset >> 1;
size_t size_byte = size >> 1;
void* cont_ptr = cont_u8 + cont_bytes;
void* nocont_ptr = nocont_u8 + nocont_bytes;
bool size_align = size % 2 == 0;
bool cont_align = cont_offset % 2 == 0;
bool nocont_align = nocont_offset % 2 == 0;
if (cont_align && nocont_align) {
memcpy(cont_ptr, nocont_ptr, size_byte);
if (!size_align) {
uint8_t* dst_ptr = (uint8_t*)cont_ptr + size_byte;
uint8_t* src_ptr = (uint8_t*)nocont_ptr + size_byte;
*dst_ptr = (*src_ptr) & 0xf;
}
} else if (!cont_align && nocont_align) {
uint8_t* dst_ptr = (uint8_t*)cont_ptr;
uint8_t* src_ptr = (uint8_t*)nocont_ptr;
for (size_t i = 0; i < size_byte; ++i) {
uint8_t dst_low = *dst_ptr;
uint8_t src_all = *src_ptr;
uint8_t last = (dst_low & 0xf) | (src_all & 0xf) << 4;
uint8_t now = ((src_all & 0xf0) >> 4) & 0xf;
*dst_ptr = last;
++dst_ptr;
*dst_ptr = now;
++src_ptr;
}
if (!size_align) {
uint8_t dst_low = *dst_ptr;
uint8_t src_all = *src_ptr;
uint8_t last = (dst_low & 0xf) | (src_all & 0xf) << 4;
*dst_ptr = last;
}
} else if (cont_align && !nocont_align) {
uint8_t* dst_ptr = (uint8_t*)cont_ptr;
uint8_t* src_ptr = (uint8_t*)nocont_ptr;
for (size_t i = 0; i < size_byte; ++i) {
uint8_t src_last_high = *src_ptr;
++src_ptr;
uint8_t src_low = *src_ptr;
uint8_t rst = (src_low & 0xf) << 4 | ((src_last_high >> 4) & 0xf);
*dst_ptr = rst;
++dst_ptr;
}
if (!size_align) {
uint8_t src_last_high = *src_ptr;
*dst_ptr = ((src_last_high >> 4) & 0xf);
}
} else {
uint8_t* dst_ptr = (uint8_t*)cont_ptr;
uint8_t* src_ptr = (uint8_t*)nocont_ptr;
{
uint8_t src_last_high = *src_ptr;
uint8_t dst_last_low = *dst_ptr;
uint8_t rst = (dst_last_low & 0xf) | (src_last_high & 0xf0);
*dst_ptr = rst;
++dst_ptr;
++src_ptr;
}
if (!size_align) {
memcpy(dst_ptr, src_ptr, size_byte);
} else {
if (size_byte > 1) {
size_t align_size = size_byte - 1;
memcpy(dst_ptr, src_ptr, align_size);
dst_ptr += align_size;
src_ptr += align_size;
}
uint8_t src_last_low = *src_ptr;
*dst_ptr = src_last_low & 0xf;
}
}
}
void memcpy_cont2noncont_4bit(
void* cont, void* non_cont, size_t cont_offset, size_t nocont_offset,
size_t size) {
memcpy_4bit(non_cont, cont, nocont_offset, cont_offset, size);
}
void memcpy_noncont2cont_4bit(
void* cont, void* non_cont, size_t cont_offset, size_t nocont_offset,
size_t size) {
memcpy_4bit(cont, non_cont, cont_offset, nocont_offset, size);
}
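
A usage sketch for the copy policy above (buffer contents and sizes are hypothetical): the offsets and the size argument are counted in int4 elements (nibbles), not bytes, so odd values are legal and are handled by the misaligned branches of memcpy_4bit.

// Copy 5 int4 elements starting at element offset 3 of the non-contiguous buffer
// into element offset 0 of the contiguous buffer (hypothetical packed data).
void example_noncont2cont_4bit_copy() {
    uint8_t cont_buf[8] = {0};
    uint8_t nocont_buf[8] = {0x21, 0x43, 0x65, 0x87, 0, 0, 0, 0};
    memcpy_noncont2cont_4bit(
            cont_buf, nocont_buf, /*cont_offset=*/0, /*nocont_offset=*/3, /*size=*/5);
}
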
template <typename T>
void call_transpose(
size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst,
......@@ -46,7 +258,7 @@ void call_transpose(
}
//! one operand contiguous, and the other non-contiguous
template <typename ctype>
template <int bits>
void dispatch_on_dtype_cont(
Handle* handle, const TensorND& cont, const TensorND& nonc,
memcpy_policy_t mcp_pol) {
......@@ -54,13 +266,13 @@ void dispatch_on_dtype_cont(
switch (nonc.layout.ndim) {
case 2: {
auto shp0 = nonc.layout.shape[0], shp1 = nonc.layout.shape[1];
auto strd0_n = nonc.layout.stride[0] * sizeof(ctype);
auto strd0_c = shp1 * sizeof(ctype);
auto strd0_n = nonc.layout.stride[0] * bits / 8;
auto strd0_c = shp1 * bits / 8;
kern = [=]() {
auto cur_ctptr = static_cast<uint8_t*>(cont.raw_ptr());
auto cur_ncptr = static_cast<uint8_t*>(nonc.raw_ptr());
for (size_t i = 0; i < shp0; ++i) {
mcp_pol(cur_ctptr, cur_ncptr, strd0_c);
mcp_pol(cur_ctptr, cur_ncptr, 0, 0, strd0_c);
cur_ctptr += strd0_c;
cur_ncptr += strd0_n;
}
......@@ -70,16 +282,16 @@ void dispatch_on_dtype_cont(
case 3: {
auto shp0 = nonc.layout.shape[0], shp1 = nonc.layout.shape[1],
shp2 = nonc.layout.shape[2];
auto strd0_n = nonc.layout.stride[0] * sizeof(ctype),
strd1_n = nonc.layout.stride[1] * sizeof(ctype);
auto strd1_c = shp2 * sizeof(ctype);
auto strd0_n = nonc.layout.stride[0] * bits / 8,
strd1_n = nonc.layout.stride[1] * bits / 8;
auto strd1_c = shp2 * bits / 8;
kern = [=]() {
auto cur_ctptr = static_cast<uint8_t*>(cont.raw_ptr());
auto ncptr_row = static_cast<uint8_t*>(nonc.raw_ptr());
for (size_t i = 0; i < shp0; ++i) {
auto cur_ncptr = ncptr_row;
for (size_t j = 0; j < shp1; ++j) {
mcp_pol(cur_ctptr, cur_ncptr, strd1_c);
mcp_pol(cur_ctptr, cur_ncptr, 0, 0, strd1_c);
cur_ctptr += strd1_c;
cur_ncptr += strd1_n;
}
......@@ -95,13 +307,64 @@ void dispatch_on_dtype_cont(
static_cast<naive::HandleImpl*>(handle)->dispatch_kern(std::move(kern));
}
template <>
void dispatch_on_dtype_cont<4>(
Handle* handle, const TensorND& cont, const TensorND& nonc,
memcpy_policy_t mcp_pol) {
thin_function<void()> kern;
switch (nonc.layout.ndim) {
case 2: {
auto shp0 = nonc.layout.shape[0], shp1 = nonc.layout.shape[1];
auto strd0_n = nonc.layout.stride[0];
auto strd0_c = shp1;
kern = [=]() {
auto cur_ctptr = static_cast<uint8_t*>(cont.raw_ptr());
auto cur_ncptr = static_cast<uint8_t*>(nonc.raw_ptr());
size_t c_cnt = 0;
size_t n_cnt = 0;
for (size_t i = 0; i < shp0; ++i) {
mcp_pol(cur_ctptr, cur_ncptr, c_cnt, n_cnt, strd0_c);
c_cnt += strd0_c;
n_cnt += strd0_n;
}
};
break;
}
case 3: {
auto shp0 = nonc.layout.shape[0], shp1 = nonc.layout.shape[1],
shp2 = nonc.layout.shape[2];
auto strd0_n = nonc.layout.stride[0], strd1_n = nonc.layout.stride[1];
auto strd1_c = shp2;
kern = [=]() {
auto cur_ctptr = static_cast<uint8_t*>(cont.raw_ptr());
auto ncptr_row = static_cast<uint8_t*>(nonc.raw_ptr());
size_t c_cnt = 0;
size_t n_cnt = 0;
for (size_t i = 0; i < shp0; ++i) {
n_cnt = i * strd0_n;
for (size_t j = 0; j < shp1; ++j) {
mcp_pol(cur_ctptr, ncptr_row, c_cnt, n_cnt, strd1_c);
c_cnt += strd1_c;
n_cnt += strd1_n;
}
}
};
break;
}
default:
megdnn_assert(0);
}
static_cast<naive::HandleImpl*>(handle)->dispatch_kern(std::move(kern));
}
void dispatch_cont(
Handle* handle, const TensorND& cont, const TensorND& nonc,
memcpy_policy_t mcp_pol) {
switch (cont.layout.dtype.enumv()) {
#define cb(_dt) \
case DTypeTrait<dtype::_dt>::enumv: \
return dispatch_on_dtype_cont<equiv_ctype<dtype::_dt>::type>( \
return dispatch_on_dtype_cont<DTypeTrait<dtype::_dt>::bits>( \
handle, cont, nonc, mcp_pol);
MEGDNN_FOREACH_DTYPE_NAME(cb)
MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb)
......@@ -110,8 +373,8 @@ void dispatch_cont(
}
}
const size_t BLOCK_SIZE = 16,
TRANSPOSE_CV_MAX_C = relayout::transpose_fallback::BLOCK_LINE_SIZE_BYTES;
const size_t BLOCK_SIZE = 16;
const size_t TRANSPOSE_CV_MAX_C = relayout::transpose_fallback::BLOCK_LINE_SIZE_BYTES;
/*!
* \tparam ctype The type of the data
......@@ -221,28 +484,34 @@ void RelayoutForwardImpl::exec(
return;
}
// FIXME: optimize for lowbit cases
if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
NaiveRelayoutForwardImpl::do_exec(src, dst);
return;
}
bool is_bit4 = is_int4(src.layout);
bool allow_nocontig = !is_bit4;
relayout::TransposeParam trans_param;
bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param, true);
bool trans =
relayout::is_transpose(src.layout, dst.layout, trans_param, allow_nocontig);
trans = check_dtype_support_transparam(trans, is_bit4, trans_param);
exec_after_preprocess(src, dst, trans ? &trans_param : nullptr);
}
void RelayoutForwardImpl::exec_after_preprocess(
const TensorND& src, const TensorND& dst, relayout::TransposeParam* transpose) {
if (transpose) {
auto kernel = [tparam = *transpose, src, dst]() {
bool is_bit4 = is_int4(src.layout);
auto kernel = [tparam = *transpose, src, dst, is_bit4]() {
auto t = tparam;
auto dsize = src.layout.dtype.size() * t.c;
void (*kptr)(size_t, size_t, size_t, size_t, void*, void*, size_t) =
nullptr;
auto src_addr = reinterpret_cast<uintptr_t>(src.raw_ptr()),
dst_addr = reinterpret_cast<uintptr_t>(dst.raw_ptr());
size_t dsize = 0;
if (is_bit4) {
dsize = t.c >> 1;
} else {
dsize = src.layout.dtype.size() * t.c;
}
if (is_bit4 && dsize == 0) {
kptr = call_transpose<dt_qint4>;
} else {
if (dsize == 1) {
megdnn_assert(t.c == 1);
kptr = call_transpose<uint8_t>;
......@@ -285,6 +554,7 @@ void RelayoutForwardImpl::exec_after_preprocess(
}
megdnn_assert(kptr);
}
}
if (kptr) {
auto sptr = src.raw_ptr();
......@@ -305,13 +575,20 @@ void RelayoutForwardImpl::exec_after_preprocess(
MEGDNN_DISPATCH_CPU_KERN_OPR(memcpy(dst.raw_ptr(), src.raw_ptr(), sz));
return;
}
memcpy_policy_t cpy_noncont2cont = memcpy_noncont2cont;
memcpy_policy_t cpy_cont2noncont = memcpy_cont2noncont;
bool is_bit4 = src.layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm;
if (is_bit4) {
cpy_noncont2cont = memcpy_noncont2cont_4bit;
cpy_cont2noncont = memcpy_cont2noncont_4bit;
}
if (is_contig(dst.layout) && is_lastdim_contig(src.layout)) {
return dispatch_cont(handle(), dst, src, memcpy_noncont2cont);
return dispatch_cont(handle(), dst, src, cpy_noncont2cont);
}
if (is_contig(src.layout) && is_lastdim_contig(dst.layout)) {
return dispatch_cont(handle(), src, dst, memcpy_cont2noncont);
return dispatch_cont(handle(), src, dst, cpy_cont2noncont);
}
NaiveRelayoutForwardImpl::do_exec(src, dst);
}
......
......@@ -98,7 +98,7 @@ template <typename Impl>
void copy_tensors(
const CheckerHelper::TensorValueArray& dest,
const CheckerHelper::TensorValueArray& src, const Impl& copy_impl) {
megdnn_assert(dest.size() == src.size());
megdnn_assert(dest.size() == src.size(), "%zu != %zu", dest.size(), src.size());
for (size_t i = 0; i < src.size(); i++) {
auto&& tensor = src[i];
if (tensor.layout.ndim == 0)
......
......@@ -34,6 +34,60 @@ TEST_F(FALLBACK, RELAYOUT_RECORD) {
checker.exec({{2, 2, 2}, {2, 2, 2}});
}
TEST_F(FALLBACK, RELAYOUT_Q4) {
Checker<Relayout> checker(handle());
UniformIntRNG rng_int4{-7, 7};
checker.set_rng(0, &rng_int4)
.set_rng(1, &rng_int4)
.set_dtype(0, dtype::QuantizedS4(1.f))
.set_dtype(1, dtype::QuantizedS4(1.f))
.execs({{2, 2, 1, 1}, {1, 1, 2, 2}})
.execs({{1, 64, 15, 15}, {1, 15, 15, 64}})
.execs({{1, 5, 9, 32}, {1, 5, 32, 9}})
.execl(TensorLayoutArray{
{{6400}, {1}, dtype::QuantizedS4{1.f}},
{{20, 320}, {1024, 1}, dtype::QuantizedS4{1.f}}})
.execl(TensorLayoutArray{
{{156}, {1}, dtype::QuantizedS4{1.f}},
{{13, 3, 4}, {16, 1, 4}, dtype::QuantizedS4{1.f}}})
.execl(TensorLayoutArray{
{{48}, {1}, dtype::QuantizedS4{1.f}},
{{3, 4, 4}, {16, 1, 4}, dtype::QuantizedS4{1.f}}})
.execl(TensorLayoutArray{
{{84}, {1}, dtype::QuantizedS4{1.f}},
{{3, 4, 7}, {28, 1, 4}, dtype::QuantizedS4{1.f}}})
.execl(TensorLayoutArray{
{{336}, {1}, dtype::QuantizedS4{1.f}},
{{3, 4, 7, 4}, {112, 4, 16, 1}, dtype::QuantizedS4{1.f}}})
.execl(TensorLayoutArray{
{{54}, {1}, dtype::QuantizedS4{1.f}},
{{6, 3, 3}, {16, 4, 1}, dtype::QuantizedS4{1.f}}})
.execl(TensorLayoutArray{
{{1200, 3}, {4, 1}, dtype::QuantizedS4{1.f}},
{{20, 60, 3}, {256, 4, 1}, dtype::QuantizedS4{1.f}}})
.execl(TensorLayoutArray{
{{20, 20, 3, 3}, {256, 12, 4, 1}, dtype::QuantizedS4{1.f}},
{{1200, 3}, {4, 1}, dtype::QuantizedS4{1.f}}})
.execl(TensorLayoutArray{
{{5, 16, 7, 7, 4}, {3136, 196, 28, 4, 1}, dtype::QuantizedS4{1.f}},
{{5, 16, 7, 7, 4}, {3136, 4, 448, 64, 1}, dtype::QuantizedS4{1.f}}})
.execl(TensorLayoutArray{
{{5, 7, 7, 16, 4}, {3136, 448, 64, 4, 1}, dtype::QuantizedS4{1.f}},
{{5, 7, 7, 16, 4}, {3136, 28, 4, 196, 1}, dtype::QuantizedS4{1.f}}})
.execl(TensorLayoutArray{
{{5, 2, 7, 7, 32},
{3136, 1568, 224, 32, 1},
dtype::QuantizedS4{1.f}},
{{5, 2, 7, 7, 32},
{3136, 32, 448, 64, 1},
dtype::QuantizedS4{1.f}}})
.execl(TensorLayoutArray{
{{5, 7, 7, 2, 32}, {3136, 448, 64, 32, 1}, dtype::QuantizedS4{1.f}},
{{5, 7, 7, 2, 32},
{3136, 224, 32, 1568, 1},
dtype::QuantizedS4{1.f}}});
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(FALLBACK, BENCHMARK_RELAYOUT_CV) {
relayout::run_cv_benchmark(handle());
......