Commit 96050073 authored by Megvii Engine Team, committed by 王彪

feat(dnn/cuda): add implicit bmm large kernel dwconv2d fprop impl

GitOrigin-RevId: feb09ebb5836d26433c4a82940bb5f22795da381
Parent 19fe2e94
@@ -181,6 +181,8 @@ if(MGE_WITH_CUDA)
   gen_cutlass_kimpl(conv2d simt CUTLASS_SOURCES)
   gen_cutlass_kimpl(conv2d tensorop8816 CUTLASS_SOURCES)
   gen_cutlass_kimpl(conv2d tensorop8832 CUTLASS_SOURCES)
+  gen_cutlass_kimpl(dwconv2d_fprop simt CUTLASS_SOURCES)
+  gen_cutlass_kimpl(dwconv2d_fprop tensorop884 CUTLASS_SOURCES)
   list(APPEND SOURCES ${CUTLASS_SOURCES})
   list(APPEND SOURCES ${CUSOURCES})
 endif()
......
@@ -92,6 +92,7 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
     for (auto&& algo : int8_nchw4_dotprod) {
         all_algos.push_back(&algo);
     }
+    fill_dwconv_algos();
     all_algos.push_back(&int8_chwn4_dotprod);
     all_algos.push_back(&fallback_nchw_qs8);
     for (size_t i = all_algo_size; i < all_algos.size(); ++i) {
@@ -301,6 +302,32 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() {
 }
 #endif

+void ConvBiasForwardImpl::AlgoPack::fill_dwconv_algos() {
+    using AlgoParam = AlgoCutlassConvolutionBase::AlgoParam;
+    f32_implicit_bmm.emplace_back(AlgoParam{128, 128, 8, 32, 64, 8, 1, 1, 1, 2});
+    f32_implicit_bmm.emplace_back(AlgoParam{128, 64, 8, 64, 32, 8, 1, 1, 1, 2});
+    f32_implicit_bmm.emplace_back(AlgoParam{128, 32, 8, 64, 32, 8, 1, 1, 1, 2});
+    f32_implicit_bmm.emplace_back(AlgoParam{32, 128, 8, 32, 64, 8, 1, 1, 1, 2});
+    f32_implicit_bmm.emplace_back(AlgoParam{64, 128, 8, 64, 32, 8, 1, 1, 1, 2});
+    f32_implicit_bmm.emplace_back(AlgoParam{64, 64, 8, 64, 32, 8, 1, 1, 1, 2});
+    f32_implicit_bmm.emplace_back(AlgoParam{32, 64, 8, 32, 64, 8, 1, 1, 1, 2});
+    f32_implicit_bmm.emplace_back(AlgoParam{32, 32, 8, 32, 32, 8, 1, 1, 1, 2});
+    f32_implicit_bmm.emplace_back(AlgoParam{64, 32, 8, 64, 32, 8, 1, 1, 1, 2});
+    for (auto&& algo : f32_implicit_bmm) {
+        all_algos.push_back(&algo);
+    }
+#if CUDA_VERSION >= 10020
+    f16_implicit_bmm.emplace_back(AlgoParam{128, 128, 32, 32, 32, 32, 8, 8, 4, 2});
+    f16_implicit_bmm.emplace_back(AlgoParam{128, 256, 32, 64, 64, 32, 8, 8, 4, 2});
+    f16_implicit_bmm.emplace_back(AlgoParam{128, 64, 32, 32, 32, 32, 8, 8, 4, 2});
+    f16_implicit_bmm.emplace_back(AlgoParam{64, 128, 32, 32, 32, 32, 8, 8, 4, 2});
+    f16_implicit_bmm.emplace_back(AlgoParam{64, 64, 32, 32, 32, 32, 8, 8, 4, 2});
+    for (auto&& algo : f16_implicit_bmm) {
+        all_algos.push_back(&algo);
+    }
+#endif
+}
+
 void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() {
     using AlgoParam = AlgoInt8NCHW4DotProdImplicitGemm::AlgoParam;
     int8_nchw4_dotprod.emplace_back(AlgoParam{128, 128, 32, 64, 32, 32, 1, 1, 4, 2});
......
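For orientation: the ten AlgoParam initializers above appear to decode as follows. threadblock_m/n/k and stage are consumed by ConvolutionKey later in this commit, warp_m/n/k drive the dp4a alignment computation, and the 8x8x4 triple in the f16 entries matches the tensorop884 kernel generator registered in CMakeLists.txt. The labels below are inferred from those uses, not spelled out in the diff:

    // Probable field layout (inferred sketch, not authoritative):
    AlgoParam{128, 128, 8,  // threadblock tile M, N, K
              32, 64, 8,    // warp tile M, N, K
              1, 1, 1,      // instruction shape: 1x1x1 scalar FMA for f32;
                            // the f16 entries use 8x8x4 (HMMA, "tensorop884")
              2};           // number of pipeline stages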
@@ -84,7 +84,9 @@ public:
         CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW32_IMMA_INT8,
         CUDA_IMPLICIT_GEMM_SASS_NCHW64_IMMA_INT4_INT4,
         CUDA_IMPLICIT_GEMM_SASS_NCHW64_IMMA_UINT4_INT4,
-        CUDA_FALLBACK_NCHW_INT4
+        CUDA_FALLBACK_NCHW_INT4,
+        CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32,
+        CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16,
     };
     using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
@@ -503,6 +505,8 @@ public:
  * +----+--- AlgoInt4Int4NHWCIMMAImplicitGemm
  * +----+--- AlgoUInt4Int4NHWCIMMAImplicitGemm
  * +
+ * +--- AlgoFloat32NCHWFMAImplicitBatchedGemm
+ * +--- AlgoFloat16NCHWHMMAImplicitBatchedGemm
  */

 /*
@@ -516,7 +520,13 @@ public:
     // corresponds to cutlass::conv::ConvType. we hope that algo.h does not
     // depend on cutlass headers
-    enum class ConvType { kConvolution, kBatchConvolution, kLocal, kLocalShare };
+    enum class ConvType {
+        kConvolution,
+        kBatchConvolution,
+        kLocal,
+        kLocalShare,
+        kDepthwiseConvolution,
+    };

     // common parameters for operation selection
     struct AlgoParam {
@@ -558,7 +568,8 @@ public:
             size_t wo, size_t ph, size_t pw, size_t sh, size_t sw, size_t dh, size_t dw,
             const void* alpha, const void* beta, const void* gamma, const void* delta,
             const void* theta, const void* threshold, const void* dst_scale,
-            cudaStream_t stream, const void* extra_param = nullptr) const;
+            cudaStream_t stream, const void* extra_param = nullptr,
+            size_t groups = 1) const;

 protected:
     AlgoParam m_algo_param;
@@ -992,6 +1003,54 @@ private:
 };
 #endif

+class ConvBiasForwardImpl::AlgoFloat32NCHWFMAImplicitBatchedGemm final
+        : public AlgoCutlassConvolutionBase {
+public:
+    AlgoFloat32NCHWFMAImplicitBatchedGemm(AlgoParam algo_param)
+            : AlgoCutlassConvolutionBase(algo_param) {
+        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
+                ssprintf(
+                        "FLOAT32_NCHW_FMA_IMPLICIT_BATCHED_GEMM%s",
+                        m_algo_param.to_string().c_str()),
+                ConvBias::DirectParam{});
+    }
+    bool is_available(const SizeArgs& args) const override;
+    size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override {
+        return 0;
+    }
+    void exec(const ExecArgs& args) const override;
+    const char* name() const override { return m_name.c_str(); }
+    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
+    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32);
+
+private:
+    std::string m_name;
+};
+
+class ConvBiasForwardImpl::AlgoFloat16NCHWHMMAImplicitBatchedGemm final
+        : public AlgoCutlassConvolutionBase {
+public:
+    AlgoFloat16NCHWHMMAImplicitBatchedGemm(AlgoParam algo_param)
+            : AlgoCutlassConvolutionBase(algo_param) {
+        m_name = ConvBias::algo_name<ConvBias::DirectParam>(
+                ssprintf(
+                        "FLOAT16_NCHW_HMMA_IMPLICIT_BATCHED_GEMM%s",
+                        m_algo_param.to_string().c_str()),
+                ConvBias::DirectParam{});
+    }
+    bool is_available(const SizeArgs& args) const override;
+    size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override {
+        return 0;
+    }
+    void exec(const ExecArgs& args) const override;
+    const char* name() const override { return m_name.c_str(); }
+    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
+    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16);
+
+private:
+    std::string m_name;
+};
+
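The reported algorithm name is the format string with m_algo_param.to_string() appended. The suffix format is defined elsewhere in the repo; the value below is only a plausible illustration:

    // Hypothetical example -- the actual suffix depends on AlgoParam::to_string():
    // AlgoFloat32NCHWFMAImplicitBatchedGemm(AlgoParam{128, 128, 8, 32, 64, 8, 1, 1, 1, 2})
    //         .name() -> "FLOAT32_NCHW_FMA_IMPLICIT_BATCHED_GEMM_128X128X8_32X64X8_..."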
 class ConvBiasForwardImpl::AlgoBFloat16 final : public AlgoBase {
 public:
     bool is_available(const SizeArgs& args) const override;
@@ -1048,6 +1107,8 @@ public:
     std::vector<AlgoInt4Int4NHWCIMMAImplicitGemm> int4_int4_nhwc_imma;
     std::vector<AlgoUInt4Int4NHWCIMMAImplicitGemm> uint4_int4_nhwc_imma;
 #endif
+    std::vector<AlgoFloat32NCHWFMAImplicitBatchedGemm> f32_implicit_bmm;
+    std::vector<AlgoFloat16NCHWHMMAImplicitBatchedGemm> f16_implicit_bmm;
     AlgoGroupConvGeneral group;
     AlgoBFloat16 bfloat16;
@@ -1063,6 +1124,7 @@ private:
 #endif
     void fill_cudnn_algos();
     void fill_dp4a_algos();
+    void fill_dwconv_algos();
 };

 } // namespace cuda
......
@@ -74,13 +74,18 @@ cutlass::conv::ConvType convert_conv_type(Base::ConvType conv_type) {
             return cutlass::conv::ConvType::kLocal;
         case Base::ConvType::kLocalShare:
             return cutlass::conv::ConvType::kLocalShare;
+        case Base::ConvType::kDepthwiseConvolution:
+            return cutlass::conv::ConvType::kDepthwiseConvolution;
         default:
             megdnn_assert(0, "invalid conv type");
     }
 }

-NumericTypeID convert_dtype(DTypeEnum dtype) {
-    switch (dtype) {
+NumericTypeID convert_dtype(DType dtype) {
+    // just make convolution with no bias happy
+    if (!dtype.valid())
+        return NumericTypeID::kF32;
+    switch (dtype.enumv()) {
         case DTypeEnum::Float32:
             return NumericTypeID::kF32;
         case DTypeEnum::Float16:
@@ -100,6 +105,21 @@ NumericTypeID convert_dtype(DTypeEnum dtype) {
     }
 }

+NumericTypeID get_accumulator_dtype(
+        DType dtype, const param::ConvBias::ComputeMode comp_mode) {
+    if (dtype.category() == DTypeCategory::QUANTIZED) {
+        return NumericTypeID::kS32;
+    } else {
+        megdnn_assert(dtype.category() == DTypeCategory::FLOAT);
+        if (comp_mode == param::ConvBias::ComputeMode::DEFAULT) {
+            return convert_dtype(dtype);
+        } else {
+            megdnn_assert(comp_mode == param::ConvBias::ComputeMode::FLOAT32);
+            return NumericTypeID::kF32;
+        }
+    }
+}
+
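Concretely, the branches above resolve as follows (derived directly from the code):

    // get_accumulator_dtype(QuantizedS8, any mode)          -> kS32  (integer accumulate)
    // get_accumulator_dtype(Float32, ComputeMode::DEFAULT)  -> kF32  (via convert_dtype)
    // get_accumulator_dtype(Float16, ComputeMode::DEFAULT)  -> kF16  (accumulate in half)
    // get_accumulator_dtype(Float16, ComputeMode::FLOAT32)  -> kF32  (mixed precision)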
 struct LayoutPack {
     LayoutTypeID src;
     LayoutTypeID filter;
@@ -149,6 +169,9 @@ LayoutPack get_layout_pack(const param::ConvBias::Format format, int access_type
                 default:
                     megdnn_assert(0, "invalid access_type");
             }
+        case Format::NCHW:
+            return {LayoutTypeID::kTensorNCHW, LayoutTypeID::kTensorNCHW,
+                    LayoutTypeID::kTensorNCHW, LayoutTypeID::kTensorNCHW};
         default:
             megdnn_assert(0, "invalid format");
     }
@@ -177,6 +200,93 @@ EpilogueType get_epilogue_type(const param::ConvBias::NonlineMode mode, bool cla
             megdnn_assert(0, "invalid nonlinear mode");
 }

+std::pair<int, int> get_tensor_alignment(
+        const param::ConvBias::Format format, const TensorLayout& src,
+        const TensorLayout& filter, const Base::AlgoParam& algo_param,
+        bool is_chanwise) {
+    int alignment_src = 0;
+    int alignment_filter = 0;
+    using Format = param::ConvBias::Format;
+
+    // get tensor alignment for tensor op operations;
+    // for tensor op operations, the alignment is determined by the size of a vector
+    auto get_tensor_alignment_tensor_op = [&]() {
+        switch (format) {
+            /// case int8
+            case Format::NCHW32:
+            case Format::NCHW32_NCHW4:
+                alignment_src = 16;
+                alignment_filter = 16;
+                break;
+            /// case int4 or uint4
+            case Format::NCHW64:
+                alignment_src = 32;
+                alignment_filter = 32;
+                break;
+            case Format::NHWC:
+                alignment_src = alignment_filter = algo_param.access_size;
+                break;
+            default:
+                megdnn_throw("invalid format");
+        }
+    };
+
+    // get tensor alignment for dot product operations;
+    // for integer dot product operations, the src alignment is always 4,
+    // and the filter alignment is determined by the threadblock shape
+    auto get_tensor_alignment_dp4a = [&]() {
+        megdnn_assert(
+                format == Format::NCHW4 || format == Format::NCHW4_NCHW ||
+                format == Format::NCHW4_NHWC || format == Format::NCHW4_NCHW32);
+        alignment_src = 4;
+        // determine the filter alignment
+        constexpr int warp_size = 32;
+        int threads = warp_size * algo_param.threadblock_m * algo_param.threadblock_n *
+                      algo_param.threadblock_k /
+                      (algo_param.warp_m * algo_param.warp_n * algo_param.warp_k);
+        int threadblock_loads = filter.dtype.size(
+                algo_param.threadblock_m * algo_param.threadblock_n *
+                algo_param.threadblock_k);
+        int load_per_thread = threadblock_loads / threads;
+        if (load_per_thread >= 16)
+            alignment_filter = 16;
+        else if (load_per_thread >= 8)
+            alignment_filter = 8;
+        else {
+            megdnn_assert(load_per_thread >= 4);
+            alignment_filter = 4;
+        }
+    };
+
+    // get tensor alignment for depthwise convolution
+    auto get_tensor_alignment_dwconv2d_nchw = [&]() {
+        alignment_filter = 1;
+        size_t wi = src.dtype.size(src[3]);  // width extent in bytes
+        for (size_t candidate : {16, 4, 2}) {
+            if (wi % candidate == 0) {
+                alignment_src = candidate;
+                break;
+            }
+        }
+        alignment_src /= src.dtype.size(1);  // bytes -> elements
+    };
+
+    if (format == Format::NCHW32 || format == Format::NCHW32_NCHW4 ||
+        format == Format::NCHW64 || format == Format::NHWC) {
+        get_tensor_alignment_tensor_op();
+    } else if (
+            format == Format::NCHW4 || format == Format::NCHW4_NCHW ||
+            format == Format::NCHW4_NHWC || format == Format::NCHW4_NCHW32) {
+        get_tensor_alignment_dp4a();
+    } else {
+        /// the following is used for depthwise convolution
+        megdnn_assert(format == Format::NCHW && is_chanwise);
+        get_tensor_alignment_dwconv2d_nchw();
+    }
+    megdnn_assert(alignment_src >= 1 && alignment_filter >= 1);
+    return {alignment_src, alignment_filter};
+}
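Two worked examples for the helpers above (values follow directly from the code; the dp4a tile shape is the first entry in fill_dp4a_algos, the dwconv shapes are illustrative):

    // dp4a, AlgoParam{128, 128, 32, 64, 32, 32, ...}, int8 filter:
    //   threads           = 32 * (128*128*32) / (64*32*32) = 256
    //   threadblock_loads = 128 * 128 * 32 * 1 byte        = 524288
    //   load_per_thread   = 524288 / 256                   = 2048  (>= 16)
    //   => alignment_src = 4, alignment_filter = 16
    //
    // dwconv2d NCHW, float32 src with W = src[3] = 24:
    //   wi = 24 * 4 = 96 bytes; 96 % 16 == 0    -> alignment_src = 16 bytes
    //   alignment_src /= 4 (bytes per float32)  -> alignment_src = 4 elements
    //   alignment_filter = 1 always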
 } // namespace

 const Operation* ConvBiasForwardImpl::AlgoCutlassConvolutionBase::get_cutlass_conv_op(
@@ -185,23 +295,36 @@ const Operation* ConvBiasForwardImpl::AlgoCutlassConvolutionBase::get_cutlass_conv_op(
     auto&& param = args.opr->param();
     auto layouts = get_layout_pack(param.format, m_algo_param.access_size);
     auto epilogue_type = get_epilogue_type(
-            param.nonlineMode, args.dst_layout->dtype.enumv() != DTypeEnum::Float32);
+            param.nonlineMode,
+            args.dst_layout->dtype.category() != DTypeCategory::FLOAT);

     cutlass::conv::SpecialOptimizeDesc special_optimization =
             (use_conv_filter_unity_opt)
                     ? cutlass::conv::SpecialOptimizeDesc::CONV_FILTER_UNITY
                     : cutlass::conv::SpecialOptimizeDesc::NONE;
+
+    int alignment_src, alignment_filter;
+    auto&& fm = args.filter_meta;
+    bool is_chanwise = param.sparse == param::ConvBias::Sparse::GROUP && fm.icpg == 1 &&
+                       fm.ocpg == 1;
+    std::tie(alignment_src, alignment_filter) = get_tensor_alignment(
+            param.format, *args.src_layout, *args.filter_layout, m_algo_param,
+            is_chanwise);
+
+    auto accumulator_dtype =
+            get_accumulator_dtype(args.src_layout->dtype, param.compute_mode);
+
     ConvolutionKey key{
             convert_conv_op(conv_op),
-            convert_dtype(args.src_layout->dtype.enumv()),
+            convert_dtype(args.src_layout->dtype),
             layouts.src,
-            convert_dtype(args.filter_layout->dtype.enumv()),
+            convert_dtype(args.filter_layout->dtype),
             layouts.filter,
-            convert_dtype(args.dst_layout->dtype.enumv()),
+            convert_dtype(args.dst_layout->dtype),
             layouts.dst,
-            convert_dtype(args.bias_layout->dtype.enumv()),
+            convert_dtype(args.bias_layout->dtype),
             layouts.bias,
+            accumulator_dtype,
             convert_conv_type(conv_type),
             m_algo_param.threadblock_m,
             m_algo_param.threadblock_n,
@@ -215,6 +338,8 @@ const Operation* ConvBiasForwardImpl::AlgoCutlassConvolutionBase::get_cutlass_conv_op(
             epilogue_type,
             m_algo_param.stage,
             special_optimization,
+            alignment_src,
+            alignment_filter,
             without_shared_load};

     return Singleton::get().operation_table.find_op(key);
@@ -227,13 +352,16 @@ void ConvBiasForwardImpl::AlgoCutlassConvolutionBase::execute_cutlass_conv_op(
         size_t pw, size_t sh, size_t sw, size_t dh, size_t dw, const void* alpha,
         const void* beta, const void* gamma, const void* delta, const void* theta,
         const void* threshold, const void* dst_scale, cudaStream_t stream,
-        const void* extra_param) const {
+        const void* extra_param, size_t groups) const {
     // gcc prints warnings when size_t values are implicitly narrowed to int
     cutlass::conv::Conv2dProblemSize problem_size{
             int(n),  int(hi), int(wi), int(ci),
             int(co), int(fh), int(fw), int(ho),
             int(wo), int(ph), int(pw), int(sh),
-            int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation};
+            int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation,
+            1,            // split-k slices, always 1
+            int(groups),  // groups
+    };

     ConvolutionArguments conv_args{
             problem_size, src, filter, bias, z, dst, alpha,
......
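With the new groups argument, a channel-wise forward problem sets groups to the channel count; the is_chanwise predicate above requires icpg == ocpg == 1, so ci == co == groups. A sketch with made-up shapes matching the large-kernel use case in the commit title:

    // N=32, C=64, 56x56 input, 31x31 filter, pad 15, stride 1, dilation 1,
    // so ho = wo = (56 + 2*15 - 31) / 1 + 1 = 56 (illustrative values only):
    cutlass::conv::Conv2dProblemSize problem_size{
            32, 56, 56, 64,  // n, hi, wi, ci
            64, 31, 31, 56,  // co, fh, fw, ho
            56, 15, 15, 1,   // wo, ph, pw, sh
            1, 1, 1, cutlass::conv::Mode::kCrossCorrelation,  // sw, dh, dw, mode
            1,    // split-k slices, always 1
            64};  // groups == channel count for channel-wise convolution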
@@ -71,6 +71,9 @@ public:
     class AlgoInt4Int4NHWCIMMAImplicitGemm;
     class AlgoUInt4Int4NHWCIMMAImplicitGemm;
     class AlgoBFloat16;
+    // the following algorithms are suitable for channel-wise convolution
+    class AlgoFloat32NCHWFMAImplicitBatchedGemm;
+    class AlgoFloat16NCHWHMMAImplicitBatchedGemm;
     class AlgoPack;
......
@@ -39,6 +39,7 @@ const void* ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::
             LayoutTypeID::kTensorNC4HW4,
             NumericTypeID::kS32,
             LayoutTypeID::kTensorNC4HW4,
+            NumericTypeID::kS32,
             cutlass::conv::ConvType::kConvolution,
             m_algo_param.threadblock_m,
             m_algo_param.threadblock_n,
@@ -52,6 +53,8 @@ const void* ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::
             cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp,
             m_algo_param.stage,
             special_optimization,
+            4,   // alignment_src
+            16,  // alignment_filter
             false};
     return (void*)Singleton::get().operation_table.find_op(key);
 }
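The two added literals fill the alignment_src and alignment_filter slots of ConvolutionKey. They match what get_tensor_alignment_dp4a in the conv_bias path derives for int8 tiles of this size: alignment_src is fixed at 4 for dot-product kernels, and the per-thread filter load works out to at least 16 bytes, as in the worked example above.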
......
@@ -223,6 +223,9 @@ enum class ThreadblockSwizzleID {
     kConvolutionFpropTrans,
     kConvolutionDgradNCxHWx,
     kConvolutionDgradTrans,
+    kDepthwiseConvolutionFprop,
+    kDepthwiseConvolutionDgrad,
+    kDepthwiseConvolutionWgrad,
     kInvalid
 };
......