Commit 4fe68ac9 authored by Megvii Engine Team

feat(dnn/cuda): support transforming layout between nchw and nchw64 when channel not aligned to 64

GitOrigin-RevId: e9ecbcf2e25ce093e61aa6c8bab8909974e288d4
Parent ae6ff2c5
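This commit lets an NCHW tensor whose channel count is not a multiple of 64 be relayouted to NCHW64 (and back): the channel dimension is rounded up to a multiple of 64 and the extra channels are filled with the dtype's zero point. The sketch below is a host-side illustration of that index mapping, not MegDNN code; nchw_to_nchw64 and the byte-sized elements are assumptions for readability (the real kernels operate on packed 4-bit data).

#include <cstddef>
#include <cstdint>
#include <vector>

// Minimal sketch of the NCHW -> NCHW64 mapping with channel padding.
// Hypothetical helper for illustration only; byte-sized elements.
std::vector<uint8_t> nchw_to_nchw64(const std::vector<uint8_t>& src, size_t N,
                                    size_t C, size_t H, size_t W,
                                    uint8_t zero_point) {
    const size_t C64 = (C + 63) / 64;  // dst[1] = div_ceil(src[1], 64)
    std::vector<uint8_t> dst(N * C64 * H * W * 64, zero_point);
    for (size_t n = 0; n < N; ++n)
        for (size_t c = 0; c < C; ++c)
            for (size_t h = 0; h < H; ++h)
                for (size_t w = 0; w < W; ++w) {
                    size_t s = ((n * C + c) * H + h) * W + w;
                    size_t d = (((n * C64 + c / 64) * H + h) * W + w) * 64 +
                               c % 64;
                    dst[d] = src[s];
                }
    return dst;
}

The reverse NCHW64_NCHW transform drops the padded channels again; when param().oc is non-zero it gives the true output channel count, otherwise src[1] * 64 is used, as in the deduce_layout_fwd change below.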
......@@ -252,10 +252,10 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
megdnn_assert(dst[1] % param().group == 0);
break;
case Param::Mode::NCHW_NCHW64:
megdnn_assert(src.ndim == 4 && (src[1] % 64) == 0);
megdnn_assert(src.ndim == 4);
dst.ndim = 5;
dst[0] = src[0];
dst[1] = src[1] / 64;
dst[1] = div_ceil(src[1], 64_z);
dst[2] = src[2];
dst[3] = src[3];
dst[4] = 64;
......@@ -264,7 +264,7 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
megdnn_assert(src.ndim == 5);
dst.ndim = 4;
dst[0] = src[0];
dst[1] = src[1] * 64;
dst[1] = param().oc == 0 ? src[1] * 64 : param().oc;
dst[2] = src[2];
dst[3] = src[3];
break;
......@@ -483,12 +483,11 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
case Param::Mode::NCHW4_NCHW:
// nchw to nchw4
{
megdnn_assert(src.format == dst.format);
exec_workspace =
TensorLayout({src[0], src[1] * 4, src[2], src[3]},
src.dtype, src.format)
.reshape({src[0], src[1], 4, src[2], src[3]})
.dimshuffle({0, 1, 3, 4, 2});
exec_src = src;
dst.dtype, dst.format);
exec_src = src.dimshuffle({0, 1, 4, 2, 3});
exec_dst = dst;
}
break;
......@@ -658,13 +657,20 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
case Param::Mode::NCHW_NCHW64:
// src is {N, C, H, W}
// dst is {N, C/64, H, W, 64}
exec_src = src.reshape({src[0], src[1] / 64, 64, src[2], src[3]})
exec_workspace = TensorLayout(
{src[0], round_up(src[1], 64_z), src[2], src[3]},
src.dtype);
exec_src = exec_workspace
.reshape({src[0], div_ceil(src[1], 64_z), 64,
src[2], src[3]})
.dimshuffle({0, 1, 3, 4, 2});
exec_dst = dst;
break;
case Param::Mode::NCHW64_NCHW:
// src is {N, C/64, H, W, 64}
// dst is {N, C, H, W}
exec_workspace = TensorLayout({src[0], src[1] * 64, src[2], src[3]},
dst.dtype);
exec_src = src.dimshuffle({0, 1, 4, 2, 3});
exec_dst = dst;
break;
......
/**
* \file dnn/src/cuda/relayout_format/helper.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
namespace megdnn {
namespace cuda {
namespace relayout_format {
#define devfunc __forceinline__ __device__
template <int size_nbits>
devfunc int make_zero(int zero_point);
template <>
devfunc int make_zero<4>(int zero_point) {
return transform_int8_to_uint4x8(zero_point, zero_point, zero_point,
zero_point, zero_point, zero_point,
zero_point, zero_point);
}
template <typename AccessType, int LoadBytes>
struct global_load_with_zero_point;
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Specializations
//
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
// The redundant mov PTX instructions are used to force the compiler to
// initialize the destination registers to the zero point before ld.global
template <typename AccessType>
struct global_load_with_zero_point<AccessType, 32> {
devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
bool pred_guard, int zero_point) {
uint4* data = reinterpret_cast<uint4*>(&D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %9, 0;\n"
" mov.b32 %0, %10;\n"
" mov.b32 %1, %10;\n"
" mov.b32 %2, %10;\n"
" mov.b32 %3, %10;\n"
" mov.b32 %4, %10;\n"
" mov.b32 %5, %10;\n"
" mov.b32 %6, %10;\n"
" mov.b32 %7, %10;\n"
" @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n"
" @p ld.global.v4.u32 {%4, %5, %6, %7}, [%11];\n"
"}\n"
: "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z),
"=r"(data[0].w), "=r"(data[1].x), "=r"(data[1].y),
"=r"(data[1].z), "=r"(data[1].w)
: "l"(ptr), "r"((int)pred_guard),
"r"(reinterpret_cast<unsigned&>(zero_point)),
"l"(((uint8_t*)ptr) + 16));
}
};
template <typename AccessType>
struct global_load_with_zero_point<AccessType, 16> {
devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
bool pred_guard, int zero_point) {
uint4& data = reinterpret_cast<uint4&>(D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %5, 0;\n"
" mov.b32 %0, %6;\n"
" mov.b32 %1, %6;\n"
" mov.b32 %2, %6;\n"
" mov.b32 %3, %6;\n"
" @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
"}\n"
: "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
: "l"(ptr), "r"((int)pred_guard),
"r"(reinterpret_cast<unsigned&>(zero_point)));
}
};
template <typename AccessType>
struct global_load_with_zero_point<AccessType, 8> {
devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
bool pred_guard, int zero_point) {
uint2& data = reinterpret_cast<uint2&>(D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %3, 0;\n"
" mov.b32 %0, %4;\n"
" mov.b32 %1, %4;\n"
" @p ld.global.v2.u32 {%0, %1}, [%2];\n"
"}\n"
: "=r"(data.x), "=r"(data.y)
: "l"(ptr), "r"((int)pred_guard),
"r"(reinterpret_cast<unsigned&>(zero_point)));
}
};
template <typename AccessType>
struct global_load_with_zero_point<AccessType, 4> {
devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
bool pred_guard, int zero_point) {
unsigned& data = reinterpret_cast<unsigned&>(D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %2, 0;\n"
" mov.b32 %0, %3;\n"
" @p ld.global.u32 %0, [%1];\n"
"}\n"
: "=r"(data)
: "l"(ptr), "r"((int)pred_guard),
"r"(reinterpret_cast<unsigned&>(zero_point)));
}
};
template <typename AccessType>
struct global_load_with_zero_point<AccessType, 1> {
devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
bool pred_guard, int zero_point) {
if (pred_guard)
D = *(reinterpret_cast<AccessType const*>(ptr));
else {
unsigned uv = reinterpret_cast<unsigned&>(zero_point);
uint8_t& data = reinterpret_cast<uint8_t&>(D);
data = uv & 0xff;
}
}
};
#undef devfunc
} // namespace relayout_format
} // namespace cuda
} // namespace megdnn
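The specializations above all follow the same pattern: every destination register is first set to the packed zero-point word via redundant mov instructions, and the predicated ld.global only overwrites them when the guard is true. A simplified host-side model of the 16-byte case (load_or_fill_16B is a hypothetical name used only for illustration):

#include <cstdint>
#include <cstring>

// Model of global_load_with_zero_point<AccessType, 16>: fill 16 bytes with the
// 32-bit zero-point pattern, then overwrite them from ptr when the guard holds,
// mirroring the mov-then-predicated-load structure of the PTX.
inline void load_or_fill_16B(void* dst, const void* ptr, bool pred_guard,
                             uint32_t zero_point_pattern) {
    uint32_t lanes[4] = {zero_point_pattern, zero_point_pattern,
                         zero_point_pattern, zero_point_pattern};
    if (pred_guard)
        std::memcpy(lanes, ptr, 16);  // corresponds to the @p ld.global.v4.u32
    std::memcpy(dst, lanes, 16);
}

Keeping the fill unconditional avoids a divergent branch in the real kernel; the predicate only gates the global load itself.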
......@@ -18,6 +18,7 @@
#pragma GCC diagnostic pop
#include "src/cuda/query_blocksize.cuh"
#include "src/cuda/relayout_format/relayout_format.cuh"
#include "src/cuda/relayout_format/helper.cuh"
using namespace megdnn;
using namespace cuda;
......@@ -728,17 +729,18 @@ struct Translayout<64, 2, SrcType, dtype::Quantized4Asymm,
#undef pack
template <typename DstType>
inline __device__ DstType make_zero_pad(const char zero_point) {
inline __device__ DstType make_zero_pad(const uint8_t zero_point) {
return zero_point;
}
template <>
inline __device__ char4 make_zero_pad<char4>(const char zero_point) {
return {zero_point, zero_point, zero_point, zero_point};
inline __device__ char4 make_zero_pad<char4>(const uint8_t zero_point) {
char izp = reinterpret_cast<const char&>(zero_point);
return {izp, izp, izp, izp};
}
template <>
inline __device__ int4 make_zero_pad<int4>(const char zero_point) {
inline __device__ int4 make_zero_pad<int4>(const uint8_t zero_point) {
return {zero_point, zero_point, zero_point, zero_point};
}
......@@ -767,7 +769,7 @@ inline __device__ void write_helper<array_wrapper<uint8_t, 32>>(
: "l"(ptr_), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z),
"r"(data[0].w), "l"(((uint8_t*)ptr_) + 16), "r"(data[1].x),
"r"(data[1].y), "r"(data[1].z), "r"(data[1].w));
};
}
template <bool with_pad, int pack_w, int pack_c, bool same_scale, bool all_pad,
typename SrcType, typename DstType, typename DnnSrcType,
......@@ -825,7 +827,7 @@ struct RelayoutKern {
const SrcType* src, DstType* dst, const int ic_stride,
const int remain_ic,
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process,
const char zero_point) {
const uint8_t zero_point) {
InnerDtype read_channel[pack_c];
if (all_pad) {
const InnerDtype zero_pad = make_zero_pad<InnerDtype>(zero_point);
......@@ -855,7 +857,7 @@ __global__ void kern_nchw_nchwx(
const SrcType* src, DstType* dst, int in_n, int ic, int ihw,
int n_stride_src, int ic_stride, int n_stride_dst, int oc_stride,
CudaPostProcess<DnnSrcType, DnnDstType, same_scale> post_process,
const char zero_point, const int group, const int ocpg) {
const uint8_t zero_point, const int group, const int ocpg) {
static constexpr int size_src_type = sizeof(SrcType);
static constexpr int size_dst_type = sizeof(DstType);
#ifndef MEGDNN_COMMA
......@@ -1072,6 +1074,7 @@ public:
MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
pointer += (c_idx / pack_size) * chan_stride_in_elements +
hw_idx * pack_size * size_nbits / (8 * sizeof(Type));
channel -= c_idx;
}
MEGDNN_DEVICE __forceinline__ void add_pointer_offset(
......@@ -1079,7 +1082,7 @@ public:
pointer += offset_in_type;
}
MEGDNN_DEVICE __forceinline__ void load(Fragment& frag) {
MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) {
AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
Type* pointer_ = pointer;
#pragma unroll
......@@ -1090,11 +1093,12 @@ public:
(lane_size_in_type / pack_size_in_type) +
j;
bool guard = i < channel;
cutlass::arch::global_load<AccessType, pack_size_in_byte>(
relayout_format::global_load_with_zero_point<AccessType,
pack_size_in_byte>(
frag_ptr[frag_idx],
reinterpret_cast<void*>(pointer_ +
j * pack_size_in_type),
guard);
guard, zero_point);
}
pointer_ += chan_stride_in_elements;
}
......@@ -1173,6 +1177,7 @@ public:
MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
pointer += (c_idx / pack_size) * chan_stride_in_elements;
channel -= c_idx;
#pragma unroll
for (int i = 0; i < mask_size; ++i) {
mask[i] = 0;
......@@ -1201,7 +1206,7 @@ public:
pointer += offset_in_type;
}
MEGDNN_DEVICE __forceinline__ void load(Fragment& frag) {
MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) {
AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
Type* pointer_ = pointer;
#pragma unroll
......@@ -1214,9 +1219,11 @@ public:
int mask_index = (frag_idx >> 5);
int mask_shift = (frag_idx & 0x1f);
bool guard = (mask[mask_index] & (1 << mask_shift));
cutlass::arch::global_load<AccessType, pack_size_in_byte>(
relayout_format::global_load_with_zero_point<AccessType,
pack_size_in_byte>(
frag_ptr[frag_idx],
reinterpret_cast<void*>(pointer_ + stride[j]), guard);
reinterpret_cast<void*>(pointer_ + stride[j]), guard,
zero_point);
}
pointer_ += chan_stride_in_elements;
}
......@@ -1306,11 +1313,13 @@ struct RelayoutProblem {
int batch_size;
int channels;
int hw;
int zero_point;
MEGDNN_HOST MEGDNN_DEVICE Param(SrcIterator src_iterator_,
DstIterator dst_iterator_,
CudaPostProcess post_process_,
int n_stride_src_, int n_stride_dst_,
int batch_size_, int channels_, int hw_)
int batch_size_, int channels_, int hw_,
int zero_point_)
: src_iterator{src_iterator_},
dst_iterator{dst_iterator_},
post_process{post_process_},
......@@ -1318,7 +1327,8 @@ struct RelayoutProblem {
n_stride_dst{n_stride_dst_},
batch_size{batch_size_},
channels{channels_},
hw{hw_} {}
hw{hw_},
zero_point{zero_point_} {}
};
};
......@@ -1345,7 +1355,9 @@ __global__ void relayout_kern(typename RelayoutProblem_::Param param) {
param.dst_iterator.initialize(c_idx, hw_idx);
typename SrcIterator::Fragment src_frag;
typename DstIterator::Fragment dst_frag;
param.src_iterator.load(src_frag);
int zp = relayout_format::make_zero<SrcIterator::size_nbits>(
param.zero_point);
param.src_iterator.load(src_frag, zp);
RelayoutProblem_::Transpose::trans(
reinterpret_cast<typename SrcIterator::Fragment&>(dst_frag),
src_frag, param.post_process);
......@@ -1382,7 +1394,8 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
stype.name(), dtype.name());
#undef DEF
// no padding
if (src.layout.stride[2] == static_cast<ptrdiff_t>(src.layout[3])) {
if (stype.enumv().ev != DTypeEnum::Ev::QuantizedS4 &&
stype.enumv().ev != DTypeEnum::Ev::Quantized4Asymm) {
const int in_n = src.layout[0];
const int out_n = dst.layout[0];
const int ic = src.layout[1];
......@@ -1428,18 +1441,10 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
DISPATCH_RAW(false, 4, 4, _src_type, _dst_type, char, char, 8); \
DISPATCH_RAW(true, 1, 4, _src_type, _dst_type, char, char, 8); \
DISPATCH_RAW(false, 1, 4, _src_type, _dst_type, char, char, 8);
#define DISPATCH_4BITS(_src_type, _dst_type) \
DISPATCH_RAW(true, 8, 64, _src_type, _dst_type, char, char, 4); \
DISPATCH_RAW(false, 8, 64, _src_type, _dst_type, char, char, 4); \
DISPATCH_RAW(true, 2, 64, _src_type, _dst_type, char, char, 4); \
DISPATCH_RAW(false, 2, 64, _src_type, _dst_type, char, char, 4);
DISPATCH_INT(QuantizedS32, QuantizedS32);
DISPATCH_BYTE(Uint8, QuantizedS8);
DISPATCH_BYTE(Quantized8Asymm, QuantizedS8);
DISPATCH_BYTE(QuantizedS8, QuantizedS8);
DISPATCH_4BITS(QuantizedS4, QuantizedS4);
DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm);
#undef DISPATCH_4BITS
#undef DISPATCH_BYTE
#undef DISPATCH_INT
#undef DISPATCH_RAW
......@@ -1450,7 +1455,8 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
} else {
megdnn_assert(src_layout.dtype.is_low_bit());
int n = src.layout[0];
int c = src.layout[1];
int ic = src.layout[1];
int oc = dst.layout[1] * 64;
int h = src.layout[2];
// align to byte
int w = src.layout[3];
......@@ -1460,12 +1466,13 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
int ic_stride = src_layout.stride[1];
int n_stride_dst = dst_layout.stride[0];
int oc_stride = dst_layout.stride[1];
int problem_size = n * (c / pack_oc) * hw;
int problem_size = n * (oc / pack_oc) * hw;
bool same_scale = src_scale == dst_scale;
#define DISPATCH_RAW(_same_scale, _pack_w, _pack_oc, _src_type, _dst_type, \
_src_c_type, _dst_c_type, _size_nbits) \
if (same_scale == _same_scale && hw % _pack_w == 0 && \
stype.enumv().ev == DTypeEnum::Ev::_src_type && \
bool padding = w % 2 != 0;
#define DISPATCH_RAW(_padding, _same_scale, _pack_w, _pack_oc, _src_type, \
_dst_type, _src_c_type, _dst_c_type, _size_nbits) \
if (padding == _padding && same_scale == _same_scale && \
hw % _pack_w == 0 && stype.enumv().ev == DTypeEnum::Ev::_src_type && \
dtype.enumv().ev == DTypeEnum::Ev::_dst_type) { \
using InnerDtype_ = typename DTypeRWHelper< \
typename DTypeTrait<dtype::_src_type>::ctype, \
......@@ -1473,8 +1480,10 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
using SrcIterator_ = \
TensorIteratorOverChannel<InnerDtype_, 1, _pack_oc, _pack_w, \
_size_nbits>; \
using DstIterator_ = MaskedTensorIteratorOverChannel< \
_dst_c_type, _pack_oc, _pack_oc, _pack_w, _size_nbits>; \
using DstIterator_ = \
typename TensorIteratorPolicy<_padding, _dst_c_type, _pack_oc, \
_pack_oc, _pack_w, \
_size_nbits>::TensorIterator; \
using CudaPostProcess_ = \
CudaPostProcess<dtype::_src_type, dtype::_dst_type, \
_same_scale>; \
......@@ -1489,17 +1498,18 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
n_stride_dst = n_stride_dst * _size_nbits / (8 * sizeof(_dst_c_type)); \
oc_stride = oc_stride * _size_nbits / (8 * sizeof(_dst_c_type)); \
typename RelayoutProblem_::Param param{ \
SrcIterator_{(InnerDtype_*)src.raw_ptr, ic_stride, c, w, \
SrcIterator_{(InnerDtype_*)src.raw_ptr, ic_stride, ic, w, \
w_pad}, \
DstIterator_{(_dst_c_type*)dst.raw_ptr, oc_stride, c, w, \
DstIterator_{(_dst_c_type*)dst.raw_ptr, oc_stride, oc, w, \
w_pad}, \
CudaPostProcess_{src_scale, src_zero_point, dst_scale, \
dst_zero_point}, \
n_stride_src, \
n_stride_dst, \
n, \
c, \
hw}; \
oc, \
hw, \
src_zero_point}; \
auto kernel = relayout_kern<RelayoutProblem_>; \
int nr_threads = query_blocksize_for_kernel(kernel); \
nr_threads = std::min(nr_threads, DIVUP(problem_size, _pack_w)); \
......@@ -1507,11 +1517,15 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
const dim3 thread_dim(nr_threads); \
return kernel<<<block_dim, thread_dim, 0, stream>>>(param); \
}
#define DISPATCH_4BITS(_src_type, _dst_type) \
DISPATCH_RAW(true, 8, 64, _src_type, _dst_type, char, char, 4); \
DISPATCH_RAW(false, 8, 64, _src_type, _dst_type, char, char, 4); \
DISPATCH_RAW(true, 2, 64, _src_type, _dst_type, char, char, 4); \
DISPATCH_RAW(false, 2, 64, _src_type, _dst_type, char, char, 4);
#define DISPATCH_4BITS(_src_type, _dst_type) \
DISPATCH_RAW(true, true, 8, 64, _src_type, _dst_type, char, char, 4); \
DISPATCH_RAW(true, false, 8, 64, _src_type, _dst_type, char, char, 4); \
DISPATCH_RAW(true, true, 2, 64, _src_type, _dst_type, char, char, 4); \
DISPATCH_RAW(true, false, 2, 64, _src_type, _dst_type, char, char, 4); \
DISPATCH_RAW(false, true, 8, 64, _src_type, _dst_type, char, char, 4); \
DISPATCH_RAW(false, false, 8, 64, _src_type, _dst_type, char, char, 4); \
DISPATCH_RAW(false, true, 2, 64, _src_type, _dst_type, char, char, 4); \
DISPATCH_RAW(false, false, 2, 64, _src_type, _dst_type, char, char, 4);
DISPATCH_4BITS(QuantizedS4, QuantizedS4);
DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm);
#undef DISPATCH_4BITS
......@@ -1521,7 +1535,6 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
"Unsupported data type(src:%s, dst:%s) or image size(%dx%d).",
stype.name(), dtype.name(), h, w);
}
after_kernel_launch();
}
bool relayout_format::relayout_format_cuda_usable(
......@@ -1568,7 +1581,7 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
    megdnn_assert(pack_ic == 64, "Unsupported pack size(pack_ic:%d)", pack_ic);
#undef DEF
int n = src.layout[0];
int c = src.layout[1] * pack_ic;
int ic = src.layout[1] * pack_ic;
int h = src.layout[2];
// align to byte
int w = src.layout[3];
......@@ -1578,7 +1591,8 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
int ic_stride = src_layout.stride[1];
int n_stride_dst = dst_layout.stride[0];
int oc_stride = dst_layout.stride[1];
int problem_size = n * (c / pack_ic) * hw;
int problem_size = n * (ic / pack_ic) * hw;
int oc = dst.layout[1];
bool same_scale = src_scale == dst_scale;
bool padding = w % 2 != 0;
......@@ -1611,17 +1625,18 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
n_stride_dst = n_stride_dst * _size_nbits / (8 * sizeof(InnerDtype_)); \
oc_stride = oc_stride * _size_nbits / (8 * sizeof(InnerDtype_)); \
typename RelayoutProblem_::Param param{ \
SrcIterator_{(_src_c_type*)src.raw_ptr, ic_stride, c, w, \
SrcIterator_{(_src_c_type*)src.raw_ptr, ic_stride, ic, w, \
w_pad}, \
DstIterator_{(InnerDtype_*)dst.raw_ptr, oc_stride, c, w, \
DstIterator_{(InnerDtype_*)dst.raw_ptr, oc_stride, oc, w, \
w_pad}, \
CudaPostProcess_{src_scale, src_zero_point, dst_scale, \
dst_zero_point}, \
n_stride_src, \
n_stride_dst, \
n, \
c, \
hw}; \
ic, \
hw, \
src_zero_point}; \
auto kernel = relayout_kern<RelayoutProblem_>; \
int nr_threads = query_blocksize_for_kernel(kernel); \
nr_threads = std::min(nr_threads, DIVUP(problem_size, _pack_w)); \
......@@ -1645,7 +1660,6 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
megdnn_assert(false,
"Unsupported data type(src:%s, dst:%s) or image size(%dx%d).",
stype.name(), dtype.name(), h, w);
after_kernel_launch();
}
void relayout_format::relayout_format_cuda_nchw4_nchw(
......
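For the 4-bit paths dispatched above, the zero point handed to the iterators is not a single nibble but a 32-bit word with the nibble replicated eight times; relayout_kern builds it with relayout_format::make_zero<4>, which in turn uses transform_int8_to_uint4x8. A plain C++ sketch of the resulting bit pattern (pack_zero_point_u4x8 is a hypothetical helper, shown only to make the packing explicit):

#include <cstdint>

// Replicate a 4-bit zero point into all eight nibbles of a 32-bit word.
inline uint32_t pack_zero_point_u4x8(uint8_t zero_point) {
    uint32_t nibble = zero_point & 0xfu;
    uint32_t packed = 0;
    for (int i = 0; i < 8; ++i)
        packed |= nibble << (4 * i);
    return packed;
}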
......@@ -21,6 +21,7 @@
#include "cuda.h"
#include "src/cuda/cudnn_with_check.h"
#include "cutlass/cutlass.h"
#include "cutlass/platform/platform.h"
#define cuda_check(_x) \
do { \
......@@ -448,13 +449,12 @@ MEGDNN_DEVICE __forceinline__ static int transform_int8_to_uint4x8(
template <bool signedness, typename T>
MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(T storage,
int bits) {
uint8_t result = (uint8_t)((storage >> bits) & 0xf);
if (signedness) {
static constexpr uint8_t mask = (uint8_t)((1 << 4) - 1);
return (result & uint8_t(1 << 3)) ? ((int)(result) | ~(int)(mask))
: (int)(result);
}
return int(result);
static constexpr int shift = 28;
using type = typename cutlass::platform::conditional<signedness, int,
unsigned>::type;
unsigned intermediate = static_cast<unsigned>(storage);
type result = reinterpret_cast<type&>(intermediate);
return (result << (shift - bits)) >> shift;
}
MEGDNN_DEVICE __forceinline__ static void transform_int4x8_to_int8(
......
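The rewritten unpack_integer_4bits above extracts a nibble by shifting it into the top four bits of a 32-bit register and shifting back down, so an arithmetic right shift sign-extends QuantizedS4 values while a logical right shift zero-extends Quantized4Asymm values. A stand-alone sketch of the same trick (unpack_4bits is a hypothetical name, independent of the MegDNN template):

#include <cstdint>

// Extract the 4-bit field that starts at bit `bits` of `storage`.
// Mirrors the conditional<int, unsigned> selection used in the diff above.
template <bool signedness>
int unpack_4bits(uint32_t storage, int bits) {
    constexpr int shift = 28;
    uint32_t shifted = storage << (shift - bits);  // nibble now in bits 28..31
    if (signedness) {
        // arithmetic right shift replicates the nibble's sign bit
        return static_cast<int32_t>(shifted) >> shift;
    }
    // logical right shift on an unsigned value zero-extends
    return static_cast<int>(shifted >> shift);
}

For example, unpack_4bits<true>(0xf0u, 4) yields -1 while unpack_4bits<false>(0xf0u, 4) yields 15.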
......@@ -42,6 +42,36 @@ void recursive_cp(const TensorND& dst, const TensorND& src, size_t idx = 0,
}
}
template <size_t size_nbits>
void lowbit_recursive_cp(const TensorND& dst, const TensorND& src,
size_t idx = 0, size_t src_offset = 0,
size_t dst_offset = 0) {
MEGDNN_STATIC_ASSERT(!(8_z % size_nbits),
"size in bits of lowbit data type can only be 1, 2, 4 "
"or 8");
if (idx < (src.layout.ndim - 1)) {
for (size_t i = 0; i < src.layout[idx]; ++i) {
lowbit_recursive_cp<size_nbits>(
dst, src, idx + 1, src_offset + i * src.layout.stride[idx],
dst_offset + i * dst.layout.stride[idx]);
}
} else {
megdnn_assert(src.layout.stride[idx] == 1);
megdnn_assert(dst.layout.stride[idx] == 1);
size_t dim_bytes = div_ceil(src.layout[idx], 8_z / size_nbits);
// offset in elements
uint8_t* dptr = reinterpret_cast<uint8_t*>(dst.raw_ptr) +
(dst_offset * size_nbits / 8);
uint8_t* sptr = reinterpret_cast<uint8_t*>(src.raw_ptr) +
(src_offset * size_nbits / 8);
for (size_t i = 0; i < dim_bytes; ++i) {
*dptr = *sptr;
dptr++;
sptr++;
}
}
}
void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src) {
switch (src.layout.dtype.enumv()) {
#define cb(name, ctype) \
......@@ -54,10 +84,17 @@ void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src) {
cb(Int32, dt_int32);
cb(QuantizedS32, dt_int32);
cb(QuantizedS8, dt_qint8);
#undef cb
#define cb(name, size_nbits) \
case (DTypeEnum::name): { \
lowbit_recursive_cp<size_nbits>(dst, src); \
break; \
}
cb(QuantizedS4, 4);
cb(Quantized4Asymm, 4);
#undef cb
default:
megdnn_assert(0, "not support dtype %s", src.layout.dtype.name());
#undef cb
}
}
......@@ -66,24 +103,27 @@ void extract_from_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src,
megdnn_assert(dst.layout.is_contiguous() && src.layout.is_contiguous(),
"dst %s, src %s", dst.layout.to_string().c_str(),
src.layout.to_string().c_str());
const size_t type_size = dst.layout.dtype.size();
const size_t n = dst.layout[0];
const size_t n_stride_dst = dst.layout.stride[0];
const size_t n_stride_src = src.layout.stride[0];
const size_t n_stride_dst_in_bytes =
dst.layout.dtype.size(dst.layout.stride[0]);
const size_t n_stride_src_in_bytes =
src.layout.dtype.size(src.layout.stride[0]);
const size_t ocpg = dst.layout[1] / group;
const size_t icpg = src.layout[1] / group;
const size_t dst_hw = dst.layout[2] * dst.layout[3];
const size_t src_hw = src.layout[2] * src.layout[3];
megdnn_assert(dst_hw == src_hw);
const size_t dst_c_stride_in_bytes =
dst.layout.dtype.size(dst.layout.stride[1]);
const size_t src_c_stride_in_bytes =
src.layout.dtype.size(src.layout.stride[1]);
megdnn_assert(dst_c_stride_in_bytes == src_c_stride_in_bytes);
for (size_t nid = 0; nid < n; ++nid) {
const size_t n_offset_dst = nid * n_stride_dst * type_size;
const size_t n_offset_src = nid * n_stride_src * type_size;
const size_t n_offset_dst = nid * n_stride_dst_in_bytes;
const size_t n_offset_src = nid * n_stride_src_in_bytes;
for (size_t gid = 0; gid < group; ++gid) {
memcpy((char*)dst.raw_ptr + n_offset_dst +
gid * ocpg * dst_hw * type_size,
gid * ocpg * dst_c_stride_in_bytes,
(char*)src.raw_ptr + n_offset_src +
gid * icpg * src_hw * type_size,
ocpg * dst_hw * type_size);
gid * icpg * src_c_stride_in_bytes,
ocpg * dst_c_stride_in_bytes);
}
}
};
......@@ -415,6 +455,30 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src,
return oc * ic * h * w * src.dtype.size();
}
case Param::Mode::NCHW_NCHW64: {
if (src[1] % 64 != 0) {
size_t n = src[0];
size_t c = round_up(src[1], 64_z);
size_t h = src[2];
size_t w = src[3];
TensorLayout wsly({n, c, h, w}, src.dtype);
return wsly.span().dist_byte();
}
return 0_z;
}
case Param::Mode::NCHW64_NCHW: {
if (param().oc != 0) {
size_t n = src[0];
size_t c = src[1] * 64;
size_t h = src[2];
size_t w = src[3];
TensorLayout wsly({n, c, h, w}, dst.dtype);
return wsly.span().dist_byte();
}
return 0_z;
}
default:
return 0;
}
......@@ -437,6 +501,7 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
// clean dst
MEGDNN_DISPATCH_CPU_KERN(
m_handle, memset(dst.raw_ptr, 0, dst.layout.span().dist_byte()));
// pre
if (param().mode == Param::Mode::NCHW_NHWCD4I) {
size_t N = src.layout[0];
size_t IC = src.layout[1];
......@@ -551,6 +616,27 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
cb2(2, 4, NCHW_NCHW4, group_src_layout, workspace_layout);
}
} else if (param().mode == Param::Mode::NCHW_NCHW64) {
MIDOUT_BEGIN(megdnn_naive_relayout_format,
midout_iv(Param::Mode::NCHW_NCHW64)) {
size_t c = src.layout[1];
if (c % 64 != 0) {
uint8_t zp = 0;
if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
zp = src.layout.dtype.param<dtype::Quantized4Asymm>()
.zero_point;
zp = (zp & 0xf) | (zp << 4);
}
MEGDNN_DISPATCH_CPU_KERN(
m_handle, memset(workspace.raw_ptr, zp,
exec_workspace.span().dist_byte()));
TensorND ws_nd(workspace.raw_ptr, exec_workspace);
MEGDNN_DISPATCH_CPU_KERN(m_handle,
padding_to_workspace(ws_nd, src););
exec_src_nd.raw_ptr = workspace.raw_ptr;
}
}
MIDOUT_END();
} else if (param().mode ==
Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) {
cb(1, 4, NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT);
......@@ -574,24 +660,16 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
cb(1, 2, 4, NCHW_NCHW4_WEIGHT);
}
} else if (param().mode == Param::Mode::NCHW4_NCHW) {
if (exec_workspace.total_nr_elems() == dst.layout.total_nr_elems()) {
m_handle->relayout_opr()->exec(
exec_src_nd, {dst.raw_ptr, exec_workspace}, handle());
return;
} else {
m_handle->relayout_opr()->exec(
exec_src_nd, {workspace.raw_ptr, exec_workspace}, handle());
TensorLayout workspace_layout{{src.layout[0], src.layout[1] * 4,
src.layout[2], src.layout[3]},
src.layout.dtype,
src.layout.format};
extract_from_workspace(exec_dst_nd,
{workspace.raw_ptr, workspace_layout},
param().group);
return;
if (exec_workspace.total_nr_elems() != dst.layout.total_nr_elems()) {
exec_dst_nd = {workspace.raw_ptr, exec_workspace};
}
} else if (param().mode == Param::Mode::NCHW64_NCHW) {
if (exec_workspace.total_nr_elems() != dst.layout.total_nr_elems()) {
exec_dst_nd = {workspace.raw_ptr, exec_workspace};
}
}
// do relayout
if (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm &&
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
......@@ -600,7 +678,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
do_copy_diff_qu8_q8(dst, src);
};
MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
return;
} else if (src.layout.dtype.enumv() == DTypeEnum::Uint8 &&
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
......@@ -609,7 +686,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
do_copy_diff_u8_q8(dst, src);
};
MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
return;
} else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 &&
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
......@@ -618,7 +694,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
do_copy_diff_q8_q8(dst, src);
};
MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
return;
} else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS32 &&
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS32) {
TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
......@@ -627,7 +702,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
do_copy_diff_q32_q32(dst, src);
};
MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
return;
} else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4 &&
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
......@@ -636,7 +710,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
do_copy_diff_q4_q4(dst, src);
};
MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
return;
} else if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
dst.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
......@@ -645,9 +718,20 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
do_copy_diff_qu4_qu4(dst, src);
};
MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
return;
} else {
m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle());
}
// post
if (param().mode == Param::Mode::NCHW4_NCHW ||
param().mode == Param::Mode::NCHW64_NCHW) {
if (exec_workspace.total_nr_elems() != dst.layout.total_nr_elems()) {
megdnn_assert(exec_workspace.dtype == dst.layout.dtype);
TensorND ws_nd{workspace.raw_ptr, exec_workspace};
MEGDNN_DISPATCH_CPU_KERN(
m_handle,
extract_from_workspace(dst, ws_nd, param().group););
}
}
#undef cb
}
......
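In the naive implementation above, the unaligned cases funnel through a padded workspace: the workspace is memset to the packed zero point, padding_to_workspace / lowbit_recursive_cp copy the valid channels in, the ordinary relayout runs on the workspace, and for NCHW4_NCHW / NCHW64_NCHW the post step extract_from_workspace copies only the valid output channels of each group back into dst. A byte-granularity sketch of that final copy (extract_valid_channels is a hypothetical stand-in; contiguous NCHW layouts and byte-sized channel strides are assumed):

#include <cstddef>
#include <cstdint>
#include <cstring>

// Copy the first `ocpg` channels of every group from the padded workspace
// (which holds `icpg` channels per group) into the destination tensor.
void extract_valid_channels(uint8_t* dst, const uint8_t* src, size_t n,
                            size_t group, size_t ocpg, size_t icpg,
                            size_t chan_bytes) {
    const size_t dst_n_stride = group * ocpg * chan_bytes;
    const size_t src_n_stride = group * icpg * chan_bytes;
    for (size_t nid = 0; nid < n; ++nid)
        for (size_t gid = 0; gid < group; ++gid)
            std::memcpy(dst + nid * dst_n_stride + gid * ocpg * chan_bytes,
                        src + nid * src_n_stride + gid * icpg * chan_bytes,
                        ocpg * chan_bytes);
}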
......@@ -18,7 +18,6 @@
using namespace megdnn;
using namespace test;
#define MEGDNN_WITH_BENCHMARK 1
TEST_F(CUDA, RELAYOUT_FORMAT) {
Checker<RelayoutFormat> checker(handle_cuda());
......@@ -245,7 +244,7 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW64) {
param::RelayoutFormat param;
param.mode = param::RelayoutFormat::Mode::NCHW_NCHW64;
for (size_t n : {1, 3}) {
for (size_t c : {64, 128}) {
for (size_t c : {15, 64, 128}) {
for (size_t h : {7, 14, 16, 28}) {
for (size_t w : {2, 3, 7, 8, 16, 31}) {
checker.set_dtype(0, dtype::QuantizedS4{2.f})
......@@ -285,36 +284,41 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW64_NCHW) {
param::RelayoutFormat param;
param.mode = param::RelayoutFormat::Mode::NCHW64_NCHW;
for (size_t n : {1, 3}) {
for (size_t c : {64, 128}) {
for (size_t c : {15, 64, 128}) {
for (size_t h : {7, 14, 16, 28}) {
for (size_t w : {2, 3, 4, 7, 14, 16, 17}) {
if (c % 64 != 0) {
param.oc = c;
} else {
param.oc = 0;
}
checker.set_dtype(0, dtype::QuantizedS4{2.f})
.set_dtype(1, dtype::QuantizedS4{2.f})
.set_rng(0, &s4)
.set_param(param)
.set_epsilon(1e-3)
.execs({{n, c / 64, h, w, 64}, {}});
.execs({{n, (c + 63) / 64, h, w, 64}, {}});
checker.set_dtype(0, dtype::Quantized4Asymm{1.2f, 4})
.set_dtype(1, dtype::Quantized4Asymm{1.2f, 8})
.set_rng(0, &u4)
.set_param(param)
.set_epsilon(1e-3)
.execs({{n, c / 64, h, w, 64}, {}});
.execs({{n, (c + 63) / 64, h, w, 64}, {}});
checker.set_dtype(0, dtype::QuantizedS4{1.19990307f})
.set_dtype(1, dtype::QuantizedS4{1.f})
.set_rng(0, &s4)
.set_param(param)
.set_epsilon(1e-3)
.execs({{n, c / 64, h, w, 64}, {}});
.execs({{n, (c + 63) / 64, h, w, 64}, {}});
checker.set_dtype(0, dtype::Quantized4Asymm{1.20211209f, 8})
.set_dtype(1, dtype::Quantized4Asymm{1.f, 4})
.set_rng(0, &u4)
.set_param(param)
.set_epsilon(1e-3)
.execs({{n, c / 64, h, w, 64}, {}});
.execs({{n, (c + 63) / 64, h, w, 64}, {}});
}
}
}
......@@ -375,10 +379,14 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) {
CUBenchmarker<RelayoutFormat> benchmarker(handle_cuda());
benchmarker.set_param(param);
benchmarker.set_dtype(0, dtype::QuantizedS4{1.19990307f})
.set_dtype(1, dtype::QuantizedS4{1.20210322f});
.set_dtype(1, dtype::QuantizedS4{1.19990307f});
for (auto&& shape : shapes) {
double memaccess = double(shape.total_nr_elems()) * 1e-6;
double memaccess =
double(TensorLayout(shape, dtype::QuantizedS4{1.f})
.span()
.dist_byte()) *
2e-6;
auto time_ms = benchmarker.execs({shape, {}});
printf("execute %s, time %.4f ms, %.4f GB/s\n",
shape.to_string().c_str(), time_ms, memaccess / time_ms);
......@@ -387,8 +395,9 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) {
{
TensorShapeArray shapes = {
{1, 64, 56, 56}, {16, 64, 56, 56}, {64, 64, 56, 56},
{1, 64, 56, 55}, {16, 64, 56, 55}, {64, 64, 56, 55},
{1, 64, 56, 56}, {16, 64, 56, 56}, {64, 64, 56, 56},
{1, 64, 56, 55}, {16, 64, 56, 55}, {64, 64, 56, 55},
{1, 256, 384, 640},
};
Param param;
param.mode = param::RelayoutFormat::Mode::NCHW_NCHW64;
......@@ -399,7 +408,8 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) {
{64, 1, 56, 56, 64},
{1, 32, 7, 7, 64},
{16, 32, 7, 7, 64},
{64, 32, 7, 7, 64},
{64, 32, 7, 7, 64},
{1, 4, 384, 640, 64},
};
Param param;
param.mode = param::RelayoutFormat::Mode::NCHW64_NCHW;
......