From fc5b3a9942e5f62da26e58e1cdb44d354df1479b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Thu, 3 Jun 2021 11:15:44 +0800 Subject: [PATCH] add the fc fuse example for pass enhance, test=develop (#33250) --- paddle/fluid/framework/CMakeLists.txt | 5 + paddle/fluid/framework/ir/CMakeLists.txt | 2 +- paddle/fluid/framework/ir/fc_fuse_pass.cc | 70 ++++++++++++- paddle/fluid/framework/ir/fc_fuse_pass.h | 1 + .../framework/ir/op_compat_sensible_pass.cc | 66 +++++++++---- .../framework/ir/op_compat_sensible_pass.h | 3 + .../ir/op_compat_sensible_pass_tester.cc | 3 +- paddle/fluid/framework/op_def_api.cc | 64 ++++++++++++ paddle/fluid/framework/op_def_api.h.in | 12 +++ .../operators/compat/elementwise_add.pbtxt | 73 ++++++++++++++ paddle/fluid/operators/compat/fc.pbtxt | 97 +++++++++++++++++++ paddle/fluid/operators/compat/mul.pbtxt | 87 +++++++++++++++++ paddle/fluid/operators/compat/relu.pbtxt | 43 ++++++++ 13 files changed, 501 insertions(+), 25 deletions(-) create mode 100644 paddle/fluid/framework/op_def_api.cc create mode 100644 paddle/fluid/framework/op_def_api.h.in create mode 100644 paddle/fluid/operators/compat/elementwise_add.pbtxt create mode 100644 paddle/fluid/operators/compat/fc.pbtxt create mode 100644 paddle/fluid/operators/compat/mul.pbtxt create mode 100644 paddle/fluid/operators/compat/relu.pbtxt diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 8d1ae4926a8..f39c16002dd 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -27,7 +27,12 @@ add_subdirectory(fleet) add_subdirectory(io) #ddim lib proto_library(framework_proto SRCS framework.proto) + proto_library(op_def_proto SRCS op_def.proto) +set(OP_DEF_FOLDER "${PADDLE_SOURCE_DIR}/paddle/fluid/operators/compat/") +configure_file("op_def_api.h.in" "op_def_api.h") +cc_library(op_def_api SRCS op_def_api.cc DEPS op_def_proto) + proto_library(heter_service_proto SRCS heter_service.proto) proto_library(data_feed_proto SRCS data_feed.proto) proto_library(trainer_desc_proto SRCS trainer_desc.proto DEPS framework_proto diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index fb478bb6e89..16dfc90d27e 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -50,7 +50,7 @@ if (WITH_TESTING) endif(WITH_TESTING) cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS ${GRAPH_PATTERN_DETECTOR_DEPS}) -cc_library(op_compat_sensible_pass SRCS op_compat_sensible_pass.cc DEPS graph_pattern_detector) +cc_library(op_compat_sensible_pass SRCS op_compat_sensible_pass.cc DEPS graph_pattern_detector op_def_api) cc_library(subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor) cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS op_compat_sensible_pass) cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass) diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc index bc1be79d1b1..656d453d403 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc @@ -13,8 +13,8 @@ // limitations under the License. #include "paddle/fluid/framework/ir/fc_fuse_pass.h" - #include +#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" @@ -23,6 +23,65 @@ namespace paddle { namespace framework { namespace ir { +FCFusePass::FCFusePass() { + AddOpCompat(OpCompat("mul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumGE(1) + .End() + .AddAttr("y_num_col_dims") + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .End(); + + AddOpCompat(OpCompat("relu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); + + AddOpCompat(OpCompat("fc")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("W") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("in_num_col_dims") + .IsNumGE(1) + .End() + .AddAttr("activation_type") + .IsStringIn({"relu", ""}) + .End(); +} + void FCFusePass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); @@ -52,6 +111,10 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const { LOG(WARNING) << "The subgraph is empty."; return; } + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } VLOG(4) << "handle FC fuse"; GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern); @@ -159,6 +222,11 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const { } desc.Flush(); + if (!IsCompat(desc)) { + LOG(WARNING) << "Fc fuse pass in out fc op compat failed."; + return; + } + auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied. if (with_relu) { GraphSafeRemoveNodes( diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.h b/paddle/fluid/framework/ir/fc_fuse_pass.h index f564bbb1518..21ef17b65dc 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.h +++ b/paddle/fluid/framework/ir/fc_fuse_pass.h @@ -30,6 +30,7 @@ class Graph; class FCFusePass : public FusePassBase { public: + FCFusePass(); virtual ~FCFusePass() {} protected: diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc index b056c3b07a2..3d8e655c5b2 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc @@ -12,10 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include - #include "paddle/fluid/framework/ir/op_compat_sensible_pass.h" +#include +#include +#include +#include "paddle/fluid/framework/op_def_api.h" #include "paddle/fluid/framework/op_info.h" + namespace paddle { namespace framework { namespace ir { @@ -50,18 +53,17 @@ AttrCompat& AttrCompat::IsIntIn(const std::set& candidates) { return *this; } -//! Todo: append the definition. AttrCompat& AttrCompat::IsLeftDefault() { const std::string& op_name = op_compat_->Name(); if (!OpInfoMap::Instance().Has(op_name)) { - VLOG(3) << "Op (" << op_name << ") is not registered!"; + LOG(WARNING) << "Op (" << op_name << ") is not registered!"; conditions_.emplace_back([](const Attribute& attr) { return false; }); return *this; } const OpInfo& op_info = OpInfoMap::Instance().Get(op_name); const AttributeMap attrs = op_info.Checker()->GetAttrsDefaultValuesMap(); if (attrs.find(attr_name_) == attrs.end()) { - VLOG(3) << "Op (" << op_name << ") has no default attr:" << attr_name_; + LOG(WARNING) << "Op (" << op_name << ") has no default attr:" << attr_name_; conditions_.emplace_back([](const Attribute& attr) { return false; }); } else { Attribute default_attr = attrs.at(attr_name_); @@ -77,6 +79,10 @@ bool AttrCompat::operator()(const OpDesc& op_desc) { return true; } if (!op_desc.HasAttr(attr_name_)) { + if (!optional_) { + LOG(WARNING) << "The non-optional Attr(" << attr_name_ << ") of Op (" + << op_compat_->Name() << ") not find ! "; + } return optional_; } const Attribute attr = op_desc.GetAttr(attr_name_); @@ -149,19 +155,35 @@ InputOrOutputCompat& OpCompat::AddOutput(const std::string& name) { } bool OpCompat::Judge(const OpDesc& op_desc) { + if (is_first_judge_) { + is_first_judge_ = false; + const proto::OpDef& op_def = GetOpDef(op_name_); + if (op_def.has_extra()) { + for (const proto::OpDef_AttrDef& attr : op_def.extra().attrs()) { + extra_attrs_.emplace(attr.name()); + } + } + } + for (auto& attr_map : op_desc.GetAttrMap()) { if (attr_compats_.find(attr_map.first) == attr_compats_.end()) { + if (extra_attrs_.find(attr_map.first) != extra_attrs_.end()) { + continue; + } if (!AttrCompat(attr_map.first, this).IsLeftDefault()(op_desc)) { - VLOG(3) << "The Attr(" << attr_map.first << ") of Op (" << op_name_ - << ") not reigistered in OpCompat, not equal to default value!"; + LOG(WARNING) + << "The Attr(" << attr_map.first << ") of Op (" << op_name_ + << ") not reigistered in OpCompat, not in extra attribute, not " + "equal to default value!"; return false; } } } + for (auto& attr_compat : attr_compats_) { if (!attr_compat.second(op_desc)) { - VLOG(3) << " Check the Attr(" << attr_compat.first << ") of Op(" - << op_name_ << ") failed!"; + LOG(WARNING) << " Check the Attr(" << attr_compat.first << ") of Op(" + << op_name_ << ") failed!"; return false; } } @@ -170,8 +192,8 @@ bool OpCompat::Judge(const OpDesc& op_desc) { for (auto& input_desc : inputs_map) { if (input_compats_.find(input_desc.first) == input_compats_.end()) { if (!input_desc.second.empty()) { - VLOG(3) << "The Input (" << input_desc.first << ") of Operator (" - << op_name_ << ") not reigistered in OpCompat!"; + LOG(WARNING) << "The Input (" << input_desc.first << ") of Operator (" + << op_name_ << ") not reigistered in OpCompat!"; return false; } } @@ -179,14 +201,15 @@ bool OpCompat::Judge(const OpDesc& op_desc) { for (auto& input_val : input_compats_) { if (inputs_map.find(input_val.first) == inputs_map.end()) { if (!input_val.second.Optional()) { - VLOG(3) << "The No optional Input (" << input_val.first - << ") of Operator (" << op_name_ << ") not find in op_desc!"; + LOG(WARNING) << "The No optional Input (" << input_val.first + << ") of Operator (" << op_name_ + << ") not find in op_desc!"; return false; } } else { if (!input_val.second(inputs_map.at(input_val.first))) { - VLOG(3) << "The Input (" << input_val.first << ") of Operator (" - << op_name_ << ") compat check failed!"; + LOG(WARNING) << "The Input (" << input_val.first << ") of Operator (" + << op_name_ << ") compat check failed!"; return false; } } @@ -196,8 +219,8 @@ bool OpCompat::Judge(const OpDesc& op_desc) { for (auto& output_desc : outputs_map) { if (output_compats_.find(output_desc.first) == output_compats_.end()) { if (!output_desc.second.empty()) { - VLOG(3) << "The Output (" << output_desc.first << ") of Operator (" - << op_name_ << ") not reigistered in OpCompat!"; + LOG(WARNING) << "The Output (" << output_desc.first << ") of Operator (" + << op_name_ << ") not reigistered in OpCompat!"; return false; } } @@ -205,14 +228,15 @@ bool OpCompat::Judge(const OpDesc& op_desc) { for (auto& output_val : output_compats_) { if (outputs_map.find(output_val.first) == outputs_map.end()) { if (!output_val.second.Optional()) { - VLOG(3) << "The No optional Output (" << output_val.first - << ") of Operator (" << op_name_ << ") not find in op_desc!"; + LOG(WARNING) << "The No optional Output (" << output_val.first + << ") of Operator (" << op_name_ + << ") not find in op_desc!"; return false; } } else { if (!output_val.second(outputs_map.at(output_val.first))) { - VLOG(3) << "The Output (" << output_val.first << ") of Operator (" - << op_name_ << ") compat check failed!"; + LOG(WARNING) << "The Output (" << output_val.first << ") of Operator (" + << op_name_ << ") compat check failed!"; return false; } } diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass.h b/paddle/fluid/framework/ir/op_compat_sensible_pass.h index 3f2ea673d87..3aa985c6d46 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass.h +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass.h @@ -140,6 +140,8 @@ class OpCompat { std::unordered_map attr_compats_; std::unordered_map input_compats_; std::unordered_map output_compats_; + std::unordered_set extra_attrs_; + bool is_first_judge_ = true; }; /** @@ -203,6 +205,7 @@ class OpCompatSensiblePass : public Pass { if (!node_pair.second->IsOp()) continue; auto op_type = node_pair.second->Op()->Type(); if (!op_compat_judgers_.count(op_type)) { + LOG(WARNING) << op_type << "compat not registered!"; return false; } auto& judger = *op_compat_judgers_.at(op_type); diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc b/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc index 0878e4d9890..598b686c790 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc @@ -27,7 +27,6 @@ TEST(OpCompatSensiblePass, compatOp) { compat.AddAttr("in_num_col_dims") .IsIntIn({1, 2}) .IsNumLE(1) - .IsLeftDefault() .End() .AddAttr("activation_type") .IsStringIn({"tanh", "sigmoid"}) @@ -68,7 +67,7 @@ TEST(OpCompatSensiblePass, compatOp) { fc_op.SetOutput("Out", std::vector{"test_output"}); EXPECT_STREQ(compat.Name().c_str(), "fc"); - EXPECT_FALSE(compat.Judge(fc_op)); + EXPECT_TRUE(compat.Judge(fc_op)); } TEST(OpCompatSensiblePass, compatOpAttribute) { diff --git a/paddle/fluid/framework/op_def_api.cc b/paddle/fluid/framework/op_def_api.cc new file mode 100644 index 00000000000..d8aeb23c63e --- /dev/null +++ b/paddle/fluid/framework/op_def_api.cc @@ -0,0 +1,64 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined _WIN32 || defined __APPLE__ +#else +#define _LINUX +#endif +#include "paddle/fluid/framework/op_def_api.h" +#include +#include +#include +#include +#ifdef _LINUX +#include +#include +#include +#endif +#include +#include +#include "glog/logging.h" +#include "io/fs.h" +#include "paddle/fluid/framework/op_def.pb.h" + +namespace paddle { +namespace framework { + +const proto::OpDef& GetOpDef(const std::string& op_name) { + static std::unordered_map ops_definition; + static std::mutex mtx; + if (ops_definition.find(op_name) == ops_definition.end()) { + std::lock_guard lk(mtx); + if (ops_definition.find(op_name) == ops_definition.end()) { + proto::OpDef op_def; + std::string op_path = OP_DEF_FOLDER + op_name + ".pbtxt"; + int fd = open(op_path.c_str(), O_RDONLY); + if (fd == -1) { + LOG(WARNING) << op_path << " open failed!"; + } else { + ::google::protobuf::io::FileInputStream* input = + new ::google::protobuf::io::FileInputStream(fd); + if (!::google::protobuf::TextFormat::Parse(input, &op_def)) { + LOG(WARNING) << "Failed to parse " << op_path; + } + delete input; + close(fd); + } + ops_definition.emplace(std::make_pair(op_name, std::move(op_def))); + } + } + return ops_definition.at(op_name); +} +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/op_def_api.h.in b/paddle/fluid/framework/op_def_api.h.in new file mode 100644 index 00000000000..7a48c487709 --- /dev/null +++ b/paddle/fluid/framework/op_def_api.h.in @@ -0,0 +1,12 @@ +// the folder of pbtxt with op attribute definition +#pragma once + +#include "paddle/fluid/framework/op_def.pb.h" + +#define OP_DEF_FOLDER "@OP_DEF_FOLDER@" + +namespace paddle { +namespace framework { + const proto::OpDef& GetOpDef(const std::string& op_name); +} +} diff --git a/paddle/fluid/operators/compat/elementwise_add.pbtxt b/paddle/fluid/operators/compat/elementwise_add.pbtxt new file mode 100644 index 00000000000..3e96147ef88 --- /dev/null +++ b/paddle/fluid/operators/compat/elementwise_add.pbtxt @@ -0,0 +1,73 @@ +type: "elementwise_add" +def { + inputs { + name: "X" + } + inputs { + name: "Y" + } + outputs { + name: "Out" + } + attrs { + name: "axis" + type: INT + } +} +extra { + attrs { + name: "x_data_format" + type: STRING + # no longer to use + } + attrs { + name: "y_data_format" + type: STRING + # no longer to use + } + attrs { + name: "use_quantizer" + type: BOOLEAN + # no longer to use, Use 'mkldnn_data_type' instead. + } + attrs { + name: "use_mkldnn" + type: BOOLEAN + } + attrs { + name: "mkldnn_data_type" + type: STRING + } + attrs { + name: "Scale_x" + type: FLOAT + } + attrs { + name: "Scale_y" + type: FLOAT + } + attrs { + name: "Scale_out" + type: FLOAT + } + attrs { + name: "op_role" + type: INT + } + attrs { + name: "op_role_var" + type: STRINGS + } + attrs { + name: "op_namescope" + type: STRING + } + attrs { + name: "op_callstack" + type: STRINGS + } + attrs { + name: "op_device" + type: STRING + } +} diff --git a/paddle/fluid/operators/compat/fc.pbtxt b/paddle/fluid/operators/compat/fc.pbtxt new file mode 100644 index 00000000000..55e1a22ce4d --- /dev/null +++ b/paddle/fluid/operators/compat/fc.pbtxt @@ -0,0 +1,97 @@ +type: "fc" +def { + inputs { + name: "Input" + } + inputs { + name: "W" + } + inputs { + name: "Bias" + } + outputs { + name: "Out" + } + attrs { + name: "in_num_col_dims" + type: INT + } + attrs { + name: "activation_type" + type: STRING + } +} +extra { + attrs { + name: "use_mkldnn" + type: BOOLEAN + } + attrs { + name: "padding_weights" + type: BOOLEAN + } + attrs { + name: "@ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE@" + type: BOOLEAN + } + attrs { + name: "use_quantizer" + type: BOOLEAN + } + attrs { + name: "mkldnn_data_type" + type: STRING + } + attrs { + name: "weight_scale" + type: FLOATS + } + attrs { + name: "Input_scale" + type: FLOAT + } + attrs { + name: "out_scale" + type: FLOAT + } + attrs { + name: "out_threshold" + type: FLOAT + } + attrs { + name: "force_fp32_output" + type: BOOLEAN + } + attrs { + name: "enable_int8" + type: BOOLEAN + } + attrs { + name: "use_fc_padding" + type: BOOLEAN + } + attrs { + name: "use_gpu" + type: BOOLEAN + } + attrs { + name: "op_role" + type: INT + } + attrs { + name: "op_role_var" + type: STRINGS + } + attrs { + name: "op_namescope" + type: STRING + } + attrs { + name: "op_callstack" + type: STRINGS + } + attrs { + name: "op_device" + type: STRING + } +} diff --git a/paddle/fluid/operators/compat/mul.pbtxt b/paddle/fluid/operators/compat/mul.pbtxt new file mode 100644 index 00000000000..b40c05ad2e0 --- /dev/null +++ b/paddle/fluid/operators/compat/mul.pbtxt @@ -0,0 +1,87 @@ +type: "mul" +def { + inputs { + name: "X" + } + inputs { + name: "Y" + } + outputs { + name: "Out" + } + attrs { + name: "x_num_col_dims" + type: INT + } + attrs { + name: "y_num_col_dims" + type: INT + } +} +extra { + attrs { + name: "skip_quant" + type: BOOLEAN + } + attrs { + name: "use_mkldnn" + type: BOOLEAN + } + attrs { + name: "scale_x" + type: FLOAT + } + attrs { + name: "scale_y" + type: FLOATS + } + attrs { + name: "scale_out" + type: FLOAT + } + attrs { + name: "force_fp32_output" + type: BOOLEAN + } + attrs { + name: "enable_int8" + type: BOOLEAN + } + attrs { + name: "X_scale" + type: FLOAT + } + attrs { + name: "weight_scale" + type: FLOAT + } + attrs { + name: "out_scale" + type: FLOAT + } + attrs { + name: "out_threshold" + type: FLOAT + } + attrs { + name: "op_role" + type: INT + } + attrs { + name: "op_role_var" + type: STRINGS + } + attrs { + name: "op_namescope" + type: STRING + } + attrs { + name: "op_callstack" + type: STRINGS + } + attrs { + name: "op_device" + type: STRING + } + +} diff --git a/paddle/fluid/operators/compat/relu.pbtxt b/paddle/fluid/operators/compat/relu.pbtxt new file mode 100644 index 00000000000..359bd70c2a3 --- /dev/null +++ b/paddle/fluid/operators/compat/relu.pbtxt @@ -0,0 +1,43 @@ +type: "relu" +def { + inputs { + name: "X" + } + outputs { + name: "Out" + } +} +extra { + attrs { + name: "use_mkldnn" + type: BOOLEAN + } + attrs { + name: "use_cudnn" + type: BOOLEAN + } + attrs { + name: "op_role" + type: INT + } + attrs { + name: "op_role_var" + type: STRINGS + } + attrs { + name: "op_namescope" + type: STRING + } + attrs { + name: "op_callstack" + type: STRINGS + } + attrs { + name: "op_device" + type: STRING + } + attrs { + name: "is_test" + type: BOOLEAN + } +} -- GitLab