提交 66f70578 编写于 作者: M Megvii Engine Team 提交者: huangxinda

feat(dnn/cuda): add convolution with i8 input and i4 output

GitOrigin-RevId: 10512645d5d5ac3d985720788760bf8a3855c1f1
上级 6d686ff2
...@@ -37,9 +37,10 @@ all: ${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL} $(CUDA_MATMUL_IMPL) ...@@ -37,9 +37,10 @@ all: ${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL} $(CUDA_MATMUL_IMPL)
../src/cuda/elemwise_multi_type/kimpl: gen_elemwise_multi_type_kern_impls.py ../src/cuda/elemwise_multi_type/kimpl: gen_elemwise_multi_type_kern_impls.py
./$^ --type cuda $@ ./$^ --type cuda $@
../src/cuda/conv_bias/int8/kimpl: gen_cuda_conv_bias_kern_impls.py gen_cutlass_conv_bias_kern_impls.py ../src/cuda/conv_bias/int8/kimpl: gen_cuda_conv_bias_kern_impls.py gen_cutlass_conv_bias_kern_impls.py cutlass_generator/generator.py
./gen_cuda_conv_bias_kern_impls.py --type dp4a $@ ./gen_cuda_conv_bias_kern_impls.py --type dp4a $@
./gen_cutlass_conv_bias_kern_impls.py --type dp4a $@ ./gen_cutlass_conv_bias_kern_impls.py --type dp4a $@
python3 ./cutlass_generator/generator.py --operations all --type simt $@
../src/cuda/conv_bias/int8_imma/kimpl: gen_cuda_conv_bias_kern_impls.py gen_cutlass_conv_bias_kern_impls.py ../src/cuda/conv_bias/int8_imma/kimpl: gen_cuda_conv_bias_kern_impls.py gen_cutlass_conv_bias_kern_impls.py
./gen_cuda_conv_bias_kern_impls.py --type imma $@ ./gen_cuda_conv_bias_kern_impls.py --type imma $@
......
...@@ -43,6 +43,7 @@ pdef('Axis').add_fields('int32', 'axis', 0) ...@@ -43,6 +43,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
Doc('NCHW4_NHWC', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout'),
Doc('NHWC_NCHW', 'NHWC_NCHW means input tensors are nhwc layout, ' Doc('NHWC_NCHW', 'NHWC_NCHW means input tensors are nhwc layout, '
'output tensor is nchw layout'), 'output tensor is nchw layout'),
Doc('NHWC_NCHW4_IC_SMALL', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, ' Doc('NHWC_NCHW4_IC_SMALL', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
...@@ -99,6 +100,7 @@ pdef('Axis').add_fields('int32', 'axis', 0) ...@@ -99,6 +100,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
Doc('NCHW4_NHWC', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout'),
Doc('NHWC_NCHW', 'NHWC_NCHW means input tensors are nhwc layout, ' Doc('NHWC_NCHW', 'NHWC_NCHW means input tensors are nhwc layout, '
'output tensor is nchw layout'), 'output tensor is nchw layout'),
Doc('NHWC_NCHW4_IC_SMALL', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, ' Doc('NHWC_NCHW4_IC_SMALL', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
......
...@@ -65,7 +65,8 @@ void do_check_exec_common( ...@@ -65,7 +65,8 @@ void do_check_exec_common(
bias.to_string().c_str(), dst.to_string().c_str()); bias.to_string().c_str(), dst.to_string().c_str());
megdnn_assert(bias.shape[2] == 1); megdnn_assert(bias.shape[2] == 1);
megdnn_assert(bias.shape[3] == 1); megdnn_assert(bias.shape[3] == 1);
} else if (opr->param().format == param::ConvBias::Format::NHWC) { } else if (param().format == param::ConvBias::Format::NHWC ||
param().format == param::ConvBias::Format::NCHW4_NHWC) {
megdnn_assert(bias.shape[0] == 1); megdnn_assert(bias.shape[0] == 1);
megdnn_assert(bias.shape[1] == 1); megdnn_assert(bias.shape[1] == 1);
megdnn_assert(bias.shape[2] == 1); megdnn_assert(bias.shape[2] == 1);
......
...@@ -369,6 +369,7 @@ void make_canonized_filter_meta_nchwx( ...@@ -369,6 +369,7 @@ void make_canonized_filter_meta_nchwx(
param.format == Param::Format::NCHW8 || param.format == Param::Format::NCHW8 ||
param.format == Param::Format::NCHW32 || param.format == Param::Format::NCHW32 ||
param.format == Param::Format::NCHW4_NCHW || param.format == Param::Format::NCHW4_NCHW ||
param.format == Param::Format::NCHW4_NHWC ||
param.format == Param::Format::NCHW4_NCHW32 || param.format == Param::Format::NCHW4_NCHW32 ||
param.format == Param::Format::NCHW32_NCHW4 || param.format == Param::Format::NCHW32_NCHW4 ||
param.format == Param::Format::NCHW64); param.format == Param::Format::NCHW64);
...@@ -498,6 +499,7 @@ ConvolutionBase<Parameter>::make_canonized_filter_meta( ...@@ -498,6 +499,7 @@ ConvolutionBase<Parameter>::make_canonized_filter_meta(
} }
} else if (param().format == Param::Format::NCHW4 || } else if (param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW4_NCHW || param().format == Param::Format::NCHW4_NCHW ||
param().format == Param::Format::NCHW4_NHWC ||
param().format == Param::Format::NCHW4_NCHW32) { param().format == Param::Format::NCHW4_NCHW32) {
make_canonized_filter_meta_nchwx<4, Parameter>(src_ndim, filter, make_canonized_filter_meta_nchwx<4, Parameter>(src_ndim, filter,
param(), ret); param(), ret);
...@@ -547,7 +549,12 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src, ...@@ -547,7 +549,12 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src,
src.enumv() == DTypeEnum::Quantized4Asymm) { src.enumv() == DTypeEnum::Quantized4Asymm) {
supported_dst_dtype.push_back( supported_dst_dtype.push_back(
dtype::QuantizedS32(mul_scale(src, filter))); dtype::QuantizedS32(mul_scale(src, filter)));
if (dst.valid() && dst.enumv() == src.enumv()) { bool cond_dst =
dst.valid() && (dst.enumv() == src.enumv() ||
((dst.enumv() == DTypeEnum::QuantizedS4 ||
dst.enumv() == DTypeEnum::Quantized4Asymm) &&
src.enumv() == DTypeEnum::QuantizedS8));
if (cond_dst) {
supported_dst_dtype.push_back(dst); supported_dst_dtype.push_back(dst);
} }
if (src.enumv() == DTypeEnum::QuantizedS8) { if (src.enumv() == DTypeEnum::QuantizedS8) {
...@@ -612,6 +619,7 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, ...@@ -612,6 +619,7 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
megdnn_assert(param().format == Param::Format::NHWCD4 || megdnn_assert(param().format == Param::Format::NHWCD4 ||
param().format == Param::Format::NCHW4 || param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW4_NCHW || param().format == Param::Format::NCHW4_NCHW ||
param().format == Param::Format::NCHW4_NHWC ||
param().format == Param::Format::NCHW4_NCHW32 || param().format == Param::Format::NCHW4_NCHW32 ||
param().format == Param::Format::NCHW44 || param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW44_DOT || param().format == Param::Format::NCHW44_DOT ||
...@@ -879,6 +887,21 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, ...@@ -879,6 +887,21 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
cflt.stride[0], cflt.padding[0]); cflt.stride[0], cflt.padding[0]);
dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1],
cflt.stride[1], cflt.padding[1]); cflt.stride[1], cflt.padding[1]);
} else if (param().format == Param::Format::NCHW4_NHWC) {
megdnn_assert(src.ndim == 5,
"invalid src ndim for NCHW4_NHWC, expected=5, got=%zu",
src.ndim);
megdnn_assert(cflt.icpg * cflt.group == src[1] * 4,
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg,
cflt.group);
dst.ndim = 4;
dst[0] = src[0];
dst[1] = infer_conv_shape(src[2], cflt.dilated_spatial[0],
cflt.stride[0], cflt.padding[0]);
dst[2] = infer_conv_shape(src[3], cflt.dilated_spatial[1],
cflt.stride[1], cflt.padding[1]);
auto oc = cflt.ocpg * cflt.group;
dst[3] = oc;
} else if (param().format == Param::Format::NCHW4_NCHW32) { } else if (param().format == Param::Format::NCHW4_NCHW32) {
megdnn_assert(src.ndim == 5, megdnn_assert(src.ndim == 5,
"invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu", "invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu",
......
...@@ -35,6 +35,9 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available( ...@@ -35,6 +35,9 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
args.src_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) && args.src_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4) args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4)
return false; return false;
if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm)
return false;
if (args.src_layout->dtype == args.filter_layout->dtype && if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) { args.src_layout->dtype == dtype::BFloat16()) {
return false; return false;
......
...@@ -911,4 +911,140 @@ void megdnn::cuda::cutlass_wrapper:: ...@@ -911,4 +911,140 @@ void megdnn::cuda::cutlass_wrapper::
INST(true); INST(true);
#undef INST #undef INST
/* ===== cutlass kernel wrapper for nchw4 layout and nhwc output ===== */
#if MEGDNN_TEGRA_X1
template <bool signedness>
void megdnn::cuda::cutlass_wrapper::
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc(
const int8_t* /* d_src */, const int8_t* /* d_filter */,
const int32_t* /* d_bias */, const int8_t* /* d_z */,
int8_t* /* d_dst */, int* /* workspace */,
const convolution::ConvParam& /* param */,
uint32_t /* nonlinear_mode */, float /* alpha */,
float /* beta */, float /* gamma */, float /* delta */,
float /* theta */, float /* scale */,
const GemmCoord& /* threadblock_shape */,
const GemmCoord& /* warp_shape */, int /* stages */,
cudaStream_t /* stream */) {}
#else
template <bool signedness>
void megdnn::cuda::cutlass_wrapper::
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc(
const int8_t* d_src, const int8_t* d_filter,
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
int* workspace, const convolution::ConvParam& param,
uint32_t nonlinear_mode, float alpha, float beta, float gamma,
float delta, float theta, float scale,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
int stages, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
threadblock_k_, warp_m_, warp_n_, \
warp_k_, stages_, aligned_) \
if (threadblock_shape.m() == threadblock_m_ && \
threadblock_shape.n() == threadblock_n_ && \
threadblock_shape.k() == threadblock_k_ && \
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
warp_shape.k() == warp_k_ && stages == stages_) { \
using ThreadBlockShape = \
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
threadblock_k_>; \
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \
using Convolution = cutlass::conv::device::Convolution< \
int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \
cutlass::layout::TensorCxRSKx<4>, ElementOutput, \
cutlass::layout::TensorNHWC, int32_t, \
cutlass::layout::TensorNHWC, int32_t, \
cutlass::conv::ConvType::kConvolution, \
cutlass::arch::OpClassSimt, cutlass::arch::Sm75, \
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
cutlass::conv::threadblock:: \
ConvolutionFpropNCxHWxThreadblockSwizzle, \
stages_, 4, aligned_, NeedLoadFromConstMem, \
cutlass::arch::OpMultiplyAddSaturate>; \
typename Convolution::ConvolutionParameter conv_param( \
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
return cutlass_convolution_wrapper<Convolution>( \
d_src, d_filter, d_bias, \
reinterpret_cast<const ElementOutput*>(d_z), \
reinterpret_cast<ElementOutput*>(d_dst), workspace, \
conv_param, epilogue, stream); \
}
#define DISPATCH_KERNEL \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \
megdnn_assert(false, \
"unsupported threadblock shape (%dx%dx%d) and warp shape " \
"(%dx%dx%d)", \
threadblock_shape.m(), threadblock_shape.n(), \
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
warp_shape.k());
using ElementOutput = cutlass::integer_subbyte<4, signedness>;
using ElementAccumulator = int32_t;
using ElementBias = int32_t;
using ElementCompute = float;
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
switch (nonlinear_mode) {
case NonlineMode::IDENTITY: {
using EpilogueOp =
cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
ElementOutput, 8, ElementAccumulator, ElementBias,
ElementCompute>;
typename EpilogueOp::Params epilogue{alpha, beta, gamma,
delta + theta};
DISPATCH_KERNEL;
}
case NonlineMode::RELU: {
using EpilogueOp = cutlass::epilogue::thread::
BiasAddLinearCombinationReluClamp<
ElementOutput, 8, ElementAccumulator, ElementBias,
ElementCompute>;
typename EpilogueOp::Params epilogue{alpha, beta, gamma,
0, delta, theta};
DISPATCH_KERNEL;
}
case NonlineMode::H_SWISH: {
using EpilogueOp = cutlass::epilogue::thread::
BiasAddLinearCombinationHSwishClamp<
ElementOutput, 8, ElementAccumulator, ElementBias,
ElementCompute>;
typename EpilogueOp::Params epilogue{alpha, beta, gamma,
scale, detla, theta};
DISPATCH_KERNEL;
}
default:
megdnn_assert(false,
"unsupported nonlinear mode for conv bias operator");
}
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(signedness) \
template void megdnn::cuda::cutlass_wrapper:: \
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc<signedness>( \
const int8_t* d_src, const int8_t* d_filter, \
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
int* workspace, const convolution::ConvParam& param, \
uint32_t nonlinear_mode, float alpha, float beta, \
float gamma, float delta, float theta, float scale, \
const GemmCoord& threadblock_shape, \
const GemmCoord& warp_shape, int stages, \
cudaStream_t stream);
INST(true);
INST(false);
#undef INST
// vim: syntax=cuda.doxygen // vim: syntax=cuda.doxygen
...@@ -94,6 +94,15 @@ void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( ...@@ -94,6 +94,15 @@ void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64(
float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape,
const GemmCoord& warp_shape, cudaStream_t stream); const GemmCoord& warp_shape, cudaStream_t stream);
template <bool signedness>
void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc(
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias,
const int8_t* d_z, int8_t* d_dst, int* workspace,
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float delta, float theta,
float scale, const GemmCoord& threadblock_shape,
const GemmCoord& warp_shape, int stages, cudaStream_t stream);
} // namespace cutlass_wrapper } // namespace cutlass_wrapper
} // namespace cuda } // namespace cuda
} // namespace megdnn } // namespace megdnn
......
/**
* \file
* dnn/src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "cutlass/convolution/device/convolution.h"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;
template <typename Convolution>
void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper(
const typename Convolution::ElementSrc* d_src,
const typename Convolution::ElementFilter* d_filter,
const typename Convolution::ElementBias* d_bias,
const typename Convolution::ElementDst* d_z,
typename Convolution::ElementDst* d_dst, int* workspace,
typename Convolution::ConvolutionParameter const& conv_param,
typename Convolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream, typename Convolution::ExtraParam extra_param) {
typename Convolution::TensorRefSrc tensor_src{
const_cast<typename Convolution::ElementSrc*>(d_src),
Convolution::LayoutSrc::packed(
{conv_param.N, conv_param.H, conv_param.W, conv_param.C})};
typename Convolution::TensorRefFilter tensor_filter{
const_cast<typename Convolution::ElementFilter*>(d_filter),
Convolution::LayoutFilter::packed(
{conv_param.K, conv_param.R, conv_param.S, conv_param.C})};
typename Convolution::TensorRefBias tensor_bias{
const_cast<typename Convolution::ElementBias*>(d_bias),
Convolution::LayoutBias::packed({1, 1, 1, conv_param.K})};
typename Convolution::TensorRefDst tensor_z{
const_cast<typename Convolution::ElementDst*>(d_z),
Convolution::LayoutDst::packed(
{conv_param.N, conv_param.P, conv_param.Q, conv_param.K})};
typename Convolution::TensorRefDst tensor_dst{
d_dst,
Convolution::LayoutDst::packed(
{conv_param.N, conv_param.P, conv_param.Q, conv_param.K})};
typename Convolution::Arguments arguments{conv_param,
tensor_src.non_const_ref(),
tensor_filter.non_const_ref(),
tensor_bias.non_const_ref(),
tensor_z.non_const_ref(),
tensor_dst.non_const_ref(),
epilogue,
{},
{},
extra_param};
Convolution conv_op;
cutlass_check(conv_op.initialize(arguments, workspace));
cutlass_check(conv_op(stream));
after_kernel_launch();
}
// vim: syntax=cuda.doxygen
...@@ -37,27 +37,40 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( ...@@ -37,27 +37,40 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available(
if (!check_bias_share_in_channel(*(args.bias_layout), if (!check_bias_share_in_channel(*(args.bias_layout),
param.format)) param.format))
return false; return false;
if (param.format == Format::NCHW4_NCHW32) { bool valid_format = param.format == Format::NCHW4_NCHW32 &&
if (m_algo_param.threadblock_m % 32 != 0) m_algo_param.threadblock_m % 32 == 0;
return false; valid_format |= param.format == Format::NCHW4_NCHW &&
} else if (param.format != Format::NCHW4_NCHW && args.bias_layout->dtype.enumv() == DTypeEnum::Float32 &&
param.format != Format::NCHW4) args.dst_layout->dtype.enumv() == DTypeEnum::Float32;
return false; valid_format |=
param.format == Format::NCHW4_NHWC &&
args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32 &&
(args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm);
valid_format |= param.format == Format::NCHW4;
if (!valid_format) return false;
size_t n = args.src_layout->operator[](0), size_t n = args.src_layout->operator[](0),
ci = args.src_layout->operator[](1) * 4, ci = args.src_layout->operator[](1) * 4,
hi = args.src_layout->operator[](2), hi = args.src_layout->operator[](2),
wi = args.src_layout->operator[](3); wi = args.src_layout->operator[](3);
size_t ho = args.dst_layout->operator[](2),
wo = args.dst_layout->operator[](3);
size_t co; size_t co;
size_t dst_spatial_pos;
if (param.format == Format::NCHW4) { if (param.format == Format::NCHW4) {
co = args.dst_layout->operator[](1) * 4; co = args.dst_layout->operator[](1) * 4;
dst_spatial_pos = 2;
} else if (param.format == Format::NCHW4_NCHW) { } else if (param.format == Format::NCHW4_NCHW) {
co = args.dst_layout->operator[](1); co = args.dst_layout->operator[](1);
dst_spatial_pos = 2;
} else if (param.format == Format::NCHW4_NHWC) {
co = args.dst_layout->operator[](3);
dst_spatial_pos = 1;
} else { } else {
megdnn_assert(param.format == Format::NCHW4_NCHW32); megdnn_assert(param.format == Format::NCHW4_NCHW32);
dst_spatial_pos = 2;
co = args.dst_layout->operator[](1) * 32; co = args.dst_layout->operator[](1) * 32;
} }
size_t ho = args.dst_layout->operator[](dst_spatial_pos),
wo = args.dst_layout->operator[](dst_spatial_pos + 1);
UNPACK_CONV_PARAMETER(fm, param); UNPACK_CONV_PARAMETER(fm, param);
MARK_USED_VAR MARK_USED_VAR
// TODO support group conv // TODO support group conv
...@@ -72,7 +85,9 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( ...@@ -72,7 +85,9 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available(
available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 &&
filter_dtype.enumv() == DTypeEnum::QuantizedS8); filter_dtype.enumv() == DTypeEnum::QuantizedS8);
available &= (bias_dtype.enumv() == DTypeEnum::QuantizedS32 && available &= (bias_dtype.enumv() == DTypeEnum::QuantizedS32 &&
dst_dtype.enumv() == DTypeEnum::QuantizedS8) || (dst_dtype.enumv() == DTypeEnum::QuantizedS8 ||
dst_dtype.enumv() == DTypeEnum::QuantizedS4 ||
dst_dtype.enumv() == DTypeEnum::Quantized4Asymm)) ||
(bias_dtype.enumv() == DTypeEnum::Float32 && (bias_dtype.enumv() == DTypeEnum::Float32 &&
dst_dtype.enumv() == DTypeEnum::Float32); dst_dtype.enumv() == DTypeEnum::Float32);
// TODO: support dialtion // TODO: support dialtion
...@@ -111,17 +126,23 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( ...@@ -111,17 +126,23 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec(
ci = args.src_layout->operator[](1) * 4, ci = args.src_layout->operator[](1) * 4,
hi = args.src_layout->operator[](2), hi = args.src_layout->operator[](2),
wi = args.src_layout->operator[](3); wi = args.src_layout->operator[](3);
size_t ho = args.dst_layout->operator[](2), size_t co, dst_spatial_pos;
wo = args.dst_layout->operator[](3);
size_t co;
if (param.format == Format::NCHW4) { if (param.format == Format::NCHW4) {
co = args.dst_layout->operator[](1) * 4; co = args.dst_layout->operator[](1) * 4;
dst_spatial_pos = 2;
} else if (param.format == Format::NCHW4_NCHW) { } else if (param.format == Format::NCHW4_NCHW) {
co = args.dst_layout->operator[](1); co = args.dst_layout->operator[](1);
dst_spatial_pos = 2;
} else if (param.format == Format::NCHW4_NHWC) {
co = args.dst_layout->operator[](3);
dst_spatial_pos = 1;
} else { } else {
megdnn_assert(param.format == Format::NCHW4_NCHW32); megdnn_assert(param.format == Format::NCHW4_NCHW32);
dst_spatial_pos = 2;
co = args.dst_layout->operator[](1) * 32; co = args.dst_layout->operator[](1) * 32;
} }
size_t ho = args.dst_layout->operator[](dst_spatial_pos),
wo = args.dst_layout->operator[](dst_spatial_pos + 1);
UNPACK_CONV_PARAMETER(fm, param); UNPACK_CONV_PARAMETER(fm, param);
MARK_USED_VAR MARK_USED_VAR
auto&& stream = cuda_stream(args.opr->handle()); auto&& stream = cuda_stream(args.opr->handle());
...@@ -161,136 +182,107 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( ...@@ -161,136 +182,107 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec(
float beta = 1.f; float beta = 1.f;
float dst_scale = 1.f; float dst_scale = 1.f;
if (args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32) { if (args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32) {
megdnn_assert(args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS8); megdnn_assert(args.dst_layout->dtype.category() ==
DTypeCategory::QUANTIZED);
float bias_scale = args.bias_layout->dtype.param<dtype::QuantizedS32>() float bias_scale = args.bias_layout->dtype.param<dtype::QuantizedS32>()
.scale, .scale;
dst_scale = dst_scale = get_scale(args.dst_layout->dtype);
args.dst_layout->dtype.param<dtype::QuantizedS8>().scale;
alpha /= dst_scale, beta = bias_scale / dst_scale; alpha /= dst_scale, beta = bias_scale / dst_scale;
} }
float gamma = 0.f; float gamma = 0.f;
if (args.z_layout->ndim > 0) { if (args.z_layout->ndim > 0) {
gamma = 1.f; gamma = 1.f;
if (args.z_layout->dtype.enumv() == DTypeEnum::QuantizedS8) { if (args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) {
megdnn_assert(args.dst_layout->dtype.enumv() == megdnn_assert(args.dst_layout->dtype.category() ==
DTypeEnum::QuantizedS8); DTypeCategory::QUANTIZED);
float z_scale = args.z_layout->dtype.param<dtype::QuantizedS8>() float z_scale = get_scale(args.z_layout->dtype);
.scale;
gamma = z_scale / dst_scale; gamma = z_scale / dst_scale;
} }
} }
uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode); uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode);
if (fh == 1 && fw == 1) { bool nonunity_kernel = !(fh == 1 && fw == 1);
if (param.format == Format::NCHW4) { #define DISPATCH(_nonunity_kernel) \
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< if (nonunity_kernel == _nonunity_kernel) { \
false>( cb(_nonunity_kernel) \
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(),
args.z_tensor->compatible_ptr<int8_t>(),
args.dst_tensor->compatible_ptr<int8_t>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
m_algo_param.stage, stream);
} else if (param.format == Format::NCHW4_NCHW) {
cutlass_wrapper::
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw<false>(
args.src_tensor->compatible_ptr<int8_t>(),
filter_ptr,
args.bias_tensor->compatible_ptr<float>(),
args.z_tensor->compatible_ptr<float>(),
args.dst_tensor->compatible_ptr<float>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma,
dst_scale,
cutlass_wrapper::GemmCoord{
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
m_algo_param.stage, stream);
} else {
megdnn_assert(param.format == Format::NCHW4_NCHW32);
cutlass_wrapper::
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32<
false>(
args.src_tensor->compatible_ptr<int8_t>(),
filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(),
args.z_tensor->compatible_ptr<int8_t>(),
args.dst_tensor->compatible_ptr<int8_t>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma,
dst_scale,
cutlass_wrapper::GemmCoord{
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
m_algo_param.stage, stream);
} }
} else {
if (param.format == Format::NCHW4) { if (param.format == Format::NCHW4) {
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< #define cb(_nonunity_kernel) \
true>( cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< \
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, _nonunity_kernel>( \
args.bias_tensor->compatible_ptr<int32_t>(), args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \
args.z_tensor->compatible_ptr<int8_t>(), args.bias_tensor->compatible_ptr<int32_t>(), \
args.dst_tensor->compatible_ptr<int8_t>(), nullptr, args.z_tensor->compatible_ptr<int8_t>(), \
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, args.dst_tensor->compatible_ptr<int8_t>(), nullptr, kern_param, \
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, nonlinear_mode, alpha, beta, gamma, dst_scale, \
m_algo_param.threadblock_n, cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \
m_algo_param.threadblock_k}, m_algo_param.threadblock_n, \
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, m_algo_param.threadblock_k}, \
m_algo_param.warp_n, cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \
m_algo_param.warp_k}, m_algo_param.warp_n, \
m_algo_param.warp_k}, \
m_algo_param.stage, stream); m_algo_param.stage, stream);
DISPATCH(true);
DISPATCH(false);
#undef cb
} else if (param.format == Format::NCHW4_NCHW) { } else if (param.format == Format::NCHW4_NCHW) {
cutlass_wrapper:: #define cb(_nonunity_kernel) \
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw<true>( cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw< \
args.src_tensor->compatible_ptr<int8_t>(), _nonunity_kernel>( \
filter_ptr, args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \
args.bias_tensor->compatible_ptr<float>(), args.bias_tensor->compatible_ptr<float>(), \
args.z_tensor->compatible_ptr<float>(), args.z_tensor->compatible_ptr<float>(), \
args.dst_tensor->compatible_ptr<float>(), nullptr, args.dst_tensor->compatible_ptr<float>(), nullptr, kern_param, \
kern_param, nonlinear_mode, alpha, beta, gamma, nonlinear_mode, alpha, beta, gamma, dst_scale, \
dst_scale, cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \
cutlass_wrapper::GemmCoord{ m_algo_param.threadblock_n, \
m_algo_param.threadblock_m, m_algo_param.threadblock_k}, \
m_algo_param.threadblock_n, cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \
m_algo_param.threadblock_k}, m_algo_param.warp_n, \
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, m_algo_param.warp_k}, \
m_algo_param.warp_n,
m_algo_param.warp_k},
m_algo_param.stage, stream); m_algo_param.stage, stream);
DISPATCH(true);
DISPATCH(false);
#undef cb
} else if (param.format == Format::NCHW4_NHWC) {
#define cb(_nonunity_kernel) \
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc< \
_nonunity_kernel>( \
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \
args.bias_tensor->compatible_ptr<int32_t>(), \
reinterpret_cast<int8_t*>(args.z_tensor->raw_ptr), \
reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, \
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, \
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \
m_algo_param.threadblock_n, \
m_algo_param.threadblock_k}, \
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \
m_algo_param.warp_n, \
m_algo_param.warp_k}, \
m_algo_param.stage, stream);
cb(true);
#undef cb
} else { } else {
megdnn_assert(param.format == Format::NCHW4_NCHW32); megdnn_assert(param.format == Format::NCHW4_NCHW32);
cutlass_wrapper:: #define cb(_nonunity_kernel) \
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< cutlass_wrapper:: \
true>( do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< \
args.src_tensor->compatible_ptr<int8_t>(), _nonunity_kernel>( \
filter_ptr, args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \
args.bias_tensor->compatible_ptr<int32_t>(), args.bias_tensor->compatible_ptr<int32_t>(), \
args.z_tensor->compatible_ptr<int8_t>(), args.z_tensor->compatible_ptr<int8_t>(), \
args.dst_tensor->compatible_ptr<int8_t>(), nullptr, args.dst_tensor->compatible_ptr<int8_t>(), nullptr, \
kern_param, nonlinear_mode, alpha, beta, gamma, kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, \
dst_scale, cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \
cutlass_wrapper::GemmCoord{ m_algo_param.threadblock_n, \
m_algo_param.threadblock_m, m_algo_param.threadblock_k}, \
m_algo_param.threadblock_n, cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \
m_algo_param.threadblock_k}, m_algo_param.warp_n, \
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, m_algo_param.warp_k}, \
m_algo_param.warp_n,
m_algo_param.warp_k},
m_algo_param.stage, stream); m_algo_param.stage, stream);
} DISPATCH(true);
DISPATCH(false);
#undef cb
#undef DISPATCH
} }
after_kernel_launch(); after_kernel_launch();
} }
...@@ -315,17 +307,23 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec_preprocess( ...@@ -315,17 +307,23 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec_preprocess(
ci = args.src_layout->operator[](1) * 4, ci = args.src_layout->operator[](1) * 4,
hi = args.src_layout->operator[](2), hi = args.src_layout->operator[](2),
wi = args.src_layout->operator[](3); wi = args.src_layout->operator[](3);
size_t ho = args.dst_layout->operator[](2), size_t co, dst_spatial_pos;
wo = args.dst_layout->operator[](3);
size_t co;
if (param.format == Format::NCHW4) { if (param.format == Format::NCHW4) {
co = args.dst_layout->operator[](1) * 4; co = args.dst_layout->operator[](1) * 4;
dst_spatial_pos = 2;
} else if (param.format == Format::NCHW4_NCHW) { } else if (param.format == Format::NCHW4_NCHW) {
co = args.dst_layout->operator[](1); co = args.dst_layout->operator[](1);
dst_spatial_pos = 2;
} else if (param.format == Format::NCHW4_NHWC) {
co = args.dst_layout->operator[](3);
dst_spatial_pos = 1;
} else { } else {
megdnn_assert(param.format == Format::NCHW4_NCHW32); megdnn_assert(param.format == Format::NCHW4_NCHW32);
dst_spatial_pos = 2;
co = args.dst_layout->operator[](1) * 32; co = args.dst_layout->operator[](1) * 32;
} }
size_t ho = args.dst_layout->operator[](dst_spatial_pos),
wo = args.dst_layout->operator[](dst_spatial_pos + 1);
UNPACK_CONV_PARAMETER(fm, param); UNPACK_CONV_PARAMETER(fm, param);
MARK_USED_VAR MARK_USED_VAR
TensorLayout src{{co, ci / 4 * fh * fw}, dtype::Int32()}; TensorLayout src{{co, ci / 4 * fh * fw}, dtype::Int32()};
......
/**
* \file
* dnn/src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "cutlass/convolution/device/convolution.h"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;
template <typename Convolution>
void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper(
const typename Convolution::ElementSrc* d_src,
const typename Convolution::ElementFilter* d_filter,
const typename Convolution::ElementBias* d_bias,
const typename Convolution::ElementDst* d_z,
typename Convolution::ElementDst* d_dst, int* workspace,
typename Convolution::ConvolutionParameter const& conv_param,
typename Convolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream, typename Convolution::ExtraParam extra_param) {
typename Convolution::TensorRefSrc tensor_src{
const_cast<typename Convolution::ElementSrc*>(d_src),
Convolution::LayoutSrc::packed(
{conv_param.N, conv_param.H, conv_param.W, conv_param.C})};
typename Convolution::TensorRefFilter tensor_filter{
const_cast<typename Convolution::ElementFilter*>(d_filter),
Convolution::LayoutFilter::packed(
{conv_param.K, conv_param.R, conv_param.S, conv_param.C})};
typename Convolution::TensorRefBias tensor_bias{
const_cast<typename Convolution::ElementBias*>(d_bias),
Convolution::LayoutBias::packed({1, 1, 1, conv_param.K})};
typename Convolution::TensorRefDst tensor_z{
const_cast<typename Convolution::ElementDst*>(d_z),
Convolution::LayoutDst::packed(
{conv_param.N, conv_param.P, conv_param.Q, conv_param.K})};
typename Convolution::TensorRefDst tensor_dst{
d_dst,
Convolution::LayoutDst::packed(
{conv_param.N, conv_param.P, conv_param.Q, conv_param.K})};
typename Convolution::Arguments arguments{conv_param,
tensor_src.non_const_ref(),
tensor_filter.non_const_ref(),
tensor_bias.non_const_ref(),
tensor_z.non_const_ref(),
tensor_dst.non_const_ref(),
epilogue,
{},
{},
extra_param};
Convolution conv_op;
cutlass_check(conv_op.initialize(arguments, workspace));
cutlass_check(conv_op(stream));
after_kernel_launch();
}
// vim: syntax=cuda.doxygen
../implicit_gemm_conv_bias_cutlass_wrapper.cuinl
\ No newline at end of file
...@@ -159,6 +159,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, ...@@ -159,6 +159,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
filter_meta.format == Format::NCHW44_DOT || filter_meta.format == Format::NCHW44_DOT ||
filter_meta.format == Format::NCHW4 || filter_meta.format == Format::NCHW4 ||
filter_meta.format == Format::NCHW4_NCHW || filter_meta.format == Format::NCHW4_NCHW ||
filter_meta.format == Format::NCHW4_NHWC ||
filter_meta.format == Format::NCHW4_NCHW32 || filter_meta.format == Format::NCHW4_NCHW32 ||
filter_meta.format == Format::NCHW8 || filter_meta.format == Format::NCHW8 ||
filter_meta.format == Format::NCHW32 || filter_meta.format == Format::NCHW32 ||
...@@ -182,9 +183,15 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, ...@@ -182,9 +183,15 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
auto N = src.layout.shape[batch_pos], IH = src.layout.shape[spatial_start], auto N = src.layout.shape[batch_pos], IH = src.layout.shape[spatial_start],
IW = src.layout.shape[spatial_start + 1]; IW = src.layout.shape[spatial_start + 1];
auto FH = filter_meta.spatial[0], FW = filter_meta.spatial[1]; auto FH = filter_meta.spatial[0], FW = filter_meta.spatial[1];
auto OC = dst.layout.shape[channel_pos], size_t OC, OH, OW;
if (filter_meta.format == Format::NCHW4_NHWC) {
OC = dst.layout.shape[3], OH = dst.layout.shape[1],
OW = dst.layout.shape[2];
} else {
OC = dst.layout.shape[channel_pos],
OH = dst.layout.shape[spatial_start], OH = dst.layout.shape[spatial_start],
OW = dst.layout.shape[spatial_start + 1]; OW = dst.layout.shape[spatial_start + 1];
}
if (filter_meta.format == Format::NCHW4 || if (filter_meta.format == Format::NCHW4 ||
filter_meta.format == Format::CHWN4 || filter_meta.format == Format::CHWN4 ||
...@@ -206,6 +213,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, ...@@ -206,6 +213,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
if (filter_meta.format == Format::NCHW || if (filter_meta.format == Format::NCHW ||
filter_meta.format == Format::NCHW4 || filter_meta.format == Format::NCHW4 ||
filter_meta.format == Format::NCHW4_NCHW || filter_meta.format == Format::NCHW4_NCHW ||
filter_meta.format == Format::NCHW4_NHWC ||
filter_meta.format == Format::NCHW4_NCHW32 || filter_meta.format == Format::NCHW4_NCHW32 ||
filter_meta.format == Format::NCHW8 || filter_meta.format == Format::NCHW8 ||
filter_meta.format == Format::NCHW32 || filter_meta.format == Format::NCHW32 ||
...@@ -343,6 +351,15 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, ...@@ -343,6 +351,15 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
h * layout.stride[2] + w * layout.stride[3] + h * layout.stride[2] + w * layout.stride[3] +
(c & 0b11) * layout.stride[4]; (c & 0b11) * layout.stride[4];
} }
} else if (filter_meta.format == Format::NCHW4_NHWC) {
if (is_output) {
return n * layout.stride[0] + h * layout.stride[1] +
w * layout.stride[2] + c * layout.stride[3];
} else {
return n * layout.stride[0] + (c / 4) * layout.stride[1] +
h * layout.stride[2] + w * layout.stride[3] +
(c & 0b11) * layout.stride[4];
}
} else if (filter_meta.format == Format::NCHW4_NCHW32) { } else if (filter_meta.format == Format::NCHW4_NCHW32) {
if (is_output) { if (is_output) {
return n * layout.stride[0] + (c >> 5) * layout.stride[1] + return n * layout.stride[0] + (c >> 5) * layout.stride[1] +
...@@ -370,6 +387,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, ...@@ -370,6 +387,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
size_t fh, size_t fw) { size_t fh, size_t fw) {
if (filter_meta.format == Format::NCHW4 || if (filter_meta.format == Format::NCHW4 ||
filter_meta.format == Format::NCHW4_NCHW || filter_meta.format == Format::NCHW4_NCHW ||
filter_meta.format == Format::NCHW4_NHWC ||
filter_meta.format == Format::NCHW4_NCHW32) { filter_meta.format == Format::NCHW4_NCHW32) {
return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
(ic - ic0) / 4 * FS_IC * 4 + (ic - ic0) / 4 * FS_IC * 4 +
...@@ -695,6 +713,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, ...@@ -695,6 +713,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter,
case param::Convolution::Format::NHWC: case param::Convolution::Format::NHWC:
case param::Convolution::Format::NCHW4: case param::Convolution::Format::NCHW4:
case param::Convolution::Format::NCHW4_NCHW: case param::Convolution::Format::NCHW4_NCHW:
case param::Convolution::Format::NCHW4_NHWC:
case param::Convolution::Format::NCHW4_NCHW32: case param::Convolution::Format::NCHW4_NCHW32:
case param::Convolution::Format::NCHW8: case param::Convolution::Format::NCHW8:
case param::Convolution::Format::NCHW32: case param::Convolution::Format::NCHW32:
...@@ -820,6 +839,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, ...@@ -820,6 +839,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter,
BIAS_ADD_CHWNx(4); BIAS_ADD_CHWNx(4);
break; break;
} }
case Format::NCHW4_NHWC:
case Format::NHWC: { case Format::NHWC: {
int dst_nhw = dst.layout.shape[0] * dst.layout.shape[1] * int dst_nhw = dst.layout.shape[0] * dst.layout.shape[1] *
dst.layout.shape[2]; dst.layout.shape[2];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册