diff --git a/src/plugin/impl/opr_footprint.cpp b/src/plugin/impl/opr_footprint.cpp index 36d3e0091ed8c4b17aacc53fdfc9f19895489b01..6d6ff0450c17067e864dd1d119b4ba6f80a2295b 100644 --- a/src/plugin/impl/opr_footprint.cpp +++ b/src/plugin/impl/opr_footprint.cpp @@ -176,7 +176,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, }; if (param.format == Param::Format::NCHW4 || param.format == Param::Format::NCHW4_NCHW || - param.format == Param::Format::NCHW4_NHWC || + param.format == Param::Format::NCHW4_NHWC || param.format == Param::Format::NCHW4_NCHW32 || param.format == Param::Format::NCHW88 || param.format == Param::Format::NCHW44 || @@ -223,9 +223,9 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, uint64_t fh = static_cast(filter_shape[spatial_start]); uint64_t fw = static_cast(filter_shape[spatial_start + 1]); - + // mul and add are counted as 2 operations - + return dst_shape.total_nr_elems() * fh * fw * static_cast(src_shape[cpos]) / group * 2; } @@ -464,6 +464,14 @@ uint64_t opr_footprint_func(cg::OperatorNodeBase* opr) { return opr->output(0)->shape().total_nr_elems() * area; } +// PoolingBackWard +template <> +uint64_t opr_footprint_func(cg::OperatorNodeBase* opr) { + auto&& param = opr->cast_final_safe().param(); + auto area = param.window_h * param.window_w; + return opr->input()[0]->shape().total_nr_elems() * area; +} + // Concat template <> uint64_t opr_footprint_func(cg::OperatorNodeBase* opr) { @@ -516,6 +524,7 @@ REGISTE_PARAM_JSON_FUNC(BatchedMatrixMul) REGISTE_PARAM_JSON_FUNC(Dot) REGISTE_PARAM_JSON_FUNC(MatrixInverse) REGISTE_PARAM_JSON_FUNC(PoolingForward) +REGISTE_PARAM_JSON_FUNC(PoolingBackward) REGISTE_PARAM_JSON_FUNC(SVD) REGISTE_PARAM_JSON_FUNC(MaskConvolution) REGISTE_PARAM_JSON_FUNC(Images2Neibs) @@ -666,7 +675,7 @@ std::shared_ptr opr_param_json_func( {"max_output", json::Number::make(nms_param.max_output)}, }); } - + #endif // MGB_ENABLE_JSON @@ -700,6 +709,7 @@ void OprFootprint::init_all_footprints() { add_single_comp_footprint(); add_single_comp_footprint(); add_single_comp_footprint(); + add_single_comp_footprint(); add_single_comp_footprint(); add_single_comp_footprint(); add_single_comp_footprint(); @@ -725,6 +735,7 @@ void OprFootprint::init_all_footprints() { add_single_param_json(); add_single_param_json(); add_single_param_json(); + add_single_param_json(); add_single_param_json(); add_single_param_json(); add_single_param_json();