Commit 4cdb7454 authored by: Megvii Engine Team

feat(rvv/fallback): make nchw44 work on rvv

GitOrigin-RevId: b29552b4055c4660354a996f5e4134053ce50638
Parent 5e306b75
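This commit threads the channel pack size (`pack_oc_size`) through the fallback post-process helper so that NCHW44-packed convolution outputs are handled correctly, which in turn lets the im2col and conv1x1 algorithms be re-enabled on riscv64 (RVV). For orientation, here is a minimal sketch, not part of the patch, of how an NCHW44 layout addresses elements: channels are grouped into blocks of four and the block-local channel index is the innermost dimension. All names below are illustrative.

```cpp
#include <cstddef>

// Offset of logical element (n, c, h, w) in an NCHW44-packed buffer whose
// layout is {N, C/4, H, W, 4}; pack size 4 corresponds to pack_oc_size == 4.
size_t nchw44_offset(
        size_t n, size_t c, size_t h, size_t w,  // logical NCHW coordinates
        size_t C, size_t H, size_t W) {          // logical extents (C % 4 == 0)
    const size_t pack = 4;
    size_t cb = c / pack;  // channel-block index
    size_t ci = c % pack;  // position inside the block (innermost)
    return (((n * (C / pack) + cb) * H + h) * W + w) * pack + ci;
}
```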
@@ -25,14 +25,27 @@ void to_handle_bias_and_nonlinear(
         void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr,
         megdnn::ConvBiasForward::BiasMode bias_mode,
         megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type,
-        megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW) {
+        megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW,
+        size_t pack_oc_size) {
     auto handle = megdnn::inplace_cpu_handle();
-    auto conv_dst_tensor_layout = megdnn::TensorLayout({N, OC, OH, OW}, dst_type);
-    auto conv_dst_tensor = megdnn::TensorND{conv_dst_ptr, conv_dst_tensor_layout};
+    megdnn::TensorLayout conv_dst_tensor_layout, bias_tensor_layout;
+    megdnn::TensorND conv_dst_tensor;
+    if (1 == pack_oc_size) {
+        conv_dst_tensor_layout = megdnn::TensorLayout({N, OC, OH, OW}, dst_type);
+    } else {
+        conv_dst_tensor_layout =
+                megdnn::TensorLayout({N, OC, OH, OW, pack_oc_size}, dst_type);
+    }
+    conv_dst_tensor = megdnn::TensorND{conv_dst_ptr, conv_dst_tensor_layout};
     auto dst_tensor = megdnn::TensorND{dst_ptr, conv_dst_tensor_layout};
-    auto bias_tensor_layout = conv_dst_tensor_layout;
+    bias_tensor_layout = conv_dst_tensor_layout;
     if (megdnn::ConvBiasForward::BiasMode::BROADCAST_CHANNEL_BIAS == bias_mode) {
-        bias_tensor_layout = megdnn::TensorLayout({1, OC, 1, 1}, bias_type);
+        if (1 == pack_oc_size) {
+            bias_tensor_layout = megdnn::TensorLayout({1, OC, 1, 1}, bias_type);
+        } else {
+            bias_tensor_layout =
+                    megdnn::TensorLayout({1, OC, 1, 1, pack_oc_size}, bias_type);
+        }
     } else if (megdnn::ConvBiasForward::BiasMode::NO_BIAS == bias_mode) {
         bias_tensor_layout = megdnn::TensorLayout({}, bias_type);
     }
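In the hunk above, when `pack_oc_size != 1` both the convolution output and the broadcast channel bias gain a trailing packed dimension, so the element-wise post-process sees matching shapes. A minimal plain-C++ sketch of what `BROADCAST_CHANNEL_BIAS` means for such a packed layout (hypothetical buffers; pack size 4 assumed):

```cpp
#include <cstddef>

// dst is {N, OCB, OH, OW, 4} (OCB = channel blocks), bias is {OCB, 4}:
// each block's four bias values are shared across every (n, oh, ow).
void add_channel_bias_nchw44(
        float* dst, const float* bias, size_t N, size_t OCB, size_t OH,
        size_t OW) {
    for (size_t n = 0; n < N; ++n)
        for (size_t ocb = 0; ocb < OCB; ++ocb)
            for (size_t hw = 0; hw < OH * OW; ++hw)
                for (size_t i = 0; i < 4; ++i)
                    dst[((n * OCB + ocb) * OH * OW + hw) * 4 + i] +=
                            bias[ocb * 4 + i];
}
```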
@@ -52,10 +65,9 @@ struct PostProcess {
             megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type,
             megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW,
             size_t pack_oc_size = 1) {
-        MEGDNN_MARK_USED_VAR(pack_oc_size);
         to_handle_bias_and_nonlinear(
                 conv_dst_ptr, bias_ptr, dst_ptr, bias_mode, nonlineMode, bias_type,
-                dst_type, N, OC, OH, OW);
+                dst_type, N, OC, OH, OW, pack_oc_size);
     }
 };
@@ -79,10 +91,9 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
             megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type,
             megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW,
             size_t pack_oc_size = 1) {
-        MEGDNN_MARK_USED_VAR(pack_oc_size);
         to_handle_bias_and_nonlinear(
                 conv_dst_ptr, bias_ptr, dst_ptr, bias_mode, nonlineMode, bias_type,
-                dst_type, N, OC, OH, OW);
+                dst_type, N, OC, OH, OW, pack_oc_size);
     }
 };
@@ -94,13 +105,12 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> {
             megdnn::param::ConvBias::NonlineMode nonlineMode, megdnn::DType bias_type,
             megdnn::DType dst_type, size_t N, size_t OC, size_t OH, size_t OW,
             size_t pack_oc_size = 1) {
-        MEGDNN_MARK_USED_VAR(pack_oc_size);
         if (bias_mode == megdnn::ConvBiasForward::BiasMode::NO_BIAS) {
             return;
         }
         to_handle_bias_and_nonlinear(
                 conv_dst_ptr, bias_ptr, dst_ptr, bias_mode, nonlineMode, bias_type,
-                dst_type, N, OC, OH, OW);
+                dst_type, N, OC, OH, OW, pack_oc_size);
     }
 };
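With `MEGDNN_MARK_USED_VAR(pack_oc_size)` removed from all three specializations, a caller's pack size now actually reaches `to_handle_bias_and_nonlinear`. A hedged usage sketch follows; the template arguments, `PostprocessMode::FLOAT`, and the buffer variables are illustrative, not taken from this patch:

```cpp
// Illustrative call site: passing pack_oc_size = 4 selects the packed
// {N, OC, OH, OW, 4} layouts built above (OC counts channel blocks here,
// i.e. logical channels / 4) instead of plain NCHW.
PostProcess<float, float, megdnn::PostprocessMode::FLOAT>::run(
        conv_dst, bias, dst,
        megdnn::ConvBiasForward::BiasMode::BROADCAST_CHANNEL_BIAS,
        megdnn::param::ConvBias::NonlineMode::RELU,
        megdnn::dtype::Float32(), megdnn::dtype::Float32(),
        N, OC, OH, OW, /*pack_oc_size=*/4);
```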
......
@@ -181,10 +181,6 @@ public:
         }
 #endif
-        //! As we haven't riscv64 postprocess yet, im2col and conv1x1 can not pass ci
-        //! test. so we just disable all im2col and conv1x1 in riscv64
-        //! FIXME: remove it when impl postprocess for riscv64
-#if !MEGDNN_RISCV64
         for (size_t ohw_tile_size : {192, 384, 96, 48, 24}) {
             refhold.emplace_back(new AlgoIm2col(
                     static_cast<MatrixMulImpl::AlgoBase*>(algo), ohw_tile_size));
@@ -195,7 +191,6 @@
                     static_cast<MatrixMulImpl::AlgoBase*>(algo), oc_tile_size));
             m_all_algos.emplace_back(refhold.back().get());
         }
-#endif
 #if 0
         //! As these algos maybe very slow, it will make fastrun search slow, so
......