From bc9fd1fc10f5ff3a2310cd1e04b40309ce508cf7 Mon Sep 17 00:00:00 2001 From: Zhaolong Xing Date: Mon, 8 Jul 2019 20:13:24 +0800 Subject: [PATCH] CHERRY-Pick: Inference: fix mask rcnn model diff, optim memory usage, memory leak. #18532 (#18547) fix mask rcnn add interface for setting optim_cache_dir(eg: when in trt int8 mode, and load model from memory, there should be a interface for setting the trt calibration table data dir) test=release/1.5 --- .../framework/ir/graph_pattern_detector.cc | 12 ++ .../framework/ir/graph_pattern_detector.h | 3 + paddle/fluid/inference/analysis/argument.h | 3 +- .../inference/analysis/ir_pass_manager.cc | 21 ++-- .../ir_passes/tensorrt_subgraph_pass.cc | 1 + .../inference/analysis/passes/CMakeLists.txt | 2 + .../passes/inference_op_replace_pass.cc | 47 ++++++++ .../passes/inference_op_replace_pass.h | 43 +++++++ .../analysis/passes/memory_optimize_pass.cc | 31 ++++- .../fluid/inference/analysis/passes/passes.cc | 3 + paddle/fluid/inference/api/analysis_config.cc | 6 +- .../fluid/inference/api/analysis_predictor.cc | 2 +- .../inference/api/paddle_analysis_config.h | 12 +- .../fluid/inference/api/paddle_pass_builder.h | 4 +- .../controlflow/conditional_block_infer_op.cc | 74 ++++++++++++ .../controlflow/conditional_block_op.cc | 91 +------------- .../controlflow/conditional_block_op.h | 111 ++++++++++++++++++ .../operators/tensorrt/tensorrt_engine_op.h | 5 +- paddle/fluid/pybind/inference_api.cc | 2 +- 19 files changed, 354 insertions(+), 119 deletions(-) create mode 100644 paddle/fluid/inference/analysis/passes/inference_op_replace_pass.cc create mode 100644 paddle/fluid/inference/analysis/passes/inference_op_replace_pass.h create mode 100644 paddle/fluid/operators/controlflow/conditional_block_infer_op.cc create mode 100644 paddle/fluid/operators/controlflow/conditional_block_op.h diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 15b3429ef..1e05275f6 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -504,6 +504,16 @@ PDNode *PDNode::assert_op_has_n_outputs(const std::string &op_type, size_t n) { return this; } +PDNode *PDNode::assert_has_n_inputs(size_t n) { + asserts_.emplace_back([=](Node *x) { return x->inputs.size() == n; }); + return this; +} + +PDNode *PDNode::assert_has_n_outputs(size_t n) { + asserts_.emplace_back([=](Node *x) { return x->outputs.size() == n; }); + return this; +} + PDNode *PDNode::assert_more(PDNode::teller_t &&teller) { asserts_.emplace_back(std::move(teller)); return this; @@ -1444,11 +1454,13 @@ PDNode *patterns::ConvAffineChannel::operator()( auto *ac_scale_var = pattern->NewNode(ac_scale_repr()) ->AsInput() ->assert_is_persistable_var() + ->assert_has_n_outputs(1) ->assert_is_op_input("affine_channel", "Scale"); // AC Bias auto *ac_bias_var = pattern->NewNode(ac_bias_repr()) ->AsInput() ->assert_is_persistable_var() + ->assert_has_n_outputs(1) ->assert_is_op_input("affine_channel", "Bias"); // AC output diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 1c53b9105..0d8f9d5b0 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -131,6 +131,9 @@ struct PDNode { const std::unordered_set& op_types, const std::string& argument, int nth); + PDNode* assert_has_n_inputs(size_t n); + PDNode* assert_has_n_outputs(size_t n); + template PDNode* assert_op_attr(const std::string& attr_name, const T& attr) { asserts_.emplace_back([=](Node* x) { diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index 7bcd1f01b..e468bc226 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -59,7 +59,6 @@ struct Argument { using unique_ptr_t = std::unique_ptr>; using fusion_statis_t = std::unordered_map; - using engine_opt_info_t = std::map; using anakin_max_shape_t = std::map>; bool Has(const std::string& key) const { return valid_fields_.count(key); } @@ -130,7 +129,7 @@ struct Argument { DECL_ARGUMENT_FIELD(model_program_path, ModelProgramPath, std::string); DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string); DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool); - DECL_ARGUMENT_FIELD(engine_opt_info, EngineOptInfo, engine_opt_info_t); + DECL_ARGUMENT_FIELD(optim_cache_dir, OptimCacheDir, std::string); // The overall graph to work on. DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph); diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index f290e6fce..2dae51371 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -94,11 +94,20 @@ void IRPassManager::CreatePasses(Argument *argument, bool use_static_engine = argument->tensorrt_use_static_engine(); bool model_from_memory = argument->model_from_memory(); - bool int8_valid = !(model_from_memory && enable_int8); + std::string optim_cache_dir = argument->optim_cache_dir(); + bool int8_valid = + !(model_from_memory && optim_cache_dir.empty() && enable_int8); PADDLE_ENFORCE(int8_valid, - "TRT INT8 Now don't support model load from memory."); - - if ((!model_from_memory && use_static_engine) || enable_int8) { + "When you are in TRT INT8 mode, and load model from " + "memory, you should set optim_cache_dir using " + "config.SetOptimCacheDir()"); + PADDLE_ENFORCE(!(model_from_memory && use_static_engine), + "When you are using Paddle-TRT, and also using load model " + "from memory, you should set the use_static to false."); + + if (!optim_cache_dir.empty()) { + pass->Set("model_opt_cache_dir", new std::string(optim_cache_dir)); + } else if (use_static_engine || enable_int8) { std::string model_opt_cache_dir = argument->Has("model_dir") ? argument->model_dir() @@ -110,8 +119,6 @@ void IRPassManager::CreatePasses(Argument *argument, pass->Set("gpu_device_id", new int(argument->gpu_device_id())); pass->Set("use_static_engine", new bool(use_static_engine)); pass->Set("model_from_memory", new bool(argument->model_from_memory())); - pass->Set("engine_opt_info", new std::map( - argument->engine_opt_info())); } if (pass_name == "ngraph_subgraph_pass") { pass->Set("program", @@ -123,8 +130,6 @@ void IRPassManager::CreatePasses(Argument *argument, pass->Set("use_gpu", new bool(argument->use_gpu())); pass->Set("gpu_device_id", new int(argument->gpu_device_id())); pass->Set("model_from_memory", new bool(argument->model_from_memory())); - pass->Set("engine_opt_info", new std::map( - argument->engine_opt_info())); pass->Set("predictor_id", new int(argument->predictor_id())); pass->Set("max_input_shape", new std::map>( argument->anakin_max_input_shape())); diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 37c3fc795..ce8f57c0f 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -226,6 +226,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp( std::unique_ptr calibrator; if (enable_int8 && calibration_data.size() != 0) { calibrator.reset(new tensorrt::TRTInt8Calibrator(calibration_data)); + LOG(INFO) << "RUN Paddle TRT int8 calibration mode..."; } // When in int8 mode and calibration_mode, the program just produce the // calibration table data. diff --git a/paddle/fluid/inference/analysis/passes/CMakeLists.txt b/paddle/fluid/inference/analysis/passes/CMakeLists.txt index a8d0c69a5..860dc3097 100644 --- a/paddle/fluid/inference/analysis/passes/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/passes/CMakeLists.txt @@ -4,6 +4,7 @@ cc_library(memory_optim_pass SRCS memory_optimize_pass.cc DEPS analysis_pass zer cc_library(ir_params_sync_among_devices_pass SRCS ir_params_sync_among_devices_pass.cc DEPS analysis_pass argument ir_pass_manager) cc_library(ir_graph_to_program_pass SRCS ir_graph_to_program_pass.cc DEPS analysis_pass graph_to_program_pass) cc_library(adjust_cudnn_workspace_size_pass SRCS adjust_cudnn_workspace_size_pass.cc DEPS analysis_pass graph_to_program_pass) +cc_library(inference_op_replace_pass SRCS inference_op_replace_pass.cc DEPS analysis_pass graph_to_program_pass) cc_library(analysis_passes SRCS passes.cc DEPS ir_graph_build_pass @@ -11,6 +12,7 @@ cc_library(analysis_passes SRCS passes.cc DEPS ir_params_sync_among_devices_pass adjust_cudnn_workspace_size_pass memory_optim_pass + inference_op_replace_pass ir_graph_to_program_pass ) diff --git a/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.cc b/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.cc new file mode 100644 index 000000000..ef7d13da8 --- /dev/null +++ b/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/analysis/passes/inference_op_replace_pass.h" +#include + +namespace paddle { +namespace inference { +namespace analysis { + +void InferenceOpReplacePass::RunImpl(Argument* argument) { + if (!argument->use_gpu()) return; + std::unordered_map replaced_map{ + {"conditional_block", "conditional_block_infer"}, + }; + + auto& graph = argument->main_graph(); + auto nodes = graph.Nodes(); + + for (auto& node : nodes) { + if (!node->IsOp()) continue; + auto* op_desc = node->Op(); + std::string op_type = op_desc->Type(); + if (!replaced_map.count(op_type)) continue; + op_desc->SetType(replaced_map[op_type]); + op_desc->Flush(); + } +} + +std::string InferenceOpReplacePass::repr() const { + return "inference-op-replace-pass"; +} + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.h b/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.h new file mode 100644 index 000000000..7fbdd88e0 --- /dev/null +++ b/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.h @@ -0,0 +1,43 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/inference/analysis/analysis_pass.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace inference { +namespace analysis { + +/* + * There are some ops (while, conditional_block_op etc) which have different + * optimization points under predicion and training conditions. + * So, We added the corresponding inference impl to these ops separately. + * This pass replaces these ops with corresponding inference ops. + */ +class InferenceOpReplacePass : public AnalysisPass { + public: + void RunImpl(Argument *argument) override; + std::string repr() const override; +}; + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc index 1f4077eec..c894acfd4 100644 --- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc +++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -108,11 +109,34 @@ int DataTypeToSpace(framework::proto::VarType_Type type) { void MemoryOptimizePass::CollectVarMemorySize( space_table_t* space_table) const { const int fake_batch_size = 1; + auto valid_var = [&](framework::ir::Node* node) -> bool { + std::set invalid_op = {"while", "conditional_block", + "tensorrt_engine", + "conditional_block_infer"}; + for (auto* tmp : node->inputs) { + CHECK(tmp->IsOp()); + std::string op_type = tmp->Op()->Type(); + if (std::find(invalid_op.begin(), invalid_op.end(), op_type) != + invalid_op.end()) { + return false; + } + } + for (auto* tmp : node->outputs) { + CHECK(tmp->IsOp()); + std::string op_type = tmp->Op()->Type(); + if (std::find(invalid_op.begin(), invalid_op.end(), op_type) != + invalid_op.end()) { + return false; + } + } + return true; + }; // Collect tensors from graph. for (auto* node : graph_->Nodes()) { if (node->IsVar() && node->Var()->GetType() == - framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) { + framework::proto::VarType::Type::VarType_Type_LOD_TENSOR && + valid_var(node)) { // Parameters will not be reused. if (node->Var()->Persistable()) continue; auto shape = node->Var()->GetShape(); @@ -135,12 +159,9 @@ void MakeSimpleReusePlan( std::unordered_map* cluster_size) { std::vector mem_nodes; for (auto& data : lifecycles) { + if (!space_table.count(data.first)) continue; MemNode temp_node; temp_node.name = data.first; - PADDLE_ENFORCE( - space_table.count(data.first), - "%s variable should be in the spacetable during memory optimize", - data.first); temp_node.size = space_table.at(data.first); temp_node.cluster = -1; temp_node.lifetime = data.second; diff --git a/paddle/fluid/inference/analysis/passes/passes.cc b/paddle/fluid/inference/analysis/passes/passes.cc index a55904ed5..97debcec5 100644 --- a/paddle/fluid/inference/analysis/passes/passes.cc +++ b/paddle/fluid/inference/analysis/passes/passes.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/inference/analysis/passes/passes.h" #include "paddle/fluid/inference/analysis/passes/adjust_cudnn_workspace_size_pass.h" +#include "paddle/fluid/inference/analysis/passes/inference_op_replace_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_analysis_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.h" @@ -38,6 +39,8 @@ PassRegistry::PassRegistry() { std::unique_ptr(new IrParamsSyncAmongDevicesPass)); passes_.emplace("adjust_cudnn_workspace_size_pass", std::unique_ptr(new AdjustCudnnWorkSpacePass)); + passes_.emplace("inference_op_replace_pass", + std::unique_ptr(new InferenceOpReplacePass)); passes_.emplace( "ir_graph_to_program_pass", std::unique_ptr(new IrGraphToProgramPass)); diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 890c90697..4d0bf7746 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -90,6 +90,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { CP_MEMBER(model_from_memory_); // the memory model reuses prog_file_ and // params_file_ fields. + CP_MEMBER(opt_cache_dir_); prog_file_ = std::move(other.prog_file_); params_file_ = std::move(other.params_file_); @@ -406,11 +407,6 @@ void AnalysisConfig::SetModelBuffer(const char *prog_buffer, Update(); } -void AnalysisConfig::SetEngineOptInfo( - std::map engine_opt_info) { - engine_opt_info_ = engine_opt_info; -} - NativeConfig AnalysisConfig::ToNativeConfig() const { NativeConfig config; config.model_dir = model_dir_; diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 5d9d5a317..e7a8549d3 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -368,10 +368,10 @@ void AnalysisPredictor::PrepareArgument() { argument_.SetStaticMemoryOptimForceUpdate( config_.static_memory_optim_force_update_); argument_.SetModelFromMemory(config_.model_from_memory_); - argument_.SetEngineOptInfo(config_.engine_opt_info_); // Analyze inference_program argument_.SetUseAnakin(config_.anakin_engine_enabled()); argument_.SetPredictorID(predictor_id_); + argument_.SetOptimCacheDir(config_.opt_cache_dir_); if (!config_.model_dir().empty()) { argument_.SetModelDir(config_.model_dir()); } else { diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index e3682d270..e94ca5e96 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -61,6 +61,11 @@ struct AnalysisConfig { /** Set parameter composed file path. */ void SetParamsFile(const std::string& x) { params_file_ = x; } + /** Set opt cache dir. + */ + void SetOptimCacheDir(const std::string& opt_cache_dir) { + opt_cache_dir_ = opt_cache_dir; + } /** Get the model directory path. */ const std::string& model_dir() const { return model_dir_; } @@ -143,7 +148,7 @@ struct AnalysisConfig { int max_batch_size = 1, int min_subgraph_size = 3, Precision precision = Precision::kFloat32, bool use_static = false, - bool use_calib_mode = false); + bool use_calib_mode = true); /** A boolean state telling whether the TensorRT engine is used. */ bool tensorrt_engine_enabled() const { return use_tensorrt_; } @@ -223,7 +228,6 @@ struct AnalysisConfig { /** A boolean state telling whether the model is set from the CPU memory. */ bool model_from_memory() const { return model_from_memory_; } - void SetEngineOptInfo(std::map engine_opt_info); /** Turn on memory optimize * NOTE still in development, will release latter. @@ -311,15 +315,15 @@ struct AnalysisConfig { bool anakin_auto_config_layout_{false}; std::vector anakin_passes_filter_; std::vector anakin_ops_filter_; - std::map engine_opt_info_; bool use_mkldnn_quantizer_{false}; std::shared_ptr mkldnn_quantizer_config_; // If the config is already used on a predictor, it becomes invalid. - mutable bool is_valid_{true}; // Any config can only be used with one predictor. // Variables held by config can take up a lot of memory in some cases. // So we release the memory when the predictor is set up. + mutable bool is_valid_{true}; + std::string opt_cache_dir_; }; } // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h index 4236399aa..62b7ab304 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.h +++ b/paddle/fluid/inference/api/paddle_pass_builder.h @@ -73,8 +73,8 @@ class PaddlePassBuilder { protected: std::vector analysis_passes_{ {"ir_graph_build_pass", "ir_analysis_pass", - "ir_params_sync_among_devices_pass", - "adjust_cudnn_workspace_size_pass"}}; + "ir_params_sync_among_devices_pass", "adjust_cudnn_workspace_size_pass", + "inference_op_replace_pass"}}; std::vector passes_; }; diff --git a/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc b/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc new file mode 100644 index 000000000..8ad2f7938 --- /dev/null +++ b/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc @@ -0,0 +1,74 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/controlflow/conditional_block_op.h" + +namespace paddle { +namespace operators { + +/* We will implement the op with block separately in the future. + * The main reason is that some of the training requirements + * in these OPS can lead to problems(such as memory leaks) during inference. + */ +class ConditionalBlockInferOp : public ConditionalOp { + public: + ConditionalBlockInferOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : ConditionalOp(type, inputs, outputs, attrs) {} + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + bool need_run; + if (Attr("is_scalar_condition")) { + // When is_scalar_condition is True, the conditional variable is a scalar, + // whether need to execute the operators in sub-block depends on the + // conditional variable (Cond). + auto xs = InputTensors(scope, "Cond"); + need_run = ScalarCondition(xs); + } else { + // When is_scalar_condition is False, the conditional variable maybe a + // vector or tensor, whether need to execute the operators in sub-block + // depends on the input variables (Input). + auto xs = InputTensors(scope, "Input"); + need_run = std::all_of( + xs.begin(), xs.end(), + [](const framework::LoDTensor *t) { return t->numel() != 0; }); + } + + if (need_run) { + auto *scope_var = scope.FindVar(Output("Scope")); + PADDLE_ENFORCE(scope_var != nullptr, "Must set scope"); + auto *scopes = scope_var->GetMutable>(); + scopes->resize(1); + scopes->front() = &scope.NewScope(); + auto &cur_scope = *scopes->front(); + + framework::Executor exec(dev_place); + auto *block = Attr("sub_block"); + exec.Run(*block->Program(), &cur_scope, block->ID(), false); + scope.DeleteScope(scopes->front()); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(conditional_block_infer, ops::ConditionalBlockInferOp, + ops::ConditionalBlockOpProtoMaker, + paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc index f0dc71819..8358ef755 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc @@ -11,67 +11,12 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include -#include -#include -#include "paddle/fluid/framework/executor.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/var_type.h" + +#include "paddle/fluid/operators/controlflow/conditional_block_op.h" namespace paddle { namespace operators { -class ConditionalOp : public framework::OperatorBase { - public: - ConditionalOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorBase(type, inputs, outputs, attrs) {} - - protected: - std::vector InputTensors( - const framework::Scope &scope, const std::string &in_name) const { - std::vector retv; - auto xs = Inputs(in_name); - retv.resize(xs.size(), nullptr); - std::transform( - xs.begin(), xs.end(), retv.begin(), - [&scope](const std::string &var_name) -> const framework::LoDTensor * { - auto *var = scope.FindVar(var_name); - PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", var_name); - return &var->Get(); - }); - return retv; - } - - bool ScalarCondition( - const std::vector &ips) const { - if (!(ips.size() == 1UL && ips[0]->IsInitialized())) { - PADDLE_THROW("should have one initialized input as condition"); - } - - PADDLE_ENFORCE(ips[0]->type() == framework::proto::VarType::BOOL && - ips[0]->numel() == 1, - "condition input's data type should be bool, " - "numel should be 1, actual numel is %d", - ips[0]->numel()); - bool res = false; - if (platform::is_gpu_place(ips[0]->place())) { -#ifdef PADDLE_WITH_CUDA - framework::LoDTensor cpu_tensor; - framework::TensorCopy(*ips[0], platform::CPUPlace(), &cpu_tensor); - platform::DeviceContextPool::Instance().Get(ips[0]->place())->Wait(); - res = cpu_tensor.data()[0]; -#endif - } else { - res = ips[0]->data()[0]; - } - return res; - } -}; - class ConditionalBlockOp : public ConditionalOp { public: ConditionalBlockOp(const std::string &type, @@ -115,38 +60,6 @@ class ConditionalBlockOp : public ConditionalOp { } }; -class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("Cond", - "The conditional variable of this operator. If Cond is empty, the " - "whole sub-block will not be executed.") - .AsDuplicable(); - AddInput("Input", "The input variables of the sub-block.").AsDuplicable(); - AddOutput("Out", "The output variables of the sub-block.").AsDuplicable(); - AddOutput("Scope", - "(std::vector) The step scope of conditional block. To " - "unify the conditional block, rnn and while op, the type of " - "scope is std::vector"); - AddAttr( - "sub_block", "The step block of conditional block operator"); - AddAttr("is_scalar_condition", - "The conditional variable (Cond) is used as scalar " - "condition.") - .SetDefault(false); - AddComment(R"DOC(Conditional block operator - -If `is_scalar_condition` is True, the conditional variable (Cond) is a scalar, -run the operators in sub-block if Cond is True. - -If `is_scalar_condition` is False, the conditional variable (Cond) is a vector or -tensor, run the operators in sub-block if all of input variables are not empty. - - -)DOC"); - } -}; - class ConditionalBlockGradOp : public ConditionalOp { public: ConditionalBlockGradOp(const std::string &type, diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.h b/paddle/fluid/operators/controlflow/conditional_block_op.h new file mode 100644 index 000000000..9a079c845 --- /dev/null +++ b/paddle/fluid/operators/controlflow/conditional_block_op.h @@ -0,0 +1,111 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/var_type.h" + +namespace paddle { +namespace operators { + +class ConditionalOp : public framework::OperatorBase { + public: + ConditionalOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + protected: + std::vector InputTensors( + const framework::Scope &scope, const std::string &in_name) const { + std::vector retv; + auto xs = Inputs(in_name); + retv.resize(xs.size(), nullptr); + std::transform( + xs.begin(), xs.end(), retv.begin(), + [&scope](const std::string &var_name) -> const framework::LoDTensor * { + auto *var = scope.FindVar(var_name); + PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", var_name); + return &var->Get(); + }); + return retv; + } + + bool ScalarCondition( + const std::vector &ips) const { + if (!(ips.size() == 1UL && ips[0]->IsInitialized())) { + PADDLE_THROW("should have one initialized input as condition"); + } + + PADDLE_ENFORCE(ips[0]->type() == framework::proto::VarType::BOOL && + ips[0]->numel() == 1, + "condition input's data type should be bool, " + "numel should be 1, actual numel is %d", + ips[0]->numel()); + bool res = false; + if (platform::is_gpu_place(ips[0]->place())) { +#ifdef PADDLE_WITH_CUDA + framework::LoDTensor cpu_tensor; + framework::TensorCopy(*ips[0], platform::CPUPlace(), &cpu_tensor); + platform::DeviceContextPool::Instance().Get(ips[0]->place())->Wait(); + res = cpu_tensor.data()[0]; +#endif + } else { + res = ips[0]->data()[0]; + } + return res; + } +}; + +class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Cond", + "The conditional variable of this operator. If Cond is empty, the " + "whole sub-block will not be executed.") + .AsDuplicable(); + AddInput("Input", "The input variables of the sub-block.").AsDuplicable(); + AddOutput("Out", "The output variables of the sub-block.").AsDuplicable(); + AddOutput("Scope", + "(std::vector) The step scope of conditional block. To " + "unify the conditional block, rnn and while op, the type of " + "scope is std::vector"); + AddAttr( + "sub_block", "The step block of conditional block operator"); + AddAttr("is_scalar_condition", + "The conditional variable (Cond) is used as scalar " + "condition.") + .SetDefault(false); + AddComment(R"DOC(Conditional block operator + +If `is_scalar_condition` is True, the conditional variable (Cond) is a scalar, +run the operators in sub-block if Cond is True. + +If `is_scalar_condition` is False, the conditional variable (Cond) is a vector or +tensor, run the operators in sub-block if all of input variables are not empty. + + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index 21cf15cb0..79c9f759a 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -121,8 +121,9 @@ class TensorRTEngineOp : public framework::OperatorBase { // This process will builds a 32-bit trt engine, runs it on the calibration // set, and records a histogram for each // tensor of the distribution of activation values. - LOG_FIRST_N(INFO, 1) << "The TRT engine: " << engine_key_ - << " is running calibration trt int8... "; + LOG_FIRST_N(INFO, 1) << "This process is generating calibration table for " + "Paddle TRT int8..."; + int runtime_batch = 1; if (!Singleton::Global().Has(engine_key_)) { TRTCalibratorEngine *calib_res = diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 27f0e30d0..d8664425b 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -237,7 +237,7 @@ void BindAnalysisConfig(py::module *m) { py::arg("workspace_size") = 1 << 20, py::arg("max_batch_size") = 1, py::arg("min_subgraph_size") = 3, py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32, - py::arg("use_static") = true, py::arg("use_calib_mode") = false) + py::arg("use_static") = false, py::arg("use_calib_mode") = true) .def("enable_anakin_engine", &AnalysisConfig::EnableAnakinEngine, py::arg("max_batch_size") = 1, py::arg("max_input_shape") = -- GitLab