提交 c3a4b222 编写于 作者: M Megvii Engine Team

feat(dnn/cuda): add cutlass impls for fused convolution reformat operation

GitOrigin-RevId: 02ef559c3f7367a3ee40d9f5017dcf3ece72ac0f
上级 5f44203d
......@@ -25,6 +25,8 @@ using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;
/* ================= cutlass kernel wrapper for nchw32 layout ================
*/
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
......@@ -148,6 +150,131 @@ INST(true);
INST(false);
#undef INST
/* ==== cutlass kernel wrapper for nchw32 layout and nchw4 output ===== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4(
const int8_t* /* d_src */, const int8_t* /* d_filter */,
const int32_t* /* d_bias */, const int8_t* /* d_z */,
int8_t* /* d_dst */, int* /* workspace */,
const convolution::ConvParam& /* param */,
uint32_t /* nonlinear_mode */, float /* alpha */,
float /* beta */, float /* gamma */, float /* scale */,
const GemmCoord& /* threadblock_shape */,
const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4(
const int8_t* d_src, const int8_t* d_filter,
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
int* workspace, const convolution::ConvParam& param,
uint32_t nonlinear_mode, float alpha, float beta, float gamma,
float scale, const GemmCoord& threadblock_shape,
const GemmCoord& warp_shape, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
threadblock_k_, warp_m_, warp_n_, \
warp_k_) \
if (threadblock_shape.m() == threadblock_m_ && \
threadblock_shape.n() == threadblock_n_ && \
threadblock_shape.k() == threadblock_k_ && \
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
warp_shape.k() == warp_k_) { \
using ThreadBlockShape = \
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
threadblock_k_>; \
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; \
using Convolution = cutlass::convolution::device::Convolution< \
int8_t, cutlass::layout::TensorNCxHWx<32>, int8_t, \
cutlass::layout::TensorCxRSKx<32>, ElementOutput, \
cutlass::layout::TensorNCxHWx<4>, int32_t, \
cutlass::layout::TensorNCxHWx<4>, int32_t, \
cutlass::convolution::ConvType::kConvolution, \
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
cutlass::convolution::threadblock:: \
ConvolutionNCxHWxThreadblockSwizzle< \
cutlass::convolution::ConvType::kConvolution>, \
2, 16, 16, NeedLoadFromConstMem>; \
typename Convolution::ConvolutionParameter conv_param{ \
param.n, param.ci, param.co, param.hi, param.wi, \
param.fh, param.fw, param.ho, param.wo, param.sh, \
param.sw, param.ph, param.pw, 1, 1}; \
return cutlass_convolution_wrapper<Convolution>( \
d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \
epilogue, stream); \
}
#define DISPATCH_KERNEL \
DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 16, 32, 64); \
megdnn_assert(false, \
"unsupported threadblock shape (%dx%dx%d) and warp shape " \
"(%dx%dx%d)", \
threadblock_shape.m(), threadblock_shape.n(), \
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
warp_shape.k());
using ElementOutput = int8_t;
using ElementAccumulator = int32_t;
using ElementBias = int32_t;
using ElementCompute = float;
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
switch (nonlinear_mode) {
case NonlineMode::IDENTITY: {
using EpilogueOp =
cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
ElementOutput, 4, ElementAccumulator, ElementBias,
ElementCompute>;
typename EpilogueOp::Params epilogue{alpha, beta, gamma};
DISPATCH_KERNEL;
}
case NonlineMode::RELU: {
using EpilogueOp = cutlass::epilogue::thread::
BiasAddLinearCombinationReluClamp<
ElementOutput, 4, ElementAccumulator, ElementBias,
ElementCompute>;
typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
DISPATCH_KERNEL;
}
case NonlineMode::H_SWISH: {
using EpilogueOp = cutlass::epilogue::thread::
BiasAddLinearCombinationHSwishClamp<
ElementOutput, 4, ElementAccumulator, ElementBias,
ElementCompute>;
typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
DISPATCH_KERNEL;
}
default:
megdnn_assert(false,
"unsupported nonlinear mode for conv bias operator");
}
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(need_load_from_const_mem) \
template void megdnn::cuda::cutlass_wrapper:: \
do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4< \
need_load_from_const_mem>( \
const int8_t* d_src, const int8_t* d_filter, \
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
int* workspace, const convolution::ConvParam& param, \
uint32_t nonlinear_mode, float alpha, float beta, \
float gamma, float scale, \
const GemmCoord& threadblock_shape, \
const GemmCoord& warp_shape, cudaStream_t stream);
INST(true);
INST(false);
#undef INST
/* ================ cutlass kernel wrapper for nchw4 layout ================= */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
......@@ -275,6 +402,7 @@ INST(true);
INST(false);
#undef INST
/* ===== cutlass kernel wrapper for nchw4 layout and nchw output ===== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
......@@ -401,4 +529,131 @@ void megdnn::cuda::cutlass_wrapper::
INST(true);
INST(false);
#undef INST
/* ====== cutlass kernel wrapper for nchw4 layout and nchw32 output ====== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32(
const int8_t* /* d_src */, const int8_t* /* d_filter */,
const int32_t* /* d_bias */, const int8_t* /* d_z */,
int8_t* /* d_dst */, int* /* workspace */,
const convolution::ConvParam& /* param */,
uint32_t /* nonlinear_mode */, float /* alpha */,
float /* beta */, float /* gamma */, float /* scale */,
const GemmCoord& /* threadblock_shape */,
const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32(
const int8_t* d_src, const int8_t* d_filter,
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
int* workspace, const convolution::ConvParam& param,
uint32_t nonlinear_mode, float alpha, float beta, float gamma,
float scale, const GemmCoord& threadblock_shape,
const GemmCoord& warp_shape, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
threadblock_k_, warp_m_, warp_n_, \
warp_k_, aligned_) \
if (threadblock_shape.m() == threadblock_m_ && \
threadblock_shape.n() == threadblock_n_ && \
threadblock_shape.k() == threadblock_k_ && \
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
warp_shape.k() == warp_k_) { \
using ThreadBlockShape = \
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
threadblock_k_>; \
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \
using Convolution = cutlass::convolution::device::Convolution< \
int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \
cutlass::layout::TensorCxRSKx<4>, ElementOutput, \
cutlass::layout::TensorNCxHWx<32>, int32_t, \
cutlass::layout::TensorNCxHWx<32>, int32_t, \
cutlass::convolution::ConvType::kConvolution, \
cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
cutlass::convolution::threadblock:: \
ConvolutionNCxHWxThreadblockSwizzle< \
cutlass::convolution::ConvType::kConvolution>, \
2, 4, aligned_, NeedLoadFromConstMem>; \
typename Convolution::ConvolutionParameter conv_param{ \
param.n, param.ci, param.co, param.hi, param.wi, \
param.fh, param.fw, param.ho, param.wo, param.sh, \
param.sw, param.ph, param.pw, 1, 1}; \
return cutlass_convolution_wrapper<Convolution>( \
d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \
epilogue, stream); \
}
#define DISPATCH_KERNEL \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 16); \
megdnn_assert(false, \
"unsupported threadblock shape (%dx%dx%d) and warp shape " \
"(%dx%dx%d)", \
threadblock_shape.m(), threadblock_shape.n(), \
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
warp_shape.k());
using ElementOutput = int8_t;
using ElementAccumulator = int32_t;
using ElementBias = int32_t;
using ElementCompute = float;
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
switch (nonlinear_mode) {
case NonlineMode::IDENTITY: {
using EpilogueOp =
cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
ElementOutput, 4, ElementAccumulator, ElementBias,
ElementCompute>;
typename EpilogueOp::Params epilogue{alpha, beta, gamma};
DISPATCH_KERNEL;
}
case NonlineMode::RELU: {
using EpilogueOp = cutlass::epilogue::thread::
BiasAddLinearCombinationReluClamp<
ElementOutput, 4, ElementAccumulator, ElementBias,
ElementCompute>;
typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
DISPATCH_KERNEL;
}
case NonlineMode::H_SWISH: {
using EpilogueOp = cutlass::epilogue::thread::
BiasAddLinearCombinationHSwishClamp<
ElementOutput, 4, ElementAccumulator, ElementBias,
ElementCompute>;
typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
DISPATCH_KERNEL;
}
default:
megdnn_assert(false,
"unsupported nonlinear mode for conv bias operator");
}
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(need_load_from_const_mem) \
template void megdnn::cuda::cutlass_wrapper:: \
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< \
need_load_from_const_mem>( \
const int8_t* d_src, const int8_t* d_filter, \
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
int* workspace, const convolution::ConvParam& param, \
uint32_t nonlinear_mode, float alpha, float beta, \
float gamma, float scale, \
const GemmCoord& threadblock_shape, \
const GemmCoord& warp_shape, cudaStream_t stream);
INST(true);
INST(false);
#undef INST
// vim: syntax=cuda.doxygen
......@@ -40,6 +40,15 @@ void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32(
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
cudaStream_t stream);
template <bool NeedLoadFromConstMem>
void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4(
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias,
const int8_t* d_z, int8_t* d_dst, int* workspace,
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float scale,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
cudaStream_t stream);
template <bool NeedLoadFromConstMem>
void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4(
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias,
......@@ -58,6 +67,15 @@ void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw(
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
cudaStream_t stream);
template <bool NeedLoadFromConstMem>
void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32(
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias,
const int8_t* d_z, int8_t* d_dst, int* workspace,
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float scale,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
cudaStream_t stream);
} // namespace cutlass_wrapper
} // namespace cuda
} // namespace megdnn
......
......@@ -35,10 +35,23 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::is_available(
if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout),
param.format))
return false;
if (param.format != Format::NCHW32)
if (param.format != Format::NCHW32 && param.format != Format::NCHW32_NCHW4)
return false;
UNPACK_CONV_BIAS_NCHW32_PARAM(*(args.src_layout), fm, *(args.dst_layout),
param);
size_t n = args.src_layout->operator[](0),
ci = args.src_layout->operator[](1) * 32,
hi = args.src_layout->operator[](2),
wi = args.src_layout->operator[](3);
size_t ho = args.dst_layout->operator[](2),
wo = args.dst_layout->operator[](3);
size_t co;
if (param.format == Format::NCHW32) {
co = args.dst_layout->operator[](1) * 32;
} else {
megdnn_assert(param.format == Format::NCHW32_NCHW4);
co = args.dst_layout->operator[](1) * 4;
}
UNPACK_CONV_PARAMETER(fm, param);
MARK_USED_VAR
// TODO support group conv
available &= param.sparse == Sparse::DENSE;
// mode must be cross correlation
......@@ -84,8 +97,21 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
using Format = Param::Format;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
UNPACK_CONV_BIAS_NCHW32_PARAM(*(args.src_layout), fm, *(args.dst_layout),
param);
size_t n = args.src_layout->operator[](0),
ci = args.src_layout->operator[](1) * 32,
hi = args.src_layout->operator[](2),
wi = args.src_layout->operator[](3);
size_t ho = args.dst_layout->operator[](2),
wo = args.dst_layout->operator[](3);
size_t co;
if (param.format == Format::NCHW32) {
co = args.dst_layout->operator[](1) * 32;
} else {
megdnn_assert(param.format == Format::NCHW32_NCHW4);
co = args.dst_layout->operator[](1) * 4;
}
UNPACK_CONV_PARAMETER(fm, param);
MARK_USED_VAR
auto&& stream = cuda_stream(args.opr->handle());
int8_t* filter_ptr = nullptr;
......@@ -137,33 +163,79 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
}
uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode);
if (fh == 1 && fw == 1) {
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32<
false>(args.src_tensor->compatible_ptr<int8_t>(), filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr,
args.dst_tensor->compatible_ptr<int8_t>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma,
dst_scale,
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
if (param.format == Format::NCHW32) {
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32<
false>(
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr,
args.dst_tensor->compatible_ptr<int8_t>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
} else {
megdnn_assert(param.format == Format::NCHW32_NCHW4);
cutlass_wrapper::
do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4<
false>(
args.src_tensor->compatible_ptr<int8_t>(),
filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(),
z_dev_ptr,
args.dst_tensor->compatible_ptr<int8_t>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma,
dst_scale,
cutlass_wrapper::GemmCoord{
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
}
} else {
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32<true>(
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr,
args.dst_tensor->compatible_ptr<int8_t>(), nullptr, kern_param,
nonlinear_mode, alpha, beta, gamma, dst_scale,
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
if (param.format == Format::NCHW32) {
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32<
true>(
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr,
args.dst_tensor->compatible_ptr<int8_t>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
} else {
megdnn_assert(param.format == Format::NCHW32_NCHW4);
cutlass_wrapper::
do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4<
true>(
args.src_tensor->compatible_ptr<int8_t>(),
filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(),
z_dev_ptr,
args.dst_tensor->compatible_ptr<int8_t>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma,
dst_scale,
cutlass_wrapper::GemmCoord{
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
}
}
after_kernel_launch();
}
std::string ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::to_string(
......@@ -189,8 +261,21 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec_preprocess(
using Format = Param::Format;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
UNPACK_CONV_BIAS_NCHW32_PARAM(*(args.src_layout), fm, *(args.dst_layout),
param);
size_t n = args.src_layout->operator[](0),
ci = args.src_layout->operator[](1) * 32,
hi = args.src_layout->operator[](2),
wi = args.src_layout->operator[](3);
size_t ho = args.dst_layout->operator[](2),
wo = args.dst_layout->operator[](3);
size_t co;
if (param.format == Format::NCHW32) {
co = args.dst_layout->operator[](1) * 32;
} else {
megdnn_assert(param.format == Format::NCHW32_NCHW4);
co = args.dst_layout->operator[](1) * 4;
}
UNPACK_CONV_PARAMETER(fm, param);
MARK_USED_VAR
TensorLayout src{{co, ci / 32, fh, fw, 32}, dtype::Int8()};
src.init_contiguous_stride();
TensorLayout dst = src;
......
......@@ -208,6 +208,24 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec(
stream);
} else {
megdnn_assert(param.format == Format::NCHW4_NCHW32);
cutlass_wrapper::
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32<
false>(
args.src_tensor->compatible_ptr<int8_t>(),
filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(),
args.z_tensor->compatible_ptr<int8_t>(),
args.dst_tensor->compatible_ptr<int8_t>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma,
dst_scale,
cutlass_wrapper::GemmCoord{
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
}
} else {
if (param.format == Format::NCHW4) {
......@@ -246,6 +264,24 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec(
} else {
megdnn_assert(param.format == Format::NCHW4_NCHW32);
cutlass_wrapper::
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32<
true>(
args.src_tensor->compatible_ptr<int8_t>(),
filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(),
args.z_tensor->compatible_ptr<int8_t>(),
args.dst_tensor->compatible_ptr<int8_t>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma,
dst_scale,
cutlass_wrapper::GemmCoord{
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
}
}
after_kernel_launch();
......
......@@ -1232,6 +1232,73 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW4_NCHW) {
run({{16, 4, 46, 80, 4}, {4, 4, 3, 3, 4}, {1, 4, 1, 1}});
}
TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW4_NCHW32) {
require_compute_capability(6, 1);
using namespace conv_bias;
Checker<ConvBiasForward> checker(handle_cuda());
UniformIntRNG int_rng{-3, 3};
UniformIntRNG bias_rng{-50, 50};
ConvBias::Param param;
param.format = ConvBias::Param::Format::NCHW4_NCHW32;
param.nonlineMode = ConvBias::Param::NonlineMode::IDENTITY;
checker.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
"INT8_NCHW4_DOTPROD_IMPLICIT_GEMM"));
checker.set_dtype(0, dtype::QuantizedS8(1.9980618f))
.set_dtype(1, dtype::QuantizedS8(1.9980927f))
.set_dtype(2, dtype::QuantizedS32(1.9980618f * 1.9980927f))
.set_dtype(3, dtype::QuantizedS8(1.9980618f))
.set_dtype(4, dtype::QuantizedS8(1.9980618f))
.set_rng(0, &int_rng)
.set_rng(1, &int_rng)
.set_rng(2, &bias_rng)
.set_rng(3, &int_rng)
.set_param(param);
auto run = [&](const TensorShapeArray& shapes) {
checker.execs({shapes[0], shapes[1], shapes[2], {}, {}});
};
run({{16, 4, 23, 40, 4}, {32, 4, 3, 3, 4}, {1, 1, 1, 1, 32}});
run({{16, 4, 92, 160, 4}, {32, 4, 3, 3, 4}, {1, 1, 1, 1, 32}});
run({{16, 4, 46, 80, 4}, {32, 4, 3, 3, 4}, {1, 1, 1, 1, 32}});
}
#if CUDA_VERSION >= 10020
TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_NCHW4) {
require_compute_capability(7, 5);
using namespace conv_bias;
Checker<ConvBiasForward> checker(handle_cuda());
UniformIntRNG int_rng{-3, 3};
UniformIntRNG bias_rng{-50, 50};
ConvBias::Param param;
param.format = ConvBias::Param::Format::NCHW32_NCHW4;
param.nonlineMode = ConvBias::Param::NonlineMode::IDENTITY;
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<
ConvBiasForward>(
ConvBias::algo_name<ConvBias::DirectParam>(
"INT8_NCHW32_IMMA_IMPLICIT_GEMM_256X128X64_64X64X64",
ConvBias::DirectParam{})
.c_str()));
checker.set_dtype(0, dtype::QuantizedS8(1.9980618f))
.set_dtype(1, dtype::QuantizedS8(1.9980927f))
.set_dtype(2, dtype::QuantizedS32(1.9980618f * 1.9980927f))
.set_dtype(3, dtype::QuantizedS8(1.9980618f))
.set_dtype(4, dtype::QuantizedS8(1.9980618f))
.set_rng(0, &int_rng)
.set_rng(1, &int_rng)
.set_rng(2, &bias_rng)
.set_rng(3, &int_rng)
.set_param(param);
auto run = [&](const TensorShapeArray& shapes) {
checker.execs({shapes[0], shapes[1], shapes[2], {}, {}});
};
run({{16, 2, 23, 40, 32}, {20, 2, 3, 3, 32}, {1, 5, 1, 1, 4}});
run({{16, 1, 92, 160, 32}, {24, 1, 3, 3, 32}, {1, 6, 1, 1, 4}});
run({{16, 2, 46, 80, 32}, {4, 2, 3, 3, 32}, {1, 1, 1, 1, 4}});
}
#endif
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_CHWN4) {
require_compute_capability(6, 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册