From b6e84806ede67c4a0062837542c6dea693cb885c Mon Sep 17 00:00:00 2001 From: Ruibiao Chen Date: Tue, 26 Jul 2022 20:13:55 +0800 Subject: [PATCH] Merge kProgramDescs in GraphToProgram (#44526) --- paddle/fluid/framework/CMakeLists.txt | 10 ++-- paddle/fluid/framework/ir/CMakeLists.txt | 2 +- paddle/fluid/framework/ir/graph_helper.cc | 34 ++++++++--- paddle/fluid/framework/ir/pass.cc | 57 +----------------- ...program_processing.cc => program_utils.cc} | 58 ++++++++++++++++++- .../{program_processing.h => program_utils.h} | 5 +- ...ocessing_test.cc => program_utils_test.cc} | 2 +- python/paddle/fluid/executor.py | 2 + tools/parallel_UT_rule.py | 4 +- 9 files changed, 98 insertions(+), 76 deletions(-) rename paddle/fluid/framework/{program_processing.cc => program_utils.cc} (67%) rename paddle/fluid/framework/{program_processing.h => program_utils.h} (89%) rename paddle/fluid/framework/{program_processing_test.cc => program_utils_test.cc} (99%) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index bd70e55ac45..b1308766c40 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -517,13 +517,13 @@ cc_test( DEPS op_call_stack) cc_library( - program_processing - SRCS program_processing.cc + program_utils + SRCS program_utils.cc DEPS proto_desc) cc_test( - program_processing_test - SRCS program_processing_test.cc - DEPS proto_desc program_processing) + program_utils_test + SRCS program_utils_test.cc + DEPS proto_desc program_utils) if(WITH_GPU) nv_test( diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 32131703524..680ae9a681a 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -67,7 +67,7 @@ cc_library( cc_library( graph_helper SRCS graph_helper.cc - DEPS graph scale_loss_grad_op_handle) + DEPS graph program_utils scale_loss_grad_op_handle) cc_library( pass SRCS pass.cc diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index d6bd2e4a80d..80568b77665 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -19,7 +19,9 @@ limitations under the License. */ #include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" +#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/op_proto_maker.h" +#include "paddle/fluid/framework/program_utils.h" DECLARE_bool(convert_all_blocks); PADDLE_DEFINE_EXPORTED_string(print_sub_graph_dir, @@ -559,20 +561,27 @@ static void GraphToBlock(const Graph &graph, << vars2remove.size() << " nodes"; } + std::vector vars_in_graph; + for (Node *node : graph.Nodes()) { + if (node->IsVar() && node->Var() && + node->GetVarNodeBlockId() == graph.GetBlockId()) { + vars_in_graph.emplace_back(*node->Var()->Proto()); + } + } + + // add vars_in_graph to blcok block->clear_vars(); std::unordered_set visited_vars; - for (Node *n : graph.Nodes()) { - if (n->IsVar()) { - if (n->Var() && visited_vars.count(n->Var()->Name()) == 0 && - !vars2remove.count(n->Var()->Name()) && - n->GetVarNodeBlockId() == graph.GetBlockId()) { - visited_vars.insert(n->Var()->Name()); - block->add_vars()->MergeFrom(*n->Var()->Proto()); - } + for (proto::VarDesc &var : vars_in_graph) { + const std::string &var_name = var.name(); + if (visited_vars.find(var_name) == visited_vars.end() && + vars2remove.find(var_name) == vars2remove.end()) { + block->add_vars()->MergeFrom(var); + visited_vars.insert(var_name); } } - block->clear_ops(); + block->clear_ops(); std::vector nodes; if (sort_kind != nullptr) { // Inference Memory Optimize relays on this branch. @@ -630,6 +639,13 @@ void GraphToProgram(const Graph &graph, } program->CopyFrom(program_pb); + + if (graph.Has(details::kProgramDescs)) { + details::ProgramDescs program_descs = + graph.Get(details::kProgramDescs); + VLOG(8) << "Merge main programs"; + MergePrograms(program, program_descs, /*append=*/false); + } } static std::vector> GetOpDependencies( diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index a92ceeaa548..35f72deab89 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/op_proto_maker.h" +#include "paddle/fluid/framework/program_utils.h" namespace paddle { namespace framework { @@ -78,62 +79,6 @@ Graph *Pass::Apply(Graph *graph) const { return graph; } -template -static void VisitAllElements(Container &&container, - Visitor &&visitor, - bool reverse) { - if (reverse) { - std::for_each(container.rbegin(), container.rend(), visitor); - } else { - std::for_each(container.begin(), container.end(), visitor); - } -} - -static void MergePrograms(ProgramDesc *dst, - const details::ProgramDescs &srcs, - bool append) { - PADDLE_ENFORCE_NOT_NULL( - dst, platform::errors::InvalidArgument("Dst program must be provided.")); - bool reverse = !append; - - auto create_var_visitor = [dst](const ProgramDesc &src) { - PADDLE_ENFORCE_EQ( - src.Size(), - 1, - platform::errors::Unimplemented("MergePrograms can only support to " - "merge program with only one block.")); - const auto &src_block = src.Block(0); - auto *dst_block = dst->MutableBlock(0); - for (const auto *src_new_var : src_block.AllVars()) { - if (dst_block->FindVar(src_new_var->Name())) continue; - auto *dst_new_var = dst_block->Var(src_new_var->Name()); - *dst_new_var = *src_new_var; - VLOG(10) << "Create new variable " << dst_new_var->Name(); - } - }; - VisitAllElements(srcs, create_var_visitor, reverse); - - auto create_op_visitor = [dst, reverse](const ProgramDesc &src) { - auto ops = src.Block(0).AllOps(); - auto copy_op_visitor = [dst, reverse](const OpDesc *src_op) { - auto *dst_block = dst->MutableBlock(0); - auto *op = reverse ? dst_block->PrependOp() : dst_block->AppendOp(); - op->CopyFrom(*src_op); - VLOG(10) << (reverse ? "Prepend" : "Append") << " op " << op->Type(); - // FIXME(zjl): some passes does not add VarDesc to program, - // we should fix this bug later... - for (const auto &in_var_name : op->InputArgumentNames()) { - dst_block->Var(in_var_name); - } - for (const auto &out_var_name : op->OutputArgumentNames()) { - dst_block->Var(out_var_name); - } - }; - VisitAllElements(ops, copy_op_visitor, reverse); - }; - VisitAllElements(srcs, create_op_visitor, reverse); -} - static void FillNotSpecifiedOpRole(const ProgramDesc &main_program) { for (size_t block_idx = 0; block_idx < main_program.Size(); ++block_idx) { auto ops = main_program.Block(block_idx).AllOps(); diff --git a/paddle/fluid/framework/program_processing.cc b/paddle/fluid/framework/program_utils.cc similarity index 67% rename from paddle/fluid/framework/program_processing.cc rename to paddle/fluid/framework/program_utils.cc index ba50874c420..a32350a569d 100644 --- a/paddle/fluid/framework/program_processing.cc +++ b/paddle/fluid/framework/program_utils.cc @@ -12,13 +12,69 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/program_processing.h" +#include "paddle/fluid/framework/program_utils.h" #include "paddle/fluid/framework/block_desc.h" namespace paddle { namespace framework { +template +inline void VisitAllElements(Container &&container, + Visitor &&visitor, + bool reverse) { + if (reverse) { + std::for_each(container.rbegin(), container.rend(), visitor); + } else { + std::for_each(container.begin(), container.end(), visitor); + } +} + +void MergePrograms(ProgramDesc *dst, + const std::vector &srcs, + bool append) { + PADDLE_ENFORCE_NOT_NULL( + dst, platform::errors::InvalidArgument("Dst program must be provided.")); + bool reverse = !append; + + auto create_var_visitor = [dst](const ProgramDesc &src) { + PADDLE_ENFORCE_EQ( + src.Size(), + 1, + platform::errors::Unimplemented("MergePrograms can only support to " + "merge program with only one block.")); + const auto &src_block = src.Block(0); + auto *dst_block = dst->MutableBlock(0); + for (const auto *src_new_var : src_block.AllVars()) { + if (dst_block->FindVar(src_new_var->Name())) continue; + auto *dst_new_var = dst_block->Var(src_new_var->Name()); + *dst_new_var = *src_new_var; + VLOG(10) << "Create new variable " << dst_new_var->Name(); + } + }; + VisitAllElements(srcs, create_var_visitor, reverse); + + auto create_op_visitor = [dst, reverse](const ProgramDesc &src) { + auto ops = src.Block(0).AllOps(); + auto copy_op_visitor = [dst, reverse](const OpDesc *src_op) { + auto *dst_block = dst->MutableBlock(0); + auto *op = reverse ? dst_block->PrependOp() : dst_block->AppendOp(); + op->CopyFrom(*src_op); + VLOG(10) << (reverse ? "Prepend" : "Append") << " op " << op->Type(); + // FIXME(zjl): some passes does not add VarDesc to program, + // we should fix this bug later... + for (const auto &in_var_name : op->InputArgumentNames()) { + dst_block->Var(in_var_name); + } + for (const auto &out_var_name : op->OutputArgumentNames()) { + dst_block->Var(out_var_name); + } + }; + VisitAllElements(ops, copy_op_visitor, reverse); + }; + VisitAllElements(srcs, create_op_visitor, reverse); +} + void ProgramProcessor::GetInputsOutputsInBlock( const BlockDesc ¤t_block, std::set *inner_inputs, diff --git a/paddle/fluid/framework/program_processing.h b/paddle/fluid/framework/program_utils.h similarity index 89% rename from paddle/fluid/framework/program_processing.h rename to paddle/fluid/framework/program_utils.h index b495c31793d..4a276e80112 100644 --- a/paddle/fluid/framework/program_processing.h +++ b/paddle/fluid/framework/program_utils.h @@ -18,7 +18,9 @@ limitations under the License. */ namespace paddle { namespace framework { -class ProgramDesc; +void MergePrograms(ProgramDesc *dst, + const std::vector &srcs, + bool append); class ProgramProcessor { public: @@ -30,5 +32,6 @@ class ProgramProcessor { void AddDepToBlockOp(const BlockDesc &block); }; + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/program_processing_test.cc b/paddle/fluid/framework/program_utils_test.cc similarity index 99% rename from paddle/fluid/framework/program_processing_test.cc rename to paddle/fluid/framework/program_utils_test.cc index 07e4deef0a8..051aa89e4b5 100644 --- a/paddle/fluid/framework/program_processing_test.cc +++ b/paddle/fluid/framework/program_utils_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/program_processing.h" +#include "paddle/fluid/framework/program_utils.h" #include "gtest/gtest-message.h" #include "gtest/gtest-test-part.h" diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 8fd2283f0f9..9961c1a106b 100755 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -1445,9 +1445,11 @@ class Executor(object): if key not in self._executor_cache._cached_executors: # To apply IR pass, compile the Program to IrGraph and convert it back to Program if isinstance(program, compiler.CompiledProgram): + # print(f"Program before convert:\n {inner_program}", flush=True) program._compile(scope, self.place) ir_graph = framework.IrGraph(program._graph) inner_program = ir_graph.to_program() + # print(f"Program after convert:\n {inner_program}", flush=True) else: from paddle.incubate.autograd import prim_enabled, prim2orig if prim_enabled() and program == default_main_program(): diff --git a/tools/parallel_UT_rule.py b/tools/parallel_UT_rule.py index 559f2d95b91..7b793c4cf83 100755 --- a/tools/parallel_UT_rule.py +++ b/tools/parallel_UT_rule.py @@ -470,7 +470,7 @@ HIGH_PARALLEL_JOB_NEW = [ 'cipher_utils_test', 'test_program_code', 'test_save_model_without_var', - 'program_processing_test', + 'program_utils_test', 'test_fleet_distributed_strategy', 'test_hybrid_parallel_topology', 'test_ascend_trigger', @@ -1719,7 +1719,7 @@ CPU_PARALLEL_JOB = [ 'test_sum_api', 'test_op_compat_sensible_pass', 'test_generate_pass_cc', - 'program_processing_test', + 'program_utils_test', 'build_strategy_test', 'test_fc_rnn_mkldnn_fuse_pass', 'scope_guard_test', -- GitLab