未验证 提交 dc72ffa5 编写于 作者: 王明冬 提交者: GitHub

add the IsLeftDefault definition for pass enhance,test=develop (#33081)

上级 88dfb30f
...@@ -52,7 +52,7 @@ cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS ${GRAPH_PA ...@@ -52,7 +52,7 @@ cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS ${GRAPH_PA
cc_library(op_compat_sensible_pass SRCS op_compat_sensible_pass.cc DEPS graph_pattern_detector) cc_library(op_compat_sensible_pass SRCS op_compat_sensible_pass.cc DEPS graph_pattern_detector)
cc_library(subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor) cc_library(subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass) cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS op_compat_sensible_pass)
cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass) cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass)
cc_library(coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS graph graph_helper) cc_library(coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS graph graph_helper)
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <string> #include <string>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
namespace paddle { namespace paddle {
...@@ -46,7 +46,7 @@ enum FuseOptions { ...@@ -46,7 +46,7 @@ enum FuseOptions {
FUSE_MKLDNN // fusing will be done with MKL-DNN FUSE_MKLDNN // fusing will be done with MKL-DNN
}; };
class FusePassBase : public Pass { class FusePassBase : public OpCompatSensiblePass {
public: public:
void Init(const std::string& repr, Graph* graph) const; void Init(const std::string& repr, Graph* graph) const;
Scope* param_scope() const; Scope* param_scope() const;
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#include <memory> #include <memory>
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h" #include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
#include "paddle/fluid/framework/op_info.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
...@@ -51,11 +51,33 @@ AttrCompat& AttrCompat::IsIntIn(const std::set<int>& candidates) { ...@@ -51,11 +51,33 @@ AttrCompat& AttrCompat::IsIntIn(const std::set<int>& candidates) {
} }
//! Todo: append the definition. //! Todo: append the definition.
AttrCompat& AttrCompat::IsLeftDefault() { return *this; } AttrCompat& AttrCompat::IsLeftDefault() {
const std::string& op_name = op_compat_->Name();
if (!OpInfoMap::Instance().Has(op_name)) {
VLOG(3) << "Op (" << op_name << ") is not registered!";
conditions_.emplace_back([](const Attribute& attr) { return false; });
return *this;
}
const OpInfo& op_info = OpInfoMap::Instance().Get(op_name);
const AttributeMap attrs = op_info.Checker()->GetAttrsDefaultValuesMap();
if (attrs.find(attr_name_) == attrs.end()) {
VLOG(3) << "Op (" << op_name << ") has no default attr:" << attr_name_;
conditions_.emplace_back([](const Attribute& attr) { return false; });
} else {
Attribute default_attr = attrs.at(attr_name_);
conditions_.emplace_back([default_attr](const Attribute& attr) -> bool {
return attr == default_attr;
});
}
return *this;
}
bool AttrCompat::operator()(const OpDesc& op_desc) { bool AttrCompat::operator()(const OpDesc& op_desc) {
if (conditions_.empty()) {
return true;
}
if (!op_desc.HasAttr(attr_name_)) { if (!op_desc.HasAttr(attr_name_)) {
return false; return optional_;
} }
const Attribute attr = op_desc.GetAttr(attr_name_); const Attribute attr = op_desc.GetAttr(attr_name_);
for (auto& func : conditions_) { for (auto& func : conditions_) {
...@@ -65,6 +87,10 @@ bool AttrCompat::operator()(const OpDesc& op_desc) { ...@@ -65,6 +87,10 @@ bool AttrCompat::operator()(const OpDesc& op_desc) {
} }
return true; return true;
} }
AttrCompat& AttrCompat::IsOptional() {
optional_ = true;
return *this;
}
AttrCompat& AttrCompat::IsBoolEQ(bool v) { AttrCompat& AttrCompat::IsBoolEQ(bool v) {
conditions_.emplace_back([v](const Attribute& attr) -> bool { conditions_.emplace_back([v](const Attribute& attr) -> bool {
...@@ -98,8 +124,12 @@ bool InputOrOutputCompat::operator()( ...@@ -98,8 +124,12 @@ bool InputOrOutputCompat::operator()(
} }
AttrCompat& OpCompat::AddAttr(const std::string& attr_name) { AttrCompat& OpCompat::AddAttr(const std::string& attr_name) {
attr_compats_.emplace_back(attr_name, this); PADDLE_ENFORCE_EQ(
return attr_compats_.back(); attr_compats_.find(attr_name), attr_compats_.end(),
platform::errors::InvalidArgument(
"The attrubute compat with the same name has been added"));
attr_compats_.emplace(attr_name, AttrCompat(attr_name, this));
return attr_compats_.at(attr_name);
} }
InputOrOutputCompat& OpCompat::AddInput(const std::string& name) { InputOrOutputCompat& OpCompat::AddInput(const std::string& name) {
...@@ -119,8 +149,19 @@ InputOrOutputCompat& OpCompat::AddOutput(const std::string& name) { ...@@ -119,8 +149,19 @@ InputOrOutputCompat& OpCompat::AddOutput(const std::string& name) {
} }
bool OpCompat::Judge(const OpDesc& op_desc) { bool OpCompat::Judge(const OpDesc& op_desc) {
for (auto& attr_map : op_desc.GetAttrMap()) {
if (attr_compats_.find(attr_map.first) == attr_compats_.end()) {
if (!AttrCompat(attr_map.first, this).IsLeftDefault()(op_desc)) {
VLOG(3) << "The Attr(" << attr_map.first << ") of Op (" << op_name_
<< ") not reigistered in OpCompat, not equal to default value!";
return false;
}
}
}
for (auto& attr_compat : attr_compats_) { for (auto& attr_compat : attr_compats_) {
if (!attr_compat(op_desc)) { if (!attr_compat.second(op_desc)) {
VLOG(3) << " Check the Attr(" << attr_compat.first << ") of Op("
<< op_name_ << ") failed!";
return false; return false;
} }
} }
...@@ -129,6 +170,8 @@ bool OpCompat::Judge(const OpDesc& op_desc) { ...@@ -129,6 +170,8 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
for (auto& input_desc : inputs_map) { for (auto& input_desc : inputs_map) {
if (input_compats_.find(input_desc.first) == input_compats_.end()) { if (input_compats_.find(input_desc.first) == input_compats_.end()) {
if (!input_desc.second.empty()) { if (!input_desc.second.empty()) {
VLOG(3) << "The Input (" << input_desc.first << ") of Operator ("
<< op_name_ << ") not reigistered in OpCompat!";
return false; return false;
} }
} }
...@@ -136,10 +179,14 @@ bool OpCompat::Judge(const OpDesc& op_desc) { ...@@ -136,10 +179,14 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
for (auto& input_val : input_compats_) { for (auto& input_val : input_compats_) {
if (inputs_map.find(input_val.first) == inputs_map.end()) { if (inputs_map.find(input_val.first) == inputs_map.end()) {
if (!input_val.second.Optional()) { if (!input_val.second.Optional()) {
VLOG(3) << "The No optional Input (" << input_val.first
<< ") of Operator (" << op_name_ << ") not find in op_desc!";
return false; return false;
} }
} else { } else {
if (!input_val.second(inputs_map.at(input_val.first))) { if (!input_val.second(inputs_map.at(input_val.first))) {
VLOG(3) << "The Input (" << input_val.first << ") of Operator ("
<< op_name_ << ") compat check failed!";
return false; return false;
} }
} }
...@@ -149,6 +196,8 @@ bool OpCompat::Judge(const OpDesc& op_desc) { ...@@ -149,6 +196,8 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
for (auto& output_desc : outputs_map) { for (auto& output_desc : outputs_map) {
if (output_compats_.find(output_desc.first) == output_compats_.end()) { if (output_compats_.find(output_desc.first) == output_compats_.end()) {
if (!output_desc.second.empty()) { if (!output_desc.second.empty()) {
VLOG(3) << "The Output (" << output_desc.first << ") of Operator ("
<< op_name_ << ") not reigistered in OpCompat!";
return false; return false;
} }
} }
...@@ -156,10 +205,14 @@ bool OpCompat::Judge(const OpDesc& op_desc) { ...@@ -156,10 +205,14 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
for (auto& output_val : output_compats_) { for (auto& output_val : output_compats_) {
if (outputs_map.find(output_val.first) == outputs_map.end()) { if (outputs_map.find(output_val.first) == outputs_map.end()) {
if (!output_val.second.Optional()) { if (!output_val.second.Optional()) {
VLOG(3) << "The No optional Output (" << output_val.first
<< ") of Operator (" << op_name_ << ") not find in op_desc!";
return false; return false;
} }
} else { } else {
if (!output_val.second(outputs_map.at(output_val.first))) { if (!output_val.second(outputs_map.at(output_val.first))) {
VLOG(3) << "The Output (" << output_val.first << ") of Operator ("
<< op_name_ << ") compat check failed!";
return false; return false;
} }
} }
......
...@@ -29,7 +29,7 @@ class OpCompat; ...@@ -29,7 +29,7 @@ class OpCompat;
class AttrCompat { class AttrCompat {
public: public:
AttrCompat(const std::string& attr_name, OpCompat* op_compat) AttrCompat(const std::string& attr_name, OpCompat* op_compat)
: attr_name_(attr_name), op_compat_(op_compat) {} : optional_(false), attr_name_(attr_name), op_compat_(op_compat) {}
// @{ String-related methods // @{ String-related methods
//! Assert the attribute is an string in the `candidates` domain. //! Assert the attribute is an string in the `candidates` domain.
...@@ -70,12 +70,15 @@ class AttrCompat { ...@@ -70,12 +70,15 @@ class AttrCompat {
//! Tell whether this attribute is left as default value. //! Tell whether this attribute is left as default value.
AttrCompat& IsLeftDefault(); AttrCompat& IsLeftDefault();
AttrCompat& IsOptional();
//! Jump back to retrieve OpCompat instance. //! Jump back to retrieve OpCompat instance.
OpCompat& End() { return *op_compat_; } OpCompat& End() { return *op_compat_; }
bool operator()(const OpDesc& op_desc); bool operator()(const OpDesc& op_desc);
private: private:
bool optional_;
std::string attr_name_; std::string attr_name_;
OpCompat* op_compat_; OpCompat* op_compat_;
std::vector<std::function<bool(const Attribute&)>> conditions_; std::vector<std::function<bool(const Attribute&)>> conditions_;
...@@ -134,7 +137,7 @@ class OpCompat { ...@@ -134,7 +137,7 @@ class OpCompat {
private: private:
std::string op_name_; std::string op_name_;
std::vector<AttrCompat> attr_compats_; std::unordered_map<std::string, AttrCompat> attr_compats_;
std::unordered_map<std::string, InputOrOutputCompat> input_compats_; std::unordered_map<std::string, InputOrOutputCompat> input_compats_;
std::unordered_map<std::string, InputOrOutputCompat> output_compats_; std::unordered_map<std::string, InputOrOutputCompat> output_compats_;
}; };
...@@ -179,15 +182,6 @@ class OpCompat { ...@@ -179,15 +182,6 @@ class OpCompat {
* }; * };
*/ */
class OpCompatSensiblePass : public Pass { class OpCompatSensiblePass : public Pass {
public:
//! Access the subgraph and pattern.
void AccessSubgraph(const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (IsCompat(subgraph, g)) {
AccessSubgraphImpl(subgraph, g);
}
}
protected: protected:
/** /**
* Developer should push the compatibility `teller` for each kind of Op in the * Developer should push the compatibility `teller` for each kind of Op in the
...@@ -197,12 +191,6 @@ class OpCompatSensiblePass : public Pass { ...@@ -197,12 +191,6 @@ class OpCompatSensiblePass : public Pass {
*/ */
OpCompat& AddOpCompat(OpCompat&& op_compat); OpCompat& AddOpCompat(OpCompat&& op_compat);
//! Modify the subgraph.
virtual bool AccessSubgraphImpl(
const GraphPatternDetector::subgraph_t& subgraph, Graph* g) const {
return true;
}
//! Tell the Op compability of a subgraph. //! Tell the Op compability of a subgraph.
bool IsCompat(const GraphPatternDetector::subgraph_t& subgraph, bool IsCompat(const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) const { Graph* g) const {
...@@ -212,7 +200,7 @@ class OpCompatSensiblePass : public Pass { ...@@ -212,7 +200,7 @@ class OpCompatSensiblePass : public Pass {
// Check the all the ops in the subgraph are contained in the // Check the all the ops in the subgraph are contained in the
// op_compat. // op_compat.
for (auto& node_pair : subgraph) { for (auto& node_pair : subgraph) {
if (!node_pair.first->IsOp()) continue; if (!node_pair.second->IsOp()) continue;
auto op_type = node_pair.second->Op()->Type(); auto op_type = node_pair.second->Op()->Type();
if (!op_compat_judgers_.count(op_type)) { if (!op_compat_judgers_.count(op_type)) {
return false; return false;
......
...@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and ...@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h" #include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
namespace paddle { namespace paddle {
...@@ -23,7 +23,7 @@ namespace ir { ...@@ -23,7 +23,7 @@ namespace ir {
TEST(OpCompatSensiblePass, compatOp) { TEST(OpCompatSensiblePass, compatOp) {
auto lambda = [](const std::string& str) { return str == "tanh"; }; auto lambda = [](const std::string& str) { return str == "tanh"; };
OpCompat compat("FC"); OpCompat compat("fc");
compat.AddAttr("in_num_col_dims") compat.AddAttr("in_num_col_dims")
.IsIntIn({1, 2}) .IsIntIn({1, 2})
.IsNumLE(1) .IsNumLE(1)
...@@ -67,10 +67,75 @@ TEST(OpCompatSensiblePass, compatOp) { ...@@ -67,10 +67,75 @@ TEST(OpCompatSensiblePass, compatOp) {
fc_op.SetInput("Bias", std::vector<std::string>{"test_input_1"}); fc_op.SetInput("Bias", std::vector<std::string>{"test_input_1"});
fc_op.SetOutput("Out", std::vector<std::string>{"test_output"}); fc_op.SetOutput("Out", std::vector<std::string>{"test_output"});
EXPECT_STREQ(compat.Name().c_str(), "FC"); EXPECT_STREQ(compat.Name().c_str(), "fc");
EXPECT_FALSE(compat.Judge(fc_op));
}
TEST(OpCompatSensiblePass, compatOpAttribute) {
OpCompat compat("fc");
OpDesc fc_op;
std::unordered_map<std::string, Attribute> attr_map;
attr_map["in_num_col_dims"] = 1;
fc_op.SetAttrMap(attr_map);
OpInfo info;
info.checker_ = new OpAttrChecker();
OpInfoMap::Instance().Insert("fc", info);
EXPECT_FALSE(compat.Judge(fc_op));
info.checker_->AddAttrChecker<int>("in_num_col_dims").SetDefault(1);
EXPECT_TRUE(compat.Judge(fc_op));
delete info.checker_;
}
TEST(OpCompatSensiblePass, compatOpAttributeOptional) {
OpCompat compat("fc");
compat.AddAttr("activation_type")
.IsOptional()
.IsStringIn({"tanh", "sigmoid"});
OpDesc fc_op;
EXPECT_TRUE(compat.Judge(fc_op)); EXPECT_TRUE(compat.Judge(fc_op));
} }
TEST(OpCompatSensiblePass, compatOpInput) {
OpCompat compat("fc");
OpDesc fc_op;
fc_op.SetInput("Input", std::vector<std::string>{"test_input"});
EXPECT_FALSE(compat.Judge(fc_op));
compat.AddInput("Input").IsTensor().End().AddInput("Bias").IsTensor().End();
EXPECT_FALSE(compat.Judge(fc_op));
fc_op.SetInput("Bias", std::vector<std::string>{"test_input", ""});
EXPECT_FALSE(compat.Judge(fc_op));
}
TEST(OpCompatSensiblePass, compatOutput) {
OpCompat compat("fc");
OpDesc fc_op;
fc_op.SetOutput("Output", std::vector<std::string>{"test_output"});
EXPECT_FALSE(compat.Judge(fc_op));
compat.AddOutput("Output")
.IsTensor()
.End()
.AddOutput("Output_2")
.IsTensor()
.End();
EXPECT_FALSE(compat.Judge(fc_op));
fc_op.SetOutput("Output_2", std::vector<std::string>{"test_output", ""});
EXPECT_FALSE(compat.Judge(fc_op));
}
class OpCompatSensiblePassTest : public OpCompatSensiblePass { class OpCompatSensiblePassTest : public OpCompatSensiblePass {
public: public:
OpCompatSensiblePassTest(); OpCompatSensiblePassTest();
...@@ -78,7 +143,7 @@ class OpCompatSensiblePassTest : public OpCompatSensiblePass { ...@@ -78,7 +143,7 @@ class OpCompatSensiblePassTest : public OpCompatSensiblePass {
}; };
OpCompatSensiblePassTest::OpCompatSensiblePassTest() { OpCompatSensiblePassTest::OpCompatSensiblePassTest() {
AddOpCompat(OpCompat("FC")) AddOpCompat(OpCompat("fc"))
.AddAttr("in_num_col_dims") .AddAttr("in_num_col_dims")
.IsNumLE(1) .IsNumLE(1)
.End() .End()
...@@ -102,7 +167,7 @@ OpCompatSensiblePassTest::OpCompatSensiblePassTest() { ...@@ -102,7 +167,7 @@ OpCompatSensiblePassTest::OpCompatSensiblePassTest() {
TEST(OpCompatSensiblePass, IsCompat) { TEST(OpCompatSensiblePass, IsCompat) {
OpCompatSensiblePassTest test; OpCompatSensiblePassTest test;
OpDesc fc_op; OpDesc fc_op;
fc_op.SetType("FC"); fc_op.SetType("fc");
std::unordered_map<std::string, Attribute> attr_map; std::unordered_map<std::string, Attribute> attr_map;
attr_map["in_num_col_dims"] = 1; attr_map["in_num_col_dims"] = 1;
attr_map["activation_type"] = std::string("tanh"); attr_map["activation_type"] = std::string("tanh");
...@@ -114,18 +179,6 @@ TEST(OpCompatSensiblePass, IsCompat) { ...@@ -114,18 +179,6 @@ TEST(OpCompatSensiblePass, IsCompat) {
fc_op.SetOutput("Out", std::vector<std::string>{"test_output"}); fc_op.SetOutput("Out", std::vector<std::string>{"test_output"});
EXPECT_TRUE(test.TestIsCompat(fc_op)); EXPECT_TRUE(test.TestIsCompat(fc_op));
ProgramDesc prog;
std::unique_ptr<Graph> g(new Graph(prog));
Node* o1 = g->CreateOpNode(&fc_op);
GraphPatternDetector detector;
PDNode* op2 =
detector.mutable_pattern()->NewNode([](Node* x) { return true; });
GraphPatternDetector::subgraph_t subgraph;
subgraph[op2] = o1;
test.AccessSubgraph(subgraph, g.get());
} }
} // namespace ir } // namespace ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册