// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/framework/ir/generate_pass.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" namespace paddle { namespace framework { namespace ir { class operation_visitor : public boost::static_visitor { public: explicit operation_visitor(const proto::PassDesc::OperationType& type) : type_(type) {} template Attribute operator()(const T1& attr, const T2& operation) const { PADDLE_THROW(platform::errors::Unimplemented("Unimplemented operand.")); } template ::value || std::is_floating_point::value>* = nullptr> Attribute operator()(const T& attr, const T& operation) const { switch (type_) { case proto::PassDesc_OperationType_kSub: { return attr - operation; } default: PADDLE_THROW( platform::errors::Unimplemented("Unimplemented operation type.")); } } private: proto::PassDesc::OperationType type_; }; Attribute GetVarAttrValue(const VarDesc* desc, const proto::PassDesc::Attr& attr) { if ("shape" == attr.name()) { std::vector shape = desc->GetShape(); if (attr.has_operation()) { if (attr.operation() == proto::PassDesc_OperationType_kSize) { return static_cast(shape.size()); } } else if (attr.has_element_index()) { int element_index = attr.element_index(); if (attr.element_index() < 0) { element_index += shape.size(); } if (element_index >= 0 && static_cast(element_index) < shape.size()) { return static_cast(shape[element_index]); } } else { return shape; } } return boost::blank(); } void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) { // Traverse all operators to create subgraph. for (int index = 0; index < pass_desc.pattern_size(); ++index) { const proto::OpDesc& op = pass_desc.pattern(index); // Create a PDNode for current operator. Use the index as name to avoid // multiple operators with same type. Get a PDNode from pattern subgraph // through index in rewrite phase. PDNode* op_pdnode = pattern->NewNode(std::to_string(index))->assert_is_op(op.type()); // Create PDNodes for inputs of current operator. for (const proto::OpDesc::Var& var : op.inputs()) { for (int n = 0; n < var.arguments_size(); ++n) { const std::string& argument = var.arguments(n); // The input may be the output of other operator. PDNode* var_pdnode = pattern->RetrieveNode(argument); if (nullptr == var_pdnode) { var_pdnode = pattern->NewNode(argument)->AsInput(); var_pdnode->assert_is_var(); } else if (var_pdnode->IsOutput()) { var_pdnode->AsIntermediate(); } var_pdnode->assert_more([&](Node* x) { for (auto* out : x->outputs) { if (out->IsOp() && out->Op()->Type() == op.type()) { const auto& inputs = out->Op()->Inputs(); const auto& iter = inputs.find(var.parameter()); if (inputs.end() != iter) { if (iter->second.end() != std::find(iter->second.begin(), iter->second.end(), x->Name())) { return true; } } } } return false; }); pattern->AddEdge(var_pdnode, op_pdnode); } } // Create PDNodes for outputs of current operator. for (const proto::OpDesc::Var& var : op.outputs()) { for (const std::string& argument : var.arguments()) { // The output may be the input of other operator. PDNode* var_pdnode = pattern->RetrieveNode(argument); if (nullptr == var_pdnode) { var_pdnode = pattern->NewNode(argument)->AsOutput(); var_pdnode->assert_is_var(); var_pdnode->assert_more([&](Node* x) { for (Node* input : x->inputs) { if (input && input->IsOp() && input->Op() && input->Op()->Type() == op.type()) { const auto& outputs = input->Op()->Outputs(); const auto& iter = outputs.find(var.parameter()); if (outputs.end() != iter) { if (iter->second.end() != std::find(iter->second.begin(), iter->second.end(), x->Name())) { return true; } } } } return false; }); } else if (var_pdnode->IsInput()) { var_pdnode->AsIntermediate(); } var_pdnode->assert_is_op_output(op.type()); pattern->AddEdge(op_pdnode, var_pdnode); } } // Set attribute condition for current operator. for (const proto::OpDesc::Attr& attr : op.attrs()) { op_pdnode->assert_more([&](Node* x) { if (x && x->IsOp()) { OpDesc* op_desc = x->Op(); if (op_desc->HasAttr(attr.name())) { return GetAttrValue(attr) == op_desc->GetAttr(attr.name()); } return false; } return false; }); } } for (const auto& condition : pass_desc.var_attr_conditions()) { if (condition.has_condition_value()) { PDNode* pdnode = pattern->RetrieveNode(condition.attr().var_name()); pdnode->assert_more([&](Node* x) { Attribute attr = GetVarAttrValue(x->Var(), condition.attr()); switch (condition.type()) { case proto::PassDesc_ConditionType_kEQ: { return attr == GetAttrValue(condition.condition_value()); } default: PADDLE_THROW(platform::errors::Unimplemented( "Unimplemented condition type.")); } }); } } } // There are some duplicate patterns. bool IsDuplicatePattern(const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { for (auto iter : subgraph) { if (nullptr == graph->RetrieveNode(iter.second->id())) { VLOG(3) << "Node [" << iter.second->Name() << "] of subgraph has been removed. So skip this optimize."; return true; } } return false; } GraphPatternDetector::handle_t GetGenerateDelete( const PDPattern& pattern, const proto::PassDesc& pass_desc) { GraphPatternDetector::handle_t handler = [&]( const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { if (IsDuplicatePattern(subgraph, graph)) { return; } // `var_node_maps` record the mapping of variable to the pattern subgraph. std::map var_node_maps; for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) { Node* node = subgraph.at(pattern.RetrieveNode(var_map.pattern_var())); const auto& iter = var_node_maps.find(var_map.replace_var()); if (var_node_maps.end() == iter) { // first node is input var_node_maps.insert({var_map.replace_var(), node}); } else { // output node for (Node* s_node : node->outputs) { iter->second->outputs.push_back(s_node); std::replace(s_node->inputs.begin(), s_node->inputs.end(), node, iter->second); s_node->Op()->RenameInput(node->Name(), iter->second->Name()); } } } // Remove nodes that are intermediate. std::unordered_set remove_nodes; for (const std::unique_ptr& pdnode : pattern.nodes()) { remove_nodes.emplace(subgraph.at(pdnode.get())); } for (auto iter : var_node_maps) { remove_nodes.erase(iter.second); } GraphSafeRemoveNodes(graph, remove_nodes); }; return handler; } GraphPatternDetector::handle_t GetGenerateRewrite( const PDPattern& pattern, const proto::PassDesc& pass_desc) { GraphPatternDetector::handle_t handler = [&]( const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { if (IsDuplicatePattern(subgraph, graph)) { return; } for (const auto& condition : pass_desc.var_attr_conditions()) { if (condition.has_condition_attr()) { Node* node = subgraph.at(pattern.RetrieveNode(condition.attr().var_name())); Attribute node_attr = GetVarAttrValue(node->Var(), condition.attr()); Attribute condition_attr; if (condition.condition_attr().role() == proto::PassDesc_RoleType_kVariable) { Node* condition_node = subgraph.at(pattern.RetrieveNode(condition.attr().var_name())); condition_attr = GetVarAttrValue(condition_node->Var(), condition.condition_attr()); } else { PADDLE_THROW( platform::errors::Unimplemented("Unimplemented for operation.")); } bool check_failed = false; if (condition.type() == proto::PassDesc_ConditionType_kEQ) { check_failed = !(node_attr == condition_attr); } if (check_failed) { VLOG(3) << "Check var [" << node->Name() << "] with attr [" << condition.attr().name() << "] failed, skip this pattern."; return; } } } // `var_node_maps` record the mapping of variable to the pattern subgraph. std::map var_node_maps; for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) { Node* node = subgraph.at(pattern.RetrieveNode(var_map.pattern_var())); var_node_maps.insert({var_map.replace_var(), node}); } // Traverse all operators to create subgraph. for (int index = 0; index < pass_desc.replace_size(); ++index) { const proto::OpDesc& op = pass_desc.replace(index); OpDesc op_desc; std::vector in_nodes, out_nodes; op_desc.SetType(op.type()); // Create Nodes for inputs of current operator. for (const proto::OpDesc::Var& var : op.inputs()) { std::vector arguments; for (const std::string& argument : var.arguments()) { // The input may be mapped on the operator of pattern subgraph. Node* node = nullptr; auto iter = var_node_maps.find(argument); if (var_node_maps.end() == iter) { VarDesc var_desc(patterns::UniqueKey(argument)); node = graph->CreateVarNode(&var_desc); var_node_maps.insert({argument, node}); } else { node = iter->second; } in_nodes.push_back(node); arguments.push_back(node->Name()); } op_desc.SetInput(var.parameter(), arguments); } // Create Nodes for outputs of current operator. for (const proto::OpDesc::Var& var : op.outputs()) { std::vector arguments; for (const std::string& argument : var.arguments()) { // The output may be mapped on the operator of pattern subgraph. Node* node = nullptr; auto iter = var_node_maps.find(argument); if (var_node_maps.end() == iter) { VarDesc var_desc(patterns::UniqueKey(argument)); node = graph->CreateVarNode(&var_desc); var_node_maps.insert({argument, node}); } else { node = iter->second; } out_nodes.push_back(node); arguments.push_back(node->Name()); } op_desc.SetOutput(var.parameter(), arguments); } // Set attribute for current operator. for (const proto::OpDesc::Attr& attr : op.attrs()) { op_desc.SetAttr(attr.name(), GetAttrValue(attr)); } for (const auto& attr_map : pass_desc.op_attr_maps()) { if (attr_map.replace_attr().op_index() == index) { Attribute attr; if (attr_map.pattern_attr().role() == proto::PassDesc_RoleType_kVariable) { Node* condition_node = subgraph.at( pattern.RetrieveNode(attr_map.pattern_attr().var_name())); attr = GetVarAttrValue(condition_node->Var(), attr_map.pattern_attr()); } else { Node* condition_node = subgraph.at(pattern.RetrieveNode( std::to_string(attr_map.pattern_attr().op_index()))); attr = condition_node->Op()->GetAttr(attr_map.pattern_attr().name()); } if (attr_map.has_operation()) { Attribute operation = GetAttrValue(attr_map.operation().value()); attr = boost::apply_visitor( operation_visitor(attr_map.operation().type()), attr, operation); } op_desc.SetAttr(attr_map.replace_attr().name(), attr); } } // Create a Node for current operator. Node* op_node = graph->CreateOpNode(&op_desc); for (Node* node : in_nodes) { IR_NODE_LINK_TO(node, op_node); } for (Node* node : out_nodes) { IR_NODE_LINK_TO(op_node, node); } } // Remove nodes that are intermediate. std::unordered_set remove_nodes; for (const std::unique_ptr& pdnode : pattern.nodes()) { remove_nodes.emplace(subgraph.at(pdnode.get())); } for (auto iter : var_node_maps) { remove_nodes.erase(iter.second); } GraphSafeRemoveNodes(graph, remove_nodes); }; return handler; } GeneratePass::GeneratePass(const std::string& binary_str) { multi_pass_desc_.ParseFromString(binary_str); VerifyDesc(); } GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc) : multi_pass_desc_(multi_pass_desc) { VerifyDesc(); } void GeneratePass::ApplyImpl(Graph* graph) const { for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) { GraphPatternDetector detector; InitGeneratePattern(pass_desc, detector.mutable_pattern()); if (pass_desc.replace_size() == 0) { detector(graph, GetGenerateDelete(detector.pattern(), pass_desc)); } else { detector(graph, GetGenerateRewrite(detector.pattern(), pass_desc)); } // The rewrited graph needs to be verified. Current Pass should be skipped // if validation failed. Rewrite based on the original graph cannot // implement rollback operation. VerifyGraph(*graph); } } void GeneratePass::VerifyDesc() const { PADDLE_ENFORCE_NE(multi_pass_desc_.pass_descs_size(), 0, platform::errors::InvalidArgument( "Size of PassDesc should not be empty.")); } bool GeneratePass::VerifyGraph(const Graph& graph) { // Return true temporarily. return true; } namespace generate_pass { VarHelper::VarHelper(const char* name) : name_(name), type_(Type::kInput) {} VarHelper::VarHelper(const std::string& name, Type type) : name_(name), type_(type) {} OpHelper::OpHelper(const char* type, SubgraphHelper* subgraph_helper) : type_(type), subgraph_helper_(subgraph_helper) { op_desc_ = subgraph_helper_->ProgramDesc()->mutable_blocks(0)->add_ops(); op_desc_->set_type(type_); } OpHelper::Arguments::Arguments(const char* parameter, const VarHelper& var_helper) : parameter_(parameter) { var_helpers_.push_back(var_helper); } OpHelper::Arguments::Arguments(const char* parameter, std::initializer_list var_helpers) : parameter_(parameter), var_helpers_(var_helpers) {} OpHelper& OpHelper::operator()(const Arguments& input) { proto::OpDesc::Var* var = op_desc_->add_inputs(); var->set_parameter(input.parameter_); for (const VarHelper& var_helper : input.var_helpers_) { var->add_arguments()->assign(var_helper.name_); if (VarHelper::Type::kInput == var_helper.type_) { subgraph_helper_->AddInputVar(var_helper.name_); } } return *this; } OpHelper& OpHelper::operator()(std::initializer_list inputs) { for (const auto& input : inputs) { operator()(input); } return *this; } VarHelper OpHelper::Out(const char* name) { std::string argument = patterns::UniqueKey(type_); proto::OpDesc::Var* var = op_desc_->add_outputs(); var->set_parameter(name); var->add_arguments()->assign(argument); return VarHelper(argument, VarHelper::Type::kOutput); } proto::ProgramDesc* SubgraphHelper::ProgramDesc() { return &program_desc_; } const proto::ProgramDesc& SubgraphHelper::ProgramDesc() const { return program_desc_; } const std::vector& SubgraphHelper::InputVars() const { return input_vars_; } const std::vector& SubgraphHelper::OutputVars() const { return output_vars_; } void SubgraphHelper::AddInputVar(const std::string& name) { auto iter = std::find(input_vars_.begin(), input_vars_.end(), name); if (input_vars_.end() == iter) { input_vars_.push_back(name); } } void SubgraphHelper::AddOutputVars(const VarHelper& var_helper) { output_vars_.push_back(var_helper.name_); } } // namespace generate_pass PassPairs::PassPairs(const SubgraphType& pattern, const SubgraphType& replace) { AddPassDesc(pattern, replace); } void PassPairs::AddPassDesc(const SubgraphType& pattern, const SubgraphType& replace) { proto::PassDesc* pass_desc = multi_pass_desc_.add_pass_descs(); pass_desc->mutable_pattern()->CopyFrom(pattern.ProgramDesc().blocks(0).ops()); pass_desc->mutable_replace()->CopyFrom(replace.ProgramDesc().blocks(0).ops()); PADDLE_ENFORCE_EQ(pattern.InputVars().size(), replace.InputVars().size(), platform::errors::InvalidArgument( "Size of lambda expression arguments is not equal " "between pattern/replace subgraph.")); for (size_t i = 0; i < pattern.InputVars().size(); i++) { proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps(); var_map->set_pattern_var(pattern.InputVars()[i]); var_map->set_replace_var(replace.InputVars()[i]); } PADDLE_ENFORCE_EQ(pattern.OutputVars().size(), replace.OutputVars().size(), platform::errors::InvalidArgument( "Size of lambda expression returns is not equal " "between pattern/replace subgraph.")); for (size_t i = 0; i < pattern.OutputVars().size(); i++) { proto::PassDesc::VarMap* var_map = pass_desc->add_var_maps(); var_map->set_pattern_var(pattern.OutputVars()[i]); var_map->set_replace_var(replace.OutputVars()[i]); } } const proto::MultiPassDesc& PassPairs::MultiPassDesc() const { return multi_pass_desc_; } } // namespace ir } // namespace framework } // namespace paddle