未验证 提交 d55f3b6f 编写于 作者: 王明冬 提交者: GitHub

add compat precondition for attention_lstm_fuse_pass, test=develop (#33711)

上级 10171806
...@@ -23,6 +23,61 @@ namespace paddle { ...@@ -23,6 +23,61 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
AttentionLSTMFusePass::AttentionLSTMFusePass() {
AddOpCompat(OpCompat("while"))
.AddInput("X") // A set of variables, unconstrained
.End()
.AddInput("Condition") // An scalar
.IsTensor()
.End()
.AddOutput("Out") // A set of variables, unconstrained
.End()
.AddOutput("StepScopes") // A vector of local scope, unconstrained
.End()
.AddAttr("sub_block")
.IsType<framework::BlockDesc*>()
.End();
AddOpCompat(OpCompat("fill_constant"))
.AddInput("ValueTensor")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensorList") // vector<Tensor<int>>
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("dtype")
.IsNumGE(0)
.IsNumLE(25)
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End()
.AddAttr("value")
.IsType<float>()
.End();
AddOpCompat(OpCompat("sequence_expand"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("ref_level")
.IsNumGE(-1)
.End();
}
struct Param { struct Param {
std::string X = "concat_0.tmp_0"; std::string X = "concat_0.tmp_0";
std::string C0 = "cell_init"; std::string C0 = "cell_init";
...@@ -43,7 +98,7 @@ struct Param { ...@@ -43,7 +98,7 @@ struct Param {
void PrepareParameters(Graph* graph, const Param& param, ir::Node* lstm_op); void PrepareParameters(Graph* graph, const Param& param, ir::Node* lstm_op);
void FindWhileOp(Graph* graph) { void AttentionLSTMFusePass::FindWhileOp(Graph* graph) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
std::unordered_set<int> fused_external_ops( std::unordered_set<int> fused_external_ops(
{35, 36, 37, 38, 43, 44, 49, 45, 46, 47, 41, 42, 53, 54, 48, {35, 36, 37, 38, 43, 44, 49, 45, 46, 47, 41, 42, 53, 54, 48,
...@@ -60,6 +115,10 @@ void FindWhileOp(Graph* graph) { ...@@ -60,6 +115,10 @@ void FindWhileOp(Graph* graph) {
auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
auto* while_pat_node = gpd.pattern().RetrieveNode("while"); auto* while_pat_node = gpd.pattern().RetrieveNode("while");
auto* while_node = subgraph.at(while_pat_node); auto* while_node = subgraph.at(while_pat_node);
marked_nodes.insert(while_node); marked_nodes.insert(while_node);
......
...@@ -23,8 +23,14 @@ namespace ir { ...@@ -23,8 +23,14 @@ namespace ir {
class Graph; class Graph;
class AttentionLSTMFusePass : public FusePassBase { class AttentionLSTMFusePass : public FusePassBase {
public:
AttentionLSTMFusePass();
protected: protected:
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
private:
void FindWhileOp(Graph* graph) const;
}; };
} // namespace ir } // namespace ir
......
...@@ -260,7 +260,7 @@ bool OpCompatSensiblePass::IsCompat( ...@@ -260,7 +260,7 @@ bool OpCompatSensiblePass::IsCompat(
auto op_type = node_pair.second->Op()->Type(); auto op_type = node_pair.second->Op()->Type();
if (!op_compat_judgers_.count(op_type)) { if (!op_compat_judgers_.count(op_type)) {
if (HasOpDef(op_type)) { if (HasOpDef(op_type)) {
LOG(WARNING) << op_type << "compat not registered!"; LOG(WARNING) << op_type << " compat not registered!";
return false; return false;
} }
continue; continue;
......
...@@ -31,6 +31,10 @@ class AttrCompat { ...@@ -31,6 +31,10 @@ class AttrCompat {
AttrCompat(const std::string& attr_name, OpCompat* op_compat) AttrCompat(const std::string& attr_name, OpCompat* op_compat)
: optional_(false), attr_name_(attr_name), op_compat_(op_compat) {} : optional_(false), attr_name_(attr_name), op_compat_(op_compat) {}
//! Assert the attribute type is `T`.
template <typename T>
AttrCompat& IsType();
// @{ String-related methods // @{ String-related methods
//! Assert the attribute is an string in the `candidates` domain. //! Assert the attribute is an string in the `candidates` domain.
AttrCompat& IsStringIn(const std::set<std::string>& candidates); AttrCompat& IsStringIn(const std::set<std::string>& candidates);
...@@ -207,6 +211,13 @@ class OpCompatSensiblePass : public Pass { ...@@ -207,6 +211,13 @@ class OpCompatSensiblePass : public Pass {
std::map<std::string, std::unique_ptr<OpCompat>> op_compat_judgers_; std::map<std::string, std::unique_ptr<OpCompat>> op_compat_judgers_;
}; };
template <typename T>
AttrCompat& AttrCompat::IsType() {
conditions_.emplace_back(
[](const Attribute& attr) -> bool { return attr.type() == typeid(T); });
return *this;
}
template <typename T> template <typename T>
AttrCompat& AttrCompat::IsNumGT(T v) { AttrCompat& AttrCompat::IsNumGT(T v) {
conditions_.emplace_back([v](const Attribute& attr) -> bool { conditions_.emplace_back([v](const Attribute& attr) -> bool {
......
...@@ -24,7 +24,6 @@ def { ...@@ -24,7 +24,6 @@ def {
name: "value" name: "value"
type: FLOAT type: FLOAT
} }
} }
extra { extra {
attrs { attrs {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册