未验证 提交 dc72ffa5 编写于 作者: 王明冬 提交者: GitHub

add the IsLeftDefault definition for pass enhance,test=develop (#33081)

上级 88dfb30f
......@@ -52,7 +52,7 @@ cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS ${GRAPH_PA
cc_library(op_compat_sensible_pass SRCS op_compat_sensible_pass.cc DEPS graph_pattern_detector)
cc_library(subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS op_compat_sensible_pass)
cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass)
cc_library(coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS graph graph_helper)
......
......@@ -17,7 +17,7 @@
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
......@@ -46,7 +46,7 @@ enum FuseOptions {
FUSE_MKLDNN // fusing will be done with MKL-DNN
};
class FusePassBase : public Pass {
class FusePassBase : public OpCompatSensiblePass {
public:
void Init(const std::string& repr, Graph* graph) const;
Scope* param_scope() const;
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
#include "paddle/fluid/framework/op_info.h"
namespace paddle {
namespace framework {
namespace ir {
......@@ -51,11 +51,33 @@ AttrCompat& AttrCompat::IsIntIn(const std::set<int>& candidates) {
}
//! Todo: append the definition.
//! Require the attribute to still hold the default value registered for this
//! op in the global OpInfoMap. If the op (or a default for this attribute)
//! is not registered, an always-false condition is pushed so Judge() fails.
AttrCompat& AttrCompat::IsLeftDefault() {
  const std::string& op_name = op_compat_->Name();
  if (!OpInfoMap::Instance().Has(op_name)) {
    VLOG(3) << "Op (" << op_name << ") is not registered!";
    // No registered op means no default to compare against: always fail.
    conditions_.emplace_back([](const Attribute& attr) { return false; });
    return *this;
  }
  const OpInfo& op_info = OpInfoMap::Instance().Get(op_name);
  // Bind by const reference: correct whether GetAttrsDefaultValuesMap()
  // returns a reference or a value, and avoids copying the whole map when it
  // is a reference.
  const AttributeMap& attrs = op_info.Checker()->GetAttrsDefaultValuesMap();
  // Single lookup instead of find() followed by at().
  auto it = attrs.find(attr_name_);
  if (it == attrs.end()) {
    VLOG(3) << "Op (" << op_name << ") has no default attr:" << attr_name_;
    conditions_.emplace_back([](const Attribute& attr) { return false; });
  } else {
    // Capture the default by value; the condition may outlive this scope.
    Attribute default_attr = it->second;
    conditions_.emplace_back([default_attr](const Attribute& attr) -> bool {
      return attr == default_attr;
    });
  }
  return *this;
}
//! Evaluate all accumulated conditions against the attribute of `op_desc`.
//! Returns true when every condition holds; a missing attribute passes only
//! if the attribute was marked optional via IsOptional().
bool AttrCompat::operator()(const OpDesc& op_desc) {
  if (conditions_.empty()) {
    // Nothing was required of this attribute.
    return true;
  }
  if (!op_desc.HasAttr(attr_name_)) {
    // Absent attribute: acceptable only when declared optional.
    return optional_;
  }
  const Attribute attr = op_desc.GetAttr(attr_name_);
  for (auto& func : conditions_) {
    if (!func(attr)) {
      return false;
    }
  }
  return true;
}
//! Mark this attribute as optional: operator() then accepts an op_desc that
//! does not carry the attribute at all. Returns *this for chaining.
AttrCompat& AttrCompat::IsOptional() {
  optional_ = true;
  return *this;
}
AttrCompat& AttrCompat::IsBoolEQ(bool v) {
conditions_.emplace_back([v](const Attribute& attr) -> bool {
......@@ -98,8 +124,12 @@ bool InputOrOutputCompat::operator()(
}
//! Register a new attribute checker for `attr_name` and return it so the
//! caller can chain constraint methods. Registering the same attribute twice
//! is a programming error and is rejected.
AttrCompat& OpCompat::AddAttr(const std::string& attr_name) {
  PADDLE_ENFORCE_EQ(
      attr_compats_.find(attr_name), attr_compats_.end(),
      platform::errors::InvalidArgument(
          "The attribute compat with the same name has been added"));
  // Use emplace's returned iterator directly instead of a second at() lookup.
  auto result = attr_compats_.emplace(attr_name, AttrCompat(attr_name, this));
  return result.first->second;
}
InputOrOutputCompat& OpCompat::AddInput(const std::string& name) {
......@@ -119,8 +149,19 @@ InputOrOutputCompat& OpCompat::AddOutput(const std::string& name) {
}
bool OpCompat::Judge(const OpDesc& op_desc) {
for (auto& attr_map : op_desc.GetAttrMap()) {
if (attr_compats_.find(attr_map.first) == attr_compats_.end()) {
if (!AttrCompat(attr_map.first, this).IsLeftDefault()(op_desc)) {
VLOG(3) << "The Attr(" << attr_map.first << ") of Op (" << op_name_
<< ") not reigistered in OpCompat, not equal to default value!";
return false;
}
}
}
for (auto& attr_compat : attr_compats_) {
if (!attr_compat(op_desc)) {
if (!attr_compat.second(op_desc)) {
VLOG(3) << " Check the Attr(" << attr_compat.first << ") of Op("
<< op_name_ << ") failed!";
return false;
}
}
......@@ -129,6 +170,8 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
for (auto& input_desc : inputs_map) {
if (input_compats_.find(input_desc.first) == input_compats_.end()) {
if (!input_desc.second.empty()) {
VLOG(3) << "The Input (" << input_desc.first << ") of Operator ("
<< op_name_ << ") not reigistered in OpCompat!";
return false;
}
}
......@@ -136,10 +179,14 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
for (auto& input_val : input_compats_) {
if (inputs_map.find(input_val.first) == inputs_map.end()) {
if (!input_val.second.Optional()) {
VLOG(3) << "The No optional Input (" << input_val.first
<< ") of Operator (" << op_name_ << ") not find in op_desc!";
return false;
}
} else {
if (!input_val.second(inputs_map.at(input_val.first))) {
VLOG(3) << "The Input (" << input_val.first << ") of Operator ("
<< op_name_ << ") compat check failed!";
return false;
}
}
......@@ -149,6 +196,8 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
for (auto& output_desc : outputs_map) {
if (output_compats_.find(output_desc.first) == output_compats_.end()) {
if (!output_desc.second.empty()) {
VLOG(3) << "The Output (" << output_desc.first << ") of Operator ("
<< op_name_ << ") not reigistered in OpCompat!";
return false;
}
}
......@@ -156,10 +205,14 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
for (auto& output_val : output_compats_) {
if (outputs_map.find(output_val.first) == outputs_map.end()) {
if (!output_val.second.Optional()) {
VLOG(3) << "The No optional Output (" << output_val.first
<< ") of Operator (" << op_name_ << ") not find in op_desc!";
return false;
}
} else {
if (!output_val.second(outputs_map.at(output_val.first))) {
VLOG(3) << "The Output (" << output_val.first << ") of Operator ("
<< op_name_ << ") compat check failed!";
return false;
}
}
......
......@@ -29,7 +29,7 @@ class OpCompat;
class AttrCompat {
public:
AttrCompat(const std::string& attr_name, OpCompat* op_compat)
: attr_name_(attr_name), op_compat_(op_compat) {}
: optional_(false), attr_name_(attr_name), op_compat_(op_compat) {}
// @{ String-related methods
//! Assert the attribute is an string in the `candidates` domain.
......@@ -70,12 +70,15 @@ class AttrCompat {
//! Tell whether this attribute is left as default value.
AttrCompat& IsLeftDefault();
AttrCompat& IsOptional();
//! Jump back to retrieve OpCompat instance.
OpCompat& End() { return *op_compat_; }
bool operator()(const OpDesc& op_desc);
private:
bool optional_;
std::string attr_name_;
OpCompat* op_compat_;
std::vector<std::function<bool(const Attribute&)>> conditions_;
......@@ -134,7 +137,7 @@ class OpCompat {
private:
std::string op_name_;
std::vector<AttrCompat> attr_compats_;
std::unordered_map<std::string, AttrCompat> attr_compats_;
std::unordered_map<std::string, InputOrOutputCompat> input_compats_;
std::unordered_map<std::string, InputOrOutputCompat> output_compats_;
};
......@@ -179,15 +182,6 @@ class OpCompat {
* };
*/
class OpCompatSensiblePass : public Pass {
public:
//! Access the subgraph and pattern.
void AccessSubgraph(const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (IsCompat(subgraph, g)) {
AccessSubgraphImpl(subgraph, g);
}
}
protected:
/**
* Developer should push the compatibility `teller` for each kind of Op in the
......@@ -197,12 +191,6 @@ class OpCompatSensiblePass : public Pass {
*/
OpCompat& AddOpCompat(OpCompat&& op_compat);
//! Modify the subgraph.
virtual bool AccessSubgraphImpl(
const GraphPatternDetector::subgraph_t& subgraph, Graph* g) const {
return true;
}
//! Tell the Op compability of a subgraph.
bool IsCompat(const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) const {
......@@ -212,7 +200,7 @@ class OpCompatSensiblePass : public Pass {
// Check the all the ops in the subgraph are contained in the
// op_compat.
for (auto& node_pair : subgraph) {
if (!node_pair.first->IsOp()) continue;
if (!node_pair.second->IsOp()) continue;
auto op_type = node_pair.second->Op()->Type();
if (!op_compat_judgers_.count(op_type)) {
return false;
......
......@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
......@@ -23,7 +23,7 @@ namespace ir {
TEST(OpCompatSensiblePass, compatOp) {
auto lambda = [](const std::string& str) { return str == "tanh"; };
OpCompat compat("FC");
OpCompat compat("fc");
compat.AddAttr("in_num_col_dims")
.IsIntIn({1, 2})
.IsNumLE(1)
......@@ -67,10 +67,75 @@ TEST(OpCompatSensiblePass, compatOp) {
fc_op.SetInput("Bias", std::vector<std::string>{"test_input_1"});
fc_op.SetOutput("Out", std::vector<std::string>{"test_output"});
EXPECT_STREQ(compat.Name().c_str(), "FC");
EXPECT_STREQ(compat.Name().c_str(), "fc");
EXPECT_FALSE(compat.Judge(fc_op));
}
// Checks that Judge() compares an attribute never registered in OpCompat
// against the default value recorded in the global OpInfoMap:
// with no default registered for "in_num_col_dims" the judge fails, and once
// a default of 1 (matching the op's value) is registered it passes.
TEST(OpCompatSensiblePass, compatOpAttribute) {
  OpCompat compat("fc");
  OpDesc fc_op;
  std::unordered_map<std::string, Attribute> attr_map;
  attr_map["in_num_col_dims"] = 1;
  fc_op.SetAttrMap(attr_map);
  OpInfo info;
  // Raw new is needed because OpInfo stores only a checker pointer.
  info.checker_ = new OpAttrChecker();
  // NOTE(review): Insert presumably copies `info`, so the global map keeps an
  // alias of this pointer after the delete below — confirm no later test
  // re-reads the "fc" entry's checker.
  OpInfoMap::Instance().Insert("fc", info);
  EXPECT_FALSE(compat.Judge(fc_op));
  info.checker_->AddAttrChecker<int>("in_num_col_dims").SetDefault(1);
  EXPECT_TRUE(compat.Judge(fc_op));
  delete info.checker_;
}
// An attribute marked optional may be absent from the op entirely and
// Judge() must still succeed.
TEST(OpCompatSensiblePass, compatOpAttributeOptional) {
  OpDesc fc_op;  // intentionally carries no attributes at all
  OpCompat compat("fc");
  AttrCompat& activation = compat.AddAttr("activation_type");
  activation.IsOptional().IsStringIn({"tanh", "sigmoid"});
  EXPECT_TRUE(compat.Judge(fc_op));
}
// Input checking in Judge(): an input unknown to OpCompat fails, a
// registered mandatory input missing from the op fails, and a provided
// input that violates its constraint fails.
TEST(OpCompatSensiblePass, compatOpInput) {
  OpCompat compat("fc");
  OpDesc fc_op;
  std::vector<std::string> input_args{"test_input"};
  fc_op.SetInput("Input", input_args);
  // "Input" has not been registered on the compat yet.
  EXPECT_FALSE(compat.Judge(fc_op));
  compat.AddInput("Input").IsTensor().End();
  compat.AddInput("Bias").IsTensor().End();
  // "Bias" is mandatory (not optional) but missing from fc_op.
  EXPECT_FALSE(compat.Judge(fc_op));
  std::vector<std::string> bias_args{"test_input", ""};
  fc_op.SetInput("Bias", bias_args);
  // Two arguments presumably violate the IsTensor constraint on "Bias".
  EXPECT_FALSE(compat.Judge(fc_op));
}
// Output checking in Judge(): mirrors the input test — unknown output fails,
// registered-but-missing output fails, and a constraint-violating output
// fails.
TEST(OpCompatSensiblePass, compatOutput) {
  OpCompat compat("fc");
  OpDesc fc_op;
  std::vector<std::string> out_args{"test_output"};
  fc_op.SetOutput("Output", out_args);
  // "Output" has not been registered on the compat yet.
  EXPECT_FALSE(compat.Judge(fc_op));
  compat.AddOutput("Output").IsTensor().End();
  compat.AddOutput("Output_2").IsTensor().End();
  // "Output_2" is mandatory but missing from fc_op.
  EXPECT_FALSE(compat.Judge(fc_op));
  std::vector<std::string> out2_args{"test_output", ""};
  fc_op.SetOutput("Output_2", out2_args);
  // Two arguments presumably violate the IsTensor constraint on "Output_2".
  EXPECT_FALSE(compat.Judge(fc_op));
}
class OpCompatSensiblePassTest : public OpCompatSensiblePass {
public:
OpCompatSensiblePassTest();
......@@ -78,7 +143,7 @@ class OpCompatSensiblePassTest : public OpCompatSensiblePass {
};
OpCompatSensiblePassTest::OpCompatSensiblePassTest() {
AddOpCompat(OpCompat("FC"))
AddOpCompat(OpCompat("fc"))
.AddAttr("in_num_col_dims")
.IsNumLE(1)
.End()
......@@ -102,7 +167,7 @@ OpCompatSensiblePassTest::OpCompatSensiblePassTest() {
TEST(OpCompatSensiblePass, IsCompat) {
OpCompatSensiblePassTest test;
OpDesc fc_op;
fc_op.SetType("FC");
fc_op.SetType("fc");
std::unordered_map<std::string, Attribute> attr_map;
attr_map["in_num_col_dims"] = 1;
attr_map["activation_type"] = std::string("tanh");
......@@ -114,18 +179,6 @@ TEST(OpCompatSensiblePass, IsCompat) {
fc_op.SetOutput("Out", std::vector<std::string>{"test_output"});
EXPECT_TRUE(test.TestIsCompat(fc_op));
ProgramDesc prog;
std::unique_ptr<Graph> g(new Graph(prog));
Node* o1 = g->CreateOpNode(&fc_op);
GraphPatternDetector detector;
PDNode* op2 =
detector.mutable_pattern()->NewNode([](Node* x) { return true; });
GraphPatternDetector::subgraph_t subgraph;
subgraph[op2] = o1;
test.AccessSubgraph(subgraph, g.get());
}
} // namespace ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册