From 5e27d16d3b85d18c34130c776b73616ee79e8744 Mon Sep 17 00:00:00 2001
From: jiangcheng
Date: Wed, 28 Jul 2021 10:14:46 +0800
Subject: [PATCH] apply pass strategy to sub graph (#34158)

When a Graph has sub-graphs, apply each pass to the main graph and to all
of its sub-graphs. Also add a unit test for BuildStrategy.

---
 paddle/fluid/framework/details/CMakeLists.txt |   4 +-
 .../fluid/framework/details/build_strategy.cc |  15 +-
 .../framework/details/build_strategy_test.cc  | 300 ++++++++++++++++++
 3 files changed, 317 insertions(+), 2 deletions(-)
 create mode 100644 paddle/fluid/framework/details/build_strategy_test.cc

diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index 219b6581fd..23a3ee2c58 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -134,11 +134,13 @@ set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass
         modify_op_lock_and_record_event_pass
         coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
         fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
-        sync_batch_norm_pass runtime_context_cache_pass)
+        sync_batch_norm_pass runtime_context_cache_pass graph_to_program_pass)
 if(NOT APPLE AND NOT WIN32 AND (WITH_GPU OR WITH_ROCM))
   set(IR_PASS_DEPS ${IR_PASS_DEPS} fusion_group_pass)
 endif()
 cc_library(build_strategy SRCS build_strategy.cc DEPS pass_builder ${IR_PASS_DEPS})
+cc_test(build_strategy_test SRCS build_strategy_test.cc
+        DEPS build_strategy op_registry op_proto_maker graph)
 
 if (WITH_MKLDNN)
   target_link_libraries(build_strategy mkldnn_placement_pass)
diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index 5636e3ed1b..b517c23275 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
 
 DECLARE_bool(use_mkldnn);
+DECLARE_bool(convert_all_blocks);
 
 namespace paddle {
 namespace framework {
@@ -312,6 +313,11 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
                                 DeviceType use_device) const {
 #endif
   VLOG(1) << "apply all passes";
+  if (FLAGS_convert_all_blocks) {
+    PADDLE_ENFORCE_EQ(
+        graph->IsMainGraph(), true,
+        platform::errors::InvalidArgument("This graph is not main_graph"));
+  }
   // Create a default one if not finalized by user.
   CreatePassesFromStrategy(false);
 
@@ -432,7 +438,14 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
       }
     }
     VLOG(1) << "Start Apply Pass " << pass->Type();
-    graph = pass->Apply(graph);
+    if (FLAGS_convert_all_blocks) {
+      for (size_t i = 0; i < graph->SubGraphsSize(); ++i) {
+        VLOG(3) << "Apply Pass " << pass->Type() << " to SubGraph " << i;
+        pass->Apply(graph->GetSubGraph(i));
+      }
+    } else {
+      graph = pass->Apply(graph);
+    }
     VLOG(1) << "Finish Apply Pass " << pass->Type();
   }
   VLOG(1) << "All Passes Applied";
diff --git a/paddle/fluid/framework/details/build_strategy_test.cc b/paddle/fluid/framework/details/build_strategy_test.cc
new file mode 100644
index 0000000000..69af77d23f
--- /dev/null
+++ b/paddle/fluid/framework/details/build_strategy_test.cc
@@ -0,0 +1,300 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "gtest/gtest-message.h"
+#include "gtest/gtest-test-part.h"
+#include "gtest/gtest.h"
+#include "gtest/gtest_pred_impl.h"
+
+#include "paddle/fluid/framework/details/build_strategy.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/var_type_inference.h"
+#include "paddle/fluid/platform/place.h"
+
+DECLARE_bool(convert_all_blocks);
+
+namespace paddle {
+namespace framework {
+
+class SumOpMaker : public OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddInput("X", "").AsDuplicable();
+    AddOutput("Out", "").AsDuplicable();
+    AddComment("");
+  }
+};
+
+class SumOpWithKernel : public OperatorWithKernel {
+ public:
+  using OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext *ctx) const override {}
+  OpKernelType GetExpectedKernelType(
+      const ExecutionContext &ctx) const override {
+    return OpKernelType(proto::VarType::FP32,
+                        ctx.Input<LoDTensor>("X")->place());
+  }
+};
+
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_OP_WITHOUT_GRADIENT(sum, paddle::framework::SumOpWithKernel,
+                             paddle::framework::SumOpMaker);
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+static std::vector<platform::Place> CreatePlaces(size_t num, bool use_cuda) {
+  std::vector<platform::Place> result;
+  result.reserve(num);
+  for (size_t i = 0; i < num; ++i) {
+    if (use_cuda) {
+      result.emplace_back(platform::CUDAPlace(i));
+    } else {
+      result.emplace_back(platform::CPUPlace());
+    }
+  }
+  return result;
+}
+
+void BuildStrategyApply(BuildStrategy *build_strategy, ir::Graph *graph) {
+  std::string loss_name = "";
+  Scope scope;
+  std::vector<Scope *> scopes = {&scope};
+
+  auto places = CreatePlaces(1, false);
+  auto device = platform::Place2DeviceType(places[0]);
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  platform::NCCLCommunicator ctxs;
+#elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL)
+  platform::BKCLCommunicator ctxs;
+#endif
+
+  build_strategy->Apply(graph, places, loss_name, scopes, 1,
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+                        device, &ctxs);
+#elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL)
+                        device, &ctxs);
+#else
+                        device);
+#endif
+}
+
+std::unique_ptr<ir::Graph> CreateGraph() {
+  ProgramDesc prog;
+  auto *op = prog.MutableBlock(0)->AppendOp();
+  op->SetType("sum");
+  op->SetInput("X", {"a1"});
+  op->SetOutput("Out", {"b1"});
+  op->SetAttr("op_role", 1);
+
+  prog.MutableBlock(0)->Var("a1")->SetType(proto::VarType::LOD_TENSOR);
+  prog.MutableBlock(0)->Var("b1")->SetType(proto::VarType::LOD_TENSOR);
+
+  std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
+  return g;
+}
+
+std::unique_ptr<ir::Graph> CreateMultiGraph() {
+  ProgramDesc prog;
+  prog.AppendBlock(prog.Block(0));
+  prog.AppendBlock(prog.Block(0));
+
+  // Set contents in block_0.
+  auto *op = prog.MutableBlock(0)->AppendOp();
+  op->SetType("sum");
+  op->SetInput("X", {"test_a", "test_b", "test_c"});
+  op->SetOutput("Out", {"test_out"});
+  op->SetAttr("op_role", 1);
+
+  prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test_out");
+  op->InferVarType(prog.MutableBlock(0));
+
+  prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR);
+  op->InferVarType(prog.MutableBlock(0));
+
+  // Set contents in block_1.
+  op = prog.MutableBlock(1)->AppendOp();
+  op->SetType("sum");
+  op->SetInput("X", {"a1"});
+  op->SetOutput("Out", {"b1"});
+  op->SetAttr("op_role", 1);
+
+  prog.MutableBlock(1)->Var("a1")->SetType(proto::VarType::LOD_TENSOR);
+  prog.MutableBlock(1)->Var("b1")->SetType(proto::VarType::LOD_TENSOR);
+
+  // Set contents in block_2.
+  op = prog.MutableBlock(2)->AppendOp();
+  op->SetType("sum");
+  op->SetInput("X", {"a2"});
+  op->SetOutput("Out", {"b2"});
+  op->SetAttr("op_role", 1);
+
+  prog.MutableBlock(2)->Var("a2")->SetType(proto::VarType::LOD_TENSOR);
+  prog.MutableBlock(2)->Var("b2")->SetType(proto::VarType::LOD_TENSOR);
+
+  std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
+  return g;
+}
+
+inline bool CheckSubGraphSame(ir::Graph *g1, ir::Graph *g2) {
+  const auto &g1_nodes_set = g1->Nodes();
+  const auto &g2_nodes_set = g2->Nodes();
+
+  if (g1_nodes_set.size() != g2_nodes_set.size()) return false;
+
+  std::vector<ir::Node *> g1_nodes(g1_nodes_set.begin(), g1_nodes_set.end());
+  std::vector<ir::Node *> g2_nodes(g2_nodes_set.begin(), g2_nodes_set.end());
+
+  auto comp = [](ir::Node *n1, ir::Node *n2) {
+    return n1->Name() > n2->Name();
+  };
+  std::stable_sort(g1_nodes.begin(), g1_nodes.end(), comp);
+  std::stable_sort(g2_nodes.begin(), g2_nodes.end(), comp);
+
+  for (size_t i = 0; i < g1_nodes.size(); ++i) {
+    const auto &n1 = g1_nodes[i];
+    const auto &n2 = g2_nodes[i];
+
+    if (n1->NodeType() != n2->NodeType()) return false;
+    if (n1->Name() != n2->Name()) return false;
+
+    auto n1_inputs = n1->inputs;
+    auto n2_inputs = n2->inputs;
+    if (n1_inputs.size() != n2_inputs.size()) return false;
+
+    std::stable_sort(n1_inputs.begin(), n1_inputs.end(), comp);
+    std::stable_sort(n2_inputs.begin(), n2_inputs.end(), comp);
+    for (size_t i = 0; i < n1_inputs.size(); ++i) {
+      if (n1_inputs[i]->Name() != n2_inputs[i]->Name()) return false;
+    }
+
+    auto n1_outputs = n1->outputs;
+    auto n2_outputs = n2->outputs;
+    if (n1_outputs.size() != n2_outputs.size()) return false;
+
+    std::stable_sort(n1_outputs.begin(), n1_outputs.end(), comp);
+    std::stable_sort(n2_outputs.begin(), n2_outputs.end(), comp);
+    for (size_t i = 0; i < n1_outputs.size(); ++i) {
+      if (n1_outputs[i]->Name() != n2_outputs[i]->Name()) return false;
+    }
+
+    if (n1->IsVar()) {
+      const auto &var1 = n1->Var();
+      const auto &var2 = n2->Var();
+      if ((var1 == nullptr) != (var2 == nullptr)) return false;
+    }
+
+    if (n1->IsOp()) {
+      const auto &op1 = n1->Op();
+      const auto &op2 = n2->Op();
+      if ((op1 == nullptr) != (op2 == nullptr)) return false;
+
+      const auto &op1_input = op1->InputNames();
+      const auto &op2_input = op2->InputNames();
+      if (op1_input.size() != op2_input.size()) return false;
+      if (op1_input != op2_input) return false;
+
+      for (size_t i = 0; i < op1_input.size(); ++i) {
+        if (op1->Input(op1_input[i]) != op2->Input(op2_input[i])) return false;
+      }
+
+      const auto &op1_output = op1->OutputNames();
+      const auto &op2_output = op2->OutputNames();
+      if (op1_output.size() != op2_output.size()) return false;
+      if (op1_output != op2_output) return false;
+
+      for (size_t i = 0; i < op1_output.size(); ++i) {
+        if (op1->Output(op1_output[i]) != op2->Output(op2_output[i]))
+          return false;
+      }
+    }
+  }
+  return true;
+}
+
+inline bool CheckGraphSame(ir::Graph *g1, ir::Graph *g2) {
+  if (g1 == nullptr || g2 == nullptr) return true;
+
+  if (FLAGS_convert_all_blocks) {
+    if (g1->SubGraphsSize() != g2->SubGraphsSize()) return false;
+
+    for (size_t i = 0; i < g1->SubGraphsSize(); ++i) {
+      if (!CheckSubGraphSame(g1->GetSubGraph(i), g2->GetSubGraph(i)))
+        return false;
+    }
+  } else {
+    if (!CheckSubGraphSame(g1, g2)) return false;
+  }
+  return true;
+}
+
+TEST(BuildStrategy, Basic) {
+  BuildStrategy build_strategy;
+
+  ProgramDesc prog;
+  ir::Graph old_graph(prog), graph(prog);
+
+  BuildStrategyApply(&build_strategy, &graph);
+
+  ASSERT_TRUE(CheckGraphSame(&old_graph, &graph));
+}
+
+TEST(BuildStrategy, TestSingleGraph) {
+  BuildStrategy build_strategy;
+  auto graph = CreateGraph();
+  ir::Graph old_graph(graph->OriginProgram());
+
+  BuildStrategyApply(&build_strategy, graph.get());
+
+  // The graph should not change, since no pass is applied here.
+  ASSERT_TRUE(CheckGraphSame(&old_graph, graph.get()));
+}
+
+TEST(BuildStrategy, TestMultiGraph) {
+  // Turn on FLAGS_convert_all_blocks so this test exercises the
+  // multi-block path.
+  bool flag_temp = FLAGS_convert_all_blocks;
+  FLAGS_convert_all_blocks = true;
+
+  BuildStrategy build_strategy;
+  auto graph = CreateMultiGraph();
+  ir::Graph old_graph(graph->OriginProgram());
+
+  BuildStrategyApply(&build_strategy, graph.get());
+
+  // The graph should not change, since no pass is applied here.
+  ASSERT_TRUE(CheckGraphSame(&old_graph, graph.get()));
+
+  // Recover FLAGS_convert_all_blocks.
+  FLAGS_convert_all_blocks = flag_temp;
+}
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
-- 
GitLab
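
The core of this change is the new dispatch inside BuildStrategy::Apply: with
FLAGS_convert_all_blocks enabled, Apply first enforces graph->IsMainGraph()
and then runs every registered pass once per sub-graph, instead of once on
the main graph only. The sketch below pulls that dispatch out of the pass
loop for readability. It is a minimal illustration, not part of the patch:
the helper name ApplyPassToAllBlocks is hypothetical, and it assumes the
graph was built from a multi-block ProgramDesc so that SubGraphsSize() and
GetSubGraph() are populated.

// Illustrative sketch only -- mirrors the dispatch this patch adds to
// BuildStrategy::Apply; the helper itself does not exist in the patch.
#include "gflags/gflags.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

DECLARE_bool(convert_all_blocks);

namespace paddle {
namespace framework {
namespace details {

ir::Graph *ApplyPassToAllBlocks(ir::Pass *pass, ir::Graph *graph) {
  if (FLAGS_convert_all_blocks) {
    // Each block of the ProgramDesc owns a sub-graph; applying the pass to
    // every one of them is what covers multi-block (control-flow) programs.
    for (size_t i = 0; i < graph->SubGraphsSize(); ++i) {
      pass->Apply(graph->GetSubGraph(i));
    }
    return graph;
  }
  // Legacy single-block behavior: a pass may return a rebuilt graph, so
  // the caller must keep the returned pointer.
  return pass->Apply(graph);
}

}  // namespace details
}  // namespace framework
}  // namespace paddle

Note the asymmetry the diff preserves: in the single-graph branch the result
of pass->Apply(graph) is reassigned, while in the sub-graph branch each
sub-graph is rewritten in place and the main graph pointer stays unchanged.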