提交 5f44203d 编写于 作者: M Megvii Engine Team

feat(dnn/cuda): add a cutlass impl for fusing convolution and dimshuffle

GitOrigin-RevId: 3fc6faef01202867f54206367d32ab01659326d0
上级 61f917fb
......@@ -275,4 +275,130 @@ INST(true);
INST(false);
#undef INST
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw(
const int8_t* /* d_src */, const int8_t* /* d_filter */,
const float* /* d_bias */, const float* /* d_z */,
float* /* d_dst */, int* /* workspace */,
const convolution::ConvParam& /* param */,
uint32_t /* nonlinear_mode */, float /* alpha */,
float /* beta */, float /* gamma */, float /* scale */,
const GemmCoord& /* threadblock_shape */,
const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw(
const int8_t* d_src, const int8_t* d_filter,
const float* d_bias, const float* d_z, float* d_dst,
int* workspace, const convolution::ConvParam& param,
uint32_t nonlinear_mode, float alpha, float beta, float gamma,
float scale, const GemmCoord& threadblock_shape,
const GemmCoord& warp_shape, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
threadblock_k_, warp_m_, warp_n_, \
warp_k_, aligned_) \
if (threadblock_shape.m() == threadblock_m_ && \
threadblock_shape.n() == threadblock_n_ && \
threadblock_shape.k() == threadblock_k_ && \
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
warp_shape.k() == warp_k_) { \
using ThreadBlockShape = \
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
threadblock_k_>; \
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \
using Convolution = cutlass::convolution::device::Convolution< \
int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \
cutlass::layout::TensorCxRSKx<4>, ElementOutput, \
cutlass::layout::TensorNCHW, float, \
cutlass::layout::TensorNCHW, int32_t, \
cutlass::convolution::ConvType::kConvolution, \
cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
cutlass::convolution::threadblock:: \
ConvolutionNCxHWxThreadblockSwizzle< \
cutlass::convolution::ConvType::kConvolution>, \
2, 4, aligned_, NeedLoadFromConstMem, \
cutlass::arch::OpMultiplyAdd>; \
typename Convolution::ConvolutionParameter conv_param{ \
param.n, param.ci, param.co, param.hi, param.wi, \
param.fh, param.fw, param.ho, param.wo, param.sh, \
param.sw, param.ph, param.pw, 1, 1}; \
return cutlass_convolution_wrapper<Convolution>( \
d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \
epilogue, stream); \
}
#define DISPATCH_KERNEL \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 4); \
megdnn_assert(false, \
"unsupported threadblock shape (%dx%dx%d) and warp shape " \
"(%dx%dx%d)", \
threadblock_shape.m(), threadblock_shape.n(), \
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
warp_shape.k());
using ElementOutput = float;
using ElementAccumulator = int32_t;
using ElementBias = float;
using ElementCompute = float;
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
switch (nonlinear_mode) {
case NonlineMode::IDENTITY: {
using EpilogueOp =
cutlass::epilogue::thread::BiasAddLinearCombination<
ElementOutput, 1, ElementAccumulator, ElementBias,
ElementCompute>;
typename EpilogueOp::Params epilogue{alpha, beta, gamma};
DISPATCH_KERNEL;
}
case NonlineMode::RELU: {
using EpilogueOp =
cutlass::epilogue::thread::BiasAddLinearCombinationRelu<
ElementOutput, 1, ElementAccumulator, ElementBias,
ElementCompute>;
typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
DISPATCH_KERNEL;
}
case NonlineMode::H_SWISH: {
using EpilogueOp =
cutlass::epilogue::thread::BiasAddLinearCombinationHSwish<
ElementOutput, 1, ElementAccumulator, ElementBias,
ElementCompute>;
typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
DISPATCH_KERNEL;
}
default:
megdnn_assert(false,
"unsupported nonlinear mode for conv bias operator");
}
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(need_load_from_const_mem) \
template void megdnn::cuda::cutlass_wrapper:: \
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw< \
need_load_from_const_mem>( \
const int8_t* d_src, const int8_t* d_filter, \
const float* d_bias, const float* d_z, float* d_dst, \
int* workspace, const convolution::ConvParam& param, \
uint32_t nonlinear_mode, float alpha, float beta, \
float gamma, float scale, \
const GemmCoord& threadblock_shape, \
const GemmCoord& warp_shape, cudaStream_t stream);
INST(true);
INST(false);
#undef INST
// vim: syntax=cuda.doxygen
......@@ -22,8 +22,11 @@ using GemmCoord = cutlass::gemm::GemmCoord;
template <typename Convolution>
void cutlass_convolution_wrapper(
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias,
const int8_t* d_z, int8_t* d_dst, int* workspace,
const typename Convolution::ElementSrc* d_src,
const typename Convolution::ElementFilter* d_filter,
const typename Convolution::ElementBias* d_bias,
const typename Convolution::ElementDst* d_z,
typename Convolution::ElementDst* d_dst, int* workspace,
typename Convolution::ConvolutionParameter const& conv_param,
typename Convolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);
......@@ -46,6 +49,15 @@ void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4(
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
cudaStream_t stream);
template <bool NeedLoadFromConstMem>
void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw(
const int8_t* d_src, const int8_t* d_filter, const float* d_bias,
const float* d_z, float* d_dst, int* workspace,
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float scale,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
cudaStream_t stream);
} // namespace cutlass_wrapper
} // namespace cuda
} // namespace megdnn
......
......@@ -32,10 +32,26 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available(
if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout),
param.format))
return false;
if (param.format != Format::NCHW4)
if (param.format != Format::NCHW4 && param.format != Format::NCHW4_NCHW &&
param.format != Format::NCHW4_NCHW32)
return false;
UNPACK_CONV_BIAS_NCHW4_PARAM(*(args.src_layout), fm, *(args.dst_layout),
param);
size_t n = args.src_layout->operator[](0),
ci = args.src_layout->operator[](1) * 4,
hi = args.src_layout->operator[](2),
wi = args.src_layout->operator[](3);
size_t ho = args.dst_layout->operator[](2),
wo = args.dst_layout->operator[](3);
size_t co;
if (param.format == Format::NCHW4) {
co = args.dst_layout->operator[](1) * 4;
} else if (param.format == Format::NCHW4_NCHW) {
co = args.dst_layout->operator[](1);
} else {
megdnn_assert(param.format == Format::NCHW4_NCHW32);
co = args.dst_layout->operator[](1) * 32;
}
UNPACK_CONV_PARAMETER(fm, param);
MARK_USED_VAR
// TODO support group conv
available &= param.sparse == Sparse::DENSE;
// mode must be cross correlation
......@@ -46,9 +62,11 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available(
bias_dtype = args.bias_layout->dtype,
dst_dtype = args.dst_layout->dtype;
available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 &&
filter_dtype.enumv() == DTypeEnum::QuantizedS8 &&
bias_dtype.enumv() == DTypeEnum::QuantizedS32 &&
dst_dtype.enumv() == DTypeEnum::QuantizedS8);
filter_dtype.enumv() == DTypeEnum::QuantizedS8);
available &= (bias_dtype.enumv() == DTypeEnum::QuantizedS32 &&
dst_dtype.enumv() == DTypeEnum::QuantizedS8) ||
(bias_dtype.enumv() == DTypeEnum::Float32 &&
dst_dtype.enumv() == DTypeEnum::Float32);
// TODO: support dialtion
available &= dh == 1 && dw == 1;
// only support sm_61 or later, platform should have fast native int8
......@@ -81,8 +99,23 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec(
using Format = Param::Format;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
UNPACK_CONV_BIAS_NCHW4_PARAM(*(args.src_layout), fm, *(args.dst_layout),
param);
size_t n = args.src_layout->operator[](0),
ci = args.src_layout->operator[](1) * 4,
hi = args.src_layout->operator[](2),
wi = args.src_layout->operator[](3);
size_t ho = args.dst_layout->operator[](2),
wo = args.dst_layout->operator[](3);
size_t co;
if (param.format == Format::NCHW4) {
co = args.dst_layout->operator[](1) * 4;
} else if (param.format == Format::NCHW4_NCHW) {
co = args.dst_layout->operator[](1);
} else {
megdnn_assert(param.format == Format::NCHW4_NCHW32);
co = args.dst_layout->operator[](1) * 32;
}
UNPACK_CONV_PARAMETER(fm, param);
MARK_USED_VAR
auto&& stream = cuda_stream(args.opr->handle());
int8_t* filter_ptr = nullptr;
......@@ -115,47 +148,107 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec(
float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale,
filter_scale =
args.filter_layout->dtype.param<dtype::QuantizedS8>().scale,
bias_scale =
args.bias_layout->dtype.param<dtype::QuantizedS32>().scale,
dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS8>().scale;
float alpha = src_scale * filter_scale / dst_scale,
beta = bias_scale / dst_scale;
int8_t* z_dev_ptr = nullptr;
float gamma = 0.0;
args.filter_layout->dtype.param<dtype::QuantizedS8>().scale;
float alpha = src_scale * filter_scale;
float beta = 1.f;
float dst_scale = 1.f;
if (args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32) {
megdnn_assert(args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS8);
float bias_scale = args.bias_layout->dtype.param<dtype::QuantizedS32>()
.scale,
dst_scale =
args.dst_layout->dtype.param<dtype::QuantizedS8>().scale;
alpha /= dst_scale, beta = bias_scale / dst_scale;
}
float gamma = 0.f;
if (args.z_layout->ndim > 0) {
z_dev_ptr = args.z_tensor->compatible_ptr<int8_t>();
float z_scale = args.z_layout->dtype.param<dtype::QuantizedS8>().scale;
gamma = z_scale / dst_scale;
gamma = 1.f;
if (args.z_layout->dtype.enumv() == DTypeEnum::QuantizedS8) {
megdnn_assert(args.dst_layout->dtype.enumv() ==
DTypeEnum::QuantizedS8);
float z_scale = args.z_layout->dtype.param<dtype::QuantizedS8>()
.scale;
gamma = z_scale / dst_scale;
}
}
uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode);
if (fh == 1 && fw == 1) {
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4<false>(
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr,
args.dst_tensor->compatible_ptr<int8_t>(), nullptr, kern_param,
nonlinear_mode, alpha, beta, gamma, dst_scale,
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
if (param.format == Format::NCHW4) {
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4<
false>(
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(),
args.z_tensor->compatible_ptr<int8_t>(),
args.dst_tensor->compatible_ptr<int8_t>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
} else if (param.format == Format::NCHW4_NCHW) {
cutlass_wrapper::
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw<false>(
args.src_tensor->compatible_ptr<int8_t>(),
filter_ptr,
args.bias_tensor->compatible_ptr<float>(),
args.z_tensor->compatible_ptr<float>(),
args.dst_tensor->compatible_ptr<float>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma,
dst_scale,
cutlass_wrapper::GemmCoord{
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
} else {
megdnn_assert(param.format == Format::NCHW4_NCHW32);
}
} else {
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4<true>(
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr,
args.dst_tensor->compatible_ptr<int8_t>(), nullptr, kern_param,
nonlinear_mode, alpha, beta, gamma, dst_scale,
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
if (param.format == Format::NCHW4) {
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4<
true>(
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(),
args.z_tensor->compatible_ptr<int8_t>(),
args.dst_tensor->compatible_ptr<int8_t>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
} else if (param.format == Format::NCHW4_NCHW) {
cutlass_wrapper::
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw<true>(
args.src_tensor->compatible_ptr<int8_t>(),
filter_ptr,
args.bias_tensor->compatible_ptr<float>(),
args.z_tensor->compatible_ptr<float>(),
args.dst_tensor->compatible_ptr<float>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma,
dst_scale,
cutlass_wrapper::GemmCoord{
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
} else {
megdnn_assert(param.format == Format::NCHW4_NCHW32);
}
}
after_kernel_launch();
}
size_t ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::
......@@ -174,8 +267,23 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec_preprocess(
using Format = Param::Format;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
UNPACK_CONV_BIAS_NCHW4_PARAM(*(args.src_layout), fm, *(args.dst_layout),
param);
size_t n = args.src_layout->operator[](0),
ci = args.src_layout->operator[](1) * 4,
hi = args.src_layout->operator[](2),
wi = args.src_layout->operator[](3);
size_t ho = args.dst_layout->operator[](2),
wo = args.dst_layout->operator[](3);
size_t co;
if (param.format == Format::NCHW4) {
co = args.dst_layout->operator[](1) * 4;
} else if (param.format == Format::NCHW4_NCHW) {
co = args.dst_layout->operator[](1);
} else {
megdnn_assert(param.format == Format::NCHW4_NCHW32);
co = args.dst_layout->operator[](1) * 32;
}
UNPACK_CONV_PARAMETER(fm, param);
MARK_USED_VAR
TensorLayout src{{co, ci / 4 * fh * fw}, dtype::Int32()};
src.init_contiguous_stride();
TensorLayout dst = src;
......
......@@ -19,25 +19,28 @@ using namespace cutlass_wrapper;
template <typename Convolution>
void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper(
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias,
const int8_t* d_z, int8_t* d_dst, int* workspace,
const typename Convolution::ElementSrc* d_src,
const typename Convolution::ElementFilter* d_filter,
const typename Convolution::ElementBias* d_bias,
const typename Convolution::ElementDst* d_z,
typename Convolution::ElementDst* d_dst, int* workspace,
typename Convolution::ConvolutionParameter const& conv_param,
typename Convolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream) {
typename Convolution::TensorRefSrc tensor_src{
const_cast<int8_t*>(d_src),
const_cast<typename Convolution::ElementSrc*>(d_src),
Convolution::LayoutSrc::packed({conv_param.n(), conv_param.hi(),
conv_param.wi(), conv_param.ci()})};
typename Convolution::TensorRefFilter tensor_filter{
const_cast<int8_t*>(d_filter),
const_cast<typename Convolution::ElementFilter*>(d_filter),
Convolution::LayoutFilter::packed({conv_param.co(), conv_param.fh(),
conv_param.fw(),
conv_param.ci()})};
typename Convolution::TensorRefBias tensor_bias{
const_cast<int32_t*>(d_bias),
const_cast<typename Convolution::ElementBias*>(d_bias),
Convolution::LayoutBias::packed({1, 1, 1, conv_param.co()})};
typename Convolution::TensorRefDst tensor_z{
const_cast<int8_t*>(d_z),
const_cast<typename Convolution::ElementDst*>(d_z),
Convolution::LayoutDst::packed({conv_param.n(), conv_param.ho(),
conv_param.wo(), conv_param.co()})};
typename Convolution::TensorRefDst tensor_dst{
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册