提交 d28eba4e 编写于 作者: M Megvii Engine Team

feat(dnn/cuda): add cutlass impls for int4 conv bias

GitOrigin-RevId: 878bb8c95511de578cb5c880a412a3e1c9f0dea9
上级 14b65e4d
# Mark generated files as binary, ignore them in git diff. # Mark generated files as binary, ignore them in git diff.
# dnn # dnn
dnn/src/cuda/conv_bias/int4/kimpl/* binary
dnn/src/cuda/conv_bias/int8/kimpl/* binary dnn/src/cuda/conv_bias/int8/kimpl/* binary
dnn/src/cuda/conv_bias/int8_imma/kimpl/* binary dnn/src/cuda/conv_bias/int8_imma/kimpl/* binary
dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary
......
...@@ -84,6 +84,9 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() { ...@@ -84,6 +84,9 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
for (auto&& algo : int8_nchw32_imma) { for (auto&& algo : int8_nchw32_imma) {
all_algos.push_back(&algo); all_algos.push_back(&algo);
} }
for (auto&& algo : int4_int4_nchw64_imma) {
all_algos.push_back(&algo);
}
#endif #endif
#endif #endif
fill_dp4a_algos(); fill_dp4a_algos();
...@@ -225,6 +228,12 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { ...@@ -225,6 +228,12 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() {
int8_nchw32_imma.emplace_back(AlgoParam{64, 64, 64, 32, 32, 64}); int8_nchw32_imma.emplace_back(AlgoParam{64, 64, 64, 32, 32, 64});
int8_nchw32_imma.emplace_back(AlgoParam{32, 64, 64, 32, 16, 64}); int8_nchw32_imma.emplace_back(AlgoParam{32, 64, 64, 32, 16, 64});
} }
{
using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam;
int4_int4_nchw64_imma.emplace_back(AlgoParam{128, 128, 128, 64, 64, 128});
int4_int4_nchw64_imma.emplace_back(AlgoParam{256, 128, 128, 64, 64, 128});
}
#endif #endif
} }
#endif #endif
......
...@@ -61,6 +61,7 @@ public: ...@@ -61,6 +61,7 @@ public:
CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8, CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8,
CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8, CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8,
CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8, CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8,
CUDA_IMPLICIT_GEMM_IMMA_NCHW64_INT4_INT4,
CUDA_BFLOAT16, CUDA_BFLOAT16,
CUDA_IMPLICIT_GEMM_SASS_NCHW4_DOTPROD_INT8, CUDA_IMPLICIT_GEMM_SASS_NCHW4_DOTPROD_INT8,
CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW4_DOTPROD_INT8, CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW4_DOTPROD_INT8,
...@@ -755,6 +756,53 @@ public: ...@@ -755,6 +756,53 @@ public:
return ret; return ret;
} }
private:
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
const SizeArgs& args) const;
AlgoParam m_algo_param;
std::string m_name;
};
class ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm final
: public AlgoBase {
public:
struct AlgoParam {
int threadblock_m;
int threadblock_n;
int threadblock_k;
int warp_m;
int warp_n;
int warp_k;
};
AlgoInt4Int4NCHW64IMMAImplicitGemm(AlgoParam algo_param)
: m_algo_param{algo_param} {
m_name = ConvBias::algo_name<ConvBias::DirectParam>(
ssprintf("INT4_INT4_NCHW64_IMMA_IMPLICIT_GEMM_%s",
to_string(m_algo_param).c_str()),
ConvBias::DirectParam{});
}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
static std::string to_string(AlgoParam algo_param);
size_t get_preprocess_workspace_in_bytes(
const SizeArgs& args) const override;
SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
const SizeArgs& args) const override;
void exec_preprocess(const ExecArgs& args) const override;
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NCHW64_INT4_INT4)
std::string param() const override {
std::string ret;
serialize_write_pod(m_algo_param, ret);
return ret;
}
private: private:
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
const SizeArgs& args) const; const SizeArgs& args) const;
...@@ -819,6 +867,7 @@ public: ...@@ -819,6 +867,7 @@ public:
#endif #endif
#if CUDA_VERSION >= 10020 #if CUDA_VERSION >= 10020
std::vector<AlgoInt8NCHW32IMMAImplicitGemm> int8_nchw32_imma; std::vector<AlgoInt8NCHW32IMMAImplicitGemm> int8_nchw32_imma;
std::vector<AlgoInt4Int4NCHW64IMMAImplicitGemm> int4_int4_nchw64_imma;
#endif #endif
std::vector<std::unique_ptr<AlgoGroupConvGeneral>> gconv_refhold; std::vector<std::unique_ptr<AlgoGroupConvGeneral>> gconv_refhold;
AlgoBFloat16 bfloat16; AlgoBFloat16 bfloat16;
......
...@@ -25,8 +25,8 @@ using namespace megdnn; ...@@ -25,8 +25,8 @@ using namespace megdnn;
using namespace cuda; using namespace cuda;
using namespace cutlass_wrapper; using namespace cutlass_wrapper;
/* ================= cutlass kernel wrapper for nchw32 layout ================ /* ====== cutlass kernel wrapper for int8 nchw32 layout ====== */
*/
#if MEGDNN_TEGRA_X1 #if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem> template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper:: void megdnn::cuda::cutlass_wrapper::
...@@ -149,7 +149,8 @@ INST(true); ...@@ -149,7 +149,8 @@ INST(true);
INST(false); INST(false);
#undef INST #undef INST
/* ==== cutlass kernel wrapper for nchw32 layout and nchw4 output ===== */ /* ===== cutlass kernel wrapper for int8 nchw32 layout and nchw4 output ===== */
#if MEGDNN_TEGRA_X1 #if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem> template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper:: void megdnn::cuda::cutlass_wrapper::
...@@ -272,7 +273,8 @@ INST(true); ...@@ -272,7 +273,8 @@ INST(true);
INST(false); INST(false);
#undef INST #undef INST
/* ================ cutlass kernel wrapper for nchw4 layout ================= */ /* ====== cutlass kernel wrapper for int8 nchw4 layout ====== */
#if MEGDNN_TEGRA_X1 #if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem> template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper:: void megdnn::cuda::cutlass_wrapper::
...@@ -401,7 +403,8 @@ INST(true); ...@@ -401,7 +403,8 @@ INST(true);
INST(false); INST(false);
#undef INST #undef INST
/* ===== cutlass kernel wrapper for nchw4 layout and nchw output ===== */ /* ====== cutlass kernel wrapper for int8 nchw4 layout and nchw output ====== */
#if MEGDNN_TEGRA_X1 #if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem> template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper:: void megdnn::cuda::cutlass_wrapper::
...@@ -531,7 +534,8 @@ INST(true); ...@@ -531,7 +534,8 @@ INST(true);
INST(false); INST(false);
#undef INST #undef INST
/* ====== cutlass kernel wrapper for nchw4 layout and nchw32 output ====== */ /* ===== cutlass kernel wrapper for int8 nchw4 layout and nchw32 output ===== */
#if MEGDNN_TEGRA_X1 #if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem> template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper:: void megdnn::cuda::cutlass_wrapper::
...@@ -658,4 +662,125 @@ INST(true); ...@@ -658,4 +662,125 @@ INST(true);
INST(false); INST(false);
#undef INST #undef INST
/* ====== cutlass kernel wrapper for int4 nchw64 layout ====== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64(
const int8_t* /* d_src */, const int8_t* /* d_filter */,
const int32_t* /* d_bias */, const int8_t* /* d_z */,
int8_t* /* d_dst */, int* /* workspace */,
const convolution::ConvParam& /* param */,
uint32_t /* nonlinear_mode */, float /* alpha */,
float /* beta */, float /* gamma */, float /* scale */,
const GemmCoord& /* threadblock_shape */,
const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64(
const int8_t* d_src, const int8_t* d_filter,
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
int* workspace, const convolution::ConvParam& param,
uint32_t nonlinear_mode, float alpha, float beta, float gamma,
float scale, const GemmCoord& threadblock_shape,
const GemmCoord& warp_shape, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
threadblock_k_, warp_m_, warp_n_, \
warp_k_) \
if (threadblock_shape.m() == threadblock_m_ && \
threadblock_shape.n() == threadblock_n_ && \
threadblock_shape.k() == threadblock_k_ && \
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
warp_shape.k() == warp_k_) { \
using ThreadBlockShape = \
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
threadblock_k_>; \
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \
using Convolution = cutlass::conv::device::Convolution< \
cutlass::int4b_t, cutlass::layout::TensorNCxHWx<64>, \
cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \
ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \
cutlass::layout::TensorNCxHWx<64>, int32_t, \
cutlass::conv::ConvType::kConvolution, \
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
cutlass::conv::threadblock:: \
ConvolutionFpropNCxHWxThreadblockSwizzle, \
2, 32, 32, NeedLoadFromConstMem>; \
typename Convolution::ConvolutionParameter conv_param( \
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
return cutlass_convolution_wrapper<Convolution>( \
reinterpret_cast<const cutlass::int4b_t*>(d_src), \
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \
reinterpret_cast<const cutlass::int4b_t*>(d_z), \
reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \
conv_param, epilogue, stream); \
}
#define DISPATCH_KERNEL \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 128, 64, 64, 128); \
megdnn_assert(false, \
"unsupported threadblock shape (%dx%dx%d) and warp shape " \
"(%dx%dx%d)", \
threadblock_shape.m(), threadblock_shape.n(), \
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
warp_shape.k());
using ElementOutput = cutlass::int4b_t;
using ElementAccumulator = int32_t;
using ElementBias = int32_t;
using ElementCompute = float;
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
switch (nonlinear_mode) {
case NonlineMode::IDENTITY: {
using EpilogueOp =
cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
ElementOutput, 16, ElementAccumulator, ElementBias,
ElementCompute>;
typename EpilogueOp::Params epilogue{alpha, beta, gamma};
DISPATCH_KERNEL;
}
case NonlineMode::RELU: {
using EpilogueOp = cutlass::epilogue::thread::
BiasAddLinearCombinationReluClamp<
ElementOutput, 16, ElementAccumulator, ElementBias,
ElementCompute>;
typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
DISPATCH_KERNEL;
}
case NonlineMode::H_SWISH: {
using EpilogueOp = cutlass::epilogue::thread::
BiasAddLinearCombinationHSwishClamp<
ElementOutput, 16, ElementAccumulator, ElementBias,
ElementCompute>;
typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
DISPATCH_KERNEL;
}
default:
megdnn_assert(false,
"unsupported nonlinear mode for conv bias operator");
}
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(need_load_from_const_mem) \
template void megdnn::cuda::cutlass_wrapper:: \
do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< \
need_load_from_const_mem>( \
const int8_t* d_src, const int8_t* d_filter, \
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
int* workspace, const convolution::ConvParam& param, \
uint32_t nonlinear_mode, float alpha, float beta, \
float gamma, float scale, \
const GemmCoord& threadblock_shape, \
const GemmCoord& warp_shape, cudaStream_t stream);
INST(true);
#undef INST
// vim: syntax=cuda.doxygen // vim: syntax=cuda.doxygen
...@@ -76,6 +76,15 @@ void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32( ...@@ -76,6 +76,15 @@ void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32(
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
int stages, cudaStream_t stream); int stages, cudaStream_t stream);
template <bool NeedLoadFromConstMem>
void do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64(
const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias,
const int8_t* d_z, int8_t* d_dst, int* workspace,
const convolution::ConvParam& param, uint32_t nonlinear_mode,
float alpha, float beta, float gamma, float scale,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
cudaStream_t stream);
} // namespace cutlass_wrapper } // namespace cutlass_wrapper
} // namespace cuda } // namespace cuda
} // namespace megdnn } // namespace megdnn
......
/**
* \file dnn/src/cuda/conv_bias/implicit_gemm_int4_nchw64_imma.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./algo.h"
#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
using namespace convolution;
#if CUDA_VERSION >= 10020
bool ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::is_available(
const SizeArgs& args) const {
if (args.bias_layout->ndim <= 0)
return false;
using Param = param::ConvBias;
using Format = Param::Format;
using Sparse = Param::Sparse;
using Mode = Param::Mode;
using NonlineMode = megdnn::param::ConvBias::NonlineMode;
auto&& param = args.opr->param();
if (!check_bias_share_in_channel(*(args.bias_layout), param.format))
return false;
if (param.format != Format::NCHW64 || param.sparse != Sparse::DENSE ||
param.mode != Mode::CROSS_CORRELATION)
return false;
if (param.nonlineMode != NonlineMode::IDENTITY &&
param.nonlineMode != NonlineMode::RELU &&
param.nonlineMode != NonlineMode::H_SWISH)
return false;
if (args.src_layout->dtype.enumv() != DTypeEnum::QuantizedS4 ||
args.filter_layout->dtype.enumv() != DTypeEnum::QuantizedS4 ||
args.bias_layout->dtype.enumv() != DTypeEnum::QuantizedS32 ||
args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS4)
return false;
if (!is_compute_capability_required(7, 5))
return false;
return true;
}
WorkspaceBundle
ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::get_workspace_bundle(
dt_byte* raw_ptr, const SizeArgs& args) const {
if (args.preprocessed_filter) {
return WorkspaceBundle{raw_ptr, {}};
} else {
size_t ws_filter = args.filter_layout->span().dist_byte();
return WorkspaceBundle{raw_ptr, {ws_filter}};
}
}
size_t
ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::get_workspace_in_bytes(
const SizeArgs& args) const {
return get_workspace_bundle(nullptr, args).total_size_in_bytes();
}
void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::exec(
const ExecArgs& args) const {
using Format = Param::Format;
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
size_t n = args.src_layout->operator[](0),
ci = args.src_layout->operator[](1) * 64,
hi = args.src_layout->operator[](2),
wi = args.src_layout->operator[](3);
size_t co = args.dst_layout->operator[](1) * 64,
ho = args.dst_layout->operator[](2),
wo = args.dst_layout->operator[](3);
UNPACK_CONV_PARAMETER(fm, param);
MARK_USED_VAR
auto&& stream = cuda_stream(args.opr->handle());
int8_t* filter_ptr = nullptr;
if (args.preprocessed_filter == nullptr) {
filter_ptr = reinterpret_cast<int8_t*>(args.workspace.raw_ptr);
// reformat filter from nchw64 to chwn64
TensorLayout src{{co, ci / 64, fh, fw, 64}, dtype::QuantizedS4()};
src.init_contiguous_stride();
TensorLayout dst = src;
dst.stride[0] = 64;
dst.stride[1] = co * fh * fw * 64;
dst.stride[2] = co * fw * 64;
dst.stride[3] = co * 64;
dst.stride[4] = 1;
TensorND ts_src, ts_dst;
ts_src.raw_ptr = args.filter_tensor->raw_ptr;
ts_src.layout = src;
ts_dst.raw_ptr = args.workspace.raw_ptr;
ts_dst.layout = dst;
auto&& transpose =
args.opr->handle()->create_operator<RelayoutForward>();
transpose->exec(ts_src, ts_dst);
} else {
filter_ptr = reinterpret_cast<int8_t*>(
args.preprocessed_filter->tensors[0].raw_ptr);
}
ConvParam kern_param;
kern_param.n = n, kern_param.co = co, kern_param.ci = ci,
kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho,
kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw,
kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh,
kern_param.fw = fw;
float src_scale = args.src_layout->dtype.param<dtype::QuantizedS4>().scale,
filter_scale =
args.filter_layout->dtype.param<dtype::QuantizedS4>().scale,
bias_scale =
args.bias_layout->dtype.param<dtype::QuantizedS32>().scale,
dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale;
float alpha = src_scale * filter_scale / dst_scale,
beta = bias_scale / dst_scale;
int8_t* z_dev_ptr = nullptr;
float gamma = 0.f;
if (args.z_layout->ndim > 0) {
z_dev_ptr = reinterpret_cast<int8_t*>(args.z_tensor->raw_ptr);
float z_scale = args.z_layout->dtype.param<dtype::QuantizedS4>().scale;
gamma = z_scale / dst_scale;
}
uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode);
cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64<
true>(
reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr), filter_ptr,
args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr,
reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr,
kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k},
cutlass_wrapper::GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n,
m_algo_param.warp_k},
stream);
}
std::string ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::to_string(
AlgoParam algo_param) {
return ssprintf("%uX%uX%u_%uX%uX%u", algo_param.threadblock_m,
algo_param.threadblock_n, algo_param.threadblock_k,
algo_param.warp_m, algo_param.warp_n, algo_param.warp_k);
}
size_t ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::
get_preprocess_workspace_in_bytes(const SizeArgs& args) const {
return 0_z;
}
SmallVector<TensorLayout> ConvBiasForwardImpl::
AlgoInt4Int4NCHW64IMMAImplicitGemm::deduce_preprocessed_filter_layout(
const SizeArgs& args) const {
return {args.filter_layout->collapse_contiguous()};
}
void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::exec_preprocess(
const ExecArgs& args) const {
auto&& param = args.opr->param();
auto&& fm = args.filter_meta;
size_t n = args.src_layout->operator[](0),
ci = args.src_layout->operator[](1) * 64,
hi = args.src_layout->operator[](2),
wi = args.src_layout->operator[](3);
size_t co = args.dst_layout->operator[](1) * 64,
ho = args.dst_layout->operator[](2),
wo = args.dst_layout->operator[](3);
UNPACK_CONV_PARAMETER(fm, param);
MARK_USED_VAR
TensorLayout src{{co, ci / 64, fh, fw, 64}, dtype::QuantizedS4()};
src.init_contiguous_stride();
TensorLayout dst = src;
dst.stride[0] = 64;
dst.stride[1] = co * fh * fw * 64;
dst.stride[2] = co * fw * 64;
dst.stride[3] = co * 64;
dst.stride[4] = 1;
TensorND ts_src, ts_dst;
ts_src.raw_ptr = args.filter_tensor->raw_ptr;
ts_src.layout = src;
ts_dst.raw_ptr = args.preprocessed_filter->tensors[0].raw_ptr;
ts_dst.layout = dst;
auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>();
transpose->exec(ts_src, ts_dst);
}
#endif
// vim: syntax=cpp.doxygen
../int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl
\ No newline at end of file
...@@ -64,6 +64,7 @@ public: ...@@ -64,6 +64,7 @@ public:
class AlgoInt8CHWN4IMMAImplicitGemmReorderFilter; class AlgoInt8CHWN4IMMAImplicitGemmReorderFilter;
class AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth; class AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth;
class AlgoInt8NCHW32IMMAImplicitGemm; class AlgoInt8NCHW32IMMAImplicitGemm;
class AlgoInt4Int4NCHW64IMMAImplicitGemm;
class AlgoBFloat16; class AlgoBFloat16;
class AlgoPack; class AlgoPack;
......
...@@ -689,7 +689,7 @@ TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_1x1_ALGO_2) { ...@@ -689,7 +689,7 @@ TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_1x1_ALGO_2) {
} }
TEST_F(CUDA, CUTLASS_WEIGHT_PREPROCESS) { TEST_F(CUDA, CUTLASS_INT8_WEIGHT_PREPROCESS) {
require_compute_capability(6, 1); require_compute_capability(6, 1);
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
handle_cuda()); handle_cuda());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册