Commit 19a554d6 authored by Megvii Engine Team

test(dnn/cuda): add testcase for transforming tensor layout between nchw and nchw64

GitOrigin-RevId: 75d579635ad177d9391b8da6ca45fab1086d3f6a
Parent 71c2f612
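The changes below are mostly mechanical (workspace ordering, byte-vs-element strides, int4 packing, and new CPU/CUDA test coverage), but the layout the new tests exercise is easy to state: NCHW_NCHW64 rearranges an (N, C, H, W) tensor into (N, C/64, H, W, 64), so 64 channels become the innermost, densely packed dimension (two int4 values per byte on device). The sketch below is a hedged host-side reference of that index mapping only, assuming C is a multiple of 64; the function name and the one-value-per-byte representation are illustrative and are not part of the megdnn API or the CUDA kernel in this patch.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference mapping for NCHW -> NCHW64 (illustrative only).
// src: int4 values widened to one int8_t each, stored in NCHW order.
// Returns the same values reordered to NCHW64, still one value per byte.
std::vector<int8_t> nchw_to_nchw64(const std::vector<int8_t>& src, size_t n,
                                   size_t c, size_t h, size_t w) {
    const size_t pack_c = 64;  // channels packed into the innermost dim
    std::vector<int8_t> dst(n * c * h * w);
    for (size_t ni = 0; ni < n; ++ni)
        for (size_t ci = 0; ci < c; ++ci)
            for (size_t hw = 0; hw < h * w; ++hw) {
                // source index in (N, C, H, W)
                size_t src_idx = (ni * c + ci) * h * w + hw;
                // destination index in (N, C/64, H, W, 64)
                size_t dst_idx =
                        ((ni * (c / pack_c) + ci / pack_c) * h * w + hw) *
                                pack_c +
                        ci % pack_c;
                dst[dst_idx] = src[src_idx];
            }
    return dst;
}
```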
......@@ -161,7 +161,7 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_bundle(
ws_size_underlying_algo, ws_size_z}};
}
return WorkspaceBundle{raw_ptr,
{ws_size_src, ws_size_filter,
ws_size_underlying_algo, ws_size_dst}};
{ws_size_src, ws_size_filter, ws_size_dst,
ws_size_underlying_algo}};
}
// vim: syntax=cpp.doxygen
......@@ -30,7 +30,10 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4 ||
param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL ||
param().mode ==
Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT,
Param::Mode::
NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT ||
param().mode == Param::Mode::NCHW_NCHW64 ||
param().mode == Param::Mode::NCHW64_NCHW,
"relayout format of cuda only support NCHW4->CHWN4 or "
"CHWN4->NCHW4 or NCHW->NCHW4");
if ((param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 ||
......
......@@ -26,6 +26,9 @@ inline void get_scale_zeropoint(const DType& tensor_dtype, float& scale,
scale = tensor_dtype.param<dtype::QuantizedS8>().scale;
} else if (tensor_dtype.enumv() == DTypeEnum::QuantizedS4) {
scale = tensor_dtype.param<dtype::QuantizedS4>().scale;
} else if (tensor_dtype.enumv() == DTypeEnum::Quantized4Asymm) {
zero_point = tensor_dtype.param<dtype::Quantized4Asymm>().zero_point;
scale = tensor_dtype.param<dtype::Quantized4Asymm>().scale;
}
}
......@@ -41,8 +44,6 @@ void relayout_format::RelayoutFormatFast::exec(const TensorND& src,
cudaStream_t stream,
RelayoutFormat::Param::Mode mode,
int group) {
auto&& stype = src.layout.dtype;
auto&& dtype = dst.layout.dtype;
float src_scale = 1.f;
float dst_scale = 1.f;
uint8_t src_zero_point = 0;
......
......@@ -538,9 +538,9 @@ struct Translayout<64, 8, SrcType, dtype::QuantizedS4, dtype::QuantizedS4,
};
#undef pack
#define pack(_idx) \
((uint8_t)(post_process(intermediate[0][_idx])) | \
((uint8_t)(post_process(intermediate[1][_idx])) << 4))
#define pack(_idx) \
((post_process(intermediate[0][_idx]) & 0xf) | \
(post_process(intermediate[1][_idx]) << 4))
template <typename SrcType, bool same_scale>
struct Translayout<64, 2, SrcType, dtype::QuantizedS4, dtype::QuantizedS4,
same_scale> {
......@@ -648,9 +648,9 @@ struct Translayout<64, 8, SrcType, dtype::Quantized4Asymm,
};
#undef pack
#define pack(_idx) \
((uint8_t)(post_process(intermediate[0][_idx])) | \
((uint8_t)(post_process(intermediate[1][_idx])) << 4))
#define pack(_idx) \
(post_process(intermediate[0][_idx]) | \
(post_process(intermediate[1][_idx]) << 4))
template <typename SrcType, bool same_scale>
struct Translayout<64, 2, SrcType, dtype::Quantized4Asymm,
dtype::Quantized4Asymm, same_scale> {
......@@ -820,13 +820,25 @@ __global__ void kern_nchw_nchwx(
int n_stride_src, int ic_stride, int n_stride_dst, int oc_stride,
CudaPostProcess<DnnSrcType, DnnDstType, same_scale> post_process,
const char zero_point, const int group, const int ocpg) {
static constexpr int size_src_type = sizeof(SrcType);
static constexpr int size_dst_type = sizeof(DstType);
#ifndef MEGDNN_COMMA
#define MEGDNN_COMMA ,
#endif
MEGDNN_STATIC_ASSERT(std::is_same<SrcType MEGDNN_COMMA DstType>::value,
"Currently this kernel only support accessing tensor "
"src and dst in same data type.");
n_stride_src /= size_src_type;
ic_stride /= size_src_type;
n_stride_dst /= size_dst_type;
oc_stride /= size_dst_type;
const int n_idx = blockIdx.y;
const int ihw_block_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int ihw_offset =
ihw_block_idx * pack_w;
const int ihw_offset_in_type =
ihw_offset * size_nbits / (8 * sizeof(SrcType));
ihw_offset * size_nbits / (8 * size_src_type);
if (ihw_offset < ihw) {
const int src_offset_base = n_idx * n_stride_src + ihw_offset_in_type;
const int dst_offset_base =
......@@ -836,7 +848,7 @@ __global__ void kern_nchw_nchwx(
const int ic_block = icpg / pack_c;
const int remain_ic = icpg % pack_c;
const int src_group_stride = icpg * ic_stride;
const int dst_group_stride = ocpg * oc_stride;
const int dst_group_stride = (ocpg / pack_c) * oc_stride;
for (int g_idx = 0; g_idx < group; ++g_idx) {
const int src_offset =
src_offset_base + g_idx * src_group_stride;
......@@ -1018,7 +1030,7 @@ public:
int chan_stride_in_elements_,
int channel_)
: pointer{pointer_},
chan_stride_in_elements{chan_stride_in_elements},
chan_stride_in_elements{chan_stride_in_elements_},
channel{channel_} {}
MEGDNN_DEVICE __forceinline__ void load(Fragment& frag) {
......@@ -1031,7 +1043,7 @@ public:
int frag_idx = i / pack_size *
(lane_size_in_type / pack_size_in_type) +
j;
bool guard = i >= channel;
bool guard = i < channel;
cutlass::arch::global_load<AccessType, pack_size_in_byte>(
frag_ptr[frag_idx],
reinterpret_cast<void*>(pointer_ +
......@@ -1052,7 +1064,7 @@ public:
int frag_idx = i / pack_size *
(lane_size_in_type / pack_size_in_type) +
j;
bool guard = i >= channel;
bool guard = i < channel;
cutlass::arch::global_store<AccessType, pack_size_in_byte>(
frag_ptr[frag_idx],
reinterpret_cast<void*>(pointer_ +
......@@ -1092,11 +1104,24 @@ __global__ void kern_nchwx_nchw(
size_nbits>;
using Transpose = Translayout<pack_c, pack_w, SrcType, DnnSrcType,
DnnDstType, same_scale>;
static constexpr int size_src_type = sizeof(SrcType);
static constexpr int size_dst_type = sizeof(DstType);
MEGDNN_STATIC_ASSERT(std::is_same<SrcType MEGDNN_COMMA DstType>::value,
"Currently this kernel only support accessing tensor "
"src and dst in same data type.");
n_stride_src /= size_src_type;
ic_stride /= size_src_type;
n_stride_dst /= size_dst_type;
oc_stride /= size_dst_type;
#undef MEGDNN_COMMA
const int n_idx = blockIdx.y;
const int ihw_block_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int ihw_offset = ihw_block_idx * pack_w;
const int ihw_offset_in_type =
ihw_offset * size_nbits / (8 * sizeof(SrcType));
ihw_offset * size_nbits / (8 * size_src_type);
const int oc_stride_inner_dtype =
oc_stride * size_dst_type / sizeof(InnerDtype);
if (ihw_offset < ihw) {
const int ic_block = (ic + pack_c - 1) / pack_c;
const int src_offset_base =
......@@ -1105,8 +1130,8 @@ __global__ void kern_nchwx_nchw(
SrcIterator src_iterator{const_cast<SrcType*>(src + src_offset_base),
ic_stride, ic};
DstIteraotr dst_iterator{
reinterpret_cast<InnerDtype*>(dst + dst_offset_base), oc_stride,
ic};
reinterpret_cast<InnerDtype*>(dst + dst_offset_base),
oc_stride_inner_dtype, ic};
for (int ic_blk_idx = 0; ic_blk_idx < ic_block; ++ic_blk_idx) {
typename SrcIterator::Fragment src_frag;
......@@ -1143,12 +1168,13 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
DEF(64, Quantized4Asymm, Quantized4Asymm)
DEF(4, QuantizedS8, QuantizedS8)
DEF(4, Uint8, QuantizedS8)
DEF(4, Quantized8Asymm, Quantized8Asymm)
DEF(4, QuantizedS32, QuantizedS32);
DEF(4, Quantized8Asymm, QuantizedS8)
DEF(4, QuantizedS32, QuantizedS32)
// clang-format on
megdnn_assert(pack_oc == 4 || pack_oc == 64,
"Unsupport pack size(pack_oc:%d)", pack_oc);
#undef DEF
"Unsupport pack size(pack_oc:%d, src:%s, dst:%s)", pack_oc,
stype.name(), dtype.name());
#undef DEF
const int in_n = src.layout[0];
const int out_n = dst.layout[0];
const int ic = src.layout[1];
......@@ -1157,6 +1183,7 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
const int oc = dst.layout[1] * pack_oc;
const int hw = h * w;
const int ocpg = oc / group;
// stride in byte
const int n_stride_src = src_layout.dtype.size(src_layout.stride[0]);
const int ic_stride = src_layout.dtype.size(src_layout.stride[1]);
const int n_stride_dst = dst_layout.dtype.size(dst_layout.stride[0]);
......@@ -1244,20 +1271,20 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
auto& src_layout = src.layout;
auto& dst_layout = dst.layout;
// check pack size
int pack_oc = std::numeric_limits<int>::min();
#define DEF(_pack_oc, _src_type, _dst_type) \
int pack_ic = std::numeric_limits<int>::min();
#define DEF(_pack_ic, _src_type, _dst_type) \
if (stype.enumv().ev == DTypeEnum::Ev::_src_type && \
dtype.enumv().ev == DTypeEnum::Ev::_dst_type) { \
pack_oc = _pack_oc; \
pack_ic = _pack_ic; \
}
// clang-format off
DEF(64, QuantizedS4, QuantizedS4)
DEF(64, Quantized4Asymm, Quantized4Asymm)
// clang-format on
megdnn_assert(pack_oc == 64, "Unsupport pack size(pack_oc:%d)", pack_oc);
megdnn_assert(pack_ic == 64, "Unsupport pack size(pack_ic:%d)", pack_ic);
#undef DEF
const int n = src.layout[0];
const int c = src.layout[1];
const int c = src.layout[1] * pack_ic;
const int h = src.layout[2];
// align to byte
const int w = src.layout[3];
......@@ -1266,7 +1293,7 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
const int ic_stride = src_layout.dtype.size(src_layout.stride[1]);
const int n_stride_dst = dst_layout.dtype.size(dst_layout.stride[0]);
const int oc_stride = dst_layout.dtype.size(dst_layout.stride[1]);
bool same_scale = src_scale == dst_scale;
#define DISPATCH_RAW(_same_scale, _pack_w, _pack_oc, _src_type, _dst_type, \
_src_c_type, _dst_c_type, _size_nbits) \
......
......@@ -378,7 +378,9 @@ MEGDNN_DEVICE __forceinline__ static float4 operator+(float4 lval,
MEGDNN_DEVICE __forceinline__ static int transform_int8_to_int4x8(
int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
unsigned out;
#if __CUDA_ARCH__ >= 750
#if __CUDA_ARCH__ >= 750 && \
((__CUDACC_VER_MAJOR__ > 10) || \
((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
asm volatile(
"{ .reg .u32 r4;"
"cvt.pack.sat.s4.s32.b32 r4, %8, %7, 0;"
......@@ -411,7 +413,9 @@ MEGDNN_DEVICE __forceinline__ static int transform_int8_to_int4x8(
MEGDNN_DEVICE __forceinline__ static int transform_int8_to_uint4x8(
int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
unsigned out;
#if __CUDA_ARCH__ >= 750
#if __CUDA_ARCH__ >= 750 && \
((__CUDACC_VER_MAJOR__ > 10) || \
((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
asm volatile(
"{ .reg .u32 r4;"
"cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;"
......
......@@ -226,6 +226,7 @@ void do_copy_diff_q8_q8(const TensorND& dst, const TensorND& src) {
++isrc;
}
}
void do_copy_diff_q32_q32(const TensorND& dst, const TensorND& src) {
auto isrc = tensor_iter_valonly<DTypeTrait<dtype::QuantizedS32>::ctype>(src)
.begin();
......@@ -253,6 +254,38 @@ void do_copy_diff_u8_q8(const TensorND& dst, const TensorND& src) {
}
}
void do_copy_diff_q4_q4(const TensorND& dst, const TensorND& src) {
auto isrc =
tensor_iter_valonly<DTypeTrait<dtype::QuantizedS4>::ctype>(src)
.begin();
auto idst =
tensor_iter_valonly<DTypeTrait<dtype::QuantizedS4>::ctype>(dst)
.begin();
auto src_dt_parm = src.layout.dtype.param<dtype::QuantizedS4>();
auto dst_dt_parm = dst.layout.dtype.param<dtype::QuantizedS4>();
for (size_t i = 0, it = dst.layout.total_nr_elems(); i < it; ++i) {
*idst = dst_dt_parm.quantize(src_dt_parm.dequantize(int8_t(*isrc)));
++idst;
++isrc;
}
}
void do_copy_diff_qu4_qu4(const TensorND& dst, const TensorND& src) {
auto isrc =
tensor_iter_valonly<DTypeTrait<dtype::Quantized4Asymm>::ctype>(src)
.begin();
auto idst =
tensor_iter_valonly<DTypeTrait<dtype::Quantized4Asymm>::ctype>(dst)
.begin();
auto src_dt_parm = src.layout.dtype.param<dtype::Quantized4Asymm>();
auto dst_dt_parm = dst.layout.dtype.param<dtype::Quantized4Asymm>();
for (size_t i = 0, it = dst.layout.total_nr_elems(); i < it; ++i) {
*idst = dst_dt_parm.quantize(src_dt_parm.dequantize(uint8_t(*isrc)));
++idst;
++isrc;
}
}
void check_layout_and_canonize(TensorLayout& src, TensorLayout& dst) {
megdnn_assert(dst.is_non_overlapping_strong());
src = src.collapse_contiguous();
......@@ -595,6 +628,24 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
};
MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
return;
} else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4 &&
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
check_layout_and_canonize(src0.layout, src0.layout);
auto func = [](const TensorND& dst, const TensorND& src) {
do_copy_diff_q4_q4(dst, src);
};
MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
return;
} else if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
dst.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
check_layout_and_canonize(src0.layout, src0.layout);
auto func = [](const TensorND& dst, const TensorND& src) {
do_copy_diff_qu4_qu4(dst, src);
};
MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
return;
} else {
m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle());
}
......
......@@ -237,6 +237,89 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW4_IC_SMALL) {
.execs({{8, 3, 768, 1280}, {}});
}
TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW64) {
Checker<RelayoutFormat> checker(handle_cuda());
UniformIntRNG s4{-8, 7};
UniformIntRNG u4{0, 15};
param::RelayoutFormat param;
param.mode = param::RelayoutFormat::Mode::NCHW_NCHW64;
for (size_t n : {1, 3}) {
for (size_t c : {64, 128}) {
for (size_t h : {7, 14, 16, 28}) {
for (size_t w : {2, 4, 14, 16}) {
checker.set_dtype(0, dtype::QuantizedS4{2.f})
.set_dtype(1, dtype::QuantizedS4{2.f})
.set_rng(0, &s4)
.set_param(param)
.execs({{n, c, h, w}, {}});
checker.set_dtype(0, dtype::Quantized4Asymm{1.2f, 8})
.set_dtype(1, dtype::Quantized4Asymm{1.2f, 4})
.set_rng(0, &u4)
.set_param(param)
.execs({{n, c, h, w}, {}});
checker.set_dtype(0, dtype::QuantizedS4{1.19990307f})
.set_dtype(1, dtype::QuantizedS4{1.f})
.set_rng(0, &s4)
.set_param(param)
.execs({{n, c, h, w}, {}});
checker.set_dtype(0, dtype::Quantized4Asymm{1.19990307f, 8})
.set_dtype(1, dtype::Quantized4Asymm{1.f, 4})
.set_rng(0, &u4)
.set_param(param)
.set_epsilon(1e-3)
.execs({{n, c, h, w}, {}});
}
}
}
}
}
TEST_F(CUDA, RELAYOUT_FORMAT_NCHW64_NCHW) {
Checker<RelayoutFormat> checker(handle_cuda());
UniformIntRNG s4{-8, 7};
UniformIntRNG u4{0, 15};
param::RelayoutFormat param;
param.mode = param::RelayoutFormat::Mode::NCHW64_NCHW;
for (size_t n : {1, 3}) {
for (size_t c : {64, 128}) {
for (size_t h : {7, 14, 16, 28}) {
for (size_t w : {2, 4, 14, 16}) {
checker.set_dtype(0, dtype::QuantizedS4{2.f})
.set_dtype(1, dtype::QuantizedS4{2.f})
.set_rng(0, &s4)
.set_param(param)
.set_epsilon(1e-3)
.execs({{n, c / 64, h, w, 64}, {}});
checker.set_dtype(0, dtype::Quantized4Asymm{1.2f, 4})
.set_dtype(1, dtype::Quantized4Asymm{1.2f, 8})
.set_rng(0, &u4)
.set_param(param)
.set_epsilon(1e-3)
.execs({{n, c / 64, h, w, 64}, {}});
checker.set_dtype(0, dtype::QuantizedS4{1.19990307f})
.set_dtype(1, dtype::QuantizedS4{1.f})
.set_rng(0, &s4)
.set_param(param)
.set_epsilon(1e-3)
.execs({{n, c / 64, h, w, 64}, {}});
checker.set_dtype(0, dtype::Quantized4Asymm{1.20211209f, 8})
.set_dtype(1, dtype::Quantized4Asymm{1.f, 4})
.set_rng(0, &u4)
.set_param(param)
.set_epsilon(1e-3)
.execs({{n, c / 64, h, w, 64}, {}});
}
}
}
}
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT) {
using Param = RelayoutFormat::Param;
......