提交 4cdb7454 编写于 作者: M Megvii Engine Team

feat(rvv/fallback): make nchw44 apply on rvv

GitOrigin-RevId: b29552b4055c4660354a996f5e4134053ce50638
上级 5e306b75
...@@ -25,14 +25,27 @@ void to_handle_bias_and_nonlinear( ...@@ -25,14 +25,27 @@ void to_handle_bias_and_nonlinear(
void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr,
megdnn::ConvBiasForward::BiasMode bias_mode, megdnn::ConvBiasForward::BiasMode bias_mode,
megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type, megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type,
megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW) { megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW,
size_t pack_oc_size) {
auto handle = megdnn::inplace_cpu_handle(); auto handle = megdnn::inplace_cpu_handle();
auto conv_dst_tensor_layout = megdnn::TensorLayout({N, OC, OH, OW}, dst_type); megdnn::TensorLayout conv_dst_tensor_layout, bias_tensor_layout;
auto conv_dst_tensor = megdnn::TensorND{conv_dst_ptr, conv_dst_tensor_layout}; megdnn::TensorND conv_dst_tensor;
if (1 == pack_oc_size) {
conv_dst_tensor_layout = megdnn::TensorLayout({N, OC, OH, OW}, dst_type);
} else {
conv_dst_tensor_layout =
megdnn::TensorLayout({N, OC, OH, OW, pack_oc_size}, dst_type);
}
conv_dst_tensor = megdnn::TensorND{conv_dst_ptr, conv_dst_tensor_layout};
auto dst_tensor = megdnn::TensorND{dst_ptr, conv_dst_tensor_layout}; auto dst_tensor = megdnn::TensorND{dst_ptr, conv_dst_tensor_layout};
auto bias_tensor_layout = conv_dst_tensor_layout; bias_tensor_layout = conv_dst_tensor_layout;
if (megdnn::ConvBiasForward::BiasMode::BROADCAST_CHANNEL_BIAS == bias_mode) { if (megdnn::ConvBiasForward::BiasMode::BROADCAST_CHANNEL_BIAS == bias_mode) {
bias_tensor_layout = megdnn::TensorLayout({1, OC, 1, 1}, bias_type); if (1 == pack_oc_size) {
bias_tensor_layout = megdnn::TensorLayout({1, OC, 1, 1}, bias_type);
} else {
bias_tensor_layout =
megdnn::TensorLayout({1, OC, 1, 1, pack_oc_size}, bias_type);
}
} else if (megdnn::ConvBiasForward::BiasMode::NO_BIAS == bias_mode) { } else if (megdnn::ConvBiasForward::BiasMode::NO_BIAS == bias_mode) {
bias_tensor_layout = megdnn::TensorLayout({}, bias_type); bias_tensor_layout = megdnn::TensorLayout({}, bias_type);
} }
...@@ -52,10 +65,9 @@ struct PostProcess { ...@@ -52,10 +65,9 @@ struct PostProcess {
megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type, megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type,
megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW, megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW,
size_t pack_oc_size = 1) { size_t pack_oc_size = 1) {
MEGDNN_MARK_USED_VAR(pack_oc_size);
to_handle_bias_and_nonlinear( to_handle_bias_and_nonlinear(
conv_dst_ptr, bias_ptr, dst_ptr, bias_mode, nonlineMode, bias_type, conv_dst_ptr, bias_ptr, dst_ptr, bias_mode, nonlineMode, bias_type,
dst_type, N, OC, OH, OW); dst_type, N, OC, OH, OW, pack_oc_size);
} }
}; };
...@@ -79,10 +91,9 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> { ...@@ -79,10 +91,9 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type, megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type,
megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW, megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW,
size_t pack_oc_size = 1) { size_t pack_oc_size = 1) {
MEGDNN_MARK_USED_VAR(pack_oc_size);
to_handle_bias_and_nonlinear( to_handle_bias_and_nonlinear(
conv_dst_ptr, bias_ptr, dst_ptr, bias_mode, nonlineMode, bias_type, conv_dst_ptr, bias_ptr, dst_ptr, bias_mode, nonlineMode, bias_type,
dst_type, N, OC, OH, OW); dst_type, N, OC, OH, OW, pack_oc_size);
} }
}; };
...@@ -94,13 +105,12 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> { ...@@ -94,13 +105,12 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> {
megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type, megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type,
megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW, megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW,
size_t pack_oc_size = 1) { size_t pack_oc_size = 1) {
MEGDNN_MARK_USED_VAR(pack_oc_size);
if (bias_mode == megdnn::ConvBiasForward::BiasMode::NO_BIAS) { if (bias_mode == megdnn::ConvBiasForward::BiasMode::NO_BIAS) {
return; return;
} }
to_handle_bias_and_nonlinear( to_handle_bias_and_nonlinear(
conv_dst_ptr, bias_ptr, dst_ptr, bias_mode, nonlineMode, bias_type, conv_dst_ptr, bias_ptr, dst_ptr, bias_mode, nonlineMode, bias_type,
dst_type, N, OC, OH, OW); dst_type, N, OC, OH, OW, pack_oc_size);
} }
}; };
......
...@@ -181,10 +181,6 @@ public: ...@@ -181,10 +181,6 @@ public:
} }
#endif #endif
//! As we haven't riscv64 postprocess yet, im2col and conv1x1 can not pass ci
//! test. so we just disable all im2col and conv1x1 in riscv64
//! FIXME: remove it when impl postprocess for riscv64
#if !MEGDNN_RISCV64
for (size_t ohw_tile_size : {192, 384, 96, 48, 24}) { for (size_t ohw_tile_size : {192, 384, 96, 48, 24}) {
refhold.emplace_back(new AlgoIm2col( refhold.emplace_back(new AlgoIm2col(
static_cast<MatrixMulImpl::AlgoBase*>(algo), ohw_tile_size)); static_cast<MatrixMulImpl::AlgoBase*>(algo), ohw_tile_size));
...@@ -195,7 +191,6 @@ public: ...@@ -195,7 +191,6 @@ public:
static_cast<MatrixMulImpl::AlgoBase*>(algo), oc_tile_size)); static_cast<MatrixMulImpl::AlgoBase*>(algo), oc_tile_size));
m_all_algos.emplace_back(refhold.back().get()); m_all_algos.emplace_back(refhold.back().get());
} }
#endif
#if 0 #if 0
//! As these algos maybe very slow, it will make fastrun search slow, so //! As these algos maybe very slow, it will make fastrun search slow, so
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册