提交 2224a252 编写于 作者: M Megvii Engine Team

fix(mge/opr): add opr_footprint support for PoolingBackward

GitOrigin-RevId: 5f1c64ef9a96915d0a39d94682e07d967711548c
上级 91675a71
...@@ -176,7 +176,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, ...@@ -176,7 +176,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
}; };
if (param.format == Param::Format::NCHW4 || if (param.format == Param::Format::NCHW4 ||
param.format == Param::Format::NCHW4_NCHW || param.format == Param::Format::NCHW4_NCHW ||
param.format == Param::Format::NCHW4_NHWC || param.format == Param::Format::NCHW4_NHWC ||
param.format == Param::Format::NCHW4_NCHW32 || param.format == Param::Format::NCHW4_NCHW32 ||
param.format == Param::Format::NCHW88 || param.format == Param::Format::NCHW88 ||
param.format == Param::Format::NCHW44 || param.format == Param::Format::NCHW44 ||
...@@ -223,9 +223,9 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, ...@@ -223,9 +223,9 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
uint64_t fh = static_cast<uint64_t>(filter_shape[spatial_start]); uint64_t fh = static_cast<uint64_t>(filter_shape[spatial_start]);
uint64_t fw = static_cast<uint64_t>(filter_shape[spatial_start + 1]); uint64_t fw = static_cast<uint64_t>(filter_shape[spatial_start + 1]);
// mul and add are counted as 2 operations // mul and add are counted as 2 operations
return dst_shape.total_nr_elems() * fh * fw * return dst_shape.total_nr_elems() * fh * fw *
static_cast<uint64_t>(src_shape[cpos]) / group * 2; static_cast<uint64_t>(src_shape[cpos]) / group * 2;
} }
...@@ -464,6 +464,14 @@ uint64_t opr_footprint_func<opr::PoolingForward>(cg::OperatorNodeBase* opr) { ...@@ -464,6 +464,14 @@ uint64_t opr_footprint_func<opr::PoolingForward>(cg::OperatorNodeBase* opr) {
return opr->output(0)->shape().total_nr_elems() * area; return opr->output(0)->shape().total_nr_elems() * area;
} }
// PoolingBackWard
template <>
uint64_t opr_footprint_func<opr::PoolingBackward>(cg::OperatorNodeBase* opr) {
auto&& param = opr->cast_final_safe<opr::PoolingBackward>().param();
auto area = param.window_h * param.window_w;
return opr->input()[0]->shape().total_nr_elems() * area;
}
// Concat // Concat
template <> template <>
uint64_t opr_footprint_func<opr::Concat>(cg::OperatorNodeBase* opr) { uint64_t opr_footprint_func<opr::Concat>(cg::OperatorNodeBase* opr) {
...@@ -516,6 +524,7 @@ REGISTE_PARAM_JSON_FUNC(BatchedMatrixMul) ...@@ -516,6 +524,7 @@ REGISTE_PARAM_JSON_FUNC(BatchedMatrixMul)
REGISTE_PARAM_JSON_FUNC(Dot) REGISTE_PARAM_JSON_FUNC(Dot)
REGISTE_PARAM_JSON_FUNC(MatrixInverse) REGISTE_PARAM_JSON_FUNC(MatrixInverse)
REGISTE_PARAM_JSON_FUNC(PoolingForward) REGISTE_PARAM_JSON_FUNC(PoolingForward)
REGISTE_PARAM_JSON_FUNC(PoolingBackward)
REGISTE_PARAM_JSON_FUNC(SVD) REGISTE_PARAM_JSON_FUNC(SVD)
REGISTE_PARAM_JSON_FUNC(MaskConvolution) REGISTE_PARAM_JSON_FUNC(MaskConvolution)
REGISTE_PARAM_JSON_FUNC(Images2Neibs) REGISTE_PARAM_JSON_FUNC(Images2Neibs)
...@@ -666,7 +675,7 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::standalone::NMSKeep>( ...@@ -666,7 +675,7 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::standalone::NMSKeep>(
{"max_output", json::Number::make(nms_param.max_output)}, {"max_output", json::Number::make(nms_param.max_output)},
}); });
} }
#endif // MGB_ENABLE_JSON #endif // MGB_ENABLE_JSON
...@@ -700,6 +709,7 @@ void OprFootprint::init_all_footprints() { ...@@ -700,6 +709,7 @@ void OprFootprint::init_all_footprints() {
add_single_comp_footprint<opr::ConvolutionBackwardFilter>(); add_single_comp_footprint<opr::ConvolutionBackwardFilter>();
add_single_comp_footprint<opr::MatrixMul>(); add_single_comp_footprint<opr::MatrixMul>();
add_single_comp_footprint<opr::PoolingForward>(); add_single_comp_footprint<opr::PoolingForward>();
add_single_comp_footprint<opr::PoolingBackward>();
add_single_comp_footprint<opr::Concat>(); add_single_comp_footprint<opr::Concat>();
add_single_comp_footprint<opr::Dimshuffle>(); add_single_comp_footprint<opr::Dimshuffle>();
add_single_comp_footprint<opr::Reduce>(); add_single_comp_footprint<opr::Reduce>();
...@@ -725,6 +735,7 @@ void OprFootprint::init_all_footprints() { ...@@ -725,6 +735,7 @@ void OprFootprint::init_all_footprints() {
add_single_param_json<opr::Dot>(); add_single_param_json<opr::Dot>();
add_single_param_json<opr::MatrixInverse>(); add_single_param_json<opr::MatrixInverse>();
add_single_param_json<opr::PoolingForward>(); add_single_param_json<opr::PoolingForward>();
add_single_param_json<opr::PoolingBackward>();
add_single_param_json<opr::SVD>(); add_single_param_json<opr::SVD>();
add_single_param_json<opr::MaskConvolution>(); add_single_param_json<opr::MaskConvolution>();
add_single_param_json<opr::Images2Neibs>(); add_single_param_json<opr::Images2Neibs>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册