diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 44dde061851fc288a7d6c5f914864ef922b61c32..5c5eac54cee7b8c268a942fa99bcdcee18ce0273 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -52,7 +52,6 @@ pass_library(graph_viz_pass base) pass_library(lock_free_optimize_pass base) pass_library(fc_fuse_pass inference) pass_library(attention_lstm_fuse_pass inference) -pass_library(infer_clean_graph_pass inference) pass_library(fc_lstm_fuse_pass inference) pass_library(embedding_fc_lstm_fuse_pass inference) pass_library(fc_gru_fuse_pass inference) diff --git a/paddle/fluid/framework/ir/infer_clean_graph_pass.cc b/paddle/fluid/framework/ir/infer_clean_graph_pass.cc deleted file mode 100644 index d76924116f6d6202557a0d76cfcdadba0a3a6de6..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/ir/infer_clean_graph_pass.cc +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "paddle/fluid/framework/ir/fuse_pass_base.h" -#include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/graph_pattern_detector.h" - -namespace paddle { -namespace framework { -namespace ir { - -class InferCleanGraphPass : public FusePassBase { - public: - virtual ~InferCleanGraphPass() {} - - protected: - void ApplyImpl(ir::Graph* graph) const { - FusePassBase::Init("original_graph", graph); - PADDLE_ENFORCE(graph); - - auto is_valid_node = [](Node* x) { - return x && IsControlDepVar(*x) && x->IsVar() && !x->Var(); - }; - - std::unordered_set invalid_nodes; - int valid_op = 0; - for (auto* node : graph->Nodes()) { - PADDLE_ENFORCE_NOT_NULL(node); - if (is_valid_node(node)) { - invalid_nodes.insert(node); - } else if (node->IsOp()) { - // Collect all the operators to help tracking number of operators. - ++valid_op; - } - } - - GraphSafeRemoveNodes(graph, invalid_nodes); - - AddStatis(valid_op); - } - - void CleanEdges(std::vector* nodes, - const std::unordered_set& to_remove) const { - auto it = std::remove_if(nodes->begin(), nodes->end(), - [&](Node* x) { return to_remove.count(x); }); - nodes->erase(it, nodes->end()); - } -}; - -} // namespace ir -} // namespace framework -} // namespace paddle - -REGISTER_PASS(infer_clean_graph_pass, - paddle::framework::ir::InferCleanGraphPass); diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index d82a063d8808591a7ebf6b70e7421a401ce969f7..71fdb5570c7c6fca56a302b5d2deee4bd1a8f9f8 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -31,6 +31,9 @@ void Analyzer::RunAnalysis(Argument *argument) { "analsis_passes is not valid in the argument."); for (auto &pass : argument->analysis_passes()) { string::PrettyLogH1("--- Running analysis [%s]", pass); + if (!argument->enable_analysis_optim() && pass == "ir_analysis_pass") + continue; + auto *ptr = PassRegistry::Global().Retreive(pass); PADDLE_ENFORCE_NOT_NULL(ptr, "no analysis pass called %s", pass); ptr->Run(argument); diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc index c814ce454840a2c6f3829599b86c9e127d07e4f4..489345da49a232e7fb21bd44c1ecf34cf1e4fe8f 100644 --- a/paddle/fluid/inference/analysis/analyzer_tester.cc +++ b/paddle/fluid/inference/analysis/analyzer_tester.cc @@ -30,7 +30,7 @@ using namespace framework; // NOLINT TEST(Analyzer, analysis_without_tensorrt) { Argument argument; argument.SetModelDir(FLAGS_inference_model_dir); - argument.SetIrAnalysisPasses({"infer_clean_graph_pass"}); + argument.SetEnableAnalysisOptim(false); argument.SetUseGPU(false); argument.SetAnalysisPasses({"ir_graph_build_pass", "ir_analysis_pass", "ir_params_sync_among_devices_pass"}); @@ -41,10 +41,10 @@ TEST(Analyzer, analysis_without_tensorrt) { TEST(Analyzer, analysis_with_tensorrt) { Argument argument; + argument.SetEnableAnalysisOptim(false); argument.SetTensorRtMaxBatchSize(3); argument.SetTensorRtWorkspaceSize(1 << 20); argument.SetModelDir(FLAGS_inference_model_dir); - argument.SetIrAnalysisPasses({"infer_clean_graph_pass"}); argument.SetUseGPU(false); argument.SetAnalysisPasses({"ir_graph_build_pass", "ir_analysis_pass", "ir_params_sync_among_devices_pass"}); diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index 3fcf579cebc11ef511bfd5e715ffbbfe7143cde2..1aceb4f469e3d9c6e163ede1dad48e01cef3d95c 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -62,6 +62,9 @@ struct Argument { using anakin_max_shape_t = std::map>; bool Has(const std::string& key) const { return valid_fields_.count(key); } + // If we set the model using config.SetModelBuffer, + // the model and parameter will occupy additional CPU resources. + // Use this interface to release these resources. void PartiallyRelease() { if (Has("model_program_path")) { if (Has("model_from_memory") && model_from_memory()) { @@ -130,6 +133,7 @@ struct Argument { DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string); DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool); DECL_ARGUMENT_FIELD(optim_cache_dir, OptimCacheDir, std::string); + DECL_ARGUMENT_FIELD(enable_analysis_optim, EnableAnalysisOptim, bool); // The overall graph to work on. DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph); diff --git a/paddle/fluid/inference/analysis/passes/CMakeLists.txt b/paddle/fluid/inference/analysis/passes/CMakeLists.txt index 860dc309760d67cc20a638286fc6409e4c93ee65..1c878d66ba97a13e14d341d08943dfe8c78228a4 100644 --- a/paddle/fluid/inference/analysis/passes/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/passes/CMakeLists.txt @@ -5,6 +5,7 @@ cc_library(ir_params_sync_among_devices_pass SRCS ir_params_sync_among_devices_p cc_library(ir_graph_to_program_pass SRCS ir_graph_to_program_pass.cc DEPS analysis_pass graph_to_program_pass) cc_library(adjust_cudnn_workspace_size_pass SRCS adjust_cudnn_workspace_size_pass.cc DEPS analysis_pass graph_to_program_pass) cc_library(inference_op_replace_pass SRCS inference_op_replace_pass.cc DEPS analysis_pass graph_to_program_pass) +cc_library(ir_graph_clean_pass SRCS ir_graph_clean_pass.cc DEPS analysis_pass) cc_library(analysis_passes SRCS passes.cc DEPS ir_graph_build_pass @@ -14,6 +15,7 @@ cc_library(analysis_passes SRCS passes.cc DEPS memory_optim_pass inference_op_replace_pass ir_graph_to_program_pass + ir_graph_clean_pass ) set(analysis_deps ${analysis_deps} diff --git a/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.cc b/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.cc index ef7d13da89dbdcd17fc10feffcdbca76559df0df..86ced982d34d80e38e24650c0d687152ab5e3dcb 100644 --- a/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.cc +++ b/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.cc @@ -20,9 +20,9 @@ namespace inference { namespace analysis { void InferenceOpReplacePass::RunImpl(Argument* argument) { - if (!argument->use_gpu()) return; std::unordered_map replaced_map{ {"conditional_block", "conditional_block_infer"}, + {"merge_lod_tensor", "merge_lod_tensor_infer"}, }; auto& graph = argument->main_graph(); diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.cc b/paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f888a28da0416b41a87b551208fbe109f54d844 --- /dev/null +++ b/paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h" +#include +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/node.h" + +namespace paddle { +namespace inference { +namespace analysis { + +void IrInferCleanGraphPass::RunImpl(Argument* argument) { + auto& graph = argument->main_graph(); + auto is_valid_node = [](framework::ir::Node* x) { + return x && IsControlDepVar(*x) && x->IsVar() && !x->Var(); + }; + + std::unordered_set invalid_nodes; + int valid_op = 0; + for (auto* node : graph.Nodes()) { + PADDLE_ENFORCE_NOT_NULL(node); + if (is_valid_node(node)) { + invalid_nodes.insert(node); + } else if (node->IsOp()) { + ++valid_op; + } + } + + GraphSafeRemoveNodes(&graph, invalid_nodes); +} + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h b/paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..a9d58aa2f4cbb5d135221b0d02c633f6f78c8190 --- /dev/null +++ b/paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h @@ -0,0 +1,34 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/inference/analysis/analysis_pass.h" + +namespace paddle { +namespace inference { +namespace analysis { + +class IrInferCleanGraphPass : public AnalysisPass { + public: + void RunImpl(Argument *argument) override; + + std::string repr() const override { return "ir_graph_clean_pass"; } +}; + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc index c894acfd48cc5be683a75a218e1d77f62bedaee6..6ecaf08f7d3329e63b0f71da46a66c67eb5c53be 100644 --- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc +++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc @@ -109,10 +109,16 @@ int DataTypeToSpace(framework::proto::VarType_Type type) { void MemoryOptimizePass::CollectVarMemorySize( space_table_t* space_table) const { const int fake_batch_size = 1; + auto valid_var = [&](framework::ir::Node* node) -> bool { - std::set invalid_op = {"while", "conditional_block", + std::set invalid_op = {"while", + "conditional_block", "tensorrt_engine", - "conditional_block_infer"}; + "conditional_block_infer", + "merge_lod_tensor_infer", + "merge_lod_tensor", + "equal", + "lod_reset"}; for (auto* tmp : node->inputs) { CHECK(tmp->IsOp()); std::string op_type = tmp->Op()->Type(); diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h index 5a907303b4d3ba2d1404de7c5b82527b384aa3de..90e285da09990c2fb5fb551e06ddf044a238e37d 100644 --- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h +++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h @@ -75,6 +75,7 @@ class MemoryOptimizePass : public AnalysisPass { int sort_kind) const; void CollectVarMemorySize(space_table_t *space_table) const; + void CollectVarMemorySize0(space_table_t *space_table) const; void CollectVarMemorySize( const std::unordered_map &batch_var_ave_dim, diff --git a/paddle/fluid/inference/analysis/passes/passes.cc b/paddle/fluid/inference/analysis/passes/passes.cc index 97debcec565696b2c87456ec7406788c8aa0661a..ca0b25c29d495dc0e71e69a6d7d2a10f0f8c2254 100644 --- a/paddle/fluid/inference/analysis/passes/passes.cc +++ b/paddle/fluid/inference/analysis/passes/passes.cc @@ -17,6 +17,7 @@ #include "paddle/fluid/inference/analysis/passes/inference_op_replace_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_analysis_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h" +#include "paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h" #include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h" @@ -32,6 +33,8 @@ PassRegistry::PassRegistry() { std::unique_ptr(new IrAnalysisPass)); passes_.emplace("ir_graph_build_pass", std::unique_ptr(new IrGraphBuildPass)); + passes_.emplace("ir_graph_clean_pass", + std::unique_ptr(new IrInferCleanGraphPass)); passes_.emplace("memory_optimize_pass", std::unique_ptr(new MemoryOptimizePass)); passes_.emplace( diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index a5e8821c1a0cd7340fe47e2db5b9643473d9d58a..df62c1fc9a65b54c87ad638ee752344be9966aea 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -135,7 +135,6 @@ bool AnalysisPredictor::PrepareProgram( const std::shared_ptr &program) { if (!program) { if (!LoadProgramDesc()) return false; - // If not cloned, the parameters should be loaded. // If config_.ir_optim() is True, parameters is loaded in // OptimizeInferenceProgram(), but other persistable variables @@ -145,17 +144,10 @@ bool AnalysisPredictor::PrepareProgram( // So in both case, create persistable variables at first. executor_->CreateVariables(*inference_program_, 0, true, sub_scope_); - // Optimize the program, and load parameters and modify them in the - // scope_. - // This will change the scope_ address. - if (config_.ir_optim()) { - status_ir_optim_enabled_ = true; - OptimizeInferenceProgram(); - } else { - // Load parameters - LOG(INFO) << "load parameters "; - LoadParameters(); - } + // if enable_ir_optim_ is false, + // the analysis pass(op fuse, graph analysis, trt subgraph, mkldnn etc) will + // not be executed. + OptimizeInferenceProgram(); } else { // If the program is passed from external, no need to optimize it, this // logic is used in the clone scenario. @@ -396,6 +388,7 @@ bool AnalysisPredictor::GetFetch(std::vector *outputs, void AnalysisPredictor::PrepareArgument() { argument_.SetUseGPU(config_.use_gpu()); argument_.SetGPUDeviceId(config_.gpu_device_id()); + argument_.SetEnableAnalysisOptim(config_.enable_ir_optim_); argument_.SetEnableMemoryOptim(config_.enable_memory_optim()); argument_.SetStaticMemoryOptim(config_.static_memory_optim_); argument_.SetStaticMemoryOptimForceUpdate( @@ -467,8 +460,6 @@ void AnalysisPredictor::PrepareArgument() { // NOTE All the members in AnalysisConfig should be copied to Argument. void AnalysisPredictor::OptimizeInferenceProgram() { - status_program_optimized_ = true; - PrepareArgument(); Analyzer().Run(&argument_); diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index 7a366b10c7b1ccf6e9c7a1be69aedc8186ff3f05..0727c7b908b81e66373c9c2a3885edb51b540018 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -178,10 +178,8 @@ class AnalysisPredictor : public PaddlePredictor { private: // Some status here that help to determine the status inside the predictor. - bool status_program_optimized_{false}; bool status_is_cloned_{false}; bool status_use_gpu_{false}; - bool status_ir_optim_enabled_{false}; }; } // namespace paddle diff --git a/paddle/fluid/inference/api/analysis_predictor_tester.cc b/paddle/fluid/inference/api/analysis_predictor_tester.cc index 44b1b8071de9d0e825ea4c8ee895c44b8951f14f..e990b2c7736ae51a1ac2ba2fd15362012288b9bb 100644 --- a/paddle/fluid/inference/api/analysis_predictor_tester.cc +++ b/paddle/fluid/inference/api/analysis_predictor_tester.cc @@ -44,7 +44,6 @@ TEST(AnalysisPredictor, analysis_off) { ASSERT_EQ(predictor->scope_->parent(), nullptr); ASSERT_EQ(predictor->sub_scope_->parent(), predictor->scope_.get()); // ir is turned off, so program shouldn't be optimized. - ASSERT_FALSE(predictor->status_program_optimized_); LOG(INFO) << "scope parameters " << predictor->scope_->LocalVarNames().size(); // 2. Dummy Input Data @@ -76,8 +75,6 @@ TEST(AnalysisPredictor, analysis_on) { ASSERT_TRUE(predictor->sub_scope_); ASSERT_EQ(predictor->scope_->parent(), nullptr); ASSERT_EQ(predictor->sub_scope_->parent(), predictor->scope_.get()); - // ir is turned on, so program should be optimized. - ASSERT_TRUE(predictor->status_program_optimized_); // 2. Dummy Input Data int64_t data[4] = {1, 2, 3, 4}; PaddleTensor tensor; diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc index 179e002f7dd8ed6ecb3c38147d24d62d1b519305..8ebdbf1673fb2648b84d0451a6bc64426dfb6ce7 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer.cc +++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc @@ -400,13 +400,14 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const { auto* builder = predictor_.config_.pass_builder(); builder->SetPasses({ - "infer_clean_graph_pass", "cpu_quantize_pass", "cpu_quantize_squash_pass", + "cpu_quantize_pass", "cpu_quantize_squash_pass", }); if (predictor_.config_.ir_debug_) builder->TurnOnDebug(); auto passes = builder->AllPasses(); predictor_.argument_.SetIrAnalysisPasses(passes); predictor_.argument_.SetAnalysisPasses( - {"ir_analysis_pass", "memory_optimize_pass", "ir_graph_to_program_pass"}); + {"ir_graph_clean_pass", "ir_analysis_pass", "memory_optimize_pass", + "ir_graph_to_program_pass"}); predictor_.argument_.SetQuantVarScales(scales_); } diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 239161bc9ef571767f752cb764597782ac368c59..f48c280087ef5e70cd47545d949ab6a3fe75e47e 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -71,8 +71,7 @@ void PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) { void PaddlePassBuilder::ClearPasses() { passes_.clear(); } const std::vector kTRTSubgraphPasses({ - "infer_clean_graph_pass", // - "conv_affine_channel_fuse_pass", // + "conv_affine_channel_fuse_pass", // "conv_eltwiseadd_affine_channel_fuse_pass", // "shuffle_channel_detect_pass", // "quant_conv2d_dequant_fuse_pass", // @@ -91,7 +90,6 @@ const std::vector kTRTSubgraphPasses({ // The following passes works for Anakin sub-graph engine. const std::vector kAnakinSubgraphPasses({ - "infer_clean_graph_pass", // "quant_conv2d_dequant_fuse_pass", // "simplify_anakin_priorbox_detection_out_pass", // "fillconstant_elementwisemul_fuse", // @@ -105,9 +103,8 @@ const std::vector kAnakinSubgraphPasses({ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { passes_.assign({ - "infer_clean_graph_pass", // - // "identity_scale_op_clean_pass", // - "conv_affine_channel_fuse_pass", // + // "identity_scale_op_clean_pass", // + "conv_affine_channel_fuse_pass", // "conv_eltwiseadd_affine_channel_fuse_pass", // "conv_bn_fuse_pass", // "conv_eltwiseadd_bn_fuse_pass", // @@ -141,8 +138,7 @@ void GpuPassStrategy::EnableNgraph() { CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { // NOTE the large fusions should be located in the front, so that they will // not be damaged by smaller ones. - passes_.assign({"infer_clean_graph_pass", // - "attention_lstm_fuse_pass", // + passes_.assign({"attention_lstm_fuse_pass", // "seqconv_eltadd_relu_fuse_pass", // // "seqpool_concat_fuse_pass", // "seqpool_cvm_concat_fuse_pass", // diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h index 62b7ab30450f15aa8cb8e4a46bc37f70af851eb0..6aa59e0950cf9d612cedcf76cc09e95a8ae228c5 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.h +++ b/paddle/fluid/inference/api/paddle_pass_builder.h @@ -72,7 +72,7 @@ class PaddlePassBuilder { protected: std::vector analysis_passes_{ - {"ir_graph_build_pass", "ir_analysis_pass", + {"ir_graph_build_pass", "ir_graph_clean_pass", "ir_analysis_pass", "ir_params_sync_among_devices_pass", "adjust_cudnn_workspace_size_pass", "inference_op_replace_pass"}}; std::vector passes_; diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index d4d872976c4d2a7bce331ed964a4050eec3d4619..cda6ef76e1d6ce74e5f2bae3d2faec318cf8acb4 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -51,8 +51,8 @@ void TensorRTEngine::FreezeNetwork() { // build engine. infer_builder_->setMaxBatchSize(max_batch_); infer_builder_->setMaxWorkspaceSize(max_workspace_); -#if IS_TRT_VERSION_GE(5000) bool enable_fp16 = (precision_ == AnalysisConfig::Precision::kHalf); +#if IS_TRT_VERSION_GE(5000) if (enable_fp16) { bool support_fp16 = infer_builder_->platformHasFastFp16(); infer_builder_->setFp16Mode(support_fp16); @@ -62,9 +62,10 @@ void TensorRTEngine::FreezeNetwork() { } } #else - LOG(INFO) << "Using FP16 in Paddle-trt must ensure that the version of TRT " - "is at least 5." - "So, use FP32 to run."; + if (enable_fp16) + LOG(INFO) << "Using FP16 in Paddle-trt must ensure that the version of TRT " + "is at least 5." + "So, use FP32 to run."; #endif bool enable_int8 = (precision_ == AnalysisConfig::Precision::kInt8); diff --git a/paddle/fluid/operators/merge_lod_tensor_op.cc b/paddle/fluid/operators/merge_lod_tensor_op.cc index 5edc233f6f73262c3d1b803aae0089f5b15d403d..6a9d8222c4435c470460fbf3564cdc8d668783ce 100644 --- a/paddle/fluid/operators/merge_lod_tensor_op.cc +++ b/paddle/fluid/operators/merge_lod_tensor_op.cc @@ -28,9 +28,9 @@ class MergeLoDTensorOp : public framework::OperatorBase { const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { + protected: + void RunBase(const framework::Scope &scope, + const platform::Place &dev_place) const { // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(dev_place); @@ -125,6 +125,33 @@ class MergeLoDTensorOp : public framework::OperatorBase { out_lod->insert(out_lod->begin(), x.lod()[i]); } } + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + RunBase(scope, dev_place); + } +}; + +class MergeLoDTensorInferOp : public MergeLoDTensorOp { + public: + MergeLoDTensorInferOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : MergeLoDTensorOp(type, inputs, outputs, attrs) {} + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + RunBase(scope, dev_place); + framework::Variable *in_true_var = scope.FindVar(Input("InTrue")); + framework::Variable *in_false_var = scope.FindVar(Input("InFalse")); + in_true_var->Clear(); + in_false_var->Clear(); + in_true_var->GetMutable(); + in_false_var->GetMutable(); + } }; class MergeLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker { @@ -196,3 +223,7 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(merge_lod_tensor, ops::MergeLoDTensorOp, ops::MergeLoDTensorOpProtoMaker, ops::MergeLoDTensorInferShape, ops::MergeLoDTensorGradMaker); +REGISTER_OPERATOR(merge_lod_tensor_infer, ops::MergeLoDTensorInferOp, + ops::MergeLoDTensorOpProtoMaker, + ops::MergeLoDTensorInferShape, + paddle::framework::EmptyGradOpMaker); diff --git a/python/paddle/fluid/tests/unittests/test_split_and_merge_lod_tensor_op.py b/python/paddle/fluid/tests/unittests/test_split_and_merge_lod_tensor_op.py index 5397d5c52158ccfb9ad5703b957ca59d6fa11418..f407eb1d8b75cc3c29fc798c19d4284881dcdd49 100644 --- a/python/paddle/fluid/tests/unittests/test_split_and_merge_lod_tensor_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_and_merge_lod_tensor_op.py @@ -23,6 +23,7 @@ from paddle.fluid.executor import Executor from paddle.fluid.backward import append_backward from paddle.fluid.layers.control_flow import split_lod_tensor from paddle.fluid.layers.control_flow import merge_lod_tensor +from paddle.fluid.layer_helper import LayerHelper class TestCPULoDTensorArrayOps(unittest.TestCase): @@ -57,7 +58,7 @@ class TestCPULoDTensorArrayOps(unittest.TestCase): expect_false=expect_false, expect_out=tensor) - def test_split_and_merge_lod_tensor_level_0(self): + def split_and_merge_lod_tensor_level_0(self, use_merge_lod_infer=False): tensor = core.LoDTensor() tensor.set(np.arange(10).reshape(10, 1).astype('int32'), self.place()) tensor.set_recursive_sequence_lengths([[3, 6, 1]]) @@ -87,10 +88,23 @@ class TestCPULoDTensorArrayOps(unittest.TestCase): mask=mask, expect_true=expect_true, expect_false=expect_false, - expect_out=tensor) - - def main(self, tensor, mask, expect_true, expect_false, expect_out, - level=0): + expect_out=tensor, + use_merge_lod_infer=use_merge_lod_infer) + + def test_split_and_merge_lod_tensor_1(self): + self.split_and_merge_lod_tensor_level_0() + + def test_split_and_merge_lod_tensor_2(self): + self.split_and_merge_lod_tensor_level_0(True) + + def main(self, + tensor, + mask, + expect_true, + expect_false, + expect_out, + level=0, + use_merge_lod_infer=False): place = self.place() program = Program() with program_guard(program): @@ -103,11 +117,36 @@ class TestCPULoDTensorArrayOps(unittest.TestCase): out_true, out_false = split_lod_tensor(input=x, mask=y, level=level) out_true.persistable = True out_false.persistable = True - - out = merge_lod_tensor( - in_true=out_true, in_false=out_false, mask=y, x=x, level=level) - - out.persistable = True + if use_merge_lod_infer: + input_dict = { + 'X': x, + 'Mask': mask, + 'InTrue': out_true, + 'InFalse': out_false, + 'level': level + } + helper = LayerHelper('merge_lod_tensor_infer') + out = helper.create_variable_for_type_inference( + dtype=out_true.dtype) + helper.append_op( + type='merge_lod_tensor_infer', + inputs={ + 'X': x, + 'Mask': y, + 'InTrue': out_true, + 'InFalse': out_false + }, + outputs={'Out': out}, + attrs={'level': level}) + out.persistable = True + else: + out = merge_lod_tensor( + in_true=out_true, + in_false=out_false, + mask=y, + x=x, + level=level) + out.persistable = True exe = Executor(place) scope = core.Scope() @@ -122,9 +161,9 @@ class TestCPULoDTensorArrayOps(unittest.TestCase): var_false = scope.find_var(out_false.name).get_tensor() var_out = scope.find_var(out.name).get_tensor() - - self.check_tensor_same(var_true, expect_true) - self.check_tensor_same(var_false, expect_false) + if not use_merge_lod_infer: + self.check_tensor_same(var_true, expect_true) + self.check_tensor_same(var_false, expect_false) self.check_tensor_same(var_out, expect_out) def check_tensor_same(self, actual, expect):