From 8c463700e171113e603335e05d7c6ff9cb7f1907 Mon Sep 17 00:00:00 2001 From: "joanna.wozna.intel" Date: Thu, 2 Apr 2020 04:55:58 +0200 Subject: [PATCH] Add default pass attributes (#23042) --- .../ir/mkldnn/cpu_quantize_placement_pass.cc | 7 ++- .../cpu_quantize_placement_pass_tester.cc | 24 ++++++++++ paddle/fluid/framework/ir/pass.h | 44 ++++++++++++++++- paddle/fluid/framework/ir/pass_test.cc | 48 +++++++++++++++++++ 4 files changed, 119 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc index 2ccd4062214..d570d885c6e 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc @@ -51,6 +51,9 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const { REGISTER_PASS(cpu_quantize_placement_pass, paddle::framework::ir::CPUQuantizePlacementPass) // a vector of operator type names to be quantized ("conv2d" etc.) - .RequirePassAttr("quantize_enabled_op_types") + // the second param is the default value for this vector + .DefaultPassAttr("quantize_enabled_op_types", + new std::unordered_set()) // a vector of operator ids that are to be excluded from quantization - .RequirePassAttr("quantize_excluded_op_ids"); + // the second param is the default value for this vector + .DefaultPassAttr("quantize_excluded_op_ids", new std::unordered_set()); diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass_tester.cc index ba4d281f818..479d3087ba7 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass_tester.cc @@ -111,6 +111,25 @@ void MainTest(std::initializer_list quantize_enabled_op_types, EXPECT_EQ(use_quantizer_true_count, expected_use_quantizer_true_count); } +void DefaultAttrTest(unsigned expected_use_quantizer_true_count) { + auto prog = BuildProgramDesc(); + std::unique_ptr graph(new ir::Graph(prog)); + auto pass = PassRegistry::Instance().Get("cpu_quantize_placement_pass"); + graph.reset(pass->Apply(graph.release())); + + unsigned use_quantizer_true_count = 0; + for (auto* node : graph->Nodes()) { + if (node->IsOp()) { + auto* op = node->Op(); + if (op->HasAttr("use_quantizer") && + boost::get(op->GetAttr("use_quantizer"))) { + ++use_quantizer_true_count; + } + } + } + EXPECT_EQ(use_quantizer_true_count, expected_use_quantizer_true_count); +} + TEST(QuantizerPlacementPass, enabled_pool) { MainTest({"pool2d"}, {}, 2); } TEST(QuantizerPlacementPass, enabled_conv_excluded_one) { @@ -122,6 +141,11 @@ TEST(QuantizerPlacementPass, excluded_none) { MainTest({}, {}, 4); } +TEST(QuantizerPlacementPass, default_attr_value) { + // 2 conv + 2 pool + DefaultAttrTest(4); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index cef7feeadf2..509c3021d8c 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -100,8 +100,14 @@ class Pass { // Set a pointer to the attribute. Pass takes ownership of the attribute. template void Set(const std::string &attr_name, AttrType *attr) { - PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the pass", - attr_name); + if (default_pass_attrs_.count(attr_name) == 0) { + PADDLE_ENFORCE_EQ(attrs_.count(attr_name), 0, + platform::errors::InvalidArgument( + "Attribute %s already set in the pass", attr_name)); + } else { + VLOG(3) << "Setting the attribute " << attr_name << " for the pass " + << type_; + } attrs_[attr_name] = attr; attr_dels_[attr_name] = [attr, attr_name]() { VLOG(3) << "deleting " << attr_name; @@ -140,11 +146,21 @@ class Pass { required_graph_attrs_.insert(attrs.begin(), attrs.end()); } + // Pass doesn't take ownership. PassRegistrar should delete default_attrs + void RegisterDefaultPassAttrs( + std::map default_attr_values) { + for (auto const &attr_name : default_attr_values) { + default_pass_attrs_.insert(attr_name.first); + } + attrs_.insert(default_attr_values.begin(), default_attr_values.end()); + } + void RegisterType(const std::string &type) { type_ = type; } mutable bool applied_{false}; std::string type_; std::unordered_set required_pass_attrs_; + std::unordered_set default_pass_attrs_; std::unordered_set required_graph_attrs_; std::map attrs_; std::map> attr_dels_; @@ -203,16 +219,38 @@ struct PassRegistrar : public Registrar { std::unique_ptr pass(new PassType()); pass->RegisterRequiredPassAttrs(this->required_pass_attrs_); pass->RegisterRequiredGraphAttrs(this->required_graph_attrs_); + pass->RegisterDefaultPassAttrs(this->default_attr_values_); pass->RegisterType(pass_type); return pass; }); } + ~PassRegistrar() { + for (auto &attr : default_attr_values_) { + if (default_attr_dels_.find(attr.first) != default_attr_dels_.end()) { + default_attr_dels_[attr.first](); + } + } + default_attr_values_.clear(); + default_attr_dels_.clear(); + } + PassRegistrar &RequirePassAttr(const std::string &attr) { required_pass_attrs_.insert(attr); return *this; } + // PassRegistrar takes ownership of default_attr_value + template + PassRegistrar &DefaultPassAttr(const std::string &attr, + AttrType &&default_attr_value) { + default_attr_values_[attr] = default_attr_value; + default_attr_dels_[attr] = [default_attr_value, attr]() { + delete default_attr_value; + }; + return *this; + } + PassRegistrar &RequireGraphAttr(const std::string &attr) { required_graph_attrs_.insert(attr); return *this; @@ -221,6 +259,8 @@ struct PassRegistrar : public Registrar { private: std::unordered_set required_pass_attrs_; std::unordered_set required_graph_attrs_; + std::map default_attr_values_; + std::map> default_attr_dels_; }; #define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \ diff --git a/paddle/fluid/framework/ir/pass_test.cc b/paddle/fluid/framework/ir/pass_test.cc index 0e73b2f77d0..14e94a2bc5c 100644 --- a/paddle/fluid/framework/ir/pass_test.cc +++ b/paddle/fluid/framework/ir/pass_test.cc @@ -120,6 +120,50 @@ TEST(PassTest, TestPassAttrCheck) { exception = std::string(e.what()); } ASSERT_TRUE(exception.find("shouldn't have cycle") != exception.npos); + + pass = PassRegistry::Instance().Get("test_pass"); + pass->Set("test_pass_attr", new int); + try { + pass->Set("test_pass_attr", new int); + } catch (paddle::platform::EnforceNotMet& e) { + exception = std::string(e.what()); + } + ASSERT_TRUE( + exception.find("Attribute test_pass_attr already set in the pass") != + exception.npos); +} + +class TestPassWithDefault : public Pass { + protected: + void ApplyImpl(ir::Graph* graph) const { + graph->Set("copy_default_attr", new int); + + int test_pass_attr = this->Get("default_attr"); + graph->Get("copy_default_attr") = test_pass_attr + 1; + } +}; + +TEST(PassTest, TestPassDefaultAttrCheck) { + ProgramDesc prog; + // check if default value is set + auto pass = PassRegistry::Instance().Get("test_pass_default_attr"); + std::unique_ptr graph(new Graph(prog)); + ASSERT_EQ(pass->Get("default_attr"), 1); + graph.reset(pass->Apply(graph.release())); + ASSERT_EQ(graph->Get("copy_default_attr"), 2); + + // check if new value overrides default value + pass = PassRegistry::Instance().Get("test_pass_default_attr"); + pass->Set("default_attr", new int{3}); + ASSERT_EQ(pass->Get("default_attr"), 3); +} + +TEST(PassTest, TestPassRegistrarDeconstructor) { + auto pass_registrary = + new PassRegistrar( + "test_deconstructor"); + pass_registrary->DefaultPassAttr("deconstructor_attr", new int{1}); + pass_registrary->~PassRegistrar(); } } // namespace ir @@ -129,3 +173,7 @@ TEST(PassTest, TestPassAttrCheck) { REGISTER_PASS(test_pass, paddle::framework::ir::TestPass) .RequirePassAttr("test_pass_attr") .RequireGraphAttr("test_graph_attr"); + +REGISTER_PASS(test_pass_default_attr, + paddle::framework::ir::TestPassWithDefault) + .DefaultPassAttr("default_attr", new int{1}); -- GitLab