From 166ff39a20f39ed590a0cf868c2ad2f15cf0bbb1 Mon Sep 17 00:00:00 2001 From: ZeKai Zhou <30856589+zzk0@users.noreply.github.com> Date: Sun, 16 Oct 2022 11:32:11 +0800 Subject: [PATCH] add common subexpression elimination (#44386) --- paddle/fluid/framework/ir/CMakeLists.txt | 5 + .../common_subexpression_elimination_pass.cc | 335 ++++++++++++++++++ .../common_subexpression_elimination_pass.h | 51 +++ ...n_subexpression_elimination_pass_tester.cc | 131 +++++++ .../inference/api/paddle_pass_builder.cc | 1 + 5 files changed, 523 insertions(+) mode change 100755 => 100644 paddle/fluid/framework/ir/CMakeLists.txt create mode 100644 paddle/fluid/framework/ir/common_subexpression_elimination_pass.cc create mode 100644 paddle/fluid/framework/ir/common_subexpression_elimination_pass.h create mode 100644 paddle/fluid/framework/ir/common_subexpression_elimination_pass_tester.cc diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt old mode 100755 new mode 100644 index 08d5e23b6f..4c111ffd9a --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -115,6 +115,7 @@ pass_library(gpu_cpu_map_matmul_to_mul_pass inference) pass_library(dense_fc_to_sparse_pass inference) pass_library(dense_multihead_matmul_to_sparse_pass inference) pass_library(generate_pass DEPS pass_desc_proto) +pass_library(common_subexpression_elimination_pass inference) target_link_libraries(generate_pass pass_desc_proto) if(WITH_TENSORRT) @@ -326,6 +327,10 @@ cc_test( test_generate_pass_cc SRCS generate_pass_tester.cc DEPS generate_pass pass_desc_proto) +cc_test( + test_common_subexpression_elimination_pass_cc + SRCS common_subexpression_elimination_pass_tester.cc + DEPS common_subexpression_elimination_pass) cc_test( test_delete_dropout_pass_cc SRCS delete_dropout_op_pass_test.cc diff --git a/paddle/fluid/framework/ir/common_subexpression_elimination_pass.cc b/paddle/fluid/framework/ir/common_subexpression_elimination_pass.cc new file mode 100644 index 0000000000..18c2efd01b --- /dev/null +++ b/paddle/fluid/framework/ir/common_subexpression_elimination_pass.cc @@ -0,0 +1,335 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/common_subexpression_elimination_pass.h" +#include +#include + +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/framework/type_defs.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/utils/variant.h" + +namespace { + +std::string NodeTypeToString(paddle::framework::ir::Node::Type type) { + if (type == paddle::framework::ir::Node::Type::kOperation) { + return "kOperation"; + } else { + return "kVariable"; + } +} + +const std::unordered_set commutative_operators{"mul", + "bitwise_and", + "bitwise_or", + "equal_all", + "equal", + "not_equal", + "logical_and", + "logical_or", + "elementwise_max", + "elementwise_fmax", + "elementwise_min", + "elementwise_fmin", + "elementwise_mul", + "elementwise_add", + "add_p", + "max_p", + "mul_p", + "eq_p", + "ne_p"}; + +const std::unordered_set nondeterministic_operators{ + "dropout", + "dropout_nd", + "gaussian_random_batch_size_like", + "gaussian_random", + "randint", + "random_crop", + "random_routing", + "randperm", + "uniform_random_batch_size_like", + "uniform_random_inplace", + "uniform_random", + "fused_bias_dropout_residual_layer_norm"}; + +const std::unordered_set side_effect_operators{ + "feed", "cast", "fetch", "fill_constant", "fill_constant_batch_size_like"}; + +template +inline void HashCombine(std::size_t *seed, const T &v) { + std::hash hasher; + (*seed) ^= hasher(v) + 0x9e3779b9 + ((*seed) << 6) + ((*seed) >> 2); +} + +} // namespace + +namespace std { + +#define HASH_ATTRIBUTE(attr, id, type) \ + do { \ + if (attr.index() == id) { \ + return std::hash{}(get(attr)); \ + } \ + } while (0) + +#define HASH_VECTOR_ATTRIBUTE(attr, id, type) \ + do { \ + if (attr.index() == id) { \ + std::vector vec = get(attr); \ + size_t seed = 0; \ + for (const auto &v : vec) { \ + HashCombine(&seed, v); \ + } \ + return seed; \ + } \ + } while (0) + +template <> +struct hash { + size_t operator()(const paddle::framework::proto::VarType_Type &attr) const { + using type = typename std::underlying_type< + paddle::framework::proto::VarType_Type>::type; + return std::hash()(static_cast(attr)); + } +}; + +template <> +struct hash { + size_t operator()(const paddle::framework::Attribute &attr) const { + if (attr.index() == 0) { + return 0; + } + if (attr.index() == 7) { + return static_cast(get<7>(attr)); + } + + HASH_ATTRIBUTE(attr, 1, int); + HASH_ATTRIBUTE(attr, 2, float); + HASH_ATTRIBUTE(attr, 3, std::string); + HASH_VECTOR_ATTRIBUTE(attr, 4, int); + HASH_VECTOR_ATTRIBUTE(attr, 5, float); + HASH_VECTOR_ATTRIBUTE(attr, 6, std::string); + HASH_ATTRIBUTE(attr, 8, std::vector); + HASH_ATTRIBUTE(attr, 9, paddle::framework::BlockDesc *); + HASH_ATTRIBUTE(attr, 10, int64_t); + HASH_VECTOR_ATTRIBUTE(attr, 11, paddle::framework::BlockDesc *); + HASH_VECTOR_ATTRIBUTE(attr, 12, int64_t); + HASH_VECTOR_ATTRIBUTE(attr, 13, double); + return 0; + } +}; +} // namespace std + +namespace paddle { +namespace framework { +namespace ir { + +void CommonSubexpressionEliminationPass::ApplyImpl(ir::Graph *graph) const { + PADDLE_ENFORCE_EQ( + graph->IsMainGraph(), + true, + platform::errors::InvalidArgument( + "CommonSubexpressionEliminationPass only accepts main graph")); + + CommonSubexpressionEliminate( + graph, graph, [](Node *) -> Node * { return nullptr; }); +} + +void CommonSubexpressionEliminationPass::CommonSubexpressionEliminate( + ir::Graph *main_graph, + ir::Graph *graph, + std::function parent_exist_nodes) const { + const char *kSubBlock = "sub_block"; + std::unordered_set exist_nodes; + std::vector nodes = TopologySortOperations(*graph); + for (Node *node : nodes) { + if (node->inputs.empty()) { + continue; + } + if (side_effect_operators.count(node->Name()) != 0) { + continue; + } + if (nondeterministic_operators.count(node->Name()) != 0) { + continue; + } + + if (node->Op()->HasAttr(kSubBlock)) { + auto sub_block_id = + node->Op()->GetAttrIfExists(kSubBlock)->ID(); + CommonSubexpressionEliminate( + main_graph, + main_graph->GetSubGraph(sub_block_id), + [&exist_nodes, &parent_exist_nodes](Node *node) -> Node * { + auto exist_node = exist_nodes.find(node); + if (exist_node != exist_nodes.end()) { + return *exist_node; + } + return parent_exist_nodes(node); + }); + continue; + } + + Node *exist_node = parent_exist_nodes(node); + if (exist_node == nullptr) { + auto res = exist_nodes.insert(node); + if (!res.second) { + exist_node = *res.first; + } + } + + if (exist_node != nullptr) { + for (size_t i = 0; i < exist_node->outputs.size(); ++i) { + Node *exist_node_output = exist_node->outputs[i]; + Node *current_node_output = node->outputs[i]; + std::vector current_node_output_outputs = + current_node_output->outputs; + for (size_t i = 0; i < current_node_output_outputs.size(); ++i) { + IR_NODE_LINK_TO(exist_node_output, current_node_output_outputs[i]); + } + } + GraphSafeRemoveNodes(graph, + std::unordered_set( + node->outputs.begin(), node->outputs.end())); + GraphSafeRemoveNodes(graph, {node}); + } + } +} + +size_t HashOpNode::operator()(const Node *node) const { + PADDLE_ENFORCE_EQ(node->IsOp(), + true, + platform::errors::InvalidArgument( + "HashOpNode only supports operation node type")); + + size_t seed = 0; + std::vector inputs(node->inputs); + if (commutative_operators.count(node->Name()) != 0) { + auto comparator = [](Node *a, Node *b) { return a->Name() > b->Name(); }; + std::stable_sort(inputs.begin(), inputs.end(), comparator); + } + for (size_t i = 0; i < inputs.size(); ++i) { + HashCombine(&seed, inputs[i]->id()); + HashCombine(&seed, node->GraphId()); + } + const std::string kDepVarName = std::string(Node::kControlDepVarName); + for (size_t i = 0; i < node->outputs.size(); ++i) { + if (node->outputs[i] == nullptr) { + continue; + } + if (node->outputs[i]->IsCtrlVar()) { + HashCombine(&seed, kDepVarName); + } else if (node->outputs[i]->IsVar()) { + HashCombine(&seed, node->outputs[i]->Var()->GetType()); + } + } + OpDesc *desc = node->Op(); + std::vector attributes = desc->AttrNames(); + sort(attributes.begin(), attributes.end()); + for (const std::string &attribute : attributes) { + HashCombine(&seed, desc->GetAttr(attribute)); + } + return seed; +} + +bool EqualOpNode::operator()(const Node *lhs, const Node *rhs) const { + PADDLE_ENFORCE_EQ(lhs->IsOp() && rhs->IsOp(), + true, + platform::errors::InvalidArgument( + "EqualOpNode only supports operation node type")); + + if (lhs == nullptr && rhs == nullptr) { + return true; + } + if (lhs == nullptr || rhs == nullptr) { + return false; + } + if (lhs->NodeType() != rhs->NodeType()) { + return false; + } + if (lhs->Name() != rhs->Name()) { + return false; + } + + std::vector lhs_inputs(lhs->inputs); + std::vector rhs_inputs(rhs->inputs); + if (commutative_operators.count(lhs->Name()) != 0) { + auto comparator = [](Node *a, Node *b) { return a->Name() > b->Name(); }; + std::stable_sort(lhs_inputs.begin(), lhs_inputs.end(), comparator); + std::stable_sort(rhs_inputs.begin(), rhs_inputs.end(), comparator); + } + + // compare inputs value + if (lhs_inputs.size() != rhs_inputs.size()) { + return false; + } + if (!std::equal(lhs_inputs.begin(), lhs_inputs.end(), rhs_inputs.begin())) { + return false; + } + + // compare attribute + const OpDesc *lhs_desc = lhs->Op(); + const OpDesc *rhs_desc = rhs->Op(); + std::vector lhs_attr_names = lhs_desc->AttrNames(); + std::vector rhs_attr_names = rhs_desc->AttrNames(); + if (lhs_attr_names.size() != rhs_attr_names.size()) { + return false; + } + std::sort(lhs_attr_names.begin(), lhs_attr_names.end()); + std::sort(rhs_attr_names.begin(), rhs_attr_names.end()); + for (size_t i = 0; i < lhs_attr_names.size(); ++i) { + if (lhs_attr_names[i] != rhs_attr_names[i]) { + return false; + } + if (lhs_desc->GetAttr(lhs_attr_names[i]) != + rhs_desc->GetAttr(rhs_attr_names[i])) { + return false; + } + } + + // compare outputs value type + std::vector lhs_outputs(lhs->outputs); + std::vector rhs_outputs(rhs->outputs); + if (lhs_outputs.size() != rhs_outputs.size()) { + return false; + } + for (size_t i = 0; i < lhs_outputs.size(); ++i) { + if (!lhs_outputs[i]->IsVar() || !rhs_outputs[i]->IsVar()) { + return false; + } + if (lhs_outputs[i]->IsCtrlVar() != rhs_outputs[i]->IsCtrlVar()) { + return false; + } + if (lhs_outputs[i]->IsCtrlVar() && rhs_outputs[i]->IsCtrlVar()) { + continue; + } + if (lhs_outputs[i]->Var()->GetType() != rhs_outputs[i]->Var()->GetType()) { + return false; + } + } + return true; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(common_subexpression_elimination_pass, + paddle::framework::ir::CommonSubexpressionEliminationPass); +REGISTER_PASS_CAPABILITY(common_subexpression_elimination_pass); diff --git a/paddle/fluid/framework/ir/common_subexpression_elimination_pass.h b/paddle/fluid/framework/ir/common_subexpression_elimination_pass.h new file mode 100644 index 0000000000..bc58a51a56 --- /dev/null +++ b/paddle/fluid/framework/ir/common_subexpression_elimination_pass.h @@ -0,0 +1,51 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { + +class Graph; + +class CommonSubexpressionEliminationPass : public FusePassBase { + public: + CommonSubexpressionEliminationPass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + void CommonSubexpressionEliminate( + ir::Graph* main_graph, + ir::Graph* graph, + std::function parent_exist_nodes) const; +}; + +struct HashOpNode { + size_t operator()(const Node* node) const; +}; + +struct EqualOpNode { + bool operator()(const Node* lhs, const Node* rhs) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/common_subexpression_elimination_pass_tester.cc b/paddle/fluid/framework/ir/common_subexpression_elimination_pass_tester.cc new file mode 100644 index 0000000000..abcb123dfd --- /dev/null +++ b/paddle/fluid/framework/ir/common_subexpression_elimination_pass_tester.cc @@ -0,0 +1,131 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include +#include +#include +#include + +#include "paddle/fluid/framework/ir/common_subexpression_elimination_pass.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { + +TEST(CommonSubexpressionEliminationPass, basic_test) { + // inputs operator output + // -------------------------------------------------------------------- + // (relu(a), b) elementwise_add -> d + // (relu(a), c) elementwise_add -> e + // (d, e) elementwise_add -> f + + Layers layers; + auto* a = layers.data("a", {1024, 768}); + auto* b = layers.data("b", {1024, 768}); + auto* c = layers.data("c", {1024, 768}); + auto* d = layers.elementwise_add(layers.relu(a), b); + auto* e = layers.elementwise_add(layers.relu(a), c); + auto* f = layers.data("f", {1024, 768}); + layers.elementwise_add(d, e, f, 0); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto pass = + PassRegistry::Instance().Get("common_subexpression_elimination_pass"); + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = GetNumOpNodes(graph, "relu"); + PADDLE_ENFORCE_EQ(num_nodes_after, + 1, + platform::errors::InvalidArgument( + "Before the common subexpression elimination pass, " + "there should be 1 " + "relu op, but the result is %d", + num_nodes_after)); +} + +TEST(CommonSubexpressionEliminationPass, commutative_operator_test) { + // inputs operator output + // -------------------------------------------------------------------- + // (a, b) elementwise_add -> e + // (b, a) elementwise_add -> f + // (e, c) elementwise_add -> g + // (f, d) elementwise_add -> h + + Layers layers; + auto* a = layers.data("a", {1024, 768}); + auto* b = layers.data("b", {1024, 768}); + auto* c = layers.data("c", {1024, 768}); + auto* d = layers.data("d", {1024, 768}); + auto* e = layers.data("e", {1024, 768}); + auto* f = layers.data("f", {1024, 768}); + auto* g = layers.data("g", {1024, 768}); + auto* h = layers.data("h", {1024, 768}); + + layers.elementwise_add(a, b, e, 0); + layers.elementwise_add(b, a, f, 0); + layers.elementwise_add(e, c, g, 0); + layers.elementwise_add(f, d, h, 0); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto pass = + PassRegistry::Instance().Get("common_subexpression_elimination_pass"); + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = GetNumOpNodes(graph, "elementwise_add"); + PADDLE_ENFORCE_EQ(num_nodes_after, + 3, + platform::errors::InvalidArgument( + "Before the common subexpression elimination pass, " + "there should be 3 " + "elementwise_add op, but the result is %d", + num_nodes_after)); +} + +TEST(CommonSubexpressionEliminationPass, nondeterministic_operator_test) { + // inputs operator output + // -------------------------------------------------------------------- + // (dropout(a), b) elementwise_add -> d + // (dropout(a), c) elementwise_add -> e + // (d, e) elementwise_add -> f + + Layers layers; + auto* a = layers.data("a", {1024, 768}); + auto* b = layers.data("b", {1024, 768}); + auto* c = layers.data("c", {1024, 768}); + auto* d = + layers.elementwise_add(layers.dropout(a, 0.5, "downgrade_in_infer"), b); + auto* e = + layers.elementwise_add(layers.dropout(a, 0.5, "downgrade_in_infer"), c); + auto* f = layers.data("f", {1024, 768}); + layers.elementwise_add(d, e, f, 0); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto pass = + PassRegistry::Instance().Get("common_subexpression_elimination_pass"); + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = GetNumOpNodes(graph, "dropout"); + PADDLE_ENFORCE_EQ(num_nodes_after, + 2, + platform::errors::InvalidArgument( + "After the common subexpression elimination pass, " + "there should still be 2 " + "dropout op, but the result is %d", + num_nodes_after)); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(common_subexpression_elimination_pass); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 133c9b363a..9180cb311d 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -282,6 +282,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { "conv_eltwiseadd_bn_fuse_pass", // "conv_transpose_bn_fuse_pass", // "conv_transpose_eltwiseadd_bn_fuse_pass", // + "common_subexpression_elimination_pass", // "is_test_pass", // "constant_folding_pass", // following pass should be located in the last, since -- GitLab