From 7c6835ca09536e44d16124204855decf3203273f Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Mon, 17 Oct 2022 20:19:13 +0800 Subject: [PATCH] Revert "add common subexpression elimination (#44386)" (#47062) This reverts commit 166ff39a20f39ed590a0cf868c2ad2f15cf0bbb1. --- paddle/fluid/framework/ir/CMakeLists.txt | 5 - .../common_subexpression_elimination_pass.cc | 335 ------------------ .../common_subexpression_elimination_pass.h | 51 --- ...n_subexpression_elimination_pass_tester.cc | 131 ------- .../inference/api/paddle_pass_builder.cc | 1 - 5 files changed, 523 deletions(-) mode change 100644 => 100755 paddle/fluid/framework/ir/CMakeLists.txt delete mode 100644 paddle/fluid/framework/ir/common_subexpression_elimination_pass.cc delete mode 100644 paddle/fluid/framework/ir/common_subexpression_elimination_pass.h delete mode 100644 paddle/fluid/framework/ir/common_subexpression_elimination_pass_tester.cc diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt old mode 100644 new mode 100755 index 4c111ffd9a..08d5e23b6f --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -115,7 +115,6 @@ pass_library(gpu_cpu_map_matmul_to_mul_pass inference) pass_library(dense_fc_to_sparse_pass inference) pass_library(dense_multihead_matmul_to_sparse_pass inference) pass_library(generate_pass DEPS pass_desc_proto) -pass_library(common_subexpression_elimination_pass inference) target_link_libraries(generate_pass pass_desc_proto) if(WITH_TENSORRT) @@ -327,10 +326,6 @@ cc_test( test_generate_pass_cc SRCS generate_pass_tester.cc DEPS generate_pass pass_desc_proto) -cc_test( - test_common_subexpression_elimination_pass_cc - SRCS common_subexpression_elimination_pass_tester.cc - DEPS common_subexpression_elimination_pass) cc_test( test_delete_dropout_pass_cc SRCS delete_dropout_op_pass_test.cc diff --git a/paddle/fluid/framework/ir/common_subexpression_elimination_pass.cc b/paddle/fluid/framework/ir/common_subexpression_elimination_pass.cc deleted file mode 100644 index 18c2efd01b..0000000000 --- a/paddle/fluid/framework/ir/common_subexpression_elimination_pass.cc +++ /dev/null @@ -1,335 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/ir/common_subexpression_elimination_pass.h" -#include -#include - -#include "paddle/fluid/framework/framework.pb.h" -#include "paddle/fluid/framework/ir/graph_helper.h" -#include "paddle/fluid/framework/ir/graph_pattern_detector.h" -#include "paddle/fluid/framework/ir/node.h" -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/framework/type_defs.h" -#include "paddle/phi/core/enforce.h" -#include "paddle/utils/variant.h" - -namespace { - -std::string NodeTypeToString(paddle::framework::ir::Node::Type type) { - if (type == paddle::framework::ir::Node::Type::kOperation) { - return "kOperation"; - } else { - return "kVariable"; - } -} - -const std::unordered_set commutative_operators{"mul", - "bitwise_and", - "bitwise_or", - "equal_all", - "equal", - "not_equal", - "logical_and", - "logical_or", - "elementwise_max", - "elementwise_fmax", - "elementwise_min", - "elementwise_fmin", - "elementwise_mul", - "elementwise_add", - "add_p", - "max_p", - "mul_p", - "eq_p", - "ne_p"}; - -const std::unordered_set nondeterministic_operators{ - "dropout", - "dropout_nd", - "gaussian_random_batch_size_like", - "gaussian_random", - "randint", - "random_crop", - "random_routing", - "randperm", - "uniform_random_batch_size_like", - "uniform_random_inplace", - "uniform_random", - "fused_bias_dropout_residual_layer_norm"}; - -const std::unordered_set side_effect_operators{ - "feed", "cast", "fetch", "fill_constant", "fill_constant_batch_size_like"}; - -template -inline void HashCombine(std::size_t *seed, const T &v) { - std::hash hasher; - (*seed) ^= hasher(v) + 0x9e3779b9 + ((*seed) << 6) + ((*seed) >> 2); -} - -} // namespace - -namespace std { - -#define HASH_ATTRIBUTE(attr, id, type) \ - do { \ - if (attr.index() == id) { \ - return std::hash{}(get(attr)); \ - } \ - } while (0) - -#define HASH_VECTOR_ATTRIBUTE(attr, id, type) \ - do { \ - if (attr.index() == id) { \ - std::vector vec = get(attr); \ - size_t seed = 0; \ - for (const auto &v : vec) { \ - HashCombine(&seed, v); \ - } \ - return seed; \ - } \ - } while (0) - -template <> -struct hash { - size_t operator()(const paddle::framework::proto::VarType_Type &attr) const { - using type = typename std::underlying_type< - paddle::framework::proto::VarType_Type>::type; - return std::hash()(static_cast(attr)); - } -}; - -template <> -struct hash { - size_t operator()(const paddle::framework::Attribute &attr) const { - if (attr.index() == 0) { - return 0; - } - if (attr.index() == 7) { - return static_cast(get<7>(attr)); - } - - HASH_ATTRIBUTE(attr, 1, int); - HASH_ATTRIBUTE(attr, 2, float); - HASH_ATTRIBUTE(attr, 3, std::string); - HASH_VECTOR_ATTRIBUTE(attr, 4, int); - HASH_VECTOR_ATTRIBUTE(attr, 5, float); - HASH_VECTOR_ATTRIBUTE(attr, 6, std::string); - HASH_ATTRIBUTE(attr, 8, std::vector); - HASH_ATTRIBUTE(attr, 9, paddle::framework::BlockDesc *); - HASH_ATTRIBUTE(attr, 10, int64_t); - HASH_VECTOR_ATTRIBUTE(attr, 11, paddle::framework::BlockDesc *); - HASH_VECTOR_ATTRIBUTE(attr, 12, int64_t); - HASH_VECTOR_ATTRIBUTE(attr, 13, double); - return 0; - } -}; -} // namespace std - -namespace paddle { -namespace framework { -namespace ir { - -void CommonSubexpressionEliminationPass::ApplyImpl(ir::Graph *graph) const { - PADDLE_ENFORCE_EQ( - graph->IsMainGraph(), - true, - platform::errors::InvalidArgument( - "CommonSubexpressionEliminationPass only accepts main graph")); - - CommonSubexpressionEliminate( - graph, graph, [](Node *) -> Node * { return nullptr; }); -} - -void CommonSubexpressionEliminationPass::CommonSubexpressionEliminate( - ir::Graph *main_graph, - ir::Graph *graph, - std::function parent_exist_nodes) const { - const char *kSubBlock = "sub_block"; - std::unordered_set exist_nodes; - std::vector nodes = TopologySortOperations(*graph); - for (Node *node : nodes) { - if (node->inputs.empty()) { - continue; - } - if (side_effect_operators.count(node->Name()) != 0) { - continue; - } - if (nondeterministic_operators.count(node->Name()) != 0) { - continue; - } - - if (node->Op()->HasAttr(kSubBlock)) { - auto sub_block_id = - node->Op()->GetAttrIfExists(kSubBlock)->ID(); - CommonSubexpressionEliminate( - main_graph, - main_graph->GetSubGraph(sub_block_id), - [&exist_nodes, &parent_exist_nodes](Node *node) -> Node * { - auto exist_node = exist_nodes.find(node); - if (exist_node != exist_nodes.end()) { - return *exist_node; - } - return parent_exist_nodes(node); - }); - continue; - } - - Node *exist_node = parent_exist_nodes(node); - if (exist_node == nullptr) { - auto res = exist_nodes.insert(node); - if (!res.second) { - exist_node = *res.first; - } - } - - if (exist_node != nullptr) { - for (size_t i = 0; i < exist_node->outputs.size(); ++i) { - Node *exist_node_output = exist_node->outputs[i]; - Node *current_node_output = node->outputs[i]; - std::vector current_node_output_outputs = - current_node_output->outputs; - for (size_t i = 0; i < current_node_output_outputs.size(); ++i) { - IR_NODE_LINK_TO(exist_node_output, current_node_output_outputs[i]); - } - } - GraphSafeRemoveNodes(graph, - std::unordered_set( - node->outputs.begin(), node->outputs.end())); - GraphSafeRemoveNodes(graph, {node}); - } - } -} - -size_t HashOpNode::operator()(const Node *node) const { - PADDLE_ENFORCE_EQ(node->IsOp(), - true, - platform::errors::InvalidArgument( - "HashOpNode only supports operation node type")); - - size_t seed = 0; - std::vector inputs(node->inputs); - if (commutative_operators.count(node->Name()) != 0) { - auto comparator = [](Node *a, Node *b) { return a->Name() > b->Name(); }; - std::stable_sort(inputs.begin(), inputs.end(), comparator); - } - for (size_t i = 0; i < inputs.size(); ++i) { - HashCombine(&seed, inputs[i]->id()); - HashCombine(&seed, node->GraphId()); - } - const std::string kDepVarName = std::string(Node::kControlDepVarName); - for (size_t i = 0; i < node->outputs.size(); ++i) { - if (node->outputs[i] == nullptr) { - continue; - } - if (node->outputs[i]->IsCtrlVar()) { - HashCombine(&seed, kDepVarName); - } else if (node->outputs[i]->IsVar()) { - HashCombine(&seed, node->outputs[i]->Var()->GetType()); - } - } - OpDesc *desc = node->Op(); - std::vector attributes = desc->AttrNames(); - sort(attributes.begin(), attributes.end()); - for (const std::string &attribute : attributes) { - HashCombine(&seed, desc->GetAttr(attribute)); - } - return seed; -} - -bool EqualOpNode::operator()(const Node *lhs, const Node *rhs) const { - PADDLE_ENFORCE_EQ(lhs->IsOp() && rhs->IsOp(), - true, - platform::errors::InvalidArgument( - "EqualOpNode only supports operation node type")); - - if (lhs == nullptr && rhs == nullptr) { - return true; - } - if (lhs == nullptr || rhs == nullptr) { - return false; - } - if (lhs->NodeType() != rhs->NodeType()) { - return false; - } - if (lhs->Name() != rhs->Name()) { - return false; - } - - std::vector lhs_inputs(lhs->inputs); - std::vector rhs_inputs(rhs->inputs); - if (commutative_operators.count(lhs->Name()) != 0) { - auto comparator = [](Node *a, Node *b) { return a->Name() > b->Name(); }; - std::stable_sort(lhs_inputs.begin(), lhs_inputs.end(), comparator); - std::stable_sort(rhs_inputs.begin(), rhs_inputs.end(), comparator); - } - - // compare inputs value - if (lhs_inputs.size() != rhs_inputs.size()) { - return false; - } - if (!std::equal(lhs_inputs.begin(), lhs_inputs.end(), rhs_inputs.begin())) { - return false; - } - - // compare attribute - const OpDesc *lhs_desc = lhs->Op(); - const OpDesc *rhs_desc = rhs->Op(); - std::vector lhs_attr_names = lhs_desc->AttrNames(); - std::vector rhs_attr_names = rhs_desc->AttrNames(); - if (lhs_attr_names.size() != rhs_attr_names.size()) { - return false; - } - std::sort(lhs_attr_names.begin(), lhs_attr_names.end()); - std::sort(rhs_attr_names.begin(), rhs_attr_names.end()); - for (size_t i = 0; i < lhs_attr_names.size(); ++i) { - if (lhs_attr_names[i] != rhs_attr_names[i]) { - return false; - } - if (lhs_desc->GetAttr(lhs_attr_names[i]) != - rhs_desc->GetAttr(rhs_attr_names[i])) { - return false; - } - } - - // compare outputs value type - std::vector lhs_outputs(lhs->outputs); - std::vector rhs_outputs(rhs->outputs); - if (lhs_outputs.size() != rhs_outputs.size()) { - return false; - } - for (size_t i = 0; i < lhs_outputs.size(); ++i) { - if (!lhs_outputs[i]->IsVar() || !rhs_outputs[i]->IsVar()) { - return false; - } - if (lhs_outputs[i]->IsCtrlVar() != rhs_outputs[i]->IsCtrlVar()) { - return false; - } - if (lhs_outputs[i]->IsCtrlVar() && rhs_outputs[i]->IsCtrlVar()) { - continue; - } - if (lhs_outputs[i]->Var()->GetType() != rhs_outputs[i]->Var()->GetType()) { - return false; - } - } - return true; -} - -} // namespace ir -} // namespace framework -} // namespace paddle - -REGISTER_PASS(common_subexpression_elimination_pass, - paddle::framework::ir::CommonSubexpressionEliminationPass); -REGISTER_PASS_CAPABILITY(common_subexpression_elimination_pass); diff --git a/paddle/fluid/framework/ir/common_subexpression_elimination_pass.h b/paddle/fluid/framework/ir/common_subexpression_elimination_pass.h deleted file mode 100644 index bc58a51a56..0000000000 --- a/paddle/fluid/framework/ir/common_subexpression_elimination_pass.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include "paddle/fluid/framework/ir/fuse_pass_base.h" - -namespace paddle { -namespace framework { -namespace ir { - -class Graph; - -class CommonSubexpressionEliminationPass : public FusePassBase { - public: - CommonSubexpressionEliminationPass() {} - - protected: - void ApplyImpl(ir::Graph* graph) const override; - - private: - void CommonSubexpressionEliminate( - ir::Graph* main_graph, - ir::Graph* graph, - std::function parent_exist_nodes) const; -}; - -struct HashOpNode { - size_t operator()(const Node* node) const; -}; - -struct EqualOpNode { - bool operator()(const Node* lhs, const Node* rhs) const; -}; - -} // namespace ir -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/ir/common_subexpression_elimination_pass_tester.cc b/paddle/fluid/framework/ir/common_subexpression_elimination_pass_tester.cc deleted file mode 100644 index abcb123dfd..0000000000 --- a/paddle/fluid/framework/ir/common_subexpression_elimination_pass_tester.cc +++ /dev/null @@ -1,131 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include - -#include -#include -#include -#include - -#include "paddle/fluid/framework/ir/common_subexpression_elimination_pass.h" -#include "paddle/fluid/framework/ir/graph_helper.h" -#include "paddle/fluid/framework/ir/pass_tester_helper.h" -#include "paddle/fluid/framework/op_version_registry.h" - -namespace paddle { -namespace framework { -namespace ir { - -TEST(CommonSubexpressionEliminationPass, basic_test) { - // inputs operator output - // -------------------------------------------------------------------- - // (relu(a), b) elementwise_add -> d - // (relu(a), c) elementwise_add -> e - // (d, e) elementwise_add -> f - - Layers layers; - auto* a = layers.data("a", {1024, 768}); - auto* b = layers.data("b", {1024, 768}); - auto* c = layers.data("c", {1024, 768}); - auto* d = layers.elementwise_add(layers.relu(a), b); - auto* e = layers.elementwise_add(layers.relu(a), c); - auto* f = layers.data("f", {1024, 768}); - layers.elementwise_add(d, e, f, 0); - - std::unique_ptr graph(new ir::Graph(layers.main_program())); - auto pass = - PassRegistry::Instance().Get("common_subexpression_elimination_pass"); - graph.reset(pass->Apply(graph.release())); - int num_nodes_after = GetNumOpNodes(graph, "relu"); - PADDLE_ENFORCE_EQ(num_nodes_after, - 1, - platform::errors::InvalidArgument( - "Before the common subexpression elimination pass, " - "there should be 1 " - "relu op, but the result is %d", - num_nodes_after)); -} - -TEST(CommonSubexpressionEliminationPass, commutative_operator_test) { - // inputs operator output - // -------------------------------------------------------------------- - // (a, b) elementwise_add -> e - // (b, a) elementwise_add -> f - // (e, c) elementwise_add -> g - // (f, d) elementwise_add -> h - - Layers layers; - auto* a = layers.data("a", {1024, 768}); - auto* b = layers.data("b", {1024, 768}); - auto* c = layers.data("c", {1024, 768}); - auto* d = layers.data("d", {1024, 768}); - auto* e = layers.data("e", {1024, 768}); - auto* f = layers.data("f", {1024, 768}); - auto* g = layers.data("g", {1024, 768}); - auto* h = layers.data("h", {1024, 768}); - - layers.elementwise_add(a, b, e, 0); - layers.elementwise_add(b, a, f, 0); - layers.elementwise_add(e, c, g, 0); - layers.elementwise_add(f, d, h, 0); - - std::unique_ptr graph(new ir::Graph(layers.main_program())); - auto pass = - PassRegistry::Instance().Get("common_subexpression_elimination_pass"); - graph.reset(pass->Apply(graph.release())); - int num_nodes_after = GetNumOpNodes(graph, "elementwise_add"); - PADDLE_ENFORCE_EQ(num_nodes_after, - 3, - platform::errors::InvalidArgument( - "Before the common subexpression elimination pass, " - "there should be 3 " - "elementwise_add op, but the result is %d", - num_nodes_after)); -} - -TEST(CommonSubexpressionEliminationPass, nondeterministic_operator_test) { - // inputs operator output - // -------------------------------------------------------------------- - // (dropout(a), b) elementwise_add -> d - // (dropout(a), c) elementwise_add -> e - // (d, e) elementwise_add -> f - - Layers layers; - auto* a = layers.data("a", {1024, 768}); - auto* b = layers.data("b", {1024, 768}); - auto* c = layers.data("c", {1024, 768}); - auto* d = - layers.elementwise_add(layers.dropout(a, 0.5, "downgrade_in_infer"), b); - auto* e = - layers.elementwise_add(layers.dropout(a, 0.5, "downgrade_in_infer"), c); - auto* f = layers.data("f", {1024, 768}); - layers.elementwise_add(d, e, f, 0); - - std::unique_ptr graph(new ir::Graph(layers.main_program())); - auto pass = - PassRegistry::Instance().Get("common_subexpression_elimination_pass"); - graph.reset(pass->Apply(graph.release())); - int num_nodes_after = GetNumOpNodes(graph, "dropout"); - PADDLE_ENFORCE_EQ(num_nodes_after, - 2, - platform::errors::InvalidArgument( - "After the common subexpression elimination pass, " - "there should still be 2 " - "dropout op, but the result is %d", - num_nodes_after)); -} - -} // namespace ir -} // namespace framework -} // namespace paddle - -USE_PASS(common_subexpression_elimination_pass); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 1c424d6c75..35537cd1fc 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -282,7 +282,6 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { "conv_eltwiseadd_bn_fuse_pass", // "conv_transpose_bn_fuse_pass", // "conv_transpose_eltwiseadd_bn_fuse_pass", // - "common_subexpression_elimination_pass", // "is_test_pass", // "constant_folding_pass", // following pass should be located in the last, since -- GitLab