提交 00083d13 编写于 作者: M Megvii Engine Team 提交者: huangxinda

fix(dnn/cuda): fix recursive algo search for fallback_nchw_qs8

GitOrigin-RevId: 6be2991224bced3a38a17b6b888fd4f324d03f9f
上级 bba04f02
......@@ -575,7 +575,10 @@ public:
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8)
std::vector<SearchItem> get_subopr_list(
const TensorLayoutArray& layouts,
const OperatorBase* opr) const override;
private:
void make_inner_layout(const SizeArgs& args, TensorLayout& inner_src_layout,
TensorLayout& inner_weight_layout,
......
......@@ -69,6 +69,32 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::make_inner_layout(
}
};
std::vector<Algorithm::SearchItem>
ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_subopr_list(
const TensorLayoutArray& layouts, const OperatorBase* opr) const {
const ConvBiasForwardImpl* o = static_cast<const ConvBiasForwardImpl*>(opr);
SizeArgs args(const_cast<ConvBiasForwardImpl*>(o), layouts[0], layouts[1],
layouts[2], layouts[3], layouts[4], nullptr);
TensorLayout inner_src_layout;
TensorLayout inner_weight_layout;
TensorLayout inner_dst_layout;
TensorLayout inner_bias_layout;
TensorLayout inner_z_layout;
make_inner_layout(args, inner_src_layout, inner_weight_layout,
inner_dst_layout, inner_bias_layout, inner_z_layout);
Param inner_conv_param = o->param();
inner_conv_param.format = Param::Format::NCHW4;
std::string param_str;
Algorithm::serialize_write_pod(inner_conv_param, param_str);
return {{Algorithm::OprType::CONVBIAS_FORWARD,
param_str,
{inner_src_layout, inner_weight_layout, inner_bias_layout,
inner_z_layout, inner_dst_layout}}};
}
bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
const SizeArgs& args) const {
if (!args.src_layout->is_contiguous() ||
......@@ -109,6 +135,8 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_workspace_bundle(
}
auto opr = args.handle->create_operator<ConvBiasForward>();
opr->param() = inner_conv_param;
set_execution_policy<ConvBiasForward, ConvBiasForward*>(args.opr,
opr.get());
return WorkspaceBundle(
ptr,
{inner_src_layout.span().dist_byte(),
......@@ -164,6 +192,8 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::exec(
inner_conv_param.format =
dst_float ? Param::Format::NCHW4_NCHW : Param::Format::NCHW4;
auto inner_opr = args.handle->create_operator<ConvBiasForward>();
set_execution_policy<ConvBiasForward, ConvBiasForward*>(args.opr,
inner_opr.get());
inner_opr->param() = inner_conv_param;
relayout_nchw_nchw4->exec(*args.src_tensor, inner_src, {});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册