diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 01536fd36ff83bc53cce7cbbe077bfc6c8fd95b4..7e7f1fed5ad58db25909c25ca60f5eac80a5f478 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -52,7 +52,7 @@ cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS ${GRAPH_PA cc_library(op_compat_sensible_pass SRCS op_compat_sensible_pass.cc DEPS graph_pattern_detector) cc_library(subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor) -cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass) +cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS op_compat_sensible_pass) cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass) cc_library(coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS graph graph_helper) diff --git a/paddle/fluid/framework/ir/fuse_pass_base.h b/paddle/fluid/framework/ir/fuse_pass_base.h index ce7635bb35ce6108b4a5a356c8fb99269dbf2890..bc5fc2a16d3939648f53e91f6cd3f4f0def0fd93 100644 --- a/paddle/fluid/framework/ir/fuse_pass_base.h +++ b/paddle/fluid/framework/ir/fuse_pass_base.h @@ -17,7 +17,7 @@ #include #include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h" #include "paddle/fluid/framework/scope.h" namespace paddle { @@ -46,7 +46,7 @@ enum FuseOptions { FUSE_MKLDNN // fusing will be done with MKL-DNN }; -class FusePassBase : public Pass { +class FusePassBase : public OpCompatSensiblePass { public: void Init(const std::string& repr, Graph* graph) const; Scope* param_scope() const; diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc index f7312ca5555311351b7bfeeb6b4b11ca24867ca5..b056c3b07a2f65bf0756285857edd3355b591c29 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc @@ -15,7 +15,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/ir/op_compat_sensible_pass.h" - +#include "paddle/fluid/framework/op_info.h" namespace paddle { namespace framework { namespace ir { @@ -51,11 +51,33 @@ AttrCompat& AttrCompat::IsIntIn(const std::set& candidates) { } //! Todo: append the definition. -AttrCompat& AttrCompat::IsLeftDefault() { return *this; } +AttrCompat& AttrCompat::IsLeftDefault() { + const std::string& op_name = op_compat_->Name(); + if (!OpInfoMap::Instance().Has(op_name)) { + VLOG(3) << "Op (" << op_name << ") is not registered!"; + conditions_.emplace_back([](const Attribute& attr) { return false; }); + return *this; + } + const OpInfo& op_info = OpInfoMap::Instance().Get(op_name); + const AttributeMap attrs = op_info.Checker()->GetAttrsDefaultValuesMap(); + if (attrs.find(attr_name_) == attrs.end()) { + VLOG(3) << "Op (" << op_name << ") has no default attr:" << attr_name_; + conditions_.emplace_back([](const Attribute& attr) { return false; }); + } else { + Attribute default_attr = attrs.at(attr_name_); + conditions_.emplace_back([default_attr](const Attribute& attr) -> bool { + return attr == default_attr; + }); + } + return *this; +} bool AttrCompat::operator()(const OpDesc& op_desc) { + if (conditions_.empty()) { + return true; + } if (!op_desc.HasAttr(attr_name_)) { - return false; + return optional_; } const Attribute attr = op_desc.GetAttr(attr_name_); for (auto& func : conditions_) { @@ -65,6 +87,10 @@ bool AttrCompat::operator()(const OpDesc& op_desc) { } return true; } +AttrCompat& AttrCompat::IsOptional() { + optional_ = true; + return *this; +} AttrCompat& AttrCompat::IsBoolEQ(bool v) { conditions_.emplace_back([v](const Attribute& attr) -> bool { @@ -98,8 +124,12 @@ bool InputOrOutputCompat::operator()( } AttrCompat& OpCompat::AddAttr(const std::string& attr_name) { - attr_compats_.emplace_back(attr_name, this); - return attr_compats_.back(); + PADDLE_ENFORCE_EQ( + attr_compats_.find(attr_name), attr_compats_.end(), + platform::errors::InvalidArgument( + "The attrubute compat with the same name has been added")); + attr_compats_.emplace(attr_name, AttrCompat(attr_name, this)); + return attr_compats_.at(attr_name); } InputOrOutputCompat& OpCompat::AddInput(const std::string& name) { @@ -119,8 +149,19 @@ InputOrOutputCompat& OpCompat::AddOutput(const std::string& name) { } bool OpCompat::Judge(const OpDesc& op_desc) { + for (auto& attr_map : op_desc.GetAttrMap()) { + if (attr_compats_.find(attr_map.first) == attr_compats_.end()) { + if (!AttrCompat(attr_map.first, this).IsLeftDefault()(op_desc)) { + VLOG(3) << "The Attr(" << attr_map.first << ") of Op (" << op_name_ + << ") not reigistered in OpCompat, not equal to default value!"; + return false; + } + } + } for (auto& attr_compat : attr_compats_) { - if (!attr_compat(op_desc)) { + if (!attr_compat.second(op_desc)) { + VLOG(3) << " Check the Attr(" << attr_compat.first << ") of Op(" + << op_name_ << ") failed!"; return false; } } @@ -129,6 +170,8 @@ bool OpCompat::Judge(const OpDesc& op_desc) { for (auto& input_desc : inputs_map) { if (input_compats_.find(input_desc.first) == input_compats_.end()) { if (!input_desc.second.empty()) { + VLOG(3) << "The Input (" << input_desc.first << ") of Operator (" + << op_name_ << ") not reigistered in OpCompat!"; return false; } } @@ -136,10 +179,14 @@ bool OpCompat::Judge(const OpDesc& op_desc) { for (auto& input_val : input_compats_) { if (inputs_map.find(input_val.first) == inputs_map.end()) { if (!input_val.second.Optional()) { + VLOG(3) << "The No optional Input (" << input_val.first + << ") of Operator (" << op_name_ << ") not find in op_desc!"; return false; } } else { if (!input_val.second(inputs_map.at(input_val.first))) { + VLOG(3) << "The Input (" << input_val.first << ") of Operator (" + << op_name_ << ") compat check failed!"; return false; } } @@ -149,6 +196,8 @@ bool OpCompat::Judge(const OpDesc& op_desc) { for (auto& output_desc : outputs_map) { if (output_compats_.find(output_desc.first) == output_compats_.end()) { if (!output_desc.second.empty()) { + VLOG(3) << "The Output (" << output_desc.first << ") of Operator (" + << op_name_ << ") not reigistered in OpCompat!"; return false; } } @@ -156,10 +205,14 @@ bool OpCompat::Judge(const OpDesc& op_desc) { for (auto& output_val : output_compats_) { if (outputs_map.find(output_val.first) == outputs_map.end()) { if (!output_val.second.Optional()) { + VLOG(3) << "The No optional Output (" << output_val.first + << ") of Operator (" << op_name_ << ") not find in op_desc!"; return false; } } else { if (!output_val.second(outputs_map.at(output_val.first))) { + VLOG(3) << "The Output (" << output_val.first << ") of Operator (" + << op_name_ << ") compat check failed!"; return false; } } diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass.h b/paddle/fluid/framework/ir/op_compat_sensible_pass.h index 6c0860549fbfeecb9a10a279b2a85fc792ad4089..3f2ea673d879b8f1ca3ddbed82b6120af5044d47 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass.h +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass.h @@ -29,7 +29,7 @@ class OpCompat; class AttrCompat { public: AttrCompat(const std::string& attr_name, OpCompat* op_compat) - : attr_name_(attr_name), op_compat_(op_compat) {} + : optional_(false), attr_name_(attr_name), op_compat_(op_compat) {} // @{ String-related methods //! Assert the attribute is an string in the `candidates` domain. @@ -70,12 +70,15 @@ class AttrCompat { //! Tell whether this attribute is left as default value. AttrCompat& IsLeftDefault(); + AttrCompat& IsOptional(); + //! Jump back to retrieve OpCompat instance. OpCompat& End() { return *op_compat_; } bool operator()(const OpDesc& op_desc); private: + bool optional_; std::string attr_name_; OpCompat* op_compat_; std::vector> conditions_; @@ -134,7 +137,7 @@ class OpCompat { private: std::string op_name_; - std::vector attr_compats_; + std::unordered_map attr_compats_; std::unordered_map input_compats_; std::unordered_map output_compats_; }; @@ -179,15 +182,6 @@ class OpCompat { * }; */ class OpCompatSensiblePass : public Pass { - public: - //! Access the subgraph and pattern. - void AccessSubgraph(const GraphPatternDetector::subgraph_t& subgraph, - Graph* g) { - if (IsCompat(subgraph, g)) { - AccessSubgraphImpl(subgraph, g); - } - } - protected: /** * Developer should push the compatibility `teller` for each kind of Op in the @@ -197,12 +191,6 @@ class OpCompatSensiblePass : public Pass { */ OpCompat& AddOpCompat(OpCompat&& op_compat); - //! Modify the subgraph. - virtual bool AccessSubgraphImpl( - const GraphPatternDetector::subgraph_t& subgraph, Graph* g) const { - return true; - } - //! Tell the Op compability of a subgraph. bool IsCompat(const GraphPatternDetector::subgraph_t& subgraph, Graph* g) const { @@ -212,7 +200,7 @@ class OpCompatSensiblePass : public Pass { // Check the all the ops in the subgraph are contained in the // op_compat. for (auto& node_pair : subgraph) { - if (!node_pair.first->IsOp()) continue; + if (!node_pair.second->IsOp()) continue; auto op_type = node_pair.second->Op()->Type(); if (!op_compat_judgers_.count(op_type)) { return false; diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc b/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc index 3d0863a6d12d9538a57dfaef0160f129b8ff00ac..0878e4d9890d35bc4ecdf276880b43e9c5f4f416 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/ir/op_compat_sensible_pass.h" - #include "gtest/gtest.h" +#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h" namespace paddle { @@ -23,7 +23,7 @@ namespace ir { TEST(OpCompatSensiblePass, compatOp) { auto lambda = [](const std::string& str) { return str == "tanh"; }; - OpCompat compat("FC"); + OpCompat compat("fc"); compat.AddAttr("in_num_col_dims") .IsIntIn({1, 2}) .IsNumLE(1) @@ -67,10 +67,75 @@ TEST(OpCompatSensiblePass, compatOp) { fc_op.SetInput("Bias", std::vector{"test_input_1"}); fc_op.SetOutput("Out", std::vector{"test_output"}); - EXPECT_STREQ(compat.Name().c_str(), "FC"); + EXPECT_STREQ(compat.Name().c_str(), "fc"); + EXPECT_FALSE(compat.Judge(fc_op)); +} + +TEST(OpCompatSensiblePass, compatOpAttribute) { + OpCompat compat("fc"); + + OpDesc fc_op; + + std::unordered_map attr_map; + attr_map["in_num_col_dims"] = 1; + fc_op.SetAttrMap(attr_map); + + OpInfo info; + info.checker_ = new OpAttrChecker(); + OpInfoMap::Instance().Insert("fc", info); + + EXPECT_FALSE(compat.Judge(fc_op)); + + info.checker_->AddAttrChecker("in_num_col_dims").SetDefault(1); + + EXPECT_TRUE(compat.Judge(fc_op)); + delete info.checker_; +} + +TEST(OpCompatSensiblePass, compatOpAttributeOptional) { + OpCompat compat("fc"); + compat.AddAttr("activation_type") + .IsOptional() + .IsStringIn({"tanh", "sigmoid"}); + OpDesc fc_op; EXPECT_TRUE(compat.Judge(fc_op)); } +TEST(OpCompatSensiblePass, compatOpInput) { + OpCompat compat("fc"); + + OpDesc fc_op; + fc_op.SetInput("Input", std::vector{"test_input"}); + + EXPECT_FALSE(compat.Judge(fc_op)); + + compat.AddInput("Input").IsTensor().End().AddInput("Bias").IsTensor().End(); + EXPECT_FALSE(compat.Judge(fc_op)); + + fc_op.SetInput("Bias", std::vector{"test_input", ""}); + EXPECT_FALSE(compat.Judge(fc_op)); +} + +TEST(OpCompatSensiblePass, compatOutput) { + OpCompat compat("fc"); + + OpDesc fc_op; + fc_op.SetOutput("Output", std::vector{"test_output"}); + + EXPECT_FALSE(compat.Judge(fc_op)); + + compat.AddOutput("Output") + .IsTensor() + .End() + .AddOutput("Output_2") + .IsTensor() + .End(); + EXPECT_FALSE(compat.Judge(fc_op)); + + fc_op.SetOutput("Output_2", std::vector{"test_output", ""}); + EXPECT_FALSE(compat.Judge(fc_op)); +} + class OpCompatSensiblePassTest : public OpCompatSensiblePass { public: OpCompatSensiblePassTest(); @@ -78,7 +143,7 @@ class OpCompatSensiblePassTest : public OpCompatSensiblePass { }; OpCompatSensiblePassTest::OpCompatSensiblePassTest() { - AddOpCompat(OpCompat("FC")) + AddOpCompat(OpCompat("fc")) .AddAttr("in_num_col_dims") .IsNumLE(1) .End() @@ -102,7 +167,7 @@ OpCompatSensiblePassTest::OpCompatSensiblePassTest() { TEST(OpCompatSensiblePass, IsCompat) { OpCompatSensiblePassTest test; OpDesc fc_op; - fc_op.SetType("FC"); + fc_op.SetType("fc"); std::unordered_map attr_map; attr_map["in_num_col_dims"] = 1; attr_map["activation_type"] = std::string("tanh"); @@ -114,18 +179,6 @@ TEST(OpCompatSensiblePass, IsCompat) { fc_op.SetOutput("Out", std::vector{"test_output"}); EXPECT_TRUE(test.TestIsCompat(fc_op)); - - ProgramDesc prog; - std::unique_ptr g(new Graph(prog)); - Node* o1 = g->CreateOpNode(&fc_op); - - GraphPatternDetector detector; - PDNode* op2 = - detector.mutable_pattern()->NewNode([](Node* x) { return true; }); - GraphPatternDetector::subgraph_t subgraph; - subgraph[op2] = o1; - - test.AccessSubgraph(subgraph, g.get()); } } // namespace ir