Commit 5f44203d authored by Megvii Engine Team

feat(dnn/cuda): add a cutlass impl for fusing convolution and dimshuffle

GitOrigin-RevId: 3fc6faef01202867f54206367d32ab01659326d0
Parent 61f917fb
...@@ -275,4 +275,130 @@ INST(true); ...@@ -275,4 +275,130 @@ INST(true);
INST(false); INST(false);
#undef INST #undef INST
#if MEGDNN_TEGRA_X1
// Tegra X1 build: the real kernel below requires cutlass::arch::Sm61
// (simt dp4a path) — presumably unavailable on this platform, so the
// symbol is kept (same signature, so callers still link) but does nothing.
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw(
                const int8_t* /* d_src */, const int8_t* /* d_filter */,
                const float* /* d_bias */, const float* /* d_z */,
                float* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* scale */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
#else
// Fused int8 convolution + bias (+ optional residual z) with fused
// "dimshuffle": src/filter are consumed in NC/4HW4 packed layouts
// (TensorNCxHWx<4> / TensorCxRSKx<4>) while the float result is written
// directly in plain NCHW layout (see the Convolution template below).
//
// Template parameter:
//   NeedLoadFromConstMem - forwarded verbatim into the cutlass Convolution
//       template; presumably selects whether the filter is staged through
//       constant memory — TODO confirm against the cutlass device template.
//
// Parameters:
//   d_src / d_filter  - int8 input feature map / filter (packed by 4 channels)
//   d_bias            - float per-output-channel bias (float NCHW bias layout)
//   d_z               - optional residual input added via `gamma`
//   d_dst             - float NCHW output
//   workspace         - scratch buffer forwarded to the cutlass wrapper
//   param             - convolution geometry (n/ci/co/hi/wi/fh/fw/ho/wo/
//                       sh/sw/ph/pw); dilation is fixed to 1x1 below
//   nonlinear_mode    - value of megdnn::param_enumv::ConvBias::NonlineMode
//   alpha/beta/gamma  - epilogue scales; presumably
//                       dst = act(alpha*conv + beta*bias + gamma*z) — confirm
//   scale             - extra epilogue parameter, only consumed by H_SWISH
//   threadblock_shape / warp_shape - tile configuration; must match one of
//       the shapes listed in DISPATCH_KERNEL or the dispatch asserts
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw(
                const int8_t* d_src, const int8_t* d_filter,
                const float* d_bias, const float* d_z, float* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float scale, const GemmCoord& threadblock_shape,
                const GemmCoord& warp_shape, cudaStream_t stream) {
// If the runtime (threadblock, warp) tile shapes match this candidate,
// instantiate the matching cutlass Convolution and RETURN from the enclosing
// function; otherwise fall through to the next candidate.  Note the kernel
// keeps int8 NC/4HW4 operands but float NCHW output (ElementOutput is float,
// dst layout TensorNCHW) — this is the fused dimshuffle.
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
                                        threadblock_k_, warp_m_, warp_n_, \
                                        warp_k_, aligned_) \
    if (threadblock_shape.m() == threadblock_m_ && \
        threadblock_shape.n() == threadblock_n_ && \
        threadblock_shape.k() == threadblock_k_ && \
        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
        warp_shape.k() == warp_k_) { \
        using ThreadBlockShape = \
                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
                                         threadblock_k_>; \
        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
        using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \
        using Convolution = cutlass::convolution::device::Convolution< \
                int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \
                cutlass::layout::TensorCxRSKx<4>, ElementOutput, \
                cutlass::layout::TensorNCHW, float, \
                cutlass::layout::TensorNCHW, int32_t, \
                cutlass::convolution::ConvType::kConvolution, \
                cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \
                ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
                cutlass::convolution::threadblock:: \
                        ConvolutionNCxHWxThreadblockSwizzle< \
                                cutlass::convolution::ConvType::kConvolution>, \
                2, 4, aligned_, NeedLoadFromConstMem, \
                cutlass::arch::OpMultiplyAdd>; \
        typename Convolution::ConvolutionParameter conv_param{ \
                param.n,  param.ci, param.co, param.hi, param.wi, \
                param.fh, param.fw, param.ho, param.wo, param.sh, \
                param.sw, param.ph, param.pw, 1,        1}; \
        return cutlass_convolution_wrapper<Convolution>( \
                d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \
                epilogue, stream); \
    }
// Try every supported tile configuration in turn; the last candidate
// (16x64x8 / 16x64x8, alignment 4) is the small-tile fallback.  Reaching the
// megdnn_assert means no candidate matched the requested shapes.
#define DISPATCH_KERNEL \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 16); \
    DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 4); \
    megdnn_assert(false, \
                  "unsupported threadblock shape (%dx%dx%d) and warp shape " \
                  "(%dx%dx%d)", \
                  threadblock_shape.m(), threadblock_shape.n(), \
                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
                  warp_shape.k());
    using ElementOutput = float;
    using ElementAccumulator = int32_t;
    using ElementBias = float;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    // Each case binds a different epilogue functor, then expands the dispatch.
    // NOTE(review): there is no `break` after DISPATCH_KERNEL — each case
    // either returns inside DISPATCH_KERNEL_WITH_TILE_SHAPE or hits the
    // megdnn_assert above; if megdnn_assert were ever non-fatal, the cases
    // would fall through — confirm megdnn_assert always aborts/throws.
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombination<
                            ElementOutput, 1, ElementAccumulator, ElementBias,
                            ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombinationRelu<
                            ElementOutput, 1, ElementAccumulator, ElementBias,
                            ElementCompute>;
            // trailing 0 is presumably the ReLU threshold — TODO confirm
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
            DISPATCH_KERNEL;
        }
        case NonlineMode::H_SWISH: {
            using EpilogueOp =
                    cutlass::epilogue::thread::BiasAddLinearCombinationHSwish<
                            ElementOutput, 1, ElementAccumulator, ElementBias,
                            ElementCompute>;
            // H_SWISH is the only mode that consumes `scale`
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
// Explicit instantiation for both values of NeedLoadFromConstMem, so the
// definition can live in this .cu file while callers only see the header.
#define INST(need_load_from_const_mem) \
    template void megdnn::cuda::cutlass_wrapper:: \
            do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw< \
                    need_load_from_const_mem>( \
                    const int8_t* d_src, const int8_t* d_filter, \
                    const float* d_bias, const float* d_z, float* d_dst, \
                    int* workspace, const convolution::ConvParam& param, \
                    uint32_t nonlinear_mode, float alpha, float beta, \
                    float gamma, float scale, \
                    const GemmCoord& threadblock_shape, \
                    const GemmCoord& warp_shape, cudaStream_t stream);
INST(true);
INST(false);
#undef INST
// vim: syntax=cuda.doxygen // vim: syntax=cuda.doxygen
...@@ -22,8 +22,11 @@ using GemmCoord = cutlass::gemm::GemmCoord; ...@@ -22,8 +22,11 @@ using GemmCoord = cutlass::gemm::GemmCoord;
template <typename Convolution> template <typename Convolution>
void cutlass_convolution_wrapper( void cutlass_convolution_wrapper(
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, const typename Convolution::ElementSrc* d_src,
const int8_t* d_z, int8_t* d_dst, int* workspace, const typename Convolution::ElementFilter* d_filter,
const typename Convolution::ElementBias* d_bias,
const typename Convolution::ElementDst* d_z,
typename Convolution::ElementDst* d_dst, int* workspace,
typename Convolution::ConvolutionParameter const& conv_param, typename Convolution::ConvolutionParameter const& conv_param,
typename Convolution::EpilogueOutputOp::Params const& epilogue, typename Convolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream); cudaStream_t stream);
...@@ -46,6 +49,15 @@ void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4( ...@@ -46,6 +49,15 @@ void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4(
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
cudaStream_t stream); cudaStream_t stream);
template <bool NeedLoadFromConstMem>
void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw(
const int8_t* d_src, const int8_t* d_filter, const float* d_bias,
const float* d_z, float* d_dst, int* workspace,
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float scale,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
cudaStream_t stream);
} // namespace cutlass_wrapper } // namespace cutlass_wrapper
} // namespace cuda } // namespace cuda
} // namespace megdnn } // namespace megdnn
......
...@@ -32,10 +32,26 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( ...@@ -32,10 +32,26 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available(
if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout), if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout),
param.format)) param.format))
return false; return false;
if (param.format != Format::NCHW4) if (param.format != Format::NCHW4 && param.format != Format::NCHW4_NCHW &&
param.format != Format::NCHW4_NCHW32)
return false; return false;
UNPACK_CONV_BIAS_NCHW4_PARAM(*(args.src_layout), fm, *(args.dst_layout), size_t n = args.src_layout->operator[](0),
param); ci = args.src_layout->operator[](1) * 4,
hi = args.src_layout->operator[](2),
wi = args.src_layout->operator[](3);
size_t ho = args.dst_layout->operator[](2),
wo = args.dst_layout->operator[](3);
size_t co;
if (param.format == Format::NCHW4) {
co = args.dst_layout->operator[](1) * 4;
} else if (param.format == Format::NCHW4_NCHW) {
co = args.dst_layout->operator[](1);
} else {
megdnn_assert(param.format == Format::NCHW4_NCHW32);
co = args.dst_layout->operator[](1) * 32;
}
UNPACK_CONV_PARAMETER(fm, param);
MARK_USED_VAR
// TODO support group conv // TODO support group conv
available &= param.sparse == Sparse::DENSE; available &= param.sparse == Sparse::DENSE;
// mode must be cross correlation // mode must be cross correlation
...@@ -46,9 +62,11 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( ...@@ -46,9 +62,11 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available(
bias_dtype = args.bias_layout->dtype, bias_dtype = args.bias_layout->dtype,
dst_dtype = args.dst_layout->dtype; dst_dtype = args.dst_layout->dtype;
available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 &&
filter_dtype.enumv() == DTypeEnum::QuantizedS8 && filter_dtype.enumv() == DTypeEnum::QuantizedS8);
bias_dtype.enumv() == DTypeEnum::QuantizedS32 && available &= (bias_dtype.enumv() == DTypeEnum::QuantizedS32 &&
dst_dtype.enumv() == DTypeEnum::QuantizedS8); dst_dtype.enumv() == DTypeEnum::QuantizedS8) ||
(bias_dtype.enumv() == DTypeEnum::Float32 &&
dst_dtype.enumv() == DTypeEnum::Float32);
// TODO: support dialtion // TODO: support dialtion
available &= dh == 1 && dw == 1; available &= dh == 1 && dw == 1;
// only support sm_61 or later, platform should have fast native int8 // only support sm_61 or later, platform should have fast native int8
...@@ -81,8 +99,23 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( ...@@ -81,8 +99,23 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec(
using Format = Param::Format; using Format = Param::Format;
auto&& param = args.opr->param(); auto&& param = args.opr->param();
auto&& fm = args.filter_meta; auto&& fm = args.filter_meta;
UNPACK_CONV_BIAS_NCHW4_PARAM(*(args.src_layout), fm, *(args.dst_layout), size_t n = args.src_layout->operator[](0),
param); ci = args.src_layout->operator[](1) * 4,
hi = args.src_layout->operator[](2),
wi = args.src_layout->operator[](3);
size_t ho = args.dst_layout->operator[](2),
wo = args.dst_layout->operator[](3);
size_t co;
if (param.format == Format::NCHW4) {
co = args.dst_layout->operator[](1) * 4;
} else if (param.format == Format::NCHW4_NCHW) {
co = args.dst_layout->operator[](1);
} else {
megdnn_assert(param.format == Format::NCHW4_NCHW32);
co = args.dst_layout->operator[](1) * 32;
}
UNPACK_CONV_PARAMETER(fm, param);
MARK_USED_VAR
auto&& stream = cuda_stream(args.opr->handle()); auto&& stream = cuda_stream(args.opr->handle());
int8_t* filter_ptr = nullptr; int8_t* filter_ptr = nullptr;
...@@ -115,26 +148,39 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( ...@@ -115,26 +148,39 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec(
float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale, float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale,
filter_scale = filter_scale =
args.filter_layout->dtype.param<dtype::QuantizedS8>().scale, args.filter_layout->dtype.param<dtype::QuantizedS8>().scale;
bias_scale = float alpha = src_scale * filter_scale;
args.bias_layout->dtype.param<dtype::QuantizedS32>().scale, float beta = 1.f;
dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS8>().scale; float dst_scale = 1.f;
float alpha = src_scale * filter_scale / dst_scale, if (args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32) {
beta = bias_scale / dst_scale; megdnn_assert(args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS8);
int8_t* z_dev_ptr = nullptr; float bias_scale = args.bias_layout->dtype.param<dtype::QuantizedS32>()
float gamma = 0.0; .scale,
dst_scale =
args.dst_layout->dtype.param<dtype::QuantizedS8>().scale;
alpha /= dst_scale, beta = bias_scale / dst_scale;
}
float gamma = 0.f;
if (args.z_layout->ndim > 0) { if (args.z_layout->ndim > 0) {
z_dev_ptr = args.z_tensor->compatible_ptr<int8_t>(); gamma = 1.f;
float z_scale = args.z_layout->dtype.param<dtype::QuantizedS8>().scale; if (args.z_layout->dtype.enumv() == DTypeEnum::QuantizedS8) {
megdnn_assert(args.dst_layout->dtype.enumv() ==
DTypeEnum::QuantizedS8);
float z_scale = args.z_layout->dtype.param<dtype::QuantizedS8>()
.scale;
gamma = z_scale / dst_scale; gamma = z_scale / dst_scale;
} }
}
uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode); uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode);
if (fh == 1 && fw == 1) { if (fh == 1 && fw == 1) {
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4<false>( if (param.format == Format::NCHW4) {
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4<
false>(
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, args.src_tensor->compatible_ptr<int8_t>(), filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr, args.bias_tensor->compatible_ptr<int32_t>(),
args.dst_tensor->compatible_ptr<int8_t>(), nullptr, kern_param, args.z_tensor->compatible_ptr<int8_t>(),
nonlinear_mode, alpha, beta, gamma, dst_scale, args.dst_tensor->compatible_ptr<int8_t>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m,
m_algo_param.threadblock_n, m_algo_param.threadblock_n,
m_algo_param.threadblock_k}, m_algo_param.threadblock_k},
...@@ -142,12 +188,36 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( ...@@ -142,12 +188,36 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec(
m_algo_param.warp_n, m_algo_param.warp_n,
m_algo_param.warp_k}, m_algo_param.warp_k},
stream); stream);
} else if (param.format == Format::NCHW4_NCHW) {
cutlass_wrapper::
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw<false>(
args.src_tensor->compatible_ptr<int8_t>(),
filter_ptr,
args.bias_tensor->compatible_ptr<float>(),
args.z_tensor->compatible_ptr<float>(),
args.dst_tensor->compatible_ptr<float>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma,
dst_scale,
cutlass_wrapper::GemmCoord{
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
} else { } else {
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4<true>( megdnn_assert(param.format == Format::NCHW4_NCHW32);
}
} else {
if (param.format == Format::NCHW4) {
cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4<
true>(
args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, args.src_tensor->compatible_ptr<int8_t>(), filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr, args.bias_tensor->compatible_ptr<int32_t>(),
args.dst_tensor->compatible_ptr<int8_t>(), nullptr, kern_param, args.z_tensor->compatible_ptr<int8_t>(),
nonlinear_mode, alpha, beta, gamma, dst_scale, args.dst_tensor->compatible_ptr<int8_t>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m,
m_algo_param.threadblock_n, m_algo_param.threadblock_n,
m_algo_param.threadblock_k}, m_algo_param.threadblock_k},
...@@ -155,7 +225,30 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( ...@@ -155,7 +225,30 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec(
m_algo_param.warp_n, m_algo_param.warp_n,
m_algo_param.warp_k}, m_algo_param.warp_k},
stream); stream);
} else if (param.format == Format::NCHW4_NCHW) {
cutlass_wrapper::
do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw<true>(
args.src_tensor->compatible_ptr<int8_t>(),
filter_ptr,
args.bias_tensor->compatible_ptr<float>(),
args.z_tensor->compatible_ptr<float>(),
args.dst_tensor->compatible_ptr<float>(), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma,
dst_scale,
cutlass_wrapper::GemmCoord{
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
} else {
megdnn_assert(param.format == Format::NCHW4_NCHW32);
} }
}
after_kernel_launch();
} }
size_t ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm:: size_t ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::
...@@ -174,8 +267,23 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec_preprocess( ...@@ -174,8 +267,23 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec_preprocess(
using Format = Param::Format; using Format = Param::Format;
auto&& param = args.opr->param(); auto&& param = args.opr->param();
auto&& fm = args.filter_meta; auto&& fm = args.filter_meta;
UNPACK_CONV_BIAS_NCHW4_PARAM(*(args.src_layout), fm, *(args.dst_layout), size_t n = args.src_layout->operator[](0),
param); ci = args.src_layout->operator[](1) * 4,
hi = args.src_layout->operator[](2),
wi = args.src_layout->operator[](3);
size_t ho = args.dst_layout->operator[](2),
wo = args.dst_layout->operator[](3);
size_t co;
if (param.format == Format::NCHW4) {
co = args.dst_layout->operator[](1) * 4;
} else if (param.format == Format::NCHW4_NCHW) {
co = args.dst_layout->operator[](1);
} else {
megdnn_assert(param.format == Format::NCHW4_NCHW32);
co = args.dst_layout->operator[](1) * 32;
}
UNPACK_CONV_PARAMETER(fm, param);
MARK_USED_VAR
TensorLayout src{{co, ci / 4 * fh * fw}, dtype::Int32()}; TensorLayout src{{co, ci / 4 * fh * fw}, dtype::Int32()};
src.init_contiguous_stride(); src.init_contiguous_stride();
TensorLayout dst = src; TensorLayout dst = src;
......
...@@ -19,25 +19,28 @@ using namespace cutlass_wrapper; ...@@ -19,25 +19,28 @@ using namespace cutlass_wrapper;
template <typename Convolution> template <typename Convolution>
void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper( void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper(
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, const typename Convolution::ElementSrc* d_src,
const int8_t* d_z, int8_t* d_dst, int* workspace, const typename Convolution::ElementFilter* d_filter,
const typename Convolution::ElementBias* d_bias,
const typename Convolution::ElementDst* d_z,
typename Convolution::ElementDst* d_dst, int* workspace,
typename Convolution::ConvolutionParameter const& conv_param, typename Convolution::ConvolutionParameter const& conv_param,
typename Convolution::EpilogueOutputOp::Params const& epilogue, typename Convolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream) { cudaStream_t stream) {
typename Convolution::TensorRefSrc tensor_src{ typename Convolution::TensorRefSrc tensor_src{
const_cast<int8_t*>(d_src), const_cast<typename Convolution::ElementSrc*>(d_src),
Convolution::LayoutSrc::packed({conv_param.n(), conv_param.hi(), Convolution::LayoutSrc::packed({conv_param.n(), conv_param.hi(),
conv_param.wi(), conv_param.ci()})}; conv_param.wi(), conv_param.ci()})};
typename Convolution::TensorRefFilter tensor_filter{ typename Convolution::TensorRefFilter tensor_filter{
const_cast<int8_t*>(d_filter), const_cast<typename Convolution::ElementFilter*>(d_filter),
Convolution::LayoutFilter::packed({conv_param.co(), conv_param.fh(), Convolution::LayoutFilter::packed({conv_param.co(), conv_param.fh(),
conv_param.fw(), conv_param.fw(),
conv_param.ci()})}; conv_param.ci()})};
typename Convolution::TensorRefBias tensor_bias{ typename Convolution::TensorRefBias tensor_bias{
const_cast<int32_t*>(d_bias), const_cast<typename Convolution::ElementBias*>(d_bias),
Convolution::LayoutBias::packed({1, 1, 1, conv_param.co()})}; Convolution::LayoutBias::packed({1, 1, 1, conv_param.co()})};
typename Convolution::TensorRefDst tensor_z{ typename Convolution::TensorRefDst tensor_z{
const_cast<int8_t*>(d_z), const_cast<typename Convolution::ElementDst*>(d_z),
Convolution::LayoutDst::packed({conv_param.n(), conv_param.ho(), Convolution::LayoutDst::packed({conv_param.n(), conv_param.ho(),
conv_param.wo(), conv_param.co()})}; conv_param.wo(), conv_param.co()})};
typename Convolution::TensorRefDst tensor_dst{ typename Convolution::TensorRefDst tensor_dst{
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册