未验证 提交 0f0e1979 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #14666 from reyoung/feature/estiminate_flops

Add EstiminateFlops
...@@ -32,7 +32,9 @@ enum OpInfoFillType { ...@@ -32,7 +32,9 @@ enum OpInfoFillType {
kOpProtoAndCheckerMaker = 1, kOpProtoAndCheckerMaker = 1,
kGradOpDescMaker = 2, kGradOpDescMaker = 2,
kVarTypeInference = 3, kVarTypeInference = 3,
kShapeInference = 4 kShapeInference = 4,
kEstimateFlops = 5,
kUnknown = -1
}; };
template <typename T> template <typename T>
...@@ -48,8 +50,10 @@ struct OpInfoFillTypeID { ...@@ -48,8 +50,10 @@ struct OpInfoFillTypeID {
? kVarTypeInference ? kVarTypeInference
: (std::is_base_of<InferShapeBase, T>::value : (std::is_base_of<InferShapeBase, T>::value
? kShapeInference ? kShapeInference
: static_cast<OpInfoFillType>( : (std::is_base_of<EstimateFlopsBase,
-1))))); T>::value
? kEstimateFlops
: kUnknown)))));
} }
}; };
...@@ -139,6 +143,16 @@ struct OpInfoFiller<T, kShapeInference> { ...@@ -139,6 +143,16 @@ struct OpInfoFiller<T, kShapeInference> {
} }
}; };
template <typename T>
struct OpInfoFiller<T, kEstimateFlops> {
void operator()(const char* op_tpe, OpInfo* info) const {
info->estimate_flops_ = [](InferShapeContext* ctx) {
T estimate_flops;
return estimate_flops(ctx);
};
}
};
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -31,6 +31,12 @@ class InferShapeBase { ...@@ -31,6 +31,12 @@ class InferShapeBase {
virtual void operator()(InferShapeContext*) const = 0; virtual void operator()(InferShapeContext*) const = 0;
}; };
class EstimateFlopsBase {
public:
virtual ~EstimateFlopsBase() = default;
virtual size_t operator()(InferShapeContext*) const = 0;
};
struct OpInfo { struct OpInfo {
OpCreator creator_; OpCreator creator_;
GradOpMakerFN grad_op_maker_; GradOpMakerFN grad_op_maker_;
...@@ -38,6 +44,7 @@ struct OpInfo { ...@@ -38,6 +44,7 @@ struct OpInfo {
OpAttrChecker* checker_{nullptr}; OpAttrChecker* checker_{nullptr};
InferVarTypeFN infer_var_type_; InferVarTypeFN infer_var_type_;
InferShapeFN infer_shape_; InferShapeFN infer_shape_;
EstimateFlopsFN estimate_flops_;
bool HasOpProtoAndChecker() const { bool HasOpProtoAndChecker() const {
return proto_ != nullptr && checker_ != nullptr; return proto_ != nullptr && checker_ != nullptr;
......
...@@ -54,5 +54,7 @@ using InferVarTypeFN = ...@@ -54,5 +54,7 @@ using InferVarTypeFN =
using InferShapeFN = std::function<void(InferShapeContext*)>; using InferShapeFN = std::function<void(InferShapeContext*)>;
using EstimateFlopsFN = std::function<void(InferShapeContext*)>;
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册