提交 2224a252 编写于 作者: M Megvii Engine Team

fix(mge/opr): add opr_footprint support for PoolingBackward

GitOrigin-RevId: 5f1c64ef9a96915d0a39d94682e07d967711548c
上级 91675a71
......@@ -464,6 +464,14 @@ uint64_t opr_footprint_func<opr::PoolingForward>(cg::OperatorNodeBase* opr) {
return opr->output(0)->shape().total_nr_elems() * area;
}
// PoolingBackWard
template <>
uint64_t opr_footprint_func<opr::PoolingBackward>(cg::OperatorNodeBase* opr) {
auto&& param = opr->cast_final_safe<opr::PoolingBackward>().param();
auto area = param.window_h * param.window_w;
return opr->input()[0]->shape().total_nr_elems() * area;
}
// Concat
template <>
uint64_t opr_footprint_func<opr::Concat>(cg::OperatorNodeBase* opr) {
......@@ -516,6 +524,7 @@ REGISTE_PARAM_JSON_FUNC(BatchedMatrixMul)
REGISTE_PARAM_JSON_FUNC(Dot)
REGISTE_PARAM_JSON_FUNC(MatrixInverse)
REGISTE_PARAM_JSON_FUNC(PoolingForward)
REGISTE_PARAM_JSON_FUNC(PoolingBackward)
REGISTE_PARAM_JSON_FUNC(SVD)
REGISTE_PARAM_JSON_FUNC(MaskConvolution)
REGISTE_PARAM_JSON_FUNC(Images2Neibs)
......@@ -700,6 +709,7 @@ void OprFootprint::init_all_footprints() {
add_single_comp_footprint<opr::ConvolutionBackwardFilter>();
add_single_comp_footprint<opr::MatrixMul>();
add_single_comp_footprint<opr::PoolingForward>();
add_single_comp_footprint<opr::PoolingBackward>();
add_single_comp_footprint<opr::Concat>();
add_single_comp_footprint<opr::Dimshuffle>();
add_single_comp_footprint<opr::Reduce>();
......@@ -725,6 +735,7 @@ void OprFootprint::init_all_footprints() {
add_single_param_json<opr::Dot>();
add_single_param_json<opr::MatrixInverse>();
add_single_param_json<opr::PoolingForward>();
add_single_param_json<opr::PoolingBackward>();
add_single_param_json<opr::SVD>();
add_single_param_json<opr::MaskConvolution>();
add_single_param_json<opr::Images2Neibs>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册