From ed61d67c737590ebf2819ca9770a9a6d4e294880 Mon Sep 17 00:00:00 2001
From: chengduo
Date: Wed, 27 Mar 2019 23:37:17 -0500
Subject: [PATCH] Fix the interface of Pass::Apply (#16484)

* modify the interface of Pass::Apply
test=develop

* Polish code
test=develop

* Fix Travis CI
test=develop

* fix Pass::Apply interface
test=develop

* Fix Travis CI
test=develop
---
 .../framework/details/all_reduce_deps_pass.cc |  5 +-
 .../framework/details/all_reduce_deps_pass.h  |  3 +-
 .../alloc_continuous_space_for_grad_pass.cc   |  7 +--
 .../fluid/framework/details/build_strategy.cc | 17 +++----
 .../fluid/framework/details/build_strategy.h  | 15 +++---
 .../framework/details/eager_deletion_pass.cc  |  8 ++--
 .../details/fuse_all_reduce_op_pass.cc        |  6 +--
 .../framework/details/inplace_op_pass.cc      |  9 ++--
 .../fluid/framework/details/inplace_op_pass.h |  3 +-
 .../framework/details/memory_optimize_pass.cc |  7 +--
 .../framework/details/memory_optimize_pass.h  |  4 +-
 .../modify_op_lock_and_record_event_pass.cc   |  4 +-
 .../modify_op_lock_and_record_event_pass.h    |  3 +-
 .../details/multi_devices_graph_check_pass.cc |  6 +--
 .../details/multi_devices_graph_pass.cc       |  4 +-
 .../details/multi_devices_graph_pass.h        |  3 +-
 .../details/multi_devices_graph_print_pass.cc |  2 +
 .../details/multi_devices_graph_print_pass.h  |  5 +-
 .../details/parallel_ssa_graph_executor.cc    |  2 +-
 .../framework/details/reference_count_pass.cc |  5 +-
 .../framework/details/reference_count_pass.h  |  3 +-
 .../details/sequential_execution_pass.cc      |  4 +-
 .../details/sequential_execution_pass.h       |  3 +-
 .../details/while_op_eager_deletion_pass.cc   |  4 +-
 ...anakin_fillconstant_elementwisemul_fuse.cc | 10 ++--
 .../anakin_fillconstant_elementwisemul_fuse.h |  3 +-
 .../framework/ir/attention_lstm_fuse_pass.cc  |  9 ++--
 .../framework/ir/attention_lstm_fuse_pass.h   |  3 +-
 .../ir/conv_affine_channel_fuse_pass.cc       | 24 ++++------
 .../ir/conv_affine_channel_fuse_pass.h        |  6 +--
 .../fluid/framework/ir/conv_bn_fuse_pass.cc   | 31 +++++-------
 paddle/fluid/framework/ir/conv_bn_fuse_pass.h |  6 +--
 .../ir/conv_elementwise_add2_act_fuse.cc      |  6 +--
 .../ir/conv_elementwise_add2_act_fuse_pass.cc | 13 ++---
 .../ir/conv_elementwise_add2_act_fuse_pass.h  |  3 +-
 .../ir/conv_elementwise_add_act_fuse_pass.cc  | 12 ++---
 .../ir/conv_elementwise_add_act_fuse_pass.h   |  3 +-
 .../ir/conv_elementwise_add_fuse_pass.cc      | 10 ++--
 .../ir/conv_elementwise_add_fuse_pass.h       |  3 +-
 .../ir/embedding_fc_lstm_fuse_pass.cc         | 14 +++---
 .../ir/embedding_fc_lstm_fuse_pass.h          |  3 +-
 paddle/fluid/framework/ir/fc_fuse_pass.cc     | 13 +++--
 paddle/fluid/framework/ir/fc_fuse_pass.h      |  3 +-
 .../fluid/framework/ir/fc_fuse_pass_tester.cc |  2 +-
 paddle/fluid/framework/ir/fc_gru_fuse_pass.cc | 22 ++++-----
 paddle/fluid/framework/ir/fc_gru_fuse_pass.h  |  6 +--
 .../fluid/framework/ir/fc_lstm_fuse_pass.cc   | 21 ++++----
 paddle/fluid/framework/ir/fc_lstm_fuse_pass.h |  6 +--
 .../framework/ir/fuse_elewise_add_act_pass.cc | 48 +++++++++----------
 .../framework/ir/fuse_elewise_add_act_pass.h  | 20 ++++----
 .../ir/fuse_relu_depthwise_conv_pass.cc       | 24 +++++-----
 .../ir/fuse_relu_depthwise_conv_pass.h        |  6 +--
 .../framework/ir/graph_to_program_pass.cc     |  6 +--
 .../framework/ir/graph_to_program_pass.h      |  2 +-
 .../ir/graph_to_program_pass_test.cc          |  4 +-
 paddle/fluid/framework/ir/graph_viz_pass.cc   | 13 ++---
 paddle/fluid/framework/ir/graph_viz_pass.h    |  4 +-
 .../ir/identity_scale_op_clean_pass.cc        |  8 ++--
 .../ir/identity_scale_op_clean_pass.h         |  3 +-
 .../framework/ir/infer_clean_graph_pass.cc    | 10 ++--
 paddle/fluid/framework/ir/is_test_pass.cc     |  4 +-
 paddle/fluid/framework/ir/is_test_pass.h      |  3 +-
 .../fluid/framework/ir/is_test_pass_tester.cc |  2 +-
 .../framework/ir/lock_free_optimize_pass.cc   | 11 ++---
 .../framework/ir/lock_free_optimize_pass.h    |  3 +-
 .../ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc   | 14 +++---
 .../ir/mkldnn/conv_bias_mkldnn_fuse_pass.h    |  3 +-
 .../conv_bias_mkldnn_fuse_pass_tester.cc      |  4 +-
 .../conv_elementwise_add_mkldnn_fuse_pass.cc  | 12 ++---
 .../conv_elementwise_add_mkldnn_fuse_pass.h   |  5 +-
 ...elementwise_add_mkldnn_fuse_pass_tester.cc |  4 +-
 .../ir/mkldnn/conv_relu_mkldnn_fuse_pass.cc   | 12 ++---
 .../ir/mkldnn/conv_relu_mkldnn_fuse_pass.h    |  3 +-
 .../conv_relu_mkldnn_fuse_pass_tester.cc      |  2 +-
 .../framework/ir/mkldnn/cpu_quantize_pass.cc  | 15 +++---
 .../framework/ir/mkldnn/cpu_quantize_pass.h   |  3 +-
 .../ir/mkldnn/cpu_quantize_pass_tester.cc     |  2 +-
 .../ir/mkldnn/cpu_quantize_placement_pass.cc  |  4 +-
 .../ir/mkldnn/cpu_quantize_placement_pass.h   |  3 +-
 .../cpu_quantize_placement_pass_tester.cc     |  2 +-
 .../ir/mkldnn/cpu_quantize_squash_pass.cc     | 13 ++---
 .../ir/mkldnn/cpu_quantize_squash_pass.h      |  3 +-
 .../mkldnn/cpu_quantize_squash_pass_tester.cc |  2 +-
 .../ir/mkldnn/depthwise_conv_mkldnn_pass.cc   | 10 ++--
 .../ir/mkldnn/depthwise_conv_mkldnn_pass.h    |  3 +-
 .../depthwise_conv_mkldnn_pass_tester.cc      |  2 +-
 .../ir/mkldnn/mkldnn_placement_pass.cc        |  5 +-
 .../ir/mkldnn/mkldnn_placement_pass.h         |  3 +-
 .../ir/mkldnn/mkldnn_placement_pass_tester.cc |  2 +-
 .../framework/ir/multi_batch_merge_pass.cc    |  7 ++-
 .../framework/ir/multi_batch_merge_pass.h     |  2 +-
 paddle/fluid/framework/ir/pass.cc             | 14 +++---
 paddle/fluid/framework/ir/pass.h              |  9 ++--
 paddle/fluid/framework/ir/pass_test.cc        | 15 +++---
 .../ir/repeated_fc_relu_fuse_pass.cc          | 10 ++--
 .../framework/ir/repeated_fc_relu_fuse_pass.h |  3 +-
 .../ir/runtime_context_cache_pass.cc          |  4 +-
 .../framework/ir/runtime_context_cache_pass.h |  3 +-
 .../framework/ir/seq_concat_fc_fuse_pass.cc   | 15 +++---
 .../framework/ir/seq_concat_fc_fuse_pass.h    |  3 +-
 .../ir/seqconv_eltadd_relu_fuse_pass.cc       | 10 ++--
 .../ir/seqconv_eltadd_relu_fuse_pass.h        |  3 +-
 .../framework/ir/seqpool_concat_fuse_pass.cc  | 10 ++--
 .../framework/ir/seqpool_concat_fuse_pass.h   |  3 +-
 .../ir/seqpool_concat_fuse_pass_tester.cc     |  2 +-
 .../simplify_anakin_detection_pattern_pass.cc | 11 ++---
 .../simplify_anakin_detection_pattern_pass.h  |  3 +-
 .../framework/ir/squared_mat_sub_fuse_pass.cc | 10 ++--
 .../framework/ir/squared_mat_sub_fuse_pass.h  |  3 +-
 .../framework/ir/sync_batch_norm_pass.cc      |  4 +-
 .../fluid/framework/ir/sync_batch_norm_pass.h |  3 +-
 .../ir/sync_batch_norm_pass_tester.cc         |  2 +-
 .../ir/transpose_flatten_concat_fuse_pass.cc  | 10 ++--
 .../ir/transpose_flatten_concat_fuse_pass.h   |  3 +-
 paddle/fluid/framework/parallel_executor.cc   | 40 ++++++----------
 .../inference/analysis/ir_pass_manager.cc     |  4 +-
 .../ir_passes/anakin_subgraph_pass.cc         |  6 +--
 .../analysis/ir_passes/anakin_subgraph_pass.h |  3 +-
 .../ir_passes/tensorrt_subgraph_pass.cc       | 17 +++----
 .../ir_passes/tensorrt_subgraph_pass.h        |  3 +-
 .../passes/ir_graph_to_program_pass.cc        |  4 +-
 paddle/fluid/pybind/pybind.cc                 |  4 +-
 122 files changed, 370 insertions(+), 539 deletions(-)

diff --git a/paddle/fluid/framework/details/all_reduce_deps_pass.cc b/paddle/fluid/framework/details/all_reduce_deps_pass.cc
index 98a74d630..d93c84606 100644
--- a/paddle/fluid/framework/details/all_reduce_deps_pass.cc
+++ b/paddle/fluid/framework/details/all_reduce_deps_pass.cc
@@ -42,8 +42,7 @@ VarHandle* GetValidInput(const OpHandleBase* a) {
   return nullptr;
 }
 
-std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void AllReduceDepsPass::ApplyImpl(ir::Graph* graph) const {
   auto graph_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
 
   // get vars order
@@ -131,8 +130,6 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
     VLOG(10) << "pre_op:" << pre_op->DebugString()
              << ", op:" << op->DebugString();
   }
-
-  return graph;
 }
 
 }  // namespace details
diff --git a/paddle/fluid/framework/details/all_reduce_deps_pass.h b/paddle/fluid/framework/details/all_reduce_deps_pass.h
index e8b910898..4ed373658 100644
--- a/paddle/fluid/framework/details/all_reduce_deps_pass.h
+++ b/paddle/fluid/framework/details/all_reduce_deps_pass.h
@@ -24,8 +24,7 @@ namespace details {
 // TODO(gongwb): overlap allreduce with backward computation.
 class AllReduceDepsPass : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 
 }  // namespace details
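To make the shape of this change concrete: a pass used to take ownership of the graph and hand it back on every exit path, and now it simply mutates the graph in place. A minimal sketch of a pass before and after (the class name here is hypothetical, invented for illustration; it is not part of this patch):

    // Hypothetical DemoPass, showing only the interface change.
    #include "paddle/fluid/framework/ir/pass.h"

    namespace paddle {
    namespace framework {
    namespace ir {

    class DemoPass : public Pass {
     protected:
      // Before this patch:
      //   std::unique_ptr<ir::Graph> ApplyImpl(
      //       std::unique_ptr<ir::Graph> graph) const override {
      //     ...
      //     return graph;  // every exit path had to hand ownership back
      //   }
      // After this patch, the pass mutates the graph in place:
      void ApplyImpl(ir::Graph* graph) const override {
        if (graph->Nodes().empty()) {
          return;  // early exits no longer need `return std::move(graph);`
        }
        // ... mutate *graph in place ...
      }
    };

    }  // namespace ir
    }  // namespace framework
    }  // namespace paddle
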
diff --git a/paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.cc b/paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.cc
index fbc8bbf56..e195e93fb 100644
--- a/paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.cc
+++ b/paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.cc
@@ -46,8 +46,7 @@ static framework::proto::VarType::Type kDefaultDtype =
 
 class AllocContinuousSpaceForGradPass : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override {
+  void ApplyImpl(ir::Graph *graph) const override {
     ir::Graph &result = *graph;
 
     auto &places = Get<const std::vector<platform::Place>>(kPlaces);
@@ -65,7 +64,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
 
     if (params_grads.size() == 0) {
       VLOG(10) << "Doesn't find gradients";
-      return std::move(graph);
+      return;
     }
 
     std::unordered_map<std::string, ir::Node *> vars;
@@ -124,8 +123,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
 
     InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars,
                                       fused_var_name, params_grads);
-
-    return std::move(graph);
   }
 
   template
diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index 5d9db2375..078403f30 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -204,15 +204,16 @@ bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
   return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0;
 }
 
-std::unique_ptr<ir::Graph> BuildStrategy::Apply(
-    std::unique_ptr<ir::Graph> graph,
-    const std::vector<platform::Place> &places,
-    const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
-    const size_t &nranks,
+ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
+                                const std::vector<platform::Place> &places,
+                                const std::string &loss_var_name,
+                                const std::vector<Scope *> &local_scopes,
+                                const size_t &nranks,
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-    const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
+                                const bool use_cuda,
+                                platform::NCCLContextMap *nccl_ctxs) const {
 #else
-    const bool use_cuda) const {
+                                const bool use_cuda) const {
 #endif
   // Create a default one if not finalized by user.
   CreatePassesFromStrategy(false);
@@ -265,7 +266,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
       }
     }
     VLOG(3) << "Start Apply Pass " << pass->Type();
-    graph = pass->Apply(std::move(graph));
+    graph = pass->Apply(graph);
     VLOG(3) << "Finish Apply Pass " << pass->Type();
   }
   return graph;
diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h
index 4b599fb91..9587a6f0f 100644
--- a/paddle/fluid/framework/details/build_strategy.h
+++ b/paddle/fluid/framework/details/build_strategy.h
@@ -120,16 +120,15 @@ struct BuildStrategy {
 
   // Apply the passes built by the pass_builder_. The passes will be
   // applied to the Program and output an ir::Graph.
-  std::unique_ptr<ir::Graph> Apply(std::unique_ptr<ir::Graph> graph,
-                                   const std::vector<platform::Place> &places,
-                                   const std::string &loss_var_name,
-                                   const std::vector<Scope *> &local_scopes,
-                                   const size_t &nranks,
+  ir::Graph *Apply(ir::Graph *graph, const std::vector<platform::Place> &places,
+                   const std::string &loss_var_name,
+                   const std::vector<Scope *> &local_scopes,
+                   const size_t &nranks,
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-                                   const bool use_cuda,
-                                   platform::NCCLContextMap *nccl_ctxs) const;
+                   const bool use_cuda,
+                   platform::NCCLContextMap *nccl_ctxs) const;
 #else
-                                   const bool use_cuda) const;
+                   const bool use_cuda) const;
 #endif
 
   // If set true, ParallelExecutor would build the main_program into multiple
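The caller-side effect of the raw-pointer interface is visible in the loop above: BuildStrategy::Apply threads one raw pointer through every pass, and ownership never changes hands. A condensed sketch of that loop (the function name and container type are simplified stand-ins, not the exact build_strategy.cc code):

    // Illustrative sketch of pass chaining under the new interface.
    #include <memory>
    #include <vector>
    #include "paddle/fluid/framework/ir/pass.h"

    namespace fw = paddle::framework;

    fw::ir::Graph *ApplyAllPasses(
        fw::ir::Graph *graph,
        const std::vector<std::unique_ptr<fw::ir::Pass>> &passes) {
      for (auto &pass : passes) {
        // Each pass mutates *graph in place and hands the same raw pointer
        // back, so the caller retains ownership across the whole loop.
        graph = pass->Apply(graph);
      }
      return graph;
    }
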
diff --git a/paddle/fluid/framework/details/eager_deletion_pass.cc b/paddle/fluid/framework/details/eager_deletion_pass.cc
index a6baa2613..622a59b4c 100644
--- a/paddle/fluid/framework/details/eager_deletion_pass.cc
+++ b/paddle/fluid/framework/details/eager_deletion_pass.cc
@@ -170,12 +170,10 @@ static OpToVarNameSetMap ShrinkGCVars(
 
 class EagerDeletionPass : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph *graph) const override;
 };
 
-std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
   auto &ref_cnts =
       Get<std::vector<AtomicReferenceCountMap>>(kRuntimeReferenceCount);
   PADDLE_ENFORCE(ref_cnts.empty(),
@@ -240,7 +238,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
 
   auto while_op_eager_deletion_pass =
       ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass");
-  return while_op_eager_deletion_pass->Apply(std::move(graph));
+  while_op_eager_deletion_pass->Apply(graph);
 }
 
 }  // namespace details
diff --git a/paddle/fluid/framework/details/fuse_all_reduce_op_pass.cc b/paddle/fluid/framework/details/fuse_all_reduce_op_pass.cc
index f226491c9..31efd78ad 100644
--- a/paddle/fluid/framework/details/fuse_all_reduce_op_pass.cc
+++ b/paddle/fluid/framework/details/fuse_all_reduce_op_pass.cc
@@ -28,8 +28,7 @@ namespace details {
 
 class FuseAllReduceOpPass : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override {
+  void ApplyImpl(ir::Graph *graph) const override {
     ir::Graph &result = *graph;
 
     auto &places = Get<const std::vector<platform::Place>>(kPlaces);
@@ -71,7 +70,7 @@ class FuseAllReduceOpPass : public ir::Pass {
     VLOG(10) << "Find all_reduce_ops: " << all_reduce_ops.size();
     if (all_reduce_ops.size() == 0) {
-      return std::move(graph);
+      return;
     }
 
     PADDLE_ENFORCE_EQ(all_reduce_ops.size(), grads.size(),
@@ -99,7 +98,6 @@ class FuseAllReduceOpPass : public ir::Pass {
                            group_all_reduce_ops, &result);
 #endif
     }
-    return std::move(graph);
   }
 
   void InsertFusedAllReduce(const std::vector<platform::Place> &places,
diff --git a/paddle/fluid/framework/details/inplace_op_pass.cc b/paddle/fluid/framework/details/inplace_op_pass.cc
index 88f26b416..afbda33b0 100644
--- a/paddle/fluid/framework/details/inplace_op_pass.cc
+++ b/paddle/fluid/framework/details/inplace_op_pass.cc
@@ -144,10 +144,9 @@ void InplacePass::InitSSAGraphNodes() const {
   }
 }
 
-std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void InplacePass::ApplyImpl(ir::Graph* graph) const {
   var_nodes_.clear();
-  view_.Build(graph.get());
+  view_.Build(graph);
   InitSSAGraphNodes();
 
   auto cnt = 0;
@@ -155,11 +154,9 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
     VLOG(4) << "Handle op " << cnt++ << ": " << op->Name();
     if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name()))
       continue;
-    TryInplaceOpInputOutput(op, graph.get());
+    TryInplaceOpInputOutput(op, graph);
   }
   //  graph->ResolveHazard(var_nodes_);
-
-  return graph;
 }
 
 void InplacePass::InplaceModifyDesc(const std::string& var,
diff --git a/paddle/fluid/framework/details/inplace_op_pass.h b/paddle/fluid/framework/details/inplace_op_pass.h
index 01964ba8f..fbec973dd 100644
--- a/paddle/fluid/framework/details/inplace_op_pass.h
+++ b/paddle/fluid/framework/details/inplace_op_pass.h
@@ -69,8 +69,7 @@ class InplacePass : public ir::Pass {
   InplacePass();
 
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 
   void InitSSAGraphNodes() const;
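Where a pass used to end by returning another pass's result, as eager_deletion_pass.cc did above with `return while_op_eager_deletion_pass->Apply(std::move(graph));`, it now simply applies the child pass to the same raw pointer. A sketch of that delegation idiom (DemoParentPass is invented for illustration; the registry name is taken from the hunk above):

    // Hypothetical pass delegating to a registered pass under the new interface.
    void DemoParentPass::ApplyImpl(ir::Graph *graph) const {
      // ... this pass's own graph rewriting ...
      auto child =
          ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass");
      // Before: return child->Apply(std::move(graph));
      // After: the child mutates the same graph; nothing to move or return.
      child->Apply(graph);
    }
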
diff --git a/paddle/fluid/framework/details/memory_optimize_pass.cc b/paddle/fluid/framework/details/memory_optimize_pass.cc
index 80720af32..ddaef2060 100644
--- a/paddle/fluid/framework/details/memory_optimize_pass.cc
+++ b/paddle/fluid/framework/details/memory_optimize_pass.cc
@@ -44,8 +44,7 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const {
   auto nodes = graph->Nodes();
   CollectSkipVarsSet(nodes);
 
@@ -113,7 +112,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
           cfg_->RenameVarInCFGGraph(var_name, cache_name, idx);
           RenameVarInGraphDesc(var_name, cache_name, idx);
-          RenameVarInGraphNode(var_name, cache_name, idx, graph.get());
+          RenameVarInGraphNode(var_name, cache_name, idx, graph);
           pool_.Erase(cache_name);
         }
       }
@@ -128,8 +127,6 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
     }
   }
   graph->ResolveHazard(var_nodes_);
-
-  return graph;
 }
 
 void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
diff --git a/paddle/fluid/framework/details/memory_optimize_pass.h b/paddle/fluid/framework/details/memory_optimize_pass.h
index 593ffc10f..ce94890b3 100644
--- a/paddle/fluid/framework/details/memory_optimize_pass.h
+++ b/paddle/fluid/framework/details/memory_optimize_pass.h
@@ -21,6 +21,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -35,8 +36,7 @@ namespace details {
 
 class MemoryOptimizePass : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   // fill the variable map(var_nodes) by version.
   void InitSSAGraphNodes() const;
diff --git a/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc b/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc
index 67aad9f94..ae363f963 100644
--- a/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc
+++ b/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc
@@ -34,8 +34,7 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(
   return true;
 }
 
-std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> ir_graph) const {
+void ModifyOpLockAndRecordEventPass::ApplyImpl(ir::Graph *ir_graph) const {
   auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*ir_graph);
   OpGraphView graph_view(all_ops);
   for (auto &op : all_ops) {
@@ -49,7 +48,6 @@ std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
               << compute_op->DebugString();
     }
   }
-  return ir_graph;
 }
 
 }  // namespace details
diff --git a/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h b/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h
index b54e1b318..54d52d624 100644
--- a/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h
+++ b/paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h
@@ -23,8 +23,7 @@ namespace details {
 
 class ModifyOpLockAndRecordEventPass : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 
 }  // namespace details
diff --git a/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc
index a4bb1e26d..9859b04de 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc
@@ -23,10 +23,8 @@ namespace details {
 
 class SSAGraghBuilderWithChecker : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override {
-    PADDLE_ENFORCE(IsValidGraph(graph.get()));
-    return graph;
+  void ApplyImpl(ir::Graph *graph) const override {
+    PADDLE_ENFORCE(IsValidGraph(graph));
   }
 
   bool IsValidGraph(const ir::Graph *graph) const {
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index 8c61684c9..f80a098bf 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -153,8 +153,7 @@ void MultiDevSSAGraphBuilderBase::Init() const {
   PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
 }
 
-std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
   Init();
   CheckGraph(*graph);
   std::vector<ir::Node *> sorted_ops = SortOperations(*graph);
@@ -236,7 +235,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
   AddOutputToLeafOps(&result);
 
   result.Erase(kGraphOps);
-  return graph;
 }
 
 void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp(
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h
index 8bfd7b9bf..884089df3 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h
@@ -36,8 +36,7 @@ namespace details {
 
 class MultiDevSSAGraphBuilderBase : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph *graph) const override;
 
   virtual void Init() const;
diff --git a/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc
index e82eb104f..34c38ea81 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc
@@ -13,7 +13,9 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
+#include
 #include
+#include
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
diff --git a/paddle/fluid/framework/details/multi_devices_graph_print_pass.h b/paddle/fluid/framework/details/multi_devices_graph_print_pass.h
index b06c87a5c..6d57d75e8 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_print_pass.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_print_pass.h
@@ -17,6 +17,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
@@ -40,13 +41,11 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
 
 class SSAGraghBuilderWithPrinter : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override {
+  void ApplyImpl(ir::Graph* graph) const override {
     std::unique_ptr<std::ostream> fout(
         new std::ofstream(Get<std::string>(kGraphvizPath)));
     PADDLE_ENFORCE(fout->good());
     Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
-    return graph;
   }
 };
diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
index 2afac3243..137e0dd77 100644
--- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
@@ -96,7 +96,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
   auto seq_allreduce_pass =
       ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
   for (size_t i = 0; i < graphs_.size(); ++i) {
-    graphs_[i] = seq_allreduce_pass->Apply(std::move(graphs_[i]));
+    graphs_[i].reset(seq_allreduce_pass->Apply(graphs_[i].release()));
   }
 
   // set the correct size of thread pool to each device.
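Call sites that still keep their graphs in a std::unique_ptr, like the parallel_ssa_graph_executor.cc hunk above, have to bridge the two ownership styles explicitly. A sketch of that bridge (ApplyInPlace is an invented helper name; it assumes Pass::Apply returns the pointer it was handed, which holds for the in-place passes in this patch):

    #include <memory>

    // release() hands the raw pointer across the new interface; reset()
    // re-seats ownership in the smart pointer afterwards.
    void ApplyInPlace(std::unique_ptr<ir::Graph> *graph, ir::Pass *pass) {
      ir::Graph *result = pass->Apply(graph->release());
      graph->reset(result);
    }

Using `graph->get()` here instead of `release()` would leave two owners of the same graph the moment the result is stored back; release/reset keeps a single owner even if a future pass were to return a different graph.
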
diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc
index c218e55b7..25337872c 100644
--- a/paddle/fluid/framework/details/reference_count_pass.cc
+++ b/paddle/fluid/framework/details/reference_count_pass.cc
@@ -266,8 +266,7 @@ static bool ShrinkNoNeedBufferVarOpDependency(
   }
 }
 
-std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
   auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
   auto &last_live_ops_of_vars =
       Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
@@ -342,8 +341,6 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
       // Just skip this corner case
     }
   }
-
-  return graph;
 }
 
 }  // namespace details
diff --git a/paddle/fluid/framework/details/reference_count_pass.h b/paddle/fluid/framework/details/reference_count_pass.h
index bcbef0273..7bb01ee61 100644
--- a/paddle/fluid/framework/details/reference_count_pass.h
+++ b/paddle/fluid/framework/details/reference_count_pass.h
@@ -23,8 +23,7 @@ namespace details {
 
 class ReferenceCountPass : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 
 }  // namespace details
diff --git a/paddle/fluid/framework/details/sequential_execution_pass.cc b/paddle/fluid/framework/details/sequential_execution_pass.cc
index 0b53a76e7..839f8dc43 100644
--- a/paddle/fluid/framework/details/sequential_execution_pass.cc
+++ b/paddle/fluid/framework/details/sequential_execution_pass.cc
@@ -29,8 +29,7 @@ static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
          op1->Outputs() == op2->Outputs();
 }
 
-std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void SequentialExecutionPass::ApplyImpl(ir::Graph *graph) const {
   // FIXME(zjl): Insert dependencies between some distributed ops may cause
   // the multi_devices_graph_pass fails. So we skip these ops here.
   // Indeed, maybe we should not insert dependencies between these ops
@@ -98,7 +97,6 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
     VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
              << " and " << op_node_list[i]->Name();
   }
-  return graph;
 }
 
 }  // namespace details
diff --git a/paddle/fluid/framework/details/sequential_execution_pass.h b/paddle/fluid/framework/details/sequential_execution_pass.h
index ea3034877..7d6a4f4cc 100644
--- a/paddle/fluid/framework/details/sequential_execution_pass.h
+++ b/paddle/fluid/framework/details/sequential_execution_pass.h
@@ -23,8 +23,7 @@ namespace details {
 
 class SequentialExecutionPass : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 
 }  // namespace details
diff --git a/paddle/fluid/framework/details/while_op_eager_deletion_pass.cc b/paddle/fluid/framework/details/while_op_eager_deletion_pass.cc
index fd6b6dd22..8f7c99f12 100644
--- a/paddle/fluid/framework/details/while_op_eager_deletion_pass.cc
+++ b/paddle/fluid/framework/details/while_op_eager_deletion_pass.cc
@@ -23,8 +23,7 @@ namespace details {
 
 class WhileOpEagerDeletionPass : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override {
+  void ApplyImpl(ir::Graph *graph) const override {
     auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
 
     // Find all while_op and while_grad_op
@@ -50,7 +49,6 @@ class WhileOpEagerDeletionPass : public ir::Pass {
       operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
           while_ops, while_grad_ops);
     }
-    return graph;
   }
 };
diff --git a/paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.cc b/paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.cc
index 83b0da0c0..39077f642 100644
--- a/paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.cc
+++ b/paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.cc
@@ -29,10 +29,9 @@ namespace ir {
   GET_IR_NODE(elementwise_mul);      \
   GET_IR_NODE(elementwise_mul_out);
 
-std::unique_ptr<ir::Graph> AnakinFillconstantElementwisemulFuse::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "anakin_fillconstant_elementwisemul_fuse";
-  FusePassBase::Init(pattern_name, graph.get());
+  FusePassBase::Init(pattern_name, graph);
 
   GraphPatternDetector gpd;
   auto* x = gpd.mutable_pattern()
@@ -69,12 +68,11 @@ std::unique_ptr<ir::Graph> AnakinFillconstantElementwisemulFuse::ApplyImpl(
     IR_NODE_LINK_TO(scale_op, elementwise_mul_out);  // Output
 
     // Delete the unneeded nodes.
-    GraphSafeRemoveNodes(graph.get(),
+    GraphSafeRemoveNodes(graph,
                          {fill_constant, fill_constant_out, elementwise_mul});
   };
 
-  gpd(graph.get(), handler);
-  return graph;
+  gpd(graph, handler);
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.h b/paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.h
index fa95143d3..14c07c588 100644
--- a/paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.h
+++ b/paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.h
@@ -26,8 +26,7 @@ class AnakinFillconstantElementwisemulFuse : public FusePassBase {
   virtual ~AnakinFillconstantElementwisemulFuse() {}
 
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
index a9897e0bb..5a82d7927 100644
--- a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h"
 #include
+#include
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/graph_viz_pass.h"
 #include "paddle/fluid/framework/lod_tensor.h"
@@ -253,8 +254,7 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
 
 // Parameters
 
-std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void AttentionLSTMFusePass::ApplyImpl(ir::Graph* graph) const {
   PDPattern external_pattern, subblock_pattern;
 
   // Use the following variables to tell whether this model is RNN1.
@@ -269,12 +269,11 @@ std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
     }
   }
   if (count < specified_vars.size()) {
-    return graph;
+    return;
   }
 
   // Continue to fuse.
-  FindWhileOp(graph.get());
-  return graph;
+  FindWhileOp(graph);
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.h b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.h
index 39b0585d3..47ed9f039 100644
--- a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.h
+++ b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.h
@@ -22,8 +22,7 @@ namespace ir {
 
 class AttentionLSTMFusePass : public FusePassBase {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc
index a7bfb8cf1..fecc159ad 100644
--- a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc
@@ -77,10 +77,9 @@ void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight,
   weights_array_2d.colwise() *= scale_array;
 }
 
-std::unique_ptr<ir::Graph> ConvAffineChannelFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init(name_scope_, graph.get());
+void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init(name_scope_, graph);
 
   auto* scope = param_scope();
   PADDLE_ENFORCE(scope);
@@ -139,7 +138,7 @@ std::unique_ptr<ir::Graph> ConvAffineChannelFusePass::ApplyImpl(
     desc.SetAttr("axis", 1);
     auto eltwise_op = g->CreateOpNode(&desc);  // OpDesc will be copied.
 
-    GraphSafeRemoveNodes(graph.get(), {ac_scale, ac_bias, affine_channel});
+    GraphSafeRemoveNodes(graph, {ac_scale, ac_bias, affine_channel});
 
     IR_NODE_LINK_TO(conv_out, eltwise_op);
     IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
@@ -147,16 +146,14 @@ std::unique_ptr<ir::Graph> ConvAffineChannelFusePass::ApplyImpl(
     found_conv_ac_count++;
   };
 
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
 
   AddStatis(found_conv_ac_count);
-  return graph;
 }
 
-std::unique_ptr<ir::Graph> ConvEltwiseAddAffineChannelFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init(name_scope_, graph.get());
+void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init(name_scope_, graph);
 
   auto* scope = param_scope();
   PADDLE_ENFORCE(scope);
@@ -199,7 +196,7 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddAffineChannelFusePass::ApplyImpl(
     eltwise->Op()->SetAttr("axis", 1);
     eltwise->Op()->SetOutput("Out", std::vector<std::string>({ac_out->Name()}));
 
-    GraphSafeRemoveNodes(graph.get(),
+    GraphSafeRemoveNodes(graph,
                          {ac_scale, ac_bias, affine_channel, eltwise_out});
 
     IR_NODE_LINK_TO(eltwise, ac_out);
@@ -207,9 +204,8 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddAffineChannelFusePass::ApplyImpl(
     found_conv_ac_count++;
   };
 
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
   AddStatis(found_conv_ac_count);
-  return graph;
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.h b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.h
index 8c3c8b56c..d607020a4 100644
--- a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.h
+++ b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.h
@@ -31,8 +31,7 @@ class ConvAffineChannelFusePass : public FusePassBase {
   virtual ~ConvAffineChannelFusePass() {}
 
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph*) const override;
   const std::string name_scope_{"conv_affine_channel_fuse"};
 };
 
@@ -41,8 +40,7 @@ class ConvEltwiseAddAffineChannelFusePass : public FusePassBase {
   virtual ~ConvEltwiseAddAffineChannelFusePass() {}
 
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph*) const override;
   const std::string name_scope_{"conv_eltwiseadd_affine_channel_fuse"};
 };
diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
index 04765dd14..876a99964 100644
--- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
@@ -101,10 +101,9 @@ void recompute_bias_and_weights(const Scope* scope,
   weights_array_2d.colwise() *= variance_array;
 }
 
-std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init(name_scope_, graph.get());
+void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init(name_scope_, graph);
 
   auto* scope = param_scope();
   PADDLE_ENFORCE(scope);
@@ -187,7 +186,7 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
                                 std::vector<std::string>({bn_out->Name()}));
 
       GraphSafeRemoveNodes(
-          graph.get(),
+          graph,
           {conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm,
            bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance});
 
@@ -203,10 +202,9 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
       desc.SetAttr("axis", 1);
       auto eltwise_op = g->CreateOpNode(&desc);  // OpDesc will be copied.
 
-      GraphSafeRemoveNodes(
-          graph.get(),
-          {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
-           bn_variance_out, bn_saved_mean, bn_saved_variance});
+      GraphSafeRemoveNodes(graph, {bn_scale, bn_bias, bn_mean, bn_variance,
+                                   batch_norm, bn_mean_out, bn_variance_out,
+                                   bn_saved_mean, bn_saved_variance});
 
       IR_NODE_LINK_TO(conv_out, eltwise_op);
       IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
@@ -215,16 +213,14 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
     }
   };
 
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
 
   AddStatis(found_conv_bn_count);
-  return graph;
 }
 
-std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init(name_scope_, graph.get());
+void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init(name_scope_, graph);
 
   auto* scope = param_scope();
   PADDLE_ENFORCE(scope);
@@ -274,7 +270,7 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
     eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
 
     GraphSafeRemoveNodes(
-        graph.get(),
+        graph,
         {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
         bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out});
 
@@ -283,10 +279,9 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
     found_conv_bn_count++;
   };
 
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
 
   AddStatis(found_conv_bn_count);
-  return graph;
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.h b/paddle/fluid/framework/ir/conv_bn_fuse_pass.h
index cf425a273..837a48ed7 100644
--- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.h
@@ -31,8 +31,7 @@ class ConvBNFusePass : public FusePassBase {
   virtual ~ConvBNFusePass() {}
 
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"conv_bn_fuse"};
 };
 
@@ -41,8 +40,7 @@ class ConvEltwiseAddBNFusePass : public FusePassBase {
   virtual ~ConvEltwiseAddBNFusePass() {}
 
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"conv_eltwiseadd_bn_fuse"};
 };
diff --git a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse.cc b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse.cc
index 6e9905b7e..99bc5fe8c 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse.cc
+++ b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse.cc
@@ -50,10 +50,9 @@ framework::proto::OpDesc PrepareOpDesc(
   return *desc.Proto();
 }
 
-std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "conv_elementwise_add_act_fuse";
-  FusePassBase::Init(pattern_name, graph.get());
+  FusePassBase::Init(pattern_name, graph);
 
   GraphPatternDetector gpd;
   auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
@@ -95,7 +94,6 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
                          elementwise_add_out});
   };
   gpd(graph.get(), handler);
-  return graph;
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc
index c6121777e..b4d6f683c 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc
@@ -51,10 +51,9 @@ framework::proto::OpDesc PrepareOpDesc(
   return *desc.Proto();
 }
 
-std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "conv_elementwise_add2_act_fuse";
-  FusePassBase::Init(pattern_name, graph.get());
+  FusePassBase::Init(pattern_name, graph);
 
   GraphPatternDetector gpd;
   auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
@@ -92,12 +91,10 @@ std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl(
 
     // Delete the unneeded nodes.
     GraphSafeRemoveNodes(
-        graph.get(),
-        {conv_op, conv_out, elementwise_add_op, elementwise_add_op_1,
-         elementwise_add_out, elementwise_add_out_1, act_op});
+        graph, {conv_op, conv_out, elementwise_add_op, elementwise_add_op_1,
+                elementwise_add_out, elementwise_add_out_1, act_op});
   };
-  gpd(graph.get(), handler);
-  return graph;
+  gpd(graph, handler);
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h
index 9259a4ac5..ea9e465d8 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h
+++ b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h
@@ -25,8 +25,7 @@ class ConvElementwiseAdd2ActFusePass : public FusePassBase {
   virtual ~ConvElementwiseAdd2ActFusePass() {}
 
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
index fe3b4fca7..ba0a2fb96 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
@@ -48,10 +48,9 @@ framework::proto::OpDesc PrepareOpDesc(
   return *desc.Proto();
 }
 
-std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "conv_elementwise_add_act_fuse";
-  FusePassBase::Init(pattern_name, graph.get());
+  FusePassBase::Init(pattern_name, graph);
 
   GraphPatternDetector gpd;
   auto* x = gpd.mutable_pattern()
@@ -88,12 +87,11 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
     IR_NODE_LINK_TO(new_conv_op, act_out);            // Output
 
     // Delete the unneeded nodes.
-    GraphSafeRemoveNodes(graph.get(), {conv_op, conv_out, elementwise_add_op,
-                                       elementwise_add_out, act_op});
+    GraphSafeRemoveNodes(graph, {conv_op, conv_out, elementwise_add_op,
+                                 elementwise_add_out, act_op});
   };
-  gpd(graph.get(), handler);
-  return graph;
+  gpd(graph, handler);
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h
index 9c0b50f15..8b34c3551 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h
+++ b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h
@@ -25,8 +25,7 @@ class ConvElementwiseAddActFusePass : public FusePassBase {
   virtual ~ConvElementwiseAddActFusePass() {}
 
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc
index 476c9dbc3..8c491d4f5 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc
@@ -30,10 +30,9 @@ namespace ir {
   GET_IR_NODE(elementwise_add_in_y);   \
   GET_IR_NODE(elementwise_add_out);
 
-std::unique_ptr<ir::Graph> ConvElementwiseAddFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "conv_elementwise_add_fuse";
-  FusePassBase::Init(pattern_name, graph.get());
+  FusePassBase::Init(pattern_name, graph);
 
   GraphPatternDetector gpd;
   auto* x = gpd.mutable_pattern()
@@ -76,11 +75,10 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddFusePass::ApplyImpl(
     IR_NODE_LINK_TO(new_conv_op, elementwise_add_out);  // Output
 
     // Delete the unneeded nodes.
-    GraphSafeRemoveNodes(graph.get(), {conv_op, conv_out, elementwise_add_op});
+    GraphSafeRemoveNodes(graph, {conv_op, conv_out, elementwise_add_op});
   };
 
-  gpd(graph.get(), handler);
-  return graph;
+  gpd(graph, handler);
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h
index bf43bd5ce..66a562cdd 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h
+++ b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h
@@ -25,8 +25,7 @@ class ConvElementwiseAddFusePass : public FusePassBase {
   virtual ~ConvElementwiseAddFusePass() {}
 
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
index ba11f19c9..3a6bbe65b 100644
--- a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
@@ -15,6 +15,8 @@
 #include "paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h"
 #include
 #include
+#include
+#include
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/operators/math/blas.h"
@@ -201,7 +203,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
     // Remove unneeded nodes.
     // TODO(jczaja): Proper removing of lookup table
     std::unordered_set<const Node*> marked_nodes(
-        //{lookup_table, mul, lstm, elementwise_add, fc_bias, W});
+        // {lookup_table, mul, lstm, elementwise_add, fc_bias, W});
         {mul, lstm, elementwise_add, fc_bias});
     GraphSafeRemoveNodes(graph, marked_nodes);
   } else {
@@ -224,15 +226,13 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   return fusion_count;
 }
 
-std::unique_ptr<ir::Graph> EmbeddingFCLSTMFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
+void EmbeddingFCLSTMFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
 
-  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
-                                 true /*with_fc_bias*/);
+  int fusion_count =
+      BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);
 
   AddStatis(fusion_count);
-  return graph;
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h
index fde2a0a4e..65cb44397 100644
--- a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h
+++ b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h
@@ -32,8 +32,7 @@ class EmbeddingFCLSTMFusePass : public FusePassBase {
   virtual ~EmbeddingFCLSTMFusePass() {}
 
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 
   const std::string name_scope_{"embedding_fc_lstm_fuse"};
 };
diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc
index 12b31da01..ca008763b 100644
--- a/paddle/fluid/framework/ir/fc_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/framework/ir/fc_fuse_pass.h"
 #include
+#include
 #include
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -22,10 +23,9 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
-std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init("fc_fuse", graph.get());
+void FCFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init("fc_fuse", graph);
 
   std::unordered_set<Node*> nodes2delete;
 
@@ -61,7 +61,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
     desc.SetAttr("in_num_col_dims", mul->Op()->GetAttr("x_num_col_dims"));
     desc.SetType("fc");
     auto fc_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
-    GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out});
+    GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
 
     PADDLE_ENFORCE(subgraph.count(x));
     IR_NODE_LINK_TO(subgraph.at(x), fc_node);
@@ -72,10 +72,9 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
     found_fc_count++;
   };
 
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
 
   AddStatis(found_fc_count);
-  return graph;
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.h b/paddle/fluid/framework/ir/fc_fuse_pass.h
index 783a052ed..0a0fcd2da 100644
--- a/paddle/fluid/framework/ir/fc_fuse_pass.h
+++ b/paddle/fluid/framework/ir/fc_fuse_pass.h
@@ -31,8 +31,7 @@ class FCFusePass : public FusePassBase {
   virtual ~FCFusePass() {}
 
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc b/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc
index 4e1e4e27f..affe50691 100644
--- a/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc
@@ -73,7 +73,7 @@ TEST(FCFusePass, basic) {
 
   int pre_nodes = graph->Nodes().size();
 
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
 
   int after_nodes = graph->Nodes().size();
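All of the FusePassBase subclasses above follow the same conversion template. Spelled out once as a hypothetical skeleton (MyFusePass and the scope name are invented; the structure mirrors the hunks above):

    // Skeleton of a fuse pass after this patch: every former `graph.get()`
    // becomes `graph`, and the function falls off the end instead of
    // returning the unique_ptr.
    void MyFusePass::ApplyImpl(ir::Graph* graph) const {
      PADDLE_ENFORCE(graph);                 // raw pointer, checked directly
      FusePassBase::Init("my_fuse", graph);  // was Init("my_fuse", graph.get())

      GraphPatternDetector gpd;
      // ... build the pattern on gpd.mutable_pattern() ...

      int found_count = 0;
      auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                         Graph* g) {
        // ... rewrite the matched subgraph, then drop the dead nodes ...
        GraphSafeRemoveNodes(graph, {/* fused-away nodes */});
        found_count++;
      };

      gpd(graph, handler);                   // was gpd(graph.get(), handler)
      AddStatis(found_count);                // no `return graph;` needed
    }
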
diff --git a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
index a902b0b50..5f660c6d3 100644
--- a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/framework/ir/fc_gru_fuse_pass.h"
 #include
+#include
 #include "paddle/fluid/framework/lod_tensor.h"
 
 namespace paddle {
@@ -39,7 +40,6 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
 
   // Create New OpDesc
   auto gru_creater = [&](Node* gru, Node* x, Node* weight_x, Node* weight_h,
                          Node* bias, Node* hidden, Node* fc_bias) {
-
     OpDesc op_desc;
     op_desc.SetType("fusion_gru");
@@ -155,26 +155,22 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   return fusion_count;
 }
 
-std::unique_ptr<ir::Graph> MulGRUFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
+void MulGRUFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
 
-  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
-                                 false /*with_fc_bias*/);
+  int fusion_count =
+      BuildFusion(graph, name_scope_, param_scope(), false /*with_fc_bias*/);
 
   AddStatis(fusion_count);
-  return graph;
 }
 
-std::unique_ptr<ir::Graph> FCGRUFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
+void FCGRUFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
 
-  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
-                                 true /*with_fc_bias*/);
+  int fusion_count =
+      BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);
 
   AddStatis(fusion_count);
-  return graph;
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/fc_gru_fuse_pass.h b/paddle/fluid/framework/ir/fc_gru_fuse_pass.h
index e359a3289..e11cdac7e 100644
--- a/paddle/fluid/framework/ir/fc_gru_fuse_pass.h
+++ b/paddle/fluid/framework/ir/fc_gru_fuse_pass.h
@@ -30,8 +30,7 @@ class FCGRUFusePass : public FusePassBase {
   virtual ~FCGRUFusePass() {}
 
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 
   const std::string name_scope_{"fc_gru_fuse"};
 };
 
@@ -42,8 +41,7 @@ class MulGRUFusePass : public FusePassBase {
   virtual ~MulGRUFusePass() {}
 
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"fc_nobias_gru_fuse"};
 };
diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
index f5c286486..babeba961 100644
--- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
 #include
+#include
 #include "paddle/fluid/framework/lod_tensor.h"
 
 namespace paddle {
@@ -157,26 +158,22 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
   return fusion_count;
 }
 
-std::unique_ptr<ir::Graph> MulLstmFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
+void MulLstmFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
 
-  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
-                                 false /*with_fc_bias*/);
+  int fusion_count =
+      BuildFusion(graph, name_scope_, param_scope(), false /*with_fc_bias*/);
 
   AddStatis(fusion_count);
-  return graph;
 }
 
-std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
+void FCLstmFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
 
-  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
-                                 true /*with_fc_bias*/);
+  int fusion_count =
+      BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);
 
   AddStatis(fusion_count);
-  return graph;
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
index 21482615a..5dea7c91a 100644
--- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
+++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
@@ -32,8 +32,7 @@ class FCLstmFusePass : public FusePassBase {
   virtual ~FCLstmFusePass() {}
 
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 
   const std::string name_scope_{"fc_lstm_fuse"};
 };
 
@@ -43,8 +42,7 @@ class MulLstmFusePass : public FusePassBase {
   virtual ~MulLstmFusePass() {}
 
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"fc_nobias_lstm_fuse"};
 };
diff --git a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc
index 648acc4a7..bd4967316 100644
--- a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc
+++ b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc
@@ -15,6 +15,8 @@
 #include "paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h"
 #include
 #include
+#include
+#include
 #include
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -23,29 +25,25 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
-std::unique_ptr<ir::Graph> FuseElewiseAddActPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void FuseElewiseAddActPass::ApplyImpl(ir::Graph *graph) const {
   std::unordered_set<std::string> act_types = {"relu", "scale"};
-  graph = FuseActElewiseAdd(std::move(graph), act_types);
-  graph = FuseElewiseAddAct(std::move(graph), act_types);
+  graph = FuseActElewiseAdd(graph, act_types);
+  graph = FuseElewiseAddAct(graph, act_types);
   // backward
   {
     std::unordered_set<std::string> in_place_act_types = {"relu_grad"};
-    graph = FuseElewiseAddActInplaceGrad(std::move(graph), in_place_act_types);
+    graph = FuseElewiseAddActInplaceGrad(graph, in_place_act_types);
   }
 
   // Remove the removable intermediate_out.
-  RemoveIntermediateOut(graph.get());
-
-  return graph;
+  RemoveIntermediateOut(graph);
 }
 
 // ele_add(x, act(y))
-std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddAct(
-    std::unique_ptr<ir::Graph> graph,
-    const std::unordered_set<std::string> &act_types) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init("elewise_add_act", graph.get());
+ir::Graph *FuseElewiseAddActPass::FuseElewiseAddAct(
+    ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init("elewise_add_act", graph);
 
   GraphPatternDetector gpd;
   auto *x = gpd.mutable_pattern()
@@ -86,18 +84,17 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddAct(
     found_elewise_add_act_count++;
   };
 
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
 
   AddStatis(found_elewise_add_act_count);
 
   return graph;
 }
 
 // act(ele_add(x,y))
-std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
-    std::unique_ptr<ir::Graph> graph,
-    const std::unordered_set<std::string> &act_types) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init("act_elewise_add", graph.get());
+ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd(
+    ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init("act_elewise_add", graph);
 
   GraphPatternDetector gpd;
   auto *x = gpd.mutable_pattern()
@@ -137,7 +134,7 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
     found_elewise_add_act_count++;
   };
 
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
 
   AddStatis(found_elewise_add_act_count);
 
   return graph;
@@ -146,11 +143,10 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
 // the backward of act(ele_add(x,y))
 // act_grad: in["Out", "Out@GRAD"], out["X@GRAD"]
 // ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"]
-std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
-    std::unique_ptr<ir::Graph> graph,
-    const std::unordered_set<std::string> &act_types) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init("elewise_add_act_grad", graph.get());
+ir::Graph *FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
+    ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init("elewise_add_act_grad", graph);
 
   GraphPatternDetector gpd;
   auto *d_act_out = gpd.mutable_pattern()
@@ -217,7 +213,7 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
     found_elewise_add_act_count++;
   };
 
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
 
   AddStatis(found_elewise_add_act_count);
 
   return graph;
diff --git a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h
index 0fee52744..dc73f1fda 100644
--- a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h
+++ b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h
@@ -14,6 +14,8 @@
 #pragma once
 #include
+#include
+#include
 #include
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/graph.h"
@@ -32,20 +34,16 @@ class FuseElewiseAddActPass : public FusePassBase {
   virtual ~FuseElewiseAddActPass() {}
 
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph *graph) const override;
 
-  std::unique_ptr<ir::Graph> FuseElewiseAddAct(
-      std::unique_ptr<ir::Graph> graph,
-      const std::unordered_set<std::string> &act_types) const;
+  ir::Graph *FuseElewiseAddAct(
+      ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
 
-  std::unique_ptr<ir::Graph> FuseActElewiseAdd(
-      std::unique_ptr<ir::Graph> graph,
-      const std::unordered_set<std::string> &act_types) const;
+  ir::Graph *FuseActElewiseAdd(
+      ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
 
-  std::unique_ptr<ir::Graph> FuseElewiseAddActInplaceGrad(
-      std::unique_ptr<ir::Graph> graph,
-      const std::unordered_set<std::string> &act_types) const;
+  ir::Graph *FuseElewiseAddActInplaceGrad(
+      ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
 
   /**
    * Remove the removable intermediate_out.
diff --git a/paddle/fluid/framework/ir/fuse_relu_depthwise_conv_pass.cc b/paddle/fluid/framework/ir/fuse_relu_depthwise_conv_pass.cc
index fe844caed..c4e6b6e6a 100644
--- a/paddle/fluid/framework/ir/fuse_relu_depthwise_conv_pass.cc
+++ b/paddle/fluid/framework/ir/fuse_relu_depthwise_conv_pass.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/ir/fuse_relu_depthwise_conv_pass.h"
 #include
 #include
+#include
 #include
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -23,20 +24,18 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
-std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  graph = FuseReluDepthwiseConv(std::move(graph), true);
-  graph = FuseReluDepthwiseConv(std::move(graph), false);
-  return graph;
+void FuseReluDepthwiseConvPass::ApplyImpl(ir::Graph *graph) const {
+  graph = FuseReluDepthwiseConv(graph, true);
+  graph = FuseReluDepthwiseConv(graph, false);
 }
 
-std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
-    std::unique_ptr<ir::Graph> graph, bool only_forward) const {
-  PADDLE_ENFORCE(graph.get());
+ir::Graph *FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
+    ir::Graph *graph, bool only_forward) const {
+  PADDLE_ENFORCE(graph);
   if (only_forward)
-    FusePassBase::Init("relu_depthwise_conv_only_forward", graph.get());
+    FusePassBase::Init("relu_depthwise_conv_only_forward", graph);
   else
-    FusePassBase::Init("relu_depthwise_conv", graph.get());
+    FusePassBase::Init("relu_depthwise_conv", graph);
   /*
     x ---act--> y ---layer-> z
     +----------+
@@ -144,10 +143,9 @@ std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
     }
     count++;
   };
-  gpd(graph.get(), handler);
-  GraphSafeRemoveNodes(graph.get(), need_removed_nodes);
+  gpd(graph, handler);
+  GraphSafeRemoveNodes(graph, need_removed_nodes);
   AddStatis(count);
-
   return graph;
 }
diff --git a/paddle/fluid/framework/ir/fuse_relu_depthwise_conv_pass.h b/paddle/fluid/framework/ir/fuse_relu_depthwise_conv_pass.h
index efb49b830..d37c153dd 100644
--- a/paddle/fluid/framework/ir/fuse_relu_depthwise_conv_pass.h
+++ b/paddle/fluid/framework/ir/fuse_relu_depthwise_conv_pass.h
@@ -32,10 +32,8 @@ class FuseReluDepthwiseConvPass : public FusePassBase {
   virtual ~FuseReluDepthwiseConvPass() {}
 
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
-  std::unique_ptr<ir::Graph> FuseReluDepthwiseConv(
-      std::unique_ptr<ir::Graph> graph, bool only_forward) const;
+  void ApplyImpl(ir::Graph* graph) const override;
+  ir::Graph* FuseReluDepthwiseConv(ir::Graph* graph, bool only_forward) const;
 };
 
 }  // namespace ir
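Passes with private helpers, like FuseElewiseAddActPass and FuseReluDepthwiseConvPass above, keep a graph return value on the helpers so call sites still chain. A sketch of the resulting shape (MyPass and the helper names are invented; it assumes the helpers mutate their argument in place, as the hunks above do):

    // ApplyImpl chaining raw-pointer helpers. Each helper mutates *graph and
    // returns its argument, so the reassignments preserve the old chaining
    // style without transferring ownership.
    void MyPass::ApplyImpl(ir::Graph *graph) const {
      graph = FuseForward(graph);   // was: graph = FuseForward(std::move(graph));
      graph = FuseBackward(graph);
      RemoveIntermediateOut(graph); // final step uses the same pointer
    }

    ir::Graph *MyPass::FuseForward(ir::Graph *graph) const {
      PADDLE_ENFORCE(graph);
      // ... match and rewrite the graph ...
      return graph;  // same pointer in, same pointer out
    }
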
*/ #include "paddle/fluid/framework/ir/graph_to_program_pass.h" #include +#include #include +#include #include #include "paddle/fluid/framework/ir/graph.h" @@ -26,8 +28,7 @@ namespace paddle { namespace framework { namespace ir { -std::unique_ptr GraphToProgramPass::ApplyImpl( - std::unique_ptr graph) const { +void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const { // Remove the unneeded variables after memory optimization. std::unordered_set vars2remove; if (graph->Has(kGraphToProgramVarsToRemove)) { @@ -73,7 +74,6 @@ std::unique_ptr GraphToProgramPass::ApplyImpl( } program.CopyFrom(*program_pb); - return graph; } } // namespace ir diff --git a/paddle/fluid/framework/ir/graph_to_program_pass.h b/paddle/fluid/framework/ir/graph_to_program_pass.h index 4c36c3a5d..52c8f4e0f 100644 --- a/paddle/fluid/framework/ir/graph_to_program_pass.h +++ b/paddle/fluid/framework/ir/graph_to_program_pass.h @@ -26,7 +26,7 @@ const char kGraphToProgramSortKind[] = "__graph_to_program_sort_kind__"; class GraphToProgramPass : public Pass { protected: - std::unique_ptr ApplyImpl(std::unique_ptr graph) const override; + void ApplyImpl(ir::Graph* graph) const override; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/graph_to_program_pass_test.cc b/paddle/fluid/framework/ir/graph_to_program_pass_test.cc index 5d51d9751..5ee6b8a5f 100644 --- a/paddle/fluid/framework/ir/graph_to_program_pass_test.cc +++ b/paddle/fluid/framework/ir/graph_to_program_pass_test.cc @@ -14,7 +14,9 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph_to_program_pass.h" +#include #include +#include #include #include "gtest/gtest.h" #include "paddle/fluid/framework/program_desc.h" @@ -84,7 +86,7 @@ TEST(GraphToProgramPass, Basic) { ProgramDesc compiled_prog; pass->SetNotOwned("program", &compiled_prog); - pass->Apply(std::move(g)); + pass->Apply(g.get()); std::vector ops = compiled_prog.Block(0).AllOps(); EXPECT_EQ(ops[0]->Type(), "op1"); EXPECT_EQ(ops[1]->Type(), "op2"); diff --git a/paddle/fluid/framework/ir/graph_viz_pass.cc b/paddle/fluid/framework/ir/graph_viz_pass.cc index 87a28a2a6..f4df4cfeb 100644 --- a/paddle/fluid/framework/ir/graph_viz_pass.cc +++ b/paddle/fluid/framework/ir/graph_viz_pass.cc @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include "paddle/fluid/framework/ir/graph_viz_pass.h" #include +#include #include - -#include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/inference/analysis/dot.h" #include "paddle/fluid/string/printf.h" @@ -38,8 +38,7 @@ std::string FormatName(const Node* node) { } } // namespace -std::unique_ptr GraphVizPass::ApplyImpl( - std::unique_ptr graph) const { +void GraphVizPass::ApplyImpl(ir::Graph* graph) const { const std::string graph_viz_path = Get(kGraphVizPath); VLOG(3) << "draw IR graph viz to " << graph_viz_path; std::unique_ptr fout(new std::ofstream(graph_viz_path)); @@ -82,7 +81,7 @@ std::unique_ptr GraphVizPass::ApplyImpl( {Dot::Attr("style", "filled,rounded"), Dot::Attr("shape", "box"), Dot::Attr("fillcolor", "yellow")}); - auto marked_nodes = ConsumeMarkedNodes(graph.get()); + auto marked_nodes = ConsumeMarkedNodes(graph); // Create nodes for (const Node* n : graph->Nodes()) { std::string node_id = FormatName(n) + "(" + std::to_string(n->id()) + ")"; @@ -115,8 +114,6 @@ std::unique_ptr GraphVizPass::ApplyImpl( } sout << dot.Build(); - - return graph; } GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes( @@ -135,4 +132,4 @@ GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes( } // namespace paddle REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass) - .RequirePassAttr(paddle::framework::ir::kGraphVizPath); \ No newline at end of file + .RequirePassAttr(paddle::framework::ir::kGraphVizPath); diff --git a/paddle/fluid/framework/ir/graph_viz_pass.h b/paddle/fluid/framework/ir/graph_viz_pass.h index e64916a5b..7091aa6a9 100644 --- a/paddle/fluid/framework/ir/graph_viz_pass.h +++ b/paddle/fluid/framework/ir/graph_viz_pass.h @@ -18,6 +18,7 @@ limitations under the License. */ #include #include #include +#include #include #include "paddle/fluid/framework/ir/graph.h" @@ -34,8 +35,7 @@ class GraphVizPass : public Pass { using marked_nodes_t = std::unordered_set; protected: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override; + void ApplyImpl(ir::Graph* graph) const override; // Tell whether there are any marked nodes in the graph. Consume the // corresponding attribute. 
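The tester hunks throughout this patch all make the same mechanical substitution. Under the old interface a test moved ownership of the graph into the pass and took it back from the return value; under the new interface the pass borrows a raw pointer, so a test that keeps its graph in a std::unique_ptr releases it for the call and re-wraps the result. A minimal before/after sketch (the two Apply lines mirror the hunks in this patch; prog and pass stand in for the local test fixtures):

    std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
    // Before this patch, ownership round-tripped through the pass:
    //   graph = pass->Apply(std::move(graph));
    // After this patch, Apply() borrows and returns the same raw pointer,
    // so unique_ptr holders release it and re-wrap the returned graph:
    graph.reset(pass->Apply(graph.release()));
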
diff --git a/paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc b/paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc index 5bdc0c5fa..a39901e63 100644 --- a/paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc +++ b/paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc @@ -20,9 +20,8 @@ namespace paddle { namespace framework { namespace ir { -std::unique_ptr IdentityScaleOpCleanPass::ApplyImpl( - std::unique_ptr graph) const { - FusePassBase::Init("identity_scale_op_clean", graph.get()); +void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const { + FusePassBase::Init("identity_scale_op_clean", graph); // pre_op -> scale_in -> scale_op -> scale_out // -> @@ -72,8 +71,7 @@ std::unique_ptr IdentityScaleOpCleanPass::ApplyImpl( IR_NODE_LINK_TO(pre_op_var, scale_out_var); }; - detector(graph.get(), handler); - return graph; + detector(graph, handler); } } // namespace ir diff --git a/paddle/fluid/framework/ir/identity_scale_op_clean_pass.h b/paddle/fluid/framework/ir/identity_scale_op_clean_pass.h index 6da592561..d66b41125 100644 --- a/paddle/fluid/framework/ir/identity_scale_op_clean_pass.h +++ b/paddle/fluid/framework/ir/identity_scale_op_clean_pass.h @@ -22,8 +22,7 @@ namespace ir { class IdentityScaleOpCleanPass : public FusePassBase { protected: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override; + void ApplyImpl(ir::Graph* graph) const override; private: virtual ~IdentityScaleOpCleanPass() = default; diff --git a/paddle/fluid/framework/ir/infer_clean_graph_pass.cc b/paddle/fluid/framework/ir/infer_clean_graph_pass.cc index 6607c026a..d76924116 100644 --- a/paddle/fluid/framework/ir/infer_clean_graph_pass.cc +++ b/paddle/fluid/framework/ir/infer_clean_graph_pass.cc @@ -26,9 +26,9 @@ class InferCleanGraphPass : public FusePassBase { virtual ~InferCleanGraphPass() {} protected: - std::unique_ptr ApplyImpl(std::unique_ptr graph) const { - FusePassBase::Init("original_graph", graph.get()); - PADDLE_ENFORCE(graph.get()); + void ApplyImpl(ir::Graph* graph) const { + FusePassBase::Init("original_graph", graph); + PADDLE_ENFORCE(graph); auto is_valid_node = [](Node* x) { return x && IsControlDepVar(*x) && x->IsVar() && !x->Var(); @@ -46,11 +46,9 @@ class InferCleanGraphPass : public FusePassBase { } } - GraphSafeRemoveNodes(graph.get(), invalid_nodes); + GraphSafeRemoveNodes(graph, invalid_nodes); AddStatis(valid_op); - - return graph; } void CleanEdges(std::vector* nodes, diff --git a/paddle/fluid/framework/ir/is_test_pass.cc b/paddle/fluid/framework/ir/is_test_pass.cc index 57cc98e2c..bf6fe999c 100644 --- a/paddle/fluid/framework/ir/is_test_pass.cc +++ b/paddle/fluid/framework/ir/is_test_pass.cc @@ -20,8 +20,7 @@ namespace paddle { namespace framework { namespace ir { -std::unique_ptr IsTestPass::ApplyImpl( - std::unique_ptr graph) const { +void IsTestPass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "Sets is_test attrbiute to true and if it is missing, inserts it " "for activations and pooling."; auto op_list = {"pool2d", "sigmoid", "logsigmoid", @@ -47,7 +46,6 @@ std::unique_ptr IsTestPass::ApplyImpl( } } } - return graph; } } // namespace ir diff --git a/paddle/fluid/framework/ir/is_test_pass.h b/paddle/fluid/framework/ir/is_test_pass.h index 99e76ca4a..80cedbf9f 100644 --- a/paddle/fluid/framework/ir/is_test_pass.h +++ b/paddle/fluid/framework/ir/is_test_pass.h @@ -22,8 +22,7 @@ namespace ir { class IsTestPass : public Pass { protected: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override; + void ApplyImpl(ir::Graph* graph) 
const override; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/is_test_pass_tester.cc b/paddle/fluid/framework/ir/is_test_pass_tester.cc index 9696441a2..3fa543c62 100644 --- a/paddle/fluid/framework/ir/is_test_pass_tester.cc +++ b/paddle/fluid/framework/ir/is_test_pass_tester.cc @@ -97,7 +97,7 @@ TEST(IsTestPass, basic) { auto pass = PassRegistry::Instance().Get("is_test_pass"); - graph = pass->Apply(std::move(graph)); + graph.reset(pass->Apply(graph.release())); for (auto* node : graph->Nodes()) { if (node->IsOp()) { diff --git a/paddle/fluid/framework/ir/lock_free_optimize_pass.cc b/paddle/fluid/framework/ir/lock_free_optimize_pass.cc index 92e897ca9..05d23961a 100644 --- a/paddle/fluid/framework/ir/lock_free_optimize_pass.cc +++ b/paddle/fluid/framework/ir/lock_free_optimize_pass.cc @@ -32,9 +32,8 @@ const char kSumGradOpName[] = "sum"; // other optimizers later. const char kOptimizerType[] = "sgd"; -std::unique_ptr LockFreeOptimizePass::ApplyImpl( - std::unique_ptr graph) const { - PADDLE_ENFORCE(graph.get()); +void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE(graph); // We could collect all weights' name from SGD, where // W1 <- SGD(W0, Grad0) @@ -92,14 +91,14 @@ std::unique_ptr LockFreeOptimizePass::ApplyImpl( // find the forward op related to the backward op ir::Node* forward_op = - FindForwardOpViaBackwardOp(graph.get(), backward_op); + FindForwardOpViaBackwardOp(graph, backward_op); VLOG(3) << "Found forward_op " << forward_op->Name(); PADDLE_ENFORCE(forward_op); Node* new_optimizer_node = CreateNewSGDNode( - graph.get(), forward_op, backward_op, node, opt_node); + graph, forward_op, backward_op, node, opt_node); PADDLE_ENFORCE(new_optimizer_node); } @@ -140,8 +139,6 @@ std::unique_ptr LockFreeOptimizePass::ApplyImpl( } } } - - return graph; } ir::Node* LockFreeOptimizePass::CreateNewSGDNode( diff --git a/paddle/fluid/framework/ir/lock_free_optimize_pass.h b/paddle/fluid/framework/ir/lock_free_optimize_pass.h index f9157b10d..d1718857a 100644 --- a/paddle/fluid/framework/ir/lock_free_optimize_pass.h +++ b/paddle/fluid/framework/ir/lock_free_optimize_pass.h @@ -60,8 +60,7 @@ class LockFreeOptimizePass : public Pass { virtual ~LockFreeOptimizePass() {} protected: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override; + void ApplyImpl(ir::Graph* graph) const override; private: // Create a new sgd node via current optimizer node diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc index 5d0b294f6..8ef3993b0 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc @@ -38,10 +38,9 @@ LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b, return vec_y; } -std::unique_ptr ConvBiasFusePass::ApplyImpl( - std::unique_ptr graph) const { - PADDLE_ENFORCE(graph.get()); - FusePassBase::Init(name_scope_, graph.get()); +void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE(graph); + FusePassBase::Init(name_scope_, graph); auto* scope = param_scope(); PADDLE_ENFORCE(scope); @@ -99,7 +98,7 @@ std::unique_ptr ConvBiasFusePass::ApplyImpl( conv->Op()->SetOutput("Output", std::vector({eltwise_out->Name()})); - GraphSafeRemoveNodes(graph.get(), {eltwise, conv_out}); + GraphSafeRemoveNodes(graph, {eltwise, conv_out}); IR_NODE_LINK_TO(conv, eltwise_out); } else { @@ -123,14 +122,13 @@ std::unique_ptr ConvBiasFusePass::ApplyImpl( 
IR_NODE_LINK_TO(eltwise_bias, conv_bias_node); IR_NODE_LINK_TO(conv_bias_node, eltwise_out); - GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out}); + GraphSafeRemoveNodes(graph, {conv, eltwise, conv_out}); } found_conv_bias_count++; }; - gpd(graph.get(), handler); + gpd(graph, handler); AddStatis(found_conv_bias_count); - return graph; } } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h index 0ef5c177b..84106d065 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h @@ -29,8 +29,7 @@ class ConvBiasFusePass : public FusePassBase { virtual bool is_conv3d() const { return false; } protected: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override; + void ApplyImpl(ir::Graph* graph) const override; const std::string name_scope_{"conv_bias_mkldnn_fuse"}; }; /* diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc index 38b7fe520..ff7f9190f 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc @@ -13,10 +13,10 @@ // limitations under the License. #include "paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h" +#include #include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/platform/place.h" -#include #include "paddle/fluid/framework/op_proto_maker.h" namespace paddle { @@ -103,7 +103,7 @@ void MainTest(bool convWithExistingBias) { int original_nodes_num = graph->Nodes().size(); - graph = pass->Apply(std::move(graph)); + graph.reset(pass->Apply(graph.release())); int current_nodes_num = graph->Nodes().size(); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc index fb3db8134..ef7874c1c 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -16,8 +16,8 @@ #include #include #include +#include #include - #include "paddle/fluid/framework/ir/graph_traits.h" namespace paddle { @@ -327,17 +327,15 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( get_node_from_elementwise_add); } -graph_ptr ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { - FusePassBase::Init(name_scope_, graph.get()); +void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { + FusePassBase::Init(name_scope_, graph); auto fused_graph_with_stats = FuseConvAsY( name_scope_, - FuseConvAsX( - name_scope_, - FuseProjectionConv(name_scope_, std::make_pair(graph.get(), 0)))); + FuseConvAsX(name_scope_, + FuseProjectionConv(name_scope_, std::make_pair(graph, 0)))); std::cout << "Fused graph " << fused_graph_with_stats.second << std::endl; AddStatis(fused_graph_with_stats.second); - return graph; } } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h index 6629dae42..9bf1ae607 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h @@ -14,6 
+14,7 @@ #pragma once +#include #include #include #include @@ -27,7 +28,7 @@ namespace paddle { namespace framework { namespace ir { -using graph_ptr = std::unique_ptr; +using graph_ptr = ir::Graph*; using GraphWithStats = std::pair; void CorrectGraphEdges(Graph* graph, Node* from, Node* to); @@ -124,7 +125,7 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase { virtual ~ResidualConnectionMKLDNNFusePass() {} protected: - std::unique_ptr ApplyImpl(graph_ptr graph) const; + void ApplyImpl(graph_ptr graph) const; const std::string name_scope_{"residual_connection_fuse_pass"}; }; diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc index 433d89d8d..8a13596cd 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc @@ -148,7 +148,7 @@ void RunPassAndAssert(ProgramDesc* prog, const std::string& from, auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); int original_nodes_num = graph->Nodes().size(); - graph = pass->Apply(std::move(graph)); + graph.reset(pass->Apply(graph.release())); int current_nodes_num = graph->Nodes().size(); EXPECT_TRUE(is_reachable(graph)(from, to)); @@ -258,7 +258,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) { auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); int original_nodes_num = graph->Nodes().size(); - graph = pass->Apply(std::move(graph)); + graph.reset(pass->Apply(graph.release())); int current_nodes_num = graph->Nodes().size(); EXPECT_TRUE(is_reachable(graph)("a", "g")); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.cc index 4f4605398..dd0fb4560 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.cc @@ -21,10 +21,9 @@ namespace paddle { namespace framework { namespace ir { -std::unique_ptr ConvReLUFusePass::ApplyImpl( - std::unique_ptr graph) const { - PADDLE_ENFORCE(graph.get()); - FusePassBase::Init("conv_relu_mkldnn_fuse", graph.get()); +void ConvReLUFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE(graph); + FusePassBase::Init("conv_relu_mkldnn_fuse", graph); GraphPatternDetector gpd; auto* conv_input = gpd.mutable_pattern() @@ -56,7 +55,7 @@ std::unique_ptr ConvReLUFusePass::ApplyImpl( OpDesc* desc = conv->Op(); desc->SetOutput("Output", std::vector({relu_out->Name()})); desc->SetAttr("fuse_relu", true); - GraphSafeRemoveNodes(graph.get(), {relu, conv_out}); + GraphSafeRemoveNodes(graph, {relu, conv_out}); PADDLE_ENFORCE(subgraph.count(conv_input)); IR_NODE_LINK_TO(conv, relu_out); @@ -64,10 +63,9 @@ std::unique_ptr ConvReLUFusePass::ApplyImpl( found_conv_relu_count++; }; - gpd(graph.get(), handler); + gpd(graph, handler); AddStatis(found_conv_relu_count); - return graph; } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h index fe585bd7c..2174c22db 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h @@ -31,8 +31,7 @@ class ConvReLUFusePass : public FusePassBase { virtual ~ConvReLUFusePass() {} protected: - std::unique_ptr ApplyImpl( - 
std::unique_ptr graph) const override; + void ApplyImpl(ir::Graph* graph) const override; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc index 06d56f622..67a995705 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc @@ -88,7 +88,7 @@ TEST(ConvReLUFusePass, basic) { int original_nodes_num = graph->Nodes().size(); - graph = pass->Apply(std::move(graph)); + graph.reset(pass->Apply(graph.release())); int current_nodes_num = graph->Nodes().size(); diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index b3a8c2089..dff98e523 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -216,19 +216,16 @@ void CPUQuantizePass::QuantizePool(Graph* graph) const { PrettyLogDetail("--- quantized %d pool2d ops", quantize_pool_count); } -std::unique_ptr CPUQuantizePass::ApplyImpl( - std::unique_ptr graph) const { +void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "Quantizing the graph."; - PADDLE_ENFORCE(graph.get()); - FusePassBase::Init(name_scope_, graph.get()); + PADDLE_ENFORCE(graph); + FusePassBase::Init(name_scope_, graph); PADDLE_ENFORCE(param_scope()); - QuantizeConv(graph.get(), false /* with_residual_data */); - QuantizeConv(graph.get(), true /* with_residual_data */); - QuantizePool(graph.get()); - - return graph; + QuantizeConv(graph, false /* with_residual_data */); + QuantizeConv(graph, true /* with_residual_data */); + QuantizePool(graph); } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h index 9873bb04e..a178c4dc3 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h @@ -42,8 +42,7 @@ class CPUQuantizePass : public FusePassBase { virtual ~CPUQuantizePass() {} protected: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override; + void ApplyImpl(ir::Graph* graph) const override; void QuantizeConv(Graph* graph, bool with_residual_data = false) const; diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc index 0d0ed9890..8716a412e 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc @@ -139,7 +139,7 @@ void MainTest(const ProgramDesc& prog, int conv_count, int pool_count, int original_nodes_num = graph->Nodes().size(); - graph = pass->Apply(std::move(graph)); + graph.reset(pass->Apply(graph.release())); int current_nodes_num = graph->Nodes().size(); diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc index 511003dce..79a8ac68b 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc @@ -20,8 +20,7 @@ namespace paddle { namespace framework { namespace ir { -std::unique_ptr CPUQuantizePlacementPass::ApplyImpl( - std::unique_ptr graph) const { +void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "Marks operators which are to be quantized."; const auto& 
excluded_ids_list = Get>("quantize_excluded_op_ids"); @@ -43,7 +42,6 @@ std::unique_ptr CPUQuantizePlacementPass::ApplyImpl( } } } - return graph; } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.h index ef3861b24..008a462dc 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.h @@ -25,8 +25,7 @@ namespace ir { */ class CPUQuantizePlacementPass : public Pass { protected: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override; + void ApplyImpl(ir::Graph* graph) const override; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass_tester.cc index 11d72a56b..ba4d281f8 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass_tester.cc @@ -94,7 +94,7 @@ void MainTest(std::initializer_list quantize_enabled_op_types, pass->Set("quantize_excluded_op_ids", new std::unordered_set(quantize_excluded_op_ids)); - graph = pass->Apply(std::move(graph)); + graph.reset(pass->Apply(graph.release())); unsigned use_quantizer_true_count = 0; diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc index 6e74cc778..debbbd644 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc @@ -126,16 +126,13 @@ void CPUQuantizeSquashPass::Squash( found_squash_count); } -std::unique_ptr CPUQuantizeSquashPass::ApplyImpl( - std::unique_ptr graph) const { - PADDLE_ENFORCE(graph.get()); - FusePassBase::Init("cpu_quantize_squash_pass", graph.get()); +void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE(graph); + FusePassBase::Init("cpu_quantize_squash_pass", graph); std::unordered_map nodes_keep_counter; - FindNodesToKeep(graph.get(), &nodes_keep_counter); - Squash(graph.get(), &nodes_keep_counter); - - return graph; + FindNodesToKeep(graph, &nodes_keep_counter); + Squash(graph, &nodes_keep_counter); } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h index b823a2cef..e873994c5 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h @@ -34,8 +34,7 @@ class CPUQuantizeSquashPass : public FusePassBase { virtual ~CPUQuantizeSquashPass() {} protected: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override; + void ApplyImpl(ir::Graph* graph) const override; /* * For each dequantize's output find the number of operators it is an input to diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc index 3cf51d97a..fda337066 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc @@ -125,7 +125,7 @@ void MainTest(const ProgramDesc& prog, int removed_nodes_num) { int original_nodes_num = graph->Nodes().size(); - graph = pass->Apply(std::move(graph)); + graph.reset(pass->Apply(graph.release())); int current_nodes_num = graph->Nodes().size(); diff --git 
a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc index 7851e8c84..e854559ae 100644 --- a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc @@ -25,10 +25,9 @@ namespace ir { auto* id = subgraph.at(pattern.RetrieveNode(#id)); \ PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id); -std::unique_ptr DepthwiseConvMKLDNNPass::ApplyImpl( - std::unique_ptr graph) const { - PADDLE_ENFORCE(graph.get()); - FusePassBase::Init("depthwise_conv_mkldnn_pass", graph.get()); +void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE(graph); + FusePassBase::Init("depthwise_conv_mkldnn_pass", graph); GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); @@ -45,9 +44,8 @@ std::unique_ptr DepthwiseConvMKLDNNPass::ApplyImpl( found_depthwise_conv_mkldnn_count++; }; - gpd(graph.get(), handler); + gpd(graph, handler); AddStatis(found_depthwise_conv_mkldnn_count); - return graph; } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.h b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.h index 8ca6a7325..ca314afde 100644 --- a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.h @@ -25,8 +25,7 @@ class DepthwiseConvMKLDNNPass : public FusePassBase { virtual ~DepthwiseConvMKLDNNPass() {} protected: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override; + void ApplyImpl(ir::Graph* graph) const override; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc index 1783e3322..f2dfbc84a 100644 --- a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc @@ -86,7 +86,7 @@ TEST(DepthwiseConvMKLDNNPass, basic) { counters before{1, 1, 1, 1}; - graph = pass->Apply(std::move(graph)); + graph.reset(pass->Apply(graph.release())); // initialize counters before loop counters after{0, 0, 0, 0}; diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc index ccac65f3b..500419e4b 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc @@ -14,13 +14,13 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h" #include +#include namespace paddle { namespace framework { namespace ir { -std::unique_ptr MKLDNNPlacementPass::ApplyImpl( - std::unique_ptr graph) const { +void MKLDNNPlacementPass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "Applies MKL-DNN placement strategy."; const auto& op_types_list = Get>("mkldnn_enabled_op_types"); @@ -37,7 +37,6 @@ std::unique_ptr MKLDNNPlacementPass::ApplyImpl( } } } - return graph; } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h index c071d9aed..ffa62273e 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h @@ -26,8 +26,7 @@ namespace ir { */ class MKLDNNPlacementPass : public Pass { protected: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override; + void ApplyImpl(ir::Graph* graph) const override; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass_tester.cc index b6ec7e4d6..5885f327e 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass_tester.cc @@ -97,7 +97,7 @@ void MainTest(std::initializer_list mkldnn_enabled_op_types, pass->Set("mkldnn_enabled_op_types", new std::unordered_set(mkldnn_enabled_op_types)); - graph = pass->Apply(std::move(graph)); + graph.reset(pass->Apply(graph.release())); unsigned use_mkldnn_true_count = 0; diff --git a/paddle/fluid/framework/ir/multi_batch_merge_pass.cc b/paddle/fluid/framework/ir/multi_batch_merge_pass.cc index 9e77f98e9..dcc48fb93 100644 --- a/paddle/fluid/framework/ir/multi_batch_merge_pass.cc +++ b/paddle/fluid/framework/ir/multi_batch_merge_pass.cc @@ -16,8 +16,9 @@ #include #include +#include +#include #include - #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/op_proto_maker.h" @@ -68,8 +69,7 @@ VarDesc UpdateGradVarDesc( return *var_desc; } -std::unique_ptr BatchMergePass::ApplyImpl( - std::unique_ptr graph) const { +void BatchMergePass::ApplyImpl(ir::Graph* graph) const { int num_repeats = Get(kNumRepeats); std::vector forward_backward_ops; std::vector optimize_ops; @@ -325,7 +325,6 @@ std::unique_ptr BatchMergePass::ApplyImpl( } result.ResolveHazard(created); - return graph; } } // namespace ir diff --git a/paddle/fluid/framework/ir/multi_batch_merge_pass.h b/paddle/fluid/framework/ir/multi_batch_merge_pass.h index c1e5aef20..a89616683 100644 --- a/paddle/fluid/framework/ir/multi_batch_merge_pass.h +++ b/paddle/fluid/framework/ir/multi_batch_merge_pass.h @@ -36,7 +36,7 @@ class BatchMergePass : public Pass { virtual ~BatchMergePass() {} protected: - std::unique_ptr ApplyImpl(std::unique_ptr graph) const override; + void ApplyImpl(Graph* graph) const override; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index 33ccee6aa..c0ed0519b 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -18,8 +18,8 @@ limitations under the License. 
*/
 
namespace paddle {
namespace framework {
namespace ir {
-std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
-  PADDLE_ENFORCE(graph.get(), "graph passed to Pass::Apply() cannot be empty.");
+Graph* Pass::Apply(Graph* graph) const {
+  PADDLE_ENFORCE(graph, "graph passed to Pass::Apply() cannot be empty.");
   for (const std::string& attr : required_pass_attrs_) {
     PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(),
                    "Required pass atrribute %s not set.", attr);
@@ -28,16 +28,16 @@ std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
     PADDLE_ENFORCE(graph->Has(attr), "Required graph atrribute %s not set.",
                    attr);
   }
-  auto* native_graph = graph.get();
-  auto applied_graph = ApplyImpl(std::move(graph));
+  auto* native_graph = graph;
+  ApplyImpl(graph);
   // TODO(panyx0718): Add more verifications.
-  PADDLE_ENFORCE(!HasCircle(*applied_graph),
+  PADDLE_ENFORCE(!HasCircle(*graph),
                  "Illegal Pass. Generated graph shouldn't has cycle.");
-  PADDLE_ENFORCE(applied_graph.get() == native_graph,
+  PADDLE_ENFORCE(graph == native_graph,
                  "Pass::Apply() cannot delete the passed graph and shouldn't "
                  "return a new graph.(For the need of pybind11)");
   applied_ = true;
-  return applied_graph;
+  return graph;
 }
 
 PassRegistry& PassRegistry::Instance() {
diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h
index 27746ff14..6cbe9a821 100644
--- a/paddle/fluid/framework/ir/pass.h
+++ b/paddle/fluid/framework/ir/pass.h
@@ -16,8 +16,10 @@ limitations under the License. */
 
 #include <functional>
 #include <map>
+#include <memory>
 #include <string>
-
+#include <unordered_map>
+#include <unordered_set>
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/program_desc.h"
@@ -44,7 +46,7 @@ class Pass {
 
   std::string Type() const { return type_; }
 
-  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const;
+  Graph *Apply(Graph *graph) const;
 
   // Get a reference to the attributed previously set.
   template <typename AttrType>
@@ -98,9 +100,8 @@ class Pass {
   }
 
  protected:
-  virtual std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const {
+  virtual void ApplyImpl(Graph *graph) const {
     LOG(FATAL) << "Calling virtual Pass not implemented.";
-    return graph;
   }
 
 private:
diff --git a/paddle/fluid/framework/ir/pass_test.cc b/paddle/fluid/framework/ir/pass_test.cc
index 6ad7d1df8..87e3c9641 100644
--- a/paddle/fluid/framework/ir/pass_test.cc
+++ b/paddle/fluid/framework/ir/pass_test.cc
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
*/ #include "paddle/fluid/framework/ir/pass.h" +#include #include +#include #include "gtest/gtest.h" #include "paddle/fluid/framework/ir/graph.h" @@ -39,7 +41,7 @@ void BuildCircleGraph(Graph* g) { class TestPass : public Pass { protected: - std::unique_ptr ApplyImpl(std::unique_ptr graph) const { + void ApplyImpl(ir::Graph* graph) const { graph->Set("copy_test_pass_attr", new int); graph->Set("copy_test_graph_attr", new int); @@ -48,7 +50,6 @@ class TestPass : public Pass { int test_graph_attr = graph->Get("test_graph_attr"); graph->Get("copy_test_graph_attr") = test_graph_attr + 1; - return graph; } }; @@ -58,7 +59,7 @@ TEST(PassTest, TestPassAttrCheck) { std::unique_ptr graph(new Graph(prog)); std::string exception; try { - graph = pass->Apply(std::move(graph)); + graph.reset(pass->Apply(graph.release())); } catch (paddle::platform::EnforceNotMet e) { exception = std::string(e.what()); } @@ -69,7 +70,7 @@ TEST(PassTest, TestPassAttrCheck) { pass->SetNotOwned("test_pass_attr", &val); try { - graph = pass->Apply(std::move(graph)); + graph.reset(pass->Apply(graph.release())); } catch (paddle::platform::EnforceNotMet e) { exception = std::string(e.what()); } @@ -78,14 +79,14 @@ TEST(PassTest, TestPassAttrCheck) { graph.reset(new Graph(prog)); graph->Set("test_graph_attr", new int); graph->Get("test_graph_attr") = 1; - graph = pass->Apply(std::move(graph)); + graph.reset(pass->Apply(graph.release())); ASSERT_EQ(graph->Get("copy_test_pass_attr"), 2); ASSERT_EQ(graph->Get("copy_test_graph_attr"), 2); // Allow apply more than once. graph.reset(new Graph(prog)); graph->Set("test_graph_attr", new int); - graph = pass->Apply(std::move(graph)); + graph.reset(pass->Apply(graph.release())); pass = PassRegistry::Instance().Get("test_pass"); pass->SetNotOwned("test_pass_attr", &val); @@ -94,7 +95,7 @@ TEST(PassTest, TestPassAttrCheck) { graph->Set("test_graph_attr", new int); graph->Get("test_graph_attr") = 2; try { - auto tmp = pass->Apply(std::move(graph)); + pass->Apply(graph.release()); } catch (paddle::platform::EnforceNotMet e) { exception = std::string(e.what()); } diff --git a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc index 84a4ff2de..00263b8a3 100644 --- a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h" #include // for max #include +#include #include #include "paddle/fluid/framework/lod_tensor.h" @@ -365,17 +366,14 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, return fusion_count; } -std::unique_ptr RepeatedFCReluFusePass::ApplyImpl( - std::unique_ptr graph) const { - FusePassBase::Init(name_scope_, graph.get()); +void RepeatedFCReluFusePass::ApplyImpl(ir::Graph* graph) const { + FusePassBase::Init(name_scope_, graph); int fusion_count = 0; for (int i = MAX_NUM_FC; i > 1; --i) { fusion_count += - BuildFusion(graph.get(), name_scope_ + "/" + std::to_string(i), i); + BuildFusion(graph, name_scope_ + "/" + std::to_string(i), i); } AddStatis(fusion_count); - - return graph; } } // namespace ir diff --git a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h index ede0bea07..ae777bcce 100644 --- a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h +++ b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h @@ -31,8 +31,7 @@ class RepeatedFCReluFusePass : public FusePassBase { virtual 
~RepeatedFCReluFusePass() {} protected: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override; + void ApplyImpl(ir::Graph* graph) const override; const std::string name_scope_{"repeated_fc_relu_fuse"}; }; diff --git a/paddle/fluid/framework/ir/runtime_context_cache_pass.cc b/paddle/fluid/framework/ir/runtime_context_cache_pass.cc index 67b29512c..c7cf9b0dc 100644 --- a/paddle/fluid/framework/ir/runtime_context_cache_pass.cc +++ b/paddle/fluid/framework/ir/runtime_context_cache_pass.cc @@ -20,15 +20,13 @@ namespace paddle { namespace framework { namespace ir { -std::unique_ptr RuntimeContextCachePass::ApplyImpl( - std::unique_ptr graph) const { +void RuntimeContextCachePass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "Applies Runtime Context Cache strategy."; for (const Node* n : graph->Nodes()) { if (n->IsOp()) { n->Op()->SetAttr(kEnableCacheRuntimeContext, true); } } - return graph; } } // namespace ir diff --git a/paddle/fluid/framework/ir/runtime_context_cache_pass.h b/paddle/fluid/framework/ir/runtime_context_cache_pass.h index a6cf1a9ae..e4783166e 100644 --- a/paddle/fluid/framework/ir/runtime_context_cache_pass.h +++ b/paddle/fluid/framework/ir/runtime_context_cache_pass.h @@ -23,8 +23,7 @@ namespace ir { class RuntimeContextCachePass : public Pass { protected: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override; + void ApplyImpl(ir::Graph* graph) const override; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc index 012e68036..b230c5016 100644 --- a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h" #include #include - +#include #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h" -#include "paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h" #include "paddle/fluid/framework/lod_tensor.h" namespace paddle { @@ -178,9 +178,8 @@ PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) { return fc_out; } -std::unique_ptr SeqConcatFcFusePass::ApplyImpl( - std::unique_ptr graph) const { - FusePassBase::Init("seq_concat_fc_fuse", graph.get()); +void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const { + FusePassBase::Init("seq_concat_fc_fuse", graph); GraphPatternDetector detector; auto* pattern = detector.mutable_pattern(); auto* concat_out = BuildSeqExpandConcatPattern(pattern); @@ -194,8 +193,8 @@ std::unique_ptr SeqConcatFcFusePass::ApplyImpl( int fuse_count{0}; - detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph, - Graph* graph) { + detector(graph, [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { VLOG(4) << "get one concat pattern"; // fc GET_NODE(fc_w, detector.pattern()); @@ -246,8 +245,6 @@ std::unique_ptr SeqConcatFcFusePass::ApplyImpl( }); AddStatis(fuse_count); - - return graph; } } // namespace ir diff --git a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h index 06e18f9dc..d68840a55 100644 --- a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h +++ b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h @@ -27,8 +27,7 @@ class SeqConcatFcFusePass : public FusePassBase { virtual ~SeqConcatFcFusePass() {} protected: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override; + void ApplyImpl(ir::Graph* graph) const override; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc b/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc index 0a1f65d27..3fd368741 100644 --- a/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h" #include +#include #include "paddle/fluid/framework/lod_tensor.h" namespace paddle { @@ -83,14 +84,11 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) { return fusion_count; } -std::unique_ptr SeqConvEltAddReluFusePass::ApplyImpl( - std::unique_ptr graph) const { - FusePassBase::Init(name_scope_, graph.get()); +void SeqConvEltAddReluFusePass::ApplyImpl(ir::Graph* graph) const { + FusePassBase::Init(name_scope_, graph); - int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope()); + int fusion_count = BuildFusion(graph, name_scope_, param_scope()); AddStatis(fusion_count); - - return graph; } } // namespace ir diff --git a/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h b/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h index c36c6b76a..fde9b586c 100644 --- a/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h +++ b/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h @@ -28,8 +28,7 @@ class SeqConvEltAddReluFusePass : public FusePassBase { virtual ~SeqConvEltAddReluFusePass() {} protected: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override; + void ApplyImpl(ir::Graph* graph) const override; const std::string name_scope_{"seqconv_eltadd_relu_fuse"}; }; diff --git 
a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc index 63a0c24f2..4ac379eb0 100644 --- a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc +++ b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h" #include +#include #include #include "paddle/fluid/framework/lod_tensor.h" @@ -194,17 +195,14 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, return fusion_count; } -std::unique_ptr SeqPoolConcatFusePass::ApplyImpl( - std::unique_ptr graph) const { - FusePassBase::Init(name_scope_, graph.get()); +void SeqPoolConcatFusePass::ApplyImpl(ir::Graph* graph) const { + FusePassBase::Init(name_scope_, graph); int fusion_count = 0; for (int i = MAX_CONCAT_INPUTS; i > 0; --i) { fusion_count += - BuildFusion(graph.get(), name_scope_ + "/" + std::to_string(i), i); + BuildFusion(graph, name_scope_ + "/" + std::to_string(i), i); } AddStatis(fusion_count); - - return graph; } } // namespace ir diff --git a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h index a5db3528d..40a9edc5e 100644 --- a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h +++ b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h @@ -42,8 +42,7 @@ class SeqPoolConcatFusePass : public FusePassBase { virtual ~SeqPoolConcatFusePass() {} protected: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override; + void ApplyImpl(ir::Graph* graph) const override; const std::string name_scope_{"seqpool_concat_fuse"}; }; diff --git a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc index 35d1d5129..d36680385 100644 --- a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc @@ -59,7 +59,7 @@ std::unique_ptr GetNumNodesOfBeforeAfter( const std::string& pass_type = "seqpool_concat_fuse_pass") { auto pass = PassRegistry::Instance().Get(pass_type); *before = graph->Nodes().size(); - graph = pass->Apply(std::move(graph)); + graph.reset(pass->Apply(graph.release())); *after = graph->Nodes().size(); return graph; } diff --git a/paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.cc b/paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.cc index 84fb8063e..e1ddc4447 100644 --- a/paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.cc +++ b/paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.cc @@ -24,11 +24,11 @@ namespace framework { namespace ir { template -std::unique_ptr SimplifyAnakinDetectionPatternPass::ApplyImpl( - std::unique_ptr graph) const { +void SimplifyAnakinDetectionPatternPass::ApplyImpl( + ir::Graph *graph) const { const std::string pattern_name = "simplify_anakin_detection_pattern_pass" + std::to_string(times); - FusePassBase::Init(pattern_name, graph.get()); + FusePassBase::Init(pattern_name, graph); GraphPatternDetector gpd; std::vector input_nodes; @@ -207,11 +207,10 @@ std::unique_ptr SimplifyAnakinDetectionPatternPass::ApplyImpl( multiclass_nms_out->inputs.push_back(detection_out_op); // Delete the unneeded nodes. 
- GraphSafeRemoveNodes(graph.get(), delete_nodes); + GraphSafeRemoveNodes(graph, delete_nodes); }; - gpd(graph.get(), handler); - return graph; + gpd(graph, handler); } template class SimplifyAnakinDetectionPatternPass<1>; diff --git a/paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.h b/paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.h index 2338e4c38..e4a266cbe 100644 --- a/paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.h +++ b/paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.h @@ -32,8 +32,7 @@ class SimplifyAnakinDetectionPatternPass : public FusePassBase { virtual ~SimplifyAnakinDetectionPatternPass() {} protected: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override; + void ApplyImpl(ir::Graph* graph) const override; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc index 78c8cabb1..42f4a91a6 100644 --- a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc +++ b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h" #include +#include #include #include "paddle/fluid/framework/lod_tensor.h" @@ -362,13 +363,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { return fusion_count; } -std::unique_ptr SquaredMatSubFusePass::ApplyImpl( - std::unique_ptr graph) const { - FusePassBase::Init(name_scope_, graph.get()); - int fusion_count = BuildFusion(graph.get(), name_scope_); +void SquaredMatSubFusePass::ApplyImpl(ir::Graph* graph) const { + FusePassBase::Init(name_scope_, graph); + int fusion_count = BuildFusion(graph, name_scope_); AddStatis(fusion_count); - - return graph; } } // namespace ir diff --git a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h index c21ba65c4..b6165a512 100644 --- a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h +++ b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h @@ -31,8 +31,7 @@ class SquaredMatSubFusePass : public FusePassBase { virtual ~SquaredMatSubFusePass() {} protected: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override; + void ApplyImpl(ir::Graph* graph) const override; const std::string name_scope_{"squared_mat_sub_fuse"}; }; diff --git a/paddle/fluid/framework/ir/sync_batch_norm_pass.cc b/paddle/fluid/framework/ir/sync_batch_norm_pass.cc index b37003991..f4f924a60 100644 --- a/paddle/fluid/framework/ir/sync_batch_norm_pass.cc +++ b/paddle/fluid/framework/ir/sync_batch_norm_pass.cc @@ -21,8 +21,7 @@ namespace paddle { namespace framework { namespace ir { -std::unique_ptr SyncBatchNormPass::ApplyImpl( - std::unique_ptr graph) const { +void SyncBatchNormPass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "Use synchronous batch norm"; for (const Node* n : graph->Nodes()) { if (n->IsOp()) { @@ -35,7 +34,6 @@ std::unique_ptr SyncBatchNormPass::ApplyImpl( } } } - return graph; } } // namespace ir diff --git a/paddle/fluid/framework/ir/sync_batch_norm_pass.h b/paddle/fluid/framework/ir/sync_batch_norm_pass.h index 51cce3dca..694fae749 100644 --- a/paddle/fluid/framework/ir/sync_batch_norm_pass.h +++ b/paddle/fluid/framework/ir/sync_batch_norm_pass.h @@ -23,8 +23,7 @@ namespace ir { class SyncBatchNormPass : public Pass { protected: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override; + void ApplyImpl(ir::Graph* graph) const override; }; } // namespace ir diff --git 
a/paddle/fluid/framework/ir/sync_batch_norm_pass_tester.cc b/paddle/fluid/framework/ir/sync_batch_norm_pass_tester.cc
index 9c94c1746..894f96050 100644
--- a/paddle/fluid/framework/ir/sync_batch_norm_pass_tester.cc
+++ b/paddle/fluid/framework/ir/sync_batch_norm_pass_tester.cc
@@ -60,7 +60,7 @@ TEST(IsTestPass, basic) {
 
   auto pass = PassRegistry::Instance().Get("sync_batch_norm_pass");
 
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
 
   for (auto* node : graph->Nodes()) {
     if (node->IsOp()) {
diff --git a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc
index cab69c408..61c12d4b6 100644
--- a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc
@@ -26,11 +26,10 @@ namespace framework {
 namespace ir {
 
 template <int times>
-std::unique_ptr<ir::Graph> TransposeFlattenConcatFusePass<times>::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void TransposeFlattenConcatFusePass<times>::ApplyImpl(ir::Graph *graph) const {
   const std::string pattern_name =
       "transpose_flatten" + std::to_string(times) + "_concat_fuse";
-  FusePassBase::Init(pattern_name, graph.get());
+  FusePassBase::Init(pattern_name, graph);
 
   GraphPatternDetector gpd;
   std::vector<Node *> input_nodes;
@@ -117,11 +116,10 @@ std::unique_ptr<ir::Graph> TransposeFlattenConcatFusePass<times>::ApplyImpl(
     concat_out->inputs.push_back(new_conv_op);
 
     // Delete the unneeded nodes.
-    GraphSafeRemoveNodes(graph.get(), delete_nodes);
+    GraphSafeRemoveNodes(graph, delete_nodes);
   };
 
-  gpd(graph.get(), handler);
-  return graph;
+  gpd(graph, handler);
 }
 
 template class TransposeFlattenConcatFusePass<1>;
diff --git a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h
index a7d18ec86..366d26d80 100644
--- a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h
+++ b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h
@@ -30,8 +30,7 @@ class TransposeFlattenConcatFusePass : public FusePassBase {
   virtual ~TransposeFlattenConcatFusePass() {}
 
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 20a8c47d5..ab0947c63 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -77,8 +77,7 @@ class ParallelExecutorPrivate {
     }
   }
 
-  std::unique_ptr<ir::Graph> PrepareGCAndRefCnts(
-      std::unique_ptr<ir::Graph> graph, size_t max_memory_size);
+  ir::Graph *PrepareGCAndRefCnts(ir::Graph *graph, size_t max_memory_size);
 
   inline bool HasGarbageCollectors() const { return !gcs_.empty(); }
 
@@ -118,8 +117,8 @@ class ParallelExecutorPrivate {
   details::GarbageCollectorMap gcs_;
 };
 
-std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
-    std::unique_ptr<ir::Graph> graph, size_t max_memory_size) {
+ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts(
+    ir::Graph *graph, size_t max_memory_size) {
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &place = places_[i];
     if (gcs_.count(place) > 0) {
@@ -161,7 +160,7 @@ std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
                                  &global_ref_cnts_);
     ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars,
                               &last_live_ops_of_vars);
-    graph = ref_cnt_pass->Apply(std::move(graph));
+    graph = ref_cnt_pass->Apply(graph);
     VLOG(10) << "ReferenceCountPass Applied";
 
     auto eager_deletion_pass =
@@ -172,10 +171,9 @@ std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
     eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars,
                                      &last_live_ops_of_vars);
     eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_);
-    graph = eager_deletion_pass->Apply(std::move(graph));
+    graph = eager_deletion_pass->Apply(graph);
     VLOG(10) << "EagerDeletionPass Applied";
   }
-
   return graph;
 }
 
@@ -220,13 +218,11 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
     }
   }
 
-  std::unique_ptr<ir::Graph> temp_owned_graph(graph);
-
   // FIXME(Yancey1989): parallel graph mode get better performance
   // in GPU allreduce distributed training. Need an elegant way to
   // choice the execution strategy.
-  build_strategy.enable_parallel_graph_ = EnableParallelGraphExecution(
-      *temp_owned_graph, exec_strategy, build_strategy);
+  build_strategy.enable_parallel_graph_ =
+      EnableParallelGraphExecution(*graph, exec_strategy, build_strategy);
   if (build_strategy.enable_parallel_graph_)
     VLOG(0) << "The Executor would execute the graph by ParallelGraph "
                "Execution which can get better performance,"
@@ -304,27 +300,21 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
   // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
   // ncclOp
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-
-  temp_owned_graph = build_strategy.Apply(
-      std::move(temp_owned_graph), member_->places_, loss_var_name,
-      member_->local_scopes_, member_->nranks_, member_->use_cuda_,
-      member_->nccl_ctxs_.get());
+  graph = build_strategy.Apply(graph, member_->places_, loss_var_name,
+                               member_->local_scopes_, member_->nranks_,
+                               member_->use_cuda_, member_->nccl_ctxs_.get());
 #else
-  temp_owned_graph = build_strategy.Apply(
-      std::move(temp_owned_graph), member_->places_, loss_var_name,
-      member_->local_scopes_, member_->nranks_, member_->use_cuda_);
+  graph = build_strategy.Apply(graph, member_->places_, loss_var_name,
+                               member_->local_scopes_, member_->nranks_,
+                               member_->use_cuda_);
 #endif
   auto max_memory_size = GetEagerDeletionThreshold();
   VLOG(10) << "Eager Deletion Threshold "
           << static_cast<float>(max_memory_size) / (1 << 30);
   if (max_memory_size >= 0) {
-    graph = member_
-                ->PrepareGCAndRefCnts(std::move(temp_owned_graph),
-                                      static_cast<size_t>(max_memory_size))
-                .release();
-  } else {
-    graph = temp_owned_graph.release();
+    graph = member_->PrepareGCAndRefCnts(graph,
+                                         static_cast<size_t>(max_memory_size));
   }
 
   // Step 3. Create vars in each scope. Passes may also create new vars.
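The PrepareGCAndRefCnts hunks above show the main payoff of the new contract: because Pass::Apply() must now return exactly the pointer it was given (the PADDLE_ENFORCE in pass.cc checks this), chaining passes reduces to plain reassignment with no std::move bookkeeping. A sketch of that chaining pattern; RunPasses is a hypothetical helper name, not anything defined in this patch:

    // Applying a pipeline of passes under the raw-pointer interface.
    // Each Apply() mutates the graph in place and hands back the same
    // pointer, so the loop body is a simple reassignment.
    ir::Graph *RunPasses(ir::Graph *graph,
                         const std::vector<ir::Pass *> &passes) {
      for (ir::Pass *pass : passes) {
        graph = pass->Apply(graph);  // same object before and after
      }
      return graph;
    }
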
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index 7a96ac11d..78e502c67 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -140,7 +140,7 @@ std::unique_ptr IRPassManager::Apply(std::unique_ptr graph) { if (pass->Type() != "graph_viz_pass") { PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass->Type()); } - graph = pass->Apply(std::move(graph)); + graph.reset(pass->Apply(graph.release())); } return graph; } @@ -156,7 +156,7 @@ framework::proto::ProgramDesc IRPassManager::AcquireProgram( desc.CopyFrom(*program->Proto()); pass->SetNotOwned("program", &desc); auto *the_graph = graph->release(); - *graph = pass->Apply(std::unique_ptr(the_graph)); + graph->reset(pass->Apply(the_graph)); return *desc.Proto(); } diff --git a/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc index 12deed253..9e05aa5c1 100644 --- a/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc @@ -35,8 +35,8 @@ namespace analysis { using framework::ir::Node; -std::unique_ptr analysis::AnakinSubgraphPass::ApplyImpl( - std::unique_ptr graph) const { +void analysis::AnakinSubgraphPass::ApplyImpl( + framework::ir::Graph *graph) const { framework::ir::FusePassBase::Init("anakin_subgraph_pass", graph.get()); auto teller = [](const framework::ir::Node *node) { @@ -72,8 +72,6 @@ std::unique_ptr analysis::AnakinSubgraphPass::ApplyImpl( framework::ir::GraphSafeRemoveNodes(graph.get(), nodes2remove); graph->Set(framework::ir::kRepetitiveParamAttr, new std::vector(repetitive_params)); - - return graph; } std::string GenerateAnakinEngineKey(const std::set &engine_inputs, diff --git a/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.h b/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.h index c13b9ecda..e80b8bb61 100644 --- a/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.h +++ b/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.h @@ -29,8 +29,7 @@ namespace analysis { class AnakinSubgraphPass : public framework::ir::FusePassBase { public: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override; + void ApplyImpl(framework::ir::Graph *graph) const override; private: void CreateAnakinOp(framework::ir::Node *x, framework::ir::Graph *graph, diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 593994032..ef5872c52 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -31,16 +31,16 @@ namespace analysis { using framework::ir::Node; -std::unique_ptr analysis::TensorRtSubgraphPass::ApplyImpl( - std::unique_ptr graph) const { - framework::ir::FusePassBase::Init("tensorrt_subgraph_pass", graph.get()); +void analysis::TensorRtSubgraphPass::ApplyImpl( + framework::ir::Graph *graph) const { + framework::ir::FusePassBase::Init("tensorrt_subgraph_pass", graph); auto teller = [](const framework::ir::Node *node) { if (!node->IsOp() || !node->Op()) return false; return tensorrt::OpTeller::Global().Tell(node->Op()->Type(), *node->Op()); }; - SubGraphFuser fuser(graph.get(), teller, + SubGraphFuser fuser(graph, teller, Get("min_subgraph_size") /*min subgraph size*/); 
fuser(); @@ -52,12 +52,11 @@ std::unique_ptr analysis::TensorRtSubgraphPass::ApplyImpl( for (auto *node : graph->Nodes()) { if (node->IsOp() && !Agent(node).subgraph()->empty()) { - CreateTensorRTOp(node, graph.get(), graph_param_names, - &repetitive_params); + CreateTensorRTOp(node, graph, graph_param_names, &repetitive_params); std::unordered_set nodes2remove( Agent(node).subgraph()->begin(), Agent(node).subgraph()->end()); - framework::ir::GraphSafeRemoveNodes(graph.get(), nodes2remove); + framework::ir::GraphSafeRemoveNodes(graph, nodes2remove); } } @@ -67,11 +66,9 @@ std::unique_ptr analysis::TensorRtSubgraphPass::ApplyImpl( nodes2remove.insert(node); } } - framework::ir::GraphSafeRemoveNodes(graph.get(), nodes2remove); + framework::ir::GraphSafeRemoveNodes(graph, nodes2remove); graph->Set(framework::ir::kRepetitiveParamAttr, new std::vector(repetitive_params)); - - return graph; } std::string GenerateEngineKey(const std::set &engine_inputs, diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h index f043670c5..f530a5a0b 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h @@ -28,8 +28,7 @@ namespace analysis { class TensorRtSubgraphPass : public framework::ir::FusePassBase { public: - std::unique_ptr ApplyImpl( - std::unique_ptr graph) const override; + void ApplyImpl(framework::ir::Graph *graph) const override; private: void CreateTensorRTOp(framework::ir::Node *x, framework::ir::Graph *graph, diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.cc b/paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.cc index 6b3d80fce..35df396fe 100644 --- a/paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.h" +#include #include "paddle/fluid/framework/ir/graph_to_program_pass.h" #include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/program_desc.h" @@ -37,8 +38,7 @@ void IrGraphToProgramPass::RunImpl(Argument *argument) { framework::ProgramDesc desc; desc.CopyFrom(*argument->main_program().Proto()); pass->SetNotOwned("program", &desc); - auto thegraph = pass->Apply(std::move(graph)); - thegraph.release(); // the argument still own the graph. + pass->Apply(graph.release()); // the argument still own the graph. argument->SetIrAnalyzedProgram( new framework::proto::ProgramDesc(*desc.Proto())); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 3b0939ef8..d4c85fd0c 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1046,9 +1046,7 @@ All parameter, weight, gradient are variables in Paddle. int val) { self.Set(name, new int(val)); }) .def("type", &ir::Pass::Type) .def("apply", [](ir::Pass &self, std::shared_ptr graph) { - std::unique_ptr origin_graph(graph.get()); - auto optim_graph = self.Apply(std::move(origin_graph)); - optim_graph.release(); + self.Apply(graph.get()); }); py::class_> pb( -- GitLab
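Taken together, the pass.h and pybind changes mean a pass author now overrides a void, raw-pointer ApplyImpl and mutates the graph in place. A minimal sketch of a pass written against the updated base class, modeled on the is_test and sync_batch_norm passes above; MarkOpsPass and the "marked" attribute are illustrative, not part of this patch:

    #include "paddle/fluid/framework/ir/pass.h"

    namespace paddle {
    namespace framework {
    namespace ir {

    // Toy pass: tags every operator node with a boolean attribute.
    class MarkOpsPass : public Pass {
     protected:
      void ApplyImpl(ir::Graph *graph) const override {
        for (const Node *n : graph->Nodes()) {
          if (n->IsOp()) {
            n->Op()->SetAttr("marked", true);  // mutate in place, return nothing
          }
        }
      }
    };

    }  // namespace ir
    }  // namespace framework
    }  // namespace paddle

    REGISTER_PASS(mark_ops_pass, paddle::framework::ir::MarkOpsPass);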