From 589b863b986b4ab3bae4c572c89a201df4a3edc7 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 29 Nov 2018 17:20:20 +0800 Subject: [PATCH] Add EstiminateFlops test=develop --- paddle/fluid/framework/details/op_registry.h | 20 +++++++++++++++++--- paddle/fluid/framework/op_info.h | 7 +++++++ paddle/fluid/framework/type_defs.h | 2 ++ 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h index eea7e712f8f..1ce18c3d6b2 100644 --- a/paddle/fluid/framework/details/op_registry.h +++ b/paddle/fluid/framework/details/op_registry.h @@ -32,7 +32,9 @@ enum OpInfoFillType { kOpProtoAndCheckerMaker = 1, kGradOpDescMaker = 2, kVarTypeInference = 3, - kShapeInference = 4 + kShapeInference = 4, + kEstimateFlops = 5, + kUnknown = -1 }; template @@ -48,8 +50,10 @@ struct OpInfoFillTypeID { ? kVarTypeInference : (std::is_base_of::value ? kShapeInference - : static_cast( - -1))))); + : (std::is_base_of::value + ? kEstimateFlops + : kUnknown))))); } }; @@ -139,6 +143,16 @@ struct OpInfoFiller { } }; +template +struct OpInfoFiller { + void operator()(const char* op_tpe, OpInfo* info) const { + info->estimate_flops_ = [](InferShapeContext* ctx) { + T estimate_flops; + return estimate_flops(ctx); + }; + } +}; + } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/op_info.h b/paddle/fluid/framework/op_info.h index 19e5c2c73ea..e0bf5ed999f 100644 --- a/paddle/fluid/framework/op_info.h +++ b/paddle/fluid/framework/op_info.h @@ -31,6 +31,12 @@ class InferShapeBase { virtual void operator()(InferShapeContext*) const = 0; }; +class EstimateFlopsBase { + public: + virtual ~EstimateFlopsBase() = default; + virtual size_t operator()(InferShapeContext*) const = 0; +}; + struct OpInfo { OpCreator creator_; GradOpMakerFN grad_op_maker_; @@ -38,6 +44,7 @@ struct OpInfo { OpAttrChecker* checker_{nullptr}; InferVarTypeFN infer_var_type_; InferShapeFN infer_shape_; + EstimateFlopsFN estimate_flops_; bool HasOpProtoAndChecker() const { return proto_ != nullptr && checker_ != nullptr; diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index 2de6233a9e0..cdc5fa6862e 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -54,5 +54,7 @@ using InferVarTypeFN = using InferShapeFN = std::function; +using EstimateFlopsFN = std::function; + } // namespace framework } // namespace paddle -- GitLab