提交 63032170 编写于 作者: M Megvii Engine Team

feat(dnn/fallback): add gi fp16 nchw88 winograd F63 algo

GitOrigin-RevId: d986e1cbebd0f9ad89c27a62bd4e951459165d3d
上级 10503cfa
...@@ -91,5 +91,43 @@ MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( ...@@ -91,5 +91,43 @@ MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(
AlgoFP16WinogradF23_8x8_NCHW88, winograd::winograd_F23_mk8_f16_nchw88, AlgoFP16WinogradF23_8x8_NCHW88, winograd::winograd_F23_mk8_f16_nchw88,
megdnn_fallback_winograd_fp16_nchw88, param::MatrixMul::Format::MK8); megdnn_fallback_winograd_fp16_nchw88, param::MatrixMul::Format::MK8);
/* =================== AlgoFP16WinogradF63_8x8_NCHW88 ===================== */
/**
 * \brief Report whether this winograd algorithm can handle \p param.
 *
 * The algorithm is accepted only when all of the following hold:
 *  - both icpg and ocpg of the filter are multiples of 8;
 *  - the wrapped matmul algo accepts the derived matmul kernel param and
 *    reports PackMode::NO_PACK;
 *  - the filter format is NCHW88 and the filter is not flipped;
 *  - the filter is 3x3 with stride 1 and dilation 1;
 *  - compute mode is DEFAULT and the source dtype is Float16.
 *
 * The algo selection strategy argument is ignored.
 */
bool ConvBiasImpl::AlgoFP16WinogradF63_8x8_NCHW88::usable(
        const NCBKernSizeParam& param,
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    MIDOUT_BEGIN(
            megdnn_fallback_winograd_fp16_nchw88,
            midout_iv("AlgoFP16WinogradF63_8x8_NCHW88"_hash)) {
        const auto& fm = param.filter_meta;

        // Channel counts per group must be divisible by 8.
        if (fm.icpg % 8 != 0 || fm.ocpg % 8 != 0) {
            return false;
        }

        using Strategy = winograd::winograd_F63_mk8_f16_nchw88;
        using WinogradConvBias =
                megdnn::winograd::ConvBias<Strategy, param::MatrixMul::Format::MK8>;

        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        auto&& matmul_param = WinogradConvBias(strategy, m_tile_size, param)
                                      .get_matmul_kern_param(param);

        // The checks below keep the evaluation order of the original
        // short-circuit expression.
        if (!m_matmul_algo->usable(matmul_param)) {
            return false;
        }
        if (m_matmul_algo->packmode() !=
            fallback::MatrixMulImpl::AlgoBase::PackMode::NO_PACK) {
            return false;
        }
        if (fm.format != param::ConvBias::Format::NCHW88) {
            return false;
        }
        if (fm.should_flip) {
            return false;
        }

        const bool filter_is_3x3 =
                fm.spatial[0] == fm.spatial[1] && fm.spatial[0] == 3;
        const bool unit_stride = fm.stride[0] == fm.stride[1] && fm.stride[0] == 1;
        const bool unit_dilation =
                fm.dilation[0] == fm.dilation[1] && fm.dilation[0] == 1;

        return filter_is_3x3 && unit_stride && unit_dilation &&
               param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
               param.src_type.enumv() == DTypeEnum::Float16;
    }
    MIDOUT_END();
    return false;
}
// Macro invocation for AlgoFP16WinogradF63_8x8_NCHW88, parameterised with the
// winograd_F63_mk8_f16_nchw88 strategy, the megdnn_fallback_winograd_fp16_nchw88
// midout tag and the MK8 matmul format.
// NOTE(review): presumably this expands to the remaining member function
// definitions of the algo class - confirm against the macro definition.
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(
AlgoFP16WinogradF63_8x8_NCHW88, winograd::winograd_F63_mk8_f16_nchw88,
megdnn_fallback_winograd_fp16_nchw88, param::MatrixMul::Format::MK8);
#endif #endif
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -46,6 +46,24 @@ public: ...@@ -46,6 +46,24 @@ public:
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F23_8X8_NCHW88_F16) MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F23_8X8_NCHW88_F16)
}; };
/**
 * \brief Winograd conv-bias algorithm for fp16 data in NCHW88 layout.
 *
 * Wraps a matmul algorithm together with a tile size. The data members used
 * here (m_matmul_algo, m_tile_size, m_name) are not declared in this class body
 * directly; presumably they come from MEGDNN_WINOGRAD_ALGO_FUN_DECLARE.
 */
class ConvBiasImpl::AlgoFP16WinogradF63_8x8_NCHW88 final : public AlgoBase {
public:
    //! Store the matmul algo to delegate to and the tile size.
    AlgoFP16WinogradF63_8x8_NCHW88(
            fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size)
            : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}

    //! Return the algorithm name; it is built on first use and cached in m_name.
    const char* name() const override {
        if (!m_name.empty()) {
            return m_name.c_str();
        }
        m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
                m_matmul_algo->name(), {8, 6, m_tile_size, 3},
                param::ConvBias::Format::NCHW88);
        return m_name.c_str();
    }

    //! This algorithm is reported as reproducible.
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
    MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_8X8_NCHW88_F16)
};
} // namespace fallback } // namespace fallback
} // namespace megdnn } // namespace megdnn
......
...@@ -7,4 +7,5 @@ ...@@ -7,4 +7,5 @@
#define MULSF16 GiMultiplyScalerFloat16 #define MULSF16 GiMultiplyScalerFloat16
#endif #endif
// Paste the tokens `a` and `idx` into a single token (e.g. CONCAT(v, 3) -> v3).
// The `##` operator is applied directly, so arguments that are themselves macros
// are NOT expanded before pasting.
#define CONCAT(a, idx) a##idx
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -17,6 +17,9 @@ MEGDNN_REG_WINOGRAD_STRATEGY( ...@@ -17,6 +17,9 @@ MEGDNN_REG_WINOGRAD_STRATEGY(
MEGDNN_REG_WINOGRAD_STRATEGY( MEGDNN_REG_WINOGRAD_STRATEGY(
dt_float16, dt_float16, dt_float16, dt_float16, 2, 3, 8, 8, dt_float16, dt_float16, dt_float16, dt_float16, 2, 3, 8, 8,
winograd_F23_mk8_f16_nchw88) winograd_F23_mk8_f16_nchw88)
// Registers the winograd_F63_mk8_f16_nchw88 strategy with four dt_float16 type
// arguments and the numeric arguments 6, 3, 8, 8.
// NOTE(review): by analogy with the F23 registration (2, 3, 8, 8) the numbers
// presumably mean output block size 6, filter size 3 and 8x8 packing - confirm
// against the macro definition.
MEGDNN_REG_WINOGRAD_STRATEGY(
dt_float16, dt_float16, dt_float16, dt_float16, 6, 3, 8, 8,
winograd_F63_mk8_f16_nchw88)
} // namespace winograd } // namespace winograd
} // namespace fallback } // namespace fallback
} // namespace megdnn } // namespace megdnn
......
...@@ -204,6 +204,11 @@ public: ...@@ -204,6 +204,11 @@ public:
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size)); tile_size));
m_gi_winograd_algos.emplace_back(refhold.back().get()); m_gi_winograd_algos.emplace_back(refhold.back().get());
// Create the F63 fp16 NCHW88 winograd algo for the current matmul algo and tile
// size, store it in refhold, and add its raw pointer to m_gi_winograd_algos.
// NOTE(review): refhold presumably owns the object (e.g. a container of smart
// pointers) - confirm against its declaration.
refhold.emplace_back(new AlgoFP16WinogradF63_8x8_NCHW88(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
m_gi_winograd_algos.emplace_back(refhold.back().get());
} }
} }
#endif #endif
......
...@@ -228,6 +228,7 @@ public: ...@@ -228,6 +228,7 @@ public:
GI_COMMON_WINOGRAD_F73_4X4_NCHW44_F32, GI_COMMON_WINOGRAD_F73_4X4_NCHW44_F32,
GI_COMMON_WINOGRAD_F23_8X8_NCHW88_F16, GI_COMMON_WINOGRAD_F23_8X8_NCHW88_F16,
GI_COMMON_WINOGRAD_F43_8X8_NCHW88_F16, GI_COMMON_WINOGRAD_F43_8X8_NCHW88_F16,
GI_COMMON_WINOGRAD_F63_8X8_NCHW88_F16,
GI_COMMON_DIRECT_FP32, GI_COMMON_DIRECT_FP32,
GI_COMMON_DIRECT_STRD1_FP32, GI_COMMON_DIRECT_STRD1_FP32,
GI_COMMON_DIRECT_STRD2_FP32, GI_COMMON_DIRECT_STRD2_FP32,
...@@ -397,6 +398,7 @@ private: ...@@ -397,6 +398,7 @@ private:
class AlgoFP16WinogradF23_8x8_NCHW88; class AlgoFP16WinogradF23_8x8_NCHW88;
class AlgoFP16WinogradF43_8x8_NCHW88; class AlgoFP16WinogradF43_8x8_NCHW88;
class AlgoFP16WinogradF63_8x8_NCHW88;
class AlgoF32Direct; class AlgoF32Direct;
class AlgoF32DirectStride1; class AlgoF32DirectStride1;
......
...@@ -634,6 +634,19 @@ TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_F43_8_NCHW88_FP16) { ...@@ -634,6 +634,19 @@ TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_F43_8_NCHW88_FP16) {
"8:4:", checker, args, &rng, 0.006, param::MatrixMul::Format::MK8, "8:4:", checker, args, &rng, 0.006, param::MatrixMul::Format::MK8,
"WINOGRAD_NCHW88"); "WINOGRAD_NCHW88");
} }
// Runs the fp16 winograd check for algorithms matching "8:6:" in
// "WINOGRAD_NCHW88" with MK8 matmul format, over NCHW88 conv-bias arguments
// with a 3x3 kernel, using a tolerance of 0.019.
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_F63_8_NCHW88_FP16) {
    using namespace conv_bias;
    using PreprocessChecker =
            Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>>;

    std::vector<TestArg> test_args =
            get_nchw88_conv_bias_args({3}, FULL_NLMODE, BR_AND_NO_BIASMODE, 1);
    PreprocessChecker checker(handle());
    // 0x3c00 is the IEEE-754 half-precision bit pattern of 1.0; how the RNG
    // uses it is defined by Float16PeriodicalRNG.
    Float16PeriodicalRNG periodical_rng(0x3c00);

    check_winograd_fp16(
            "8:6:", checker, test_args, &periodical_rng, 0.019,
            param::MatrixMul::Format::MK8, "WINOGRAD_NCHW88");
}
#endif #endif
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_F43_4_WEIGHT_PREPROCESS) { TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_F43_4_WEIGHT_PREPROCESS) {
...@@ -1407,7 +1420,7 @@ TEST_F(FALLBACK, BENCHMARK_GI_WINOGRAD_F23_FP32_NCHW44_VS_FP16_NCHW88) { ...@@ -1407,7 +1420,7 @@ TEST_F(FALLBACK, BENCHMARK_GI_WINOGRAD_F23_FP32_NCHW44_VS_FP16_NCHW88) {
std::string algo_name_fp32 = "WINOGRAD_NCHW44:FB_GI_F32_MK4_4x8:4:2"; std::string algo_name_fp32 = "WINOGRAD_NCHW44:FB_GI_F32_MK4_4x8:4:2";
benchmark_with_contrast( benchmark_with_contrast(
args_with_computation_fp16, algo_name_fp16, data_type_fp16, args_with_computation_fp16, algo_name_fp16, data_type_fp16,
args_with_computation_fp32, algo_name_fp32, data_type_fp32, 10, {1, {0}}); args_with_computation_fp32, algo_name_fp32, data_type_fp32, 10, {1, {4}});
} }
TEST_F(FALLBACK, BENCHMARK_GI_WINOGRAD_F43_FP32_NCHW44_VS_FP16_NCHW88) { TEST_F(FALLBACK, BENCHMARK_GI_WINOGRAD_F43_FP32_NCHW44_VS_FP16_NCHW88) {
...@@ -1447,6 +1460,44 @@ TEST_F(FALLBACK, BENCHMARK_GI_WINOGRAD_F43_FP32_NCHW44_VS_FP16_NCHW88) { ...@@ -1447,6 +1460,44 @@ TEST_F(FALLBACK, BENCHMARK_GI_WINOGRAD_F43_FP32_NCHW44_VS_FP16_NCHW88) {
args_with_computation_fp16, algo_name_fp16, data_type_fp16, args_with_computation_fp16, algo_name_fp16, data_type_fp16,
args_with_computation_fp32, algo_name_fp32, data_type_fp32, 10, {1, {0}}); args_with_computation_fp32, algo_name_fp32, data_type_fp32, 10, {1, {0}});
} }
// Benchmarks the fp16 NCHW88 winograd F63 algorithm against the fp32 NCHW44
// winograd F63 algorithm on the winograd benchmark argument sets.
TEST_F(FALLBACK, BENCHMARK_GI_WINOGRAD_F63_FP32_NCHW44_VS_FP16_NCHW88) {
    using ArgWithComputation = std::pair<conv_bias::TestArg, float>;

    // Computation estimate for one test case, derived from the output shape
    // and filter dimensions 1..4, scaled by 2.0 / 2^30 * 1e3.
    // NOTE(review): the scaling presumably makes (value / time-in-ms) read as
    // GFLOPS - confirm against benchmark_with_contrast.
    auto estimate_computation = [](const conv_bias::TestArg& arg) {
        const auto out_h = (arg.src[2] + arg.param.pad_h * 2 - arg.filter[2]) /
                                   arg.param.stride_h +
                           1;
        const auto out_w = (arg.src[3] + arg.param.pad_w * 2 - arg.filter[3]) /
                                   arg.param.stride_w +
                           1;
        TensorShape dst_shape{
                arg.src[0], arg.filter[0], out_h, out_w, arg.filter[5]};
        return dst_shape.total_nr_elems() * arg.filter[1] * arg.filter[2] *
               arg.filter[3] * arg.filter[4] * 2.0 / (1024 * 1024 * 1024) * 1e3;
    };

    // Pair every argument with its computation estimate, preserving order.
    auto attach_computation = [&estimate_computation](const auto& args) {
        std::vector<ArgWithComputation> paired;
        for (const auto& arg : args) {
            paired.emplace_back(arg, estimate_computation(arg));
        }
        return paired;
    };

    const auto args_fp16 = conv_bias::get_winograd_benchmark_args(3, 8, 8);
    const auto args_fp32 = conv_bias::get_winograd_benchmark_args(3, 4, 4);

    auto args_with_computation_fp16 = attach_computation(args_fp16);
    auto args_with_computation_fp32 = attach_computation(args_fp32);

    std::vector<DType> data_type_fp16(4, dtype::Float16());
    std::vector<DType> data_type_fp32(4, dtype::Float32());

    std::string algo_name_fp16 = "WINOGRAD_NCHW88:FB_GI_F16_MK8_8x8:8:6";
    std::string algo_name_fp32 = "WINOGRAD_NCHW44:FB_GI_F32_MK4_4x8:4:6";

    benchmark_with_contrast(
            args_with_computation_fp16, algo_name_fp16, data_type_fp16,
            args_with_computation_fp32, algo_name_fp32, data_type_fp32, 10,
            {1, {4}});
}
#endif #endif
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册