未验证 提交 8c463700 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Add default pass attributes (#23042)

上级 48144e40
...@@ -51,6 +51,9 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const { ...@@ -51,6 +51,9 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(cpu_quantize_placement_pass, REGISTER_PASS(cpu_quantize_placement_pass,
paddle::framework::ir::CPUQuantizePlacementPass) paddle::framework::ir::CPUQuantizePlacementPass)
// a vector of operator type names to be quantized ("conv2d" etc.) // a vector of operator type names to be quantized ("conv2d" etc.)
.RequirePassAttr("quantize_enabled_op_types") // the second param is the default value for this vector
.DefaultPassAttr("quantize_enabled_op_types",
new std::unordered_set<std::string>())
// a vector of operator ids that are to be excluded from quantization // a vector of operator ids that are to be excluded from quantization
.RequirePassAttr("quantize_excluded_op_ids"); // the second param is the default value for this vector
.DefaultPassAttr("quantize_excluded_op_ids", new std::unordered_set<int>());
...@@ -111,6 +111,25 @@ void MainTest(std::initializer_list<std::string> quantize_enabled_op_types, ...@@ -111,6 +111,25 @@ void MainTest(std::initializer_list<std::string> quantize_enabled_op_types,
EXPECT_EQ(use_quantizer_true_count, expected_use_quantizer_true_count); EXPECT_EQ(use_quantizer_true_count, expected_use_quantizer_true_count);
} }
void DefaultAttrTest(unsigned expected_use_quantizer_true_count) {
auto prog = BuildProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("cpu_quantize_placement_pass");
graph.reset(pass->Apply(graph.release()));
unsigned use_quantizer_true_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op = node->Op();
if (op->HasAttr("use_quantizer") &&
boost::get<bool>(op->GetAttr("use_quantizer"))) {
++use_quantizer_true_count;
}
}
}
EXPECT_EQ(use_quantizer_true_count, expected_use_quantizer_true_count);
}
TEST(QuantizerPlacementPass, enabled_pool) { MainTest({"pool2d"}, {}, 2); } TEST(QuantizerPlacementPass, enabled_pool) { MainTest({"pool2d"}, {}, 2); }
TEST(QuantizerPlacementPass, enabled_conv_excluded_one) { TEST(QuantizerPlacementPass, enabled_conv_excluded_one) {
...@@ -122,6 +141,11 @@ TEST(QuantizerPlacementPass, excluded_none) { ...@@ -122,6 +141,11 @@ TEST(QuantizerPlacementPass, excluded_none) {
MainTest({}, {}, 4); MainTest({}, {}, 4);
} }
TEST(QuantizerPlacementPass, default_attr_value) {
// 2 conv + 2 pool
DefaultAttrTest(4);
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -100,8 +100,14 @@ class Pass { ...@@ -100,8 +100,14 @@ class Pass {
// Set a pointer to the attribute. Pass takes ownership of the attribute. // Set a pointer to the attribute. Pass takes ownership of the attribute.
template <typename AttrType> template <typename AttrType>
void Set(const std::string &attr_name, AttrType *attr) { void Set(const std::string &attr_name, AttrType *attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the pass", if (default_pass_attrs_.count(attr_name) == 0) {
attr_name); PADDLE_ENFORCE_EQ(attrs_.count(attr_name), 0,
platform::errors::InvalidArgument(
"Attribute %s already set in the pass", attr_name));
} else {
VLOG(3) << "Setting the attribute " << attr_name << " for the pass "
<< type_;
}
attrs_[attr_name] = attr; attrs_[attr_name] = attr;
attr_dels_[attr_name] = [attr, attr_name]() { attr_dels_[attr_name] = [attr, attr_name]() {
VLOG(3) << "deleting " << attr_name; VLOG(3) << "deleting " << attr_name;
...@@ -140,11 +146,21 @@ class Pass { ...@@ -140,11 +146,21 @@ class Pass {
required_graph_attrs_.insert(attrs.begin(), attrs.end()); required_graph_attrs_.insert(attrs.begin(), attrs.end());
} }
// Pass doesn't take ownership. PassRegistrar should delete default_attrs
void RegisterDefaultPassAttrs(
std::map<std::string, boost::any> default_attr_values) {
for (auto const &attr_name : default_attr_values) {
default_pass_attrs_.insert(attr_name.first);
}
attrs_.insert(default_attr_values.begin(), default_attr_values.end());
}
void RegisterType(const std::string &type) { type_ = type; } void RegisterType(const std::string &type) { type_ = type; }
mutable bool applied_{false}; mutable bool applied_{false};
std::string type_; std::string type_;
std::unordered_set<std::string> required_pass_attrs_; std::unordered_set<std::string> required_pass_attrs_;
std::unordered_set<std::string> default_pass_attrs_;
std::unordered_set<std::string> required_graph_attrs_; std::unordered_set<std::string> required_graph_attrs_;
std::map<std::string, boost::any> attrs_; std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_; std::map<std::string, std::function<void(void)>> attr_dels_;
...@@ -203,16 +219,38 @@ struct PassRegistrar : public Registrar { ...@@ -203,16 +219,38 @@ struct PassRegistrar : public Registrar {
std::unique_ptr<Pass> pass(new PassType()); std::unique_ptr<Pass> pass(new PassType());
pass->RegisterRequiredPassAttrs(this->required_pass_attrs_); pass->RegisterRequiredPassAttrs(this->required_pass_attrs_);
pass->RegisterRequiredGraphAttrs(this->required_graph_attrs_); pass->RegisterRequiredGraphAttrs(this->required_graph_attrs_);
pass->RegisterDefaultPassAttrs(this->default_attr_values_);
pass->RegisterType(pass_type); pass->RegisterType(pass_type);
return pass; return pass;
}); });
} }
~PassRegistrar() {
for (auto &attr : default_attr_values_) {
if (default_attr_dels_.find(attr.first) != default_attr_dels_.end()) {
default_attr_dels_[attr.first]();
}
}
default_attr_values_.clear();
default_attr_dels_.clear();
}
PassRegistrar<PassType> &RequirePassAttr(const std::string &attr) { PassRegistrar<PassType> &RequirePassAttr(const std::string &attr) {
required_pass_attrs_.insert(attr); required_pass_attrs_.insert(attr);
return *this; return *this;
} }
// PassRegistrar takes ownership of default_attr_value
template <typename AttrType>
PassRegistrar<PassType> &DefaultPassAttr(const std::string &attr,
AttrType &&default_attr_value) {
default_attr_values_[attr] = default_attr_value;
default_attr_dels_[attr] = [default_attr_value, attr]() {
delete default_attr_value;
};
return *this;
}
PassRegistrar<PassType> &RequireGraphAttr(const std::string &attr) { PassRegistrar<PassType> &RequireGraphAttr(const std::string &attr) {
required_graph_attrs_.insert(attr); required_graph_attrs_.insert(attr);
return *this; return *this;
...@@ -221,6 +259,8 @@ struct PassRegistrar : public Registrar { ...@@ -221,6 +259,8 @@ struct PassRegistrar : public Registrar {
private: private:
std::unordered_set<std::string> required_pass_attrs_; std::unordered_set<std::string> required_pass_attrs_;
std::unordered_set<std::string> required_graph_attrs_; std::unordered_set<std::string> required_graph_attrs_;
std::map<std::string, boost::any> default_attr_values_;
std::map<std::string, std::function<void(void)>> default_attr_dels_;
}; };
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \ #define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
......
...@@ -120,6 +120,50 @@ TEST(PassTest, TestPassAttrCheck) { ...@@ -120,6 +120,50 @@ TEST(PassTest, TestPassAttrCheck) {
exception = std::string(e.what()); exception = std::string(e.what());
} }
ASSERT_TRUE(exception.find("shouldn't have cycle") != exception.npos); ASSERT_TRUE(exception.find("shouldn't have cycle") != exception.npos);
pass = PassRegistry::Instance().Get("test_pass");
pass->Set<int>("test_pass_attr", new int);
try {
pass->Set<int>("test_pass_attr", new int);
} catch (paddle::platform::EnforceNotMet& e) {
exception = std::string(e.what());
}
ASSERT_TRUE(
exception.find("Attribute test_pass_attr already set in the pass") !=
exception.npos);
}
class TestPassWithDefault : public Pass {
protected:
void ApplyImpl(ir::Graph* graph) const {
graph->Set<int>("copy_default_attr", new int);
int test_pass_attr = this->Get<int>("default_attr");
graph->Get<int>("copy_default_attr") = test_pass_attr + 1;
}
};
TEST(PassTest, TestPassDefaultAttrCheck) {
ProgramDesc prog;
// check if default value is set
auto pass = PassRegistry::Instance().Get("test_pass_default_attr");
std::unique_ptr<Graph> graph(new Graph(prog));
ASSERT_EQ(pass->Get<int>("default_attr"), 1);
graph.reset(pass->Apply(graph.release()));
ASSERT_EQ(graph->Get<int>("copy_default_attr"), 2);
// check if new value overrides default value
pass = PassRegistry::Instance().Get("test_pass_default_attr");
pass->Set<int>("default_attr", new int{3});
ASSERT_EQ(pass->Get<int>("default_attr"), 3);
}
TEST(PassTest, TestPassRegistrarDeconstructor) {
auto pass_registrary =
new PassRegistrar<paddle::framework::ir::TestPassWithDefault>(
"test_deconstructor");
pass_registrary->DefaultPassAttr("deconstructor_attr", new int{1});
pass_registrary->~PassRegistrar();
} }
} // namespace ir } // namespace ir
...@@ -129,3 +173,7 @@ TEST(PassTest, TestPassAttrCheck) { ...@@ -129,3 +173,7 @@ TEST(PassTest, TestPassAttrCheck) {
REGISTER_PASS(test_pass, paddle::framework::ir::TestPass) REGISTER_PASS(test_pass, paddle::framework::ir::TestPass)
.RequirePassAttr("test_pass_attr") .RequirePassAttr("test_pass_attr")
.RequireGraphAttr("test_graph_attr"); .RequireGraphAttr("test_graph_attr");
REGISTER_PASS(test_pass_default_attr,
paddle::framework::ir::TestPassWithDefault)
.DefaultPassAttr("default_attr", new int{1});
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册