提交 2224a252 编写于 作者: M Megvii Engine Team

fix(mge/opr): add opr_footprint support for PoolingBackward

GitOrigin-RevId: 5f1c64ef9a96915d0a39d94682e07d967711548c
上级 91675a71
......@@ -176,7 +176,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
};
if (param.format == Param::Format::NCHW4 ||
param.format == Param::Format::NCHW4_NCHW ||
param.format == Param::Format::NCHW4_NHWC ||
param.format == Param::Format::NCHW4_NHWC ||
param.format == Param::Format::NCHW4_NCHW32 ||
param.format == Param::Format::NCHW88 ||
param.format == Param::Format::NCHW44 ||
......@@ -223,9 +223,9 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
uint64_t fh = static_cast<uint64_t>(filter_shape[spatial_start]);
uint64_t fw = static_cast<uint64_t>(filter_shape[spatial_start + 1]);
// mul and add are counted as 2 operations
return dst_shape.total_nr_elems() * fh * fw *
static_cast<uint64_t>(src_shape[cpos]) / group * 2;
}
......@@ -464,6 +464,14 @@ uint64_t opr_footprint_func<opr::PoolingForward>(cg::OperatorNodeBase* opr) {
return opr->output(0)->shape().total_nr_elems() * area;
}
// PoolingBackWard
template <>
uint64_t opr_footprint_func<opr::PoolingBackward>(cg::OperatorNodeBase* opr) {
auto&& param = opr->cast_final_safe<opr::PoolingBackward>().param();
auto area = param.window_h * param.window_w;
return opr->input()[0]->shape().total_nr_elems() * area;
}
// Concat
template <>
uint64_t opr_footprint_func<opr::Concat>(cg::OperatorNodeBase* opr) {
......@@ -516,6 +524,7 @@ REGISTE_PARAM_JSON_FUNC(BatchedMatrixMul)
REGISTE_PARAM_JSON_FUNC(Dot)
REGISTE_PARAM_JSON_FUNC(MatrixInverse)
REGISTE_PARAM_JSON_FUNC(PoolingForward)
REGISTE_PARAM_JSON_FUNC(PoolingBackward)
REGISTE_PARAM_JSON_FUNC(SVD)
REGISTE_PARAM_JSON_FUNC(MaskConvolution)
REGISTE_PARAM_JSON_FUNC(Images2Neibs)
......@@ -666,7 +675,7 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::standalone::NMSKeep>(
{"max_output", json::Number::make(nms_param.max_output)},
});
}
#endif // MGB_ENABLE_JSON
......@@ -700,6 +709,7 @@ void OprFootprint::init_all_footprints() {
add_single_comp_footprint<opr::ConvolutionBackwardFilter>();
add_single_comp_footprint<opr::MatrixMul>();
add_single_comp_footprint<opr::PoolingForward>();
add_single_comp_footprint<opr::PoolingBackward>();
add_single_comp_footprint<opr::Concat>();
add_single_comp_footprint<opr::Dimshuffle>();
add_single_comp_footprint<opr::Reduce>();
......@@ -725,6 +735,7 @@ void OprFootprint::init_all_footprints() {
add_single_param_json<opr::Dot>();
add_single_param_json<opr::MatrixInverse>();
add_single_param_json<opr::PoolingForward>();
add_single_param_json<opr::PoolingBackward>();
add_single_param_json<opr::SVD>();
add_single_param_json<opr::MaskConvolution>();
add_single_param_json<opr::Images2Neibs>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册