提交 71c2f612 编写于 作者: M Megvii Engine Team

feat(dnn/cuda): add relayout format to support layout transform between NCHW and NCHW64

GitOrigin-RevId: 1445ecfabe106eee57494b6bee4df14b9b81556b
上级 df009e89
......@@ -196,6 +196,32 @@ public:
const TensorLayout& layout) const override;
};
using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>;
///*!
// * \brief used for tensors with lowbit data type
// *
// * \p SIZE_NBITS is the size in bits of element of the tensor.
// *
// */
//template <size_t SIZE_NBITS_>
//class LowbitTensorFormat : public TensorFormat::ImplBase {
// static constexpr size_t SIZE_NBITS = SIZE_NBITS_;
// size_t m_align_size_in_bits;
//
//protected: //?
// LowbitTensorFormat(Type type, size_t m_align_size_in_bits);
//
//public:
// size_t align_size_in_bits() const {
// return m_align_size_in_bits;
// }
//
// std::string to_string() const override;
//
// void serialize_append(
//
//
//};
} // namespace detail
/*!
......
......@@ -895,6 +895,7 @@ Relayout mode.
* ``NCHW4`` layout: ``{N, C/4, H, W, 4}``
* ``NCHW88`` layout: ``{N, C/8, H, W, 8}``
* ``CHWN4`` layout: ``{C/4, H, W, N, 4}``
* ``NCHW64`` layout: ``{N, C/64, H, W, 64}``
**Float weight transformation definitions**
......@@ -969,6 +970,8 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
'NCHW_NCHW4',
'NCHW4_NCHW',
'NCHW_NCHW4_WEIGHT',
'NCHW_NCHW64',
'NCHW64_NCHW',
)
)
......
......@@ -251,6 +251,23 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
dst[3] = src[3];
megdnn_assert(dst[1] % param().group == 0);
break;
case Param::Mode::NCHW_NCHW64:
megdnn_assert(src.ndim == 4 && (src[1] % 64) == 0);
dst.ndim = 5;
dst[0] = src[0];
dst[1] = src[1] / 64;
dst[2] = src[2];
dst[3] = src[3];
dst[4] = 64;
break;
case Param::Mode::NCHW64_NCHW:
megdnn_assert(src.ndim == 5);
dst.ndim = 4;
dst[0] = src[0];
dst[1] = src[1] * 64;
dst[2] = src[2];
dst[3] = src[3];
break;
default:
megdnn_assert(0, "Invalid RelayoutFormat Mode");
break;
......@@ -352,7 +369,12 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
CHECK_SRC(DefaultTensorFormat::make());
dst = src;
break;
case Param::Mode::NCHW_NCHW64:
dst = src;
break;
case Param::Mode::NCHW64_NCHW:
dst = src;
break;
default:
megdnn_throw("Invalid relayout format mode");
break;
......@@ -633,6 +655,19 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
exec_src = src.dimshuffle({3, 0, 1, 2, 4});
exec_dst = dst;
break;
case Param::Mode::NCHW_NCHW64:
// src is {N, C, H, W}
// dst is {N, C/64, H, W, 64}
exec_src = src.reshape({src[0], src[1] / 64, 64, src[2], src[3]})
.dimshuffle({0, 1, 3, 4, 2});
exec_dst = dst;
break;
case Param::Mode::NCHW64_NCHW:
// src is {N, C/64, H, W, 64}
// dst is {N, C, H, W}
exec_src = src.dimshuffle({0, 1, 4, 2, 3});
exec_dst = dst;
break;
default:
megdnn_assert(0, "Invalid RelayoutFormat Mode");
}
......
......@@ -69,12 +69,9 @@ size_t ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_in_bytes(
void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec(
const ExecArgs& args) const {
using Format = Param::Format;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
auto layouts = make_underlying_tensor_layout(
*(args.src_layout), fm, *(args.bias_layout), *(args.z_layout),
*(args.dst_layout));
*(args.src_layout), *(args.filter_layout), *(args.bias_layout),
*(args.z_layout), *(args.dst_layout));
auto ws = get_workspace_bundle(args.workspace.raw_ptr, args);
auto ws_src = ws.get(0);
auto ws_filter = ws.get(1);
......@@ -82,20 +79,27 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec(
void* ws_z = nullptr;
if (args.z_layout->ndim > 0)
ws_z = ws.get(4);
auto&& stream = cuda_stream(args.opr->handle());
auto nchw2nchw64 = [](const TensorND& src, void* raw_dptr) {
if (raw_dptr == nullptr)
// auto&& stream = cuda_stream(args.opr->handle());
auto nchw2nchw64 = [&args](const TensorND& src, TensorND&& dst) {
if (dst.raw_ptr == nullptr)
return;
auto relayout = args.handle->create_operator<RelayoutFormat>();
relayout->param() = RelayoutFormat::Param::Mode::NCHW_NCHW64;
Workspace dummy;
relayout->exec(src, dst, dummy);
};
auto nchw642nchw = [](const TensorND& src, void* raw_dptr) {
auto nchw642nchw = [&args](const TensorND& src, TensorND&& dst) {
auto relayout = args.handle->create_operator<RelayoutFormat>();
relayout->param() = RelayoutFormat::Param::Mode::NCHW64_NCHW;
Workspace dummy;
relayout->exec(src, dst, dummy);
};
// reformat src
nchw2nchw64(*(args.src_tensor), ws_src);
nchw2nchw64(*(args.src_tensor), {ws_src, layouts[0]});
// reformat filter
nchw2nchw64(*(args.filter_tensor), ws_filter);
nchw2nchw64(*(args.filter_tensor), {ws_filter, layouts[1]});
// reformat z
nchw2nchw64(*(args.z_tensor), ws_z);
nchw2nchw64(*(args.z_tensor), {ws_z, layouts[3]});
TensorND src_{ws_src, layouts[0]}, filter_{ws_filter, layouts[1]},
bias_{args.bias_tensor->raw_ptr, layouts[2]}, z_{ws_z, layouts[3]},
dst_{ws_dst, layouts[4]};
......@@ -109,22 +113,22 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec(
args.preprocessed_filter};
m_underlying_algo.exec(args);
// reformat dst
nchw642nchw(dst_, args.dst_tensor->raw_ptr);
nchw642nchw(dst_, {args.dst_tensor->raw_ptr, args.dst_tensor->layout});
}
SmallVector<TensorLayout>
ConvBiasForwardImpl::AlgoFallbackNCHWQS4::make_underlying_tensor_layout(
const TensorLayout& src, const CanonizedFilterMeta& filter_meta,
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) const {
size_t n = src[0], ci = src[1], hi = src[2], wi = src[3];
size_t co = dst[1], ho = dst[2], wo = dst[3];
size_t fh = filter_meta.spatial[0], fw = filter_meta.spatial[1];
size_t fh = filter[2], fw = filter[3];
SmallVector<TensorLayout> rst;
rst.emplace_back(TensorLayout{{n, ci / 64, hi, wi, 64}, src.dtype});
rst.emplace_back(TensorLayout{{co, ci / 64, fh, fw, 64}, filter.dtype});
rst.emplace_back(TensorLayout{{1, co / 64, 1, 1, 64}, bias.dtype});
if (z.layout.ndim > 0) {
if (z.ndim > 0) {
rst.emplace_back(TensorLayout{{n, co / 64, ho, wo, 64}, z.dtype});
} else {
rst.emplace_back(TensorLayout{});
......@@ -134,15 +138,13 @@ ConvBiasForwardImpl::AlgoFallbackNCHWQS4::make_underlying_tensor_layout(
}
WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_bundle(
void* ptr, const SizeArgs& args) const {
void* raw_ptr, const SizeArgs& args) const {
size_t ws_size_src = args.src_layout->span().dist_byte();
size_t ws_size_filter = args.filter_layout->span().dist_byte();
size_t ws_size_dst = args.dst_layout->span().dist_byte();
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
auto layouts = make_underlying_tensor_layout(
*(args.src_layout), fm, *(args.bias_layout), *(args.z_layout),
*(args.dst_layout));
*(args.src_layout), *(args.filter_layout), *(args.bias_layout),
*(args.z_layout), *(args.dst_layout));
SizeArgs args_{args.opr,
layouts[0],
layouts[1],
......
......@@ -78,21 +78,26 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
return handle()->create_operator<RelayoutForward>()->exec(
{src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout});
}
if (param().mode == Param::Mode::NCHW_NCHW4 ||
bool is_trans_4bits = (param().mode == Param::Mode::NCHW_NCHW64 ||
param().mode == Param::Mode::NCHW64_NCHW) &&
(src_dtype.enumv() == DTypeEnum::QuantizedS4 ||
src_dtype.enumv() == DTypeEnum::Quantized4Asymm);
bool is_nchw_nchw4 = param().mode == Param::Mode::NCHW_NCHW4 ||
param().mode == Param::Mode::NCHW4_NCHW ||
param().mode == Param::Mode::NCHW_NCHW4_WEIGHT) {
param().mode == Param::Mode::NCHW_NCHW4_WEIGHT;
if (is_trans_4bits || is_nchw_nchw4) {
bool is_usable = relayout_format::RelayoutFormatFast::usable(
src.layout, dst.layout);
megdnn_assert(is_usable,
"RelayoutFormatNCHW_NCHW4 kernel not usable for %s(%s) "
"to %s(%s)",
"RelayoutFormatFast kernel is not usable for "
"transforming %s(%s) to %s(%s).",
src.layout.to_string().c_str(), src.layout.dtype.name(),
dst.layout.to_string().c_str(), dst.layout.dtype.name());
relayout_format::RelayoutFormatFast::exec(src, dst,
cuda_stream(this->handle()),
param().mode, param().group);
} else {
return relayout_format::RelayoutFormatFast::exec(
src, dst, cuda_stream(this->handle()), param().mode,
param().group);
}
// fallback impls
TensorLayout exec_src, exec_dst, exec_workspace;
deduce_exec_layout(src.layout, dst.layout, exec_workspace, exec_src,
exec_dst);
......@@ -100,7 +105,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
TensorND exec_dst_nd{dst.raw_ptr, exec_dst};
handle()->create_operator<RelayoutForward>()->exec(exec_src_nd,
exec_dst_nd);
}
}
size_t RelayoutFormatImpl::get_workspace_in_bytes(
......
......@@ -24,6 +24,8 @@ inline void get_scale_zeropoint(const DType& tensor_dtype, float& scale,
scale = tensor_dtype.param<dtype::Quantized8Asymm>().scale;
} else if (tensor_dtype.enumv() == DTypeEnum::QuantizedS8) {
scale = tensor_dtype.param<dtype::QuantizedS8>().scale;
} else if (tensor_dtype.enumv() == DTypeEnum::QuantizedS4) {
scale = tensor_dtype.param<dtype::QuantizedS4>().scale;
}
}
......@@ -39,9 +41,8 @@ void relayout_format::RelayoutFormatFast::exec(const TensorND& src,
cudaStream_t stream,
RelayoutFormat::Param::Mode mode,
int group) {
size_t ih = src.layout[2];
size_t iw = src.layout[3];
size_t hw = ih * iw;
auto&& stype = src.layout.dtype;
auto&& dtype = dst.layout.dtype;
float src_scale = 1.f;
float dst_scale = 1.f;
uint8_t src_zero_point = 0;
......@@ -51,22 +52,28 @@ void relayout_format::RelayoutFormatFast::exec(const TensorND& src,
if (src.layout.dtype.enumv() == DTypeEnum::Uint8) {
src_zero_point = 128;
}
if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4) {
if (hw % 4 == 0) {
relayout_format_cuda_nchw_nchw4<4>(src, dst, stream, src_scale,
if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4 ||
mode == RelayoutFormat::Param::Mode::NCHW_NCHW64) {
return relayout_format_cuda_nchw_nchwx(src, dst, stream, src_scale,
dst_scale, src_zero_point,
dst_zero_point, group);
} else {
relayout_format_cuda_nchw_nchw4<1>(src, dst, stream, src_scale,
} else if (mode == RelayoutFormat::Param::Mode::NCHW64_NCHW) {
megdnn_assert(group == 1,
"RelayoutFormat kernel only support transforming NCHW64 "
"to NCHW with group = 1(group:%d)",
group);
return relayout_format_cuda_nchwx_nchw(src, dst, stream, src_scale,
dst_scale, src_zero_point,
dst_zero_point, group);
}
dst_zero_point);
} else if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4_WEIGHT) {
relayout_format_cuda_nchw_nchw4_weight(src, dst, stream);
return relayout_format_cuda_nchw_nchw4_weight(src, dst, stream);
} else if (mode == RelayoutFormat::Param::Mode::NCHW4_NCHW) {
relayout_format_cuda_nchw4_nchw(src, dst, stream, group);
return relayout_format_cuda_nchw4_nchw(src, dst, stream, group);
} else {
megdnn_throw("only support nchw_nchw4 nchw4_nchw layout_format");
megdnn_throw(
"only support nchw_nchw64/nchw64_nchw/nchw_nchw4/nchw4_nchw "
"layout_format");
}
}
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -19,13 +19,10 @@ namespace megdnn {
namespace cuda {
namespace relayout_format {
template <int pack_w = 1>
void relayout_format_cuda_nchw_nchw4(const TensorND& src, const TensorND& dst,
const cudaStream_t& stream,
const float src_scale = 1.f,
const float dst_scale = 1.f,
const uint8_t src_zero_point = 0,
const uint8_t dst_zero_point = 0,
void relayout_format_cuda_nchw_nchwx(
const TensorND& src, const TensorND& dst, const cudaStream_t& stream,
const float src_scale = 1.f, const float dst_scale = 1.f,
const uint8_t src_zero_point = 0, const uint8_t dst_zero_point = 0,
const int group = 1);
bool relayout_format_cuda_usable(const TensorLayout& src_layout,
......@@ -35,6 +32,13 @@ void relayout_format_cuda_nchw4_nchw(const TensorND& src, const TensorND& dst,
const cudaStream_t& stream,
const int group);
void relayout_format_cuda_nchwx_nchw(const TensorND& src, const TensorND& dst,
const cudaStream_t& stream,
const float src_scale = 1.f,
const float dst_scale = 1.f,
const uint8_t src_zero_point = 0,
const uint8_t dst_zero_point = 0);
void relayout_format_cuda_nchw_nchw4_weight(const TensorND& src,
const TensorND& dst,
const cudaStream_t& stream);
......
......@@ -110,6 +110,12 @@ MEGDNN_NORETURN void report_error(const char* msg);
template <typename T, size_t N>
struct array_wrapper {
T data[N];
MEGDNN_DEVICE __forceinline__ T& operator[](size_t pos) {
return reinterpret_cast<T&>(data[pos]);
}
MEGDNN_DEVICE __forceinline__ T const& operator[](size_t pos) const {
return reinterpret_cast<T const&>(data[pos]);
}
};
/*!
......@@ -207,12 +213,29 @@ struct CudaDTypeParamImpl<dt_quint4> : DTypeParamImpl<dt_quint4> {
CudaDTypeParamImpl(const DTypeParamImpl<dt_quint4>& param)
: CudaDTypeParamImpl(param.scale, param.zero_point) {}
__device__ uint8_t quantize(float in) const {
__device__ dt_quint4 quantize(float in) const {
float v = in * inv_scale;
v = roundf(v);
v = v + zero_point;
v = fmin(fmax(0.f, v), 15.f);
return static_cast<uint8_t>(v);
return static_cast<dt_quint4>(v);
}
};
template <>
struct CudaDTypeParamImpl<dt_qint4> : DTypeParamImpl<dt_qint4> {
float inv_scale;
CudaDTypeParamImpl() = default;
CudaDTypeParamImpl(float scale)
: DTypeParamImpl<dt_qint4>(scale), inv_scale(1.0f / scale) {}
CudaDTypeParamImpl(const DTypeParamImpl<dt_qint4>& param)
: CudaDTypeParamImpl(param.scale) {}
__device__ dt_qint4 quantize(float in) const {
float v = in * inv_scale;
v = roundf(v);
v = fmin(fmax(-8.f, v), 7.f);
return static_cast<dt_qint4>(v);
}
};
......@@ -351,6 +374,110 @@ MEGDNN_DEVICE __forceinline__ static float4 operator+(float4 lval,
return make_float4(lval.x + rval.x, lval.y + rval.y, lval.z + rval.z,
lval.w + rval.w);
}
MEGDNN_DEVICE __forceinline__ static int transform_int8_to_int4x8(
int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
unsigned out;
#if __CUDA_ARCH__ >= 750
asm volatile(
"{ .reg .u32 r4;"
"cvt.pack.sat.s4.s32.b32 r4, %8, %7, 0;"
"cvt.pack.sat.s4.s32.b32 r4, %6, %5, r4;"
"cvt.pack.sat.s4.s32.b32 r4, %4, %3, r4;"
"cvt.pack.sat.s4.s32.b32 %0, %2, %1, r4;"
"}"
: "=r"(out)
: "r"(s0), "r"(s1), "r"(s2), "r"(s3), "r"(s4), "r"(s5), "r"(s6),
"r"(s7));
#else
#define CVT_SAT_S4_S32(r, bits) \
r = r <= -8 ? -8 : r; \
r = r > 7 ? 7 : r; \
r = (((unsigned)r & 0xf) << bits);
CVT_SAT_S4_S32(s0, 0)
CVT_SAT_S4_S32(s1, 4)
CVT_SAT_S4_S32(s2, 8)
CVT_SAT_S4_S32(s3, 12)
CVT_SAT_S4_S32(s4, 16)
CVT_SAT_S4_S32(s5, 20)
CVT_SAT_S4_S32(s6, 24)
CVT_SAT_S4_S32(s7, 28)
out = s0 + s1 + s2 + s3 + s4 + s5 + s6 + s7;
#undef CVT_SAT_S4_S32
#endif
return reinterpret_cast<int const&>(out);
}
MEGDNN_DEVICE __forceinline__ static int transform_int8_to_uint4x8(
int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
unsigned out;
#if __CUDA_ARCH__ >= 750
asm volatile(
"{ .reg .u32 r4;"
"cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;"
"cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;"
"cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;"
"cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;"
"}"
: "=r"(out)
: "r"(s0), "r"(s1), "r"(s2), "r"(s3), "r"(s4), "r"(s5), "r"(s6),
"r"(s7));
#else
#define CVT_SAT_U4_S32(r, bits) \
r = r <= 0 ? 0 : r; \
r = r > 15 ? 15 : r; \
r = (((unsigned)r & 0xf) << bits);
CVT_SAT_U4_S32(s0, 0)
CVT_SAT_U4_S32(s1, 4)
CVT_SAT_U4_S32(s2, 8)
CVT_SAT_U4_S32(s3, 12)
CVT_SAT_U4_S32(s4, 16)
CVT_SAT_U4_S32(s5, 20)
CVT_SAT_U4_S32(s6, 24)
CVT_SAT_U4_S32(s7, 28)
out = s0 + s1 + s2 + s3 + s4 + s5 + s6 + s7;
#undef CVT_SAT_U4_S32
#endif
return reinterpret_cast<int const&>(out);
}
template <bool signedness>
MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(unsigned storage,
unsigned bits);
template <>
MEGDNN_DEVICE __forceinline__ int unpack_integer_4bits<true>(unsigned storage,
unsigned bits) {
uint8_t result = (uint8_t)((unsigned)(storage >> bits) & 0xf);
static constexpr uint8_t mask = (uint8_t)((1 << 4) - 1);
return (result & uint8_t(1 << 3)) ? ((int)(result) | ~(int)(mask))
: (int)(result);
}
template <>
MEGDNN_DEVICE __forceinline__ int unpack_integer_4bits<false>(unsigned storage,
unsigned bits) {
uint8_t result = (uint8_t)((unsigned)(storage >> bits) & 0xf);
return (int)(result);
}
MEGDNN_DEVICE __forceinline__ static void transform_int4x8_to_int8(
int (&result)[8], const int& source) {
#pragma unroll
for (int i = 0; i < 8; i++) {
result[i] = unpack_integer_4bits<true>(
reinterpret_cast<unsigned const&>(source), (i << 2));
}
}
MEGDNN_DEVICE __forceinline__ static void transform_uint4x8_to_int8(
int (&result)[8], const int& source) {
#pragma unroll
for (int i = 0; i < 8; i++) {
result[i] = unpack_integer_4bits<false>(
reinterpret_cast<unsigned const&>(source), (i << 2));
}
}
#endif
} // namespace cuda
} // namespace megdnn
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册