diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h index eea7e712f8f6e187cdceedce77cc76d1d4ca2101..1ce18c3d6b26c541beed668a113b7a4de7f0e79e 100644 --- a/paddle/fluid/framework/details/op_registry.h +++ b/paddle/fluid/framework/details/op_registry.h @@ -32,7 +32,9 @@ enum OpInfoFillType { kOpProtoAndCheckerMaker = 1, kGradOpDescMaker = 2, kVarTypeInference = 3, - kShapeInference = 4 + kShapeInference = 4, + kEstimateFlops = 5, + kUnknown = -1 }; template @@ -48,8 +50,10 @@ struct OpInfoFillTypeID { ? kVarTypeInference : (std::is_base_of::value ? kShapeInference - : static_cast( - -1))))); + : (std::is_base_of::value + ? kEstimateFlops + : kUnknown))))); } }; @@ -139,6 +143,16 @@ struct OpInfoFiller { } }; +template +struct OpInfoFiller { + void operator()(const char* op_tpe, OpInfo* info) const { + info->estimate_flops_ = [](InferShapeContext* ctx) { + T estimate_flops; + return estimate_flops(ctx); + }; + } +}; + } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/op_info.h b/paddle/fluid/framework/op_info.h index 19e5c2c73eac74dee030a4f7820531800f737e4e..e0bf5ed999f580f217af285bf97d0bc0232f1ded 100644 --- a/paddle/fluid/framework/op_info.h +++ b/paddle/fluid/framework/op_info.h @@ -31,6 +31,12 @@ class InferShapeBase { virtual void operator()(InferShapeContext*) const = 0; }; +class EstimateFlopsBase { + public: + virtual ~EstimateFlopsBase() = default; + virtual size_t operator()(InferShapeContext*) const = 0; +}; + struct OpInfo { OpCreator creator_; GradOpMakerFN grad_op_maker_; @@ -38,6 +44,7 @@ struct OpInfo { OpAttrChecker* checker_{nullptr}; InferVarTypeFN infer_var_type_; InferShapeFN infer_shape_; + EstimateFlopsFN estimate_flops_; bool HasOpProtoAndChecker() const { return proto_ != nullptr && checker_ != nullptr; diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index 2de6233a9e0d320ec9a06d547db3575eb61925c0..cdc5fa6862e3b2a2784151302f15540a0e9db8ff 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -54,5 +54,7 @@ using InferVarTypeFN = using InferShapeFN = std::function; +using EstimateFlopsFN = std::function; + } // namespace framework } // namespace paddle