From e89b12884afa24a54a5a3d94754dd4cf461dc213 Mon Sep 17 00:00:00 2001
From: Zhaolong Xing
Date: Wed, 25 Sep 2019 10:28:38 +0800
Subject: [PATCH] Fix C++ inference bug triggered when memory optim and the
 TRT subgraph engine are enabled at the same time (#19969)

* Fix memory optimization type.
  test=develop

* 1. Fix bug triggered when TRT and memory optim are enabled together.
  2. Clean up the memory optim pass.
  test=develop
---
 paddle/fluid/inference/analysis/argument.h    |   4 +-
 .../ir_passes/tensorrt_subgraph_pass.cc       |  32 +-
 .../analysis/passes/memory_optimize_pass.cc   | 576 +-----------------
 .../analysis/passes/memory_optimize_pass.h    |  69 +--
 paddle/fluid/inference/api/analysis_config.cc |  10 +-
 .../fluid/inference/api/analysis_predictor.cc |  75 ---
 .../fluid/inference/api/analysis_predictor.h  |   5 -
 .../inference/api/paddle_analysis_config.h    |   5 +-
 .../tests/api/analyzer_dam_tester.cc          |  27 -
 9 files changed, 51 insertions(+), 752 deletions(-)

diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index 1aceb4f469e..42858655aaa 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -196,9 +196,7 @@ struct Argument {
 
   // Memory optimized related.
   DECL_ARGUMENT_FIELD(enable_memory_optim, EnableMemoryOptim, bool);
-  DECL_ARGUMENT_FIELD(static_memory_optim, StaticMemoryOptim, bool);
-  DECL_ARGUMENT_FIELD(static_memory_optim_force_update,
-                      StaticMemoryOptimForceUpdate, bool);
+
   // Indicate which kind of sort algorithm is used for operators, the memory
   // optimization relays on the sort algorithm.
   DECL_ARGUMENT_FIELD(memory_optim_sort_kind, MemoryOptimSortKind, int);
diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index 8d696e448e2..bd2f79a12aa 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -41,7 +41,8 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
   };
 
   SubGraphFuser fuser(graph, teller,
-                      Get("min_subgraph_size") /*min subgraph size*/);
+                      Get("min_subgraph_size") /*min subgraph size*/,
+                      "tensorrt_engine");
   fuser();
 
   std::vector graph_param_names =
@@ -200,13 +201,12 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
       "Ys", std::vector(output_names.begin(), output_names.end()));
 
   op_desc->SetBlockAttr("sub_block", new_block);
-  SetAttr(op_desc->Proto(), "subgraph",
-          block_desc.Proto()->SerializeAsString());
-  SetAttr(op_desc->Proto(), "max_batch_size", Get("max_batch_size"));
-  SetAttr(op_desc->Proto(), "workspace_size", Get("workspace_size"));
-  SetAttr(op_desc->Proto(), "gpu_id", Get("gpu_device_id"));
-  SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping);
-  SetAttr(op_desc->Proto(), "parameters", params);
+  op_desc->SetAttr("subgraph", block_desc.Proto()->SerializeAsString());
+  op_desc->SetAttr("max_batch_size", Get("max_batch_size"));
+  op_desc->SetAttr("workspace_size", Get("workspace_size"));
+  op_desc->SetAttr("gpu_id", Get("gpu_device_id"));
+  op_desc->SetAttr("output_name_mapping", output_mapping);
+  op_desc->SetAttr("parameters", params);
 
   // we record all inputs' shapes in attr to check if they are consistent
   // with the real inputs' shapes retrieved from scope when trt runs.
@@ -232,16 +232,16 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
     calibration_data = GetTrtCalibTableData(
         Get("model_opt_cache_dir"), engine_key, enable_int8);
   }
-  SetAttr(op_desc->Proto(), "calibration_data", calibration_data);
+  op_desc->SetAttr("calibration_data", calibration_data);
+  op_desc->SetAttr("enable_int8", enable_int8);
+  op_desc->SetAttr("enable_fp16", enable_fp16);
+  op_desc->SetAttr("use_calib_mode", use_calib_mode);
+  op_desc->SetAttr("engine_key", engine_key);
+  op_desc->SetAttr("predictor_id", predictor_id);
-  SetAttr(op_desc->Proto(), "enable_int8", enable_int8);
-  SetAttr(op_desc->Proto(), "enable_fp16", enable_fp16);
-  SetAttr(op_desc->Proto(), "use_calib_mode", use_calib_mode);
-  SetAttr(op_desc->Proto(), "engine_key", engine_key);
-  SetAttr(op_desc->Proto(), "predictor_id", predictor_id);
 
   std::string trt_engine_serialized_data = "";
-  SetAttr(op_desc->Proto(), "engine_serialized_data",
-          trt_engine_serialized_data);
+  op_desc->SetAttr("engine_serialized_data", trt_engine_serialized_data);
+  op_desc->Flush();
 
   std::unique_ptr calibrator;
   if (enable_int8 && calibration_data.size() != 0) {
diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
index 9a563467731..6fbf880356c 100644
--- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
@@ -225,264 +225,6 @@ void MakeSimpleReusePlan(
   }
 }
 
-// Collect the memory size of the tensors.
-void MemoryOptimizePass::CollectVarMemorySize(
-    const std::unordered_map& batch_var_ave_dim,
-    std::unordered_map* tensor_nodes,
-    space_table_t* space_table) const {
-  // Collect tensors from graph.
-  for (auto* node : graph_->Nodes()) {
-    if (node->IsVar() &&
-        node->Var()->GetType() ==
-            framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) {
-      // Parameters will not be reused.
-      if (node->Var()->Persistable()) continue;
-      (*tensor_nodes)[node->Name()] = node;
-      (*space_table)[node->Name()] =
-          DataTypeToSpace(node->Var()->GetDataType()) *
-          batch_var_ave_dim.at(node->Name());
-    }
-  }
-}
-
-// Find a sutable (big enough but smallest to avoid memory waste).
-//
-// Args:
-// @tensor_nodes: the tensor nodes in the ir::Graph.
-// @free_existing_tensors: the allocated tensor and are free.
-// @space_table: the memory space of tensors.
-// @tensor2use: the tensor that requires memory.
-//
-// Returns:
-// true if found some existing tensor to reuse.
-// false if no sutable tensor to reuse, one need to allocate a new tensor for
-// this requirement.
-// The suitable tensor for reuse is one that is approximately equal to the
-// memory demand.
-bool FindSuitableTensorToReuse(
-    const std::string& tensor, int space_required,
-    const std::unordered_map& tensor_nodes,
-    std::unordered_set* free_existing_tensors,
-    const space_table_t& space_table,
-    const std::vector>& var_clusters,
-    std::string* tensor2use) __SHOULD_USE_RESULT__;
-
-bool FindSuitableTensorToReuse(
-    const std::string& tensor, int space_required,
-    const std::unordered_map& tensor_nodes,
-    std::unordered_set* free_existing_tensors,
-    const space_table_t& space_table,
-    const std::vector>& var_clusters,
-    std::string* tensor2use) {
-  std::pair best_fit;
-  best_fit.second = std::numeric_limits::max();
-  VLOG(5) << "Split Tensors to " << var_clusters.size() << " clusters";
-
-  // find the cluster this var belongs to.
- const std::unordered_set* cluster = nullptr; - for (const auto& c : var_clusters) { - if (c.count(tensor)) { - cluster = &c; - break; - } - } - PADDLE_ENFORCE_NOT_NULL(cluster, - "something wrong in memory optimization, the " - "variable %s not in the clusters.", - tensor); - - for (auto& candidate : *free_existing_tensors) { - // This is not a temporary tensor. - if (!space_table.count(candidate)) continue; - // Not in the same cluster. - if (!cluster->count(candidate)) continue; - - size_t space = space_table.at(candidate); - PADDLE_ENFORCE( - space <= std::numeric_limits::type>::max(), - "space overload"); - size_t space_diff = - std::abs((std::make_signed::type)space - space_required); - if (space_diff < best_fit.second) { - best_fit.first = candidate; - best_fit.second = space_diff; - } - } - - if (best_fit.second < std::numeric_limits::max()) { - *tensor2use = best_fit.first; - return true; - } - return false; -} - -// Allocate new tensor instead of reusing the existing one. -void AllocateNewTensor( - const std::string& name, size_t space_required, - const std::unordered_map& tensor_nodes, - std::unordered_set* free_existing_tensors, - space_table_t* space_table, - std::unordered_map* reuse_table) { - // The newly born tensor is free to be used. - free_existing_tensors->insert(name); - // Register the space it has. - PADDLE_ENFORCE(space_table->count(name)); - space_table->at(name) = std::max(space_table->at(name), space_required); - // The allocated new tensor use the memory of itself. - (*reuse_table)[name] = name; -} - -// Free a tensor and make it resuable. -// @tensor: the tensor to free. -// @free_existing_tensors: the free and allocated tensors. -// @reuse_table: a map from a fake tensor to the existing allocated tensor. -void FreeATensor(const std::string& tensor, - std::unordered_set* free_existing_tensors, - std::unordered_map* reuse_table) { - if (tensor == "feed" || tensor == "fetch") return; - // the really allocated tensor. - const auto& free_tensor = reuse_table->at(tensor); - - free_existing_tensors->insert(free_tensor); -} - -// Reuse a free existing tensor. -void ReuseATensor(const std::string& tensor, const std::string& tensor2reuse, - size_t memory_size, - std::unordered_set* free_existing_tensors, - std::unordered_map* reuse_table, - space_table_t* reused_space_table) { - auto it = free_existing_tensors->find(tensor2reuse); - PADDLE_ENFORCE(it != free_existing_tensors->end()); - free_existing_tensors->erase(it); - (*reuse_table)[tensor] = tensor2reuse; - // Update the memory size of a reused tensor, the memory will grow if the - // required memory is larger. - (*reused_space_table)[tensor2reuse] = - std::max(reused_space_table->at(tensor2reuse), memory_size); -} - -// Calculate the memory usage. -void EvaluateMemoryUsage( - const std::unordered_map& reuse_table, - const space_table_t& space_table, - const std::unordered_map& var_batch_ave_size, - size_t* allocated, size_t* saved) { - *allocated = 0; - *saved = 0; - - for (auto elem : reuse_table) { - if (elem.first == elem.second) { - *allocated += space_table.at(elem.first); - VLOG(4) << elem.first << " <-> " << elem.second << " " - << space_table.at(elem.first) << " " - << space_table.at(elem.second); - } else { - *saved += space_table.at(elem.first); - VLOG(4) << "reuse " << elem.first << " -> " << elem.second; - } - } - VLOG(4) << "allocated " << *allocated; - VLOG(4) << "saved " << *saved; -} - -// Return saved ratio. 
-void MemoryOptimizePass::MakeReusePlan( - const std::vector>& var_clusters, - const std::unordered_map& var_batch_ave_size, - const space_table_t& space_table, - std::unordered_map* reuse_table, int sort_kind, - MemoryAllocation* memory_allocation) const { - // Clear the existing plan. - reuse_table->clear(); - - // The `space_table` stores the real memory size for each tensor. - // The `reused_space_table` stores the maximum memory size required by a - // tensor during the memory reusing, the small tensor might be reused by a - // larger tensor, and the memory size of the small one will grow. - auto reused_space_table = space_table; - - std::unordered_map life_cycles; - std::unordered_map tensor_nodes; - // The allocated tensors whose memory can be reused, they will live across the - // program execution. - std::unordered_set existing_tensors; - // The existing tensor that has been allocated, and is also free to reuse. - std::unordered_set free_existing_tensors; - - CollectLifeCycle(&life_cycles, sort_kind); - - for (int age = 0; age < max_lifecycle_; ++age) { - std::unordered_set born_tensors; - std::unordered_set dead_tensors; - // Gather the dead and born tensors. - for (auto elem_it = life_cycles.begin(); elem_it != life_cycles.end(); - elem_it++) { - if (elem_it->second.first == -1) { - continue; - } - const auto& tensor = elem_it->first; - const auto& lifecycle = elem_it->second; - VLOG(4) << "process " << tensor << " reuse " << lifecycle.first << "->" - << lifecycle.second; - - // Collect newly born tensors. - if (lifecycle.first == age) { - born_tensors.insert(tensor); - } - // Collect dead tensors whose memory can be reused. - else if (lifecycle.second < age) { // NOLINT - dead_tensors.insert(tensor); - // remove to avoid duplicate process. - elem_it->second.first = -1; // avoid duplicate search - } - } - - // Reuse the dead tensors for born tensors - for (const auto& tensor : born_tensors) { - // Skip the feed and fetch tensor for that they share data with others. - std::string tensor2reuse; - if (!space_table.count(tensor)) continue; - size_t space_required = space_table.at(tensor); - if (FindSuitableTensorToReuse(tensor, space_required, tensor_nodes, - &free_existing_tensors, reused_space_table, - var_clusters, &tensor2reuse)) { - if (tensor != tensor2reuse) { - VLOG(4) << tensor << " -> " << tensor2reuse; - } - ReuseATensor(tensor, tensor2reuse, space_required, - &free_existing_tensors, reuse_table, &reused_space_table); - } else { - VLOG(4) << "allocate " << tensor; - AllocateNewTensor(tensor, space_required, tensor_nodes, - &free_existing_tensors, &reused_space_table, - reuse_table); - ReuseATensor(tensor, tensor, space_required, &free_existing_tensors, - reuse_table, &reused_space_table); - } - } - - for (const auto& tensor : dead_tensors) { - // free its memory. - FreeATensor(tensor, &free_existing_tensors, reuse_table); - } - } - - EvaluateMemoryUsage(*reuse_table, reused_space_table, var_batch_ave_size, - &(memory_allocation->allocated), - &(memory_allocation->saved)); - memory_allocation->sort_kind = sort_kind; -} - -void BuildVarNodeTable(Graph* graph, - std::unordered_map* var_node_table) { - for (auto* node : graph->Nodes()) { - if (node->IsVar()) { - (*var_node_table)[node->Name()] = node; - } - } -} - // NOTE The optimized opdesc doesn't match ir::Graph. 
void UpdateOpDescsByReuse( Graph* graph, @@ -551,311 +293,35 @@ void UpdateOpDescsByReuse( } } -void MemoryOptimizePass::PerformReusePlan( - const std::unordered_map& reuse_table, - int sort_kind, std::unordered_set* vars2remove) const { - std::unordered_map var_node_table; - BuildVarNodeTable(graph_, &var_node_table); - UpdateOpDescsByReuse(graph_, reuse_table, sort_kind); - - for (auto& item : reuse_table) { - if (item.first != item.second) { - vars2remove->insert(item.first); - } - } - VLOG(2) << "to remove vars " << vars2remove->size(); -} - -std::vector split(const std::string& line, char delim) { - std::vector res; - std::string field; - std::stringstream line_stream(line); - while (std::getline(line_stream, field, delim)) { - res.emplace_back(field); - } - return res; -} - -// Deserialize the batch var shapes from the cache file. -std::vector>> DeseralizeBatchVarShapes( - const std::string& path) { - std::ifstream file(path); - PADDLE_ENFORCE(file.is_open(), "failed to open %s to read cache", path); - std::string line; - std::vector>> batch_shapes; - - while (std::getline(file, line)) { - std::map> batch; - for (const auto& var_info : split(line, ';')) { - auto fields = split(var_info, ':'); - PADDLE_ENFORCE_EQ(fields.size(), 2UL); - auto var_name = fields.front(); - auto shape_str = split(fields[1], ','); - std::vector shape; - for (const auto& v : shape_str) shape.push_back(std::stoi(v)); - batch[var_name] = shape; - } - batch_shapes.push_back(batch); - } - return batch_shapes; -} - -// Replace the -1 in shape to a real number to fake the shape. -std::vector>> FakeBatchVarShapes( - const framework::ProgramDesc& program) { - std::vector>> res; - res.emplace_back(); - auto& record = res.front(); - const int fake_batch_size = 3; - for (auto* var : program.Block(0).AllVars()) { - if (var->GetType() == - framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) { - auto shape = var->GetShape(); - for (auto& v : shape) { - if (v < 0) v = fake_batch_size; - } - record[var->Name()].assign(shape.begin(), shape.end()); - } - } - return res; -} - -// Calculate the average dim of each tensor from the batch shape cache. -std::unordered_map GetBatchAverageSize( - const std::vector>>& batches) { - std::unordered_map var2size; - // The average size of the batches for each variable. - int num_batch = 0; - for (const auto& batch : batches) { - num_batch++; - for (const auto& item : batch) { - int dim = std::accumulate(item.second.begin(), item.second.end(), 1, - [](int a, int b) { return a * b; }); - var2size[item.first] += dim; - } - } - - for (auto& item : var2size) { - item.second /= num_batch; - } - - return var2size; -} - -// Analysis the batch shapes loading from the cache file. -// By splitting the variables to different clusters by analyzing their batch -// size, we can pre-schedule the changes of difference LoDTensor when different -// length of input sequences is entered. -// This should works fine for the models operating on sentences. -std::vector> AnalysisBatchShapesByBatchSize( - const std::vector>>& batches) { - // collect the batch size of each shape and combine to a stringstream in - // converient to generate a hash. - std::unordered_map var_batchsize_hashes; - for (auto& batch : batches) { - for (auto& ele : batch) { - PADDLE_ENFORCE(!ele.second.empty()); - int batch_size = ele.second.front(); - // TODO(Superjomn) might consume large memory here, use combine hash. - var_batchsize_hashes[ele.first] << batch_size; - } - } - - // Split to sets by batch size sequences. 
- std::unordered_map> - shape_sets; - for (auto& ele : var_batchsize_hashes) { - auto hash = std::hash()(ele.second.str()); - shape_sets[hash].insert(ele.first); - } - std::vector> res; - for (auto& ele : shape_sets) { - res.emplace_back(std::move(ele.second)); - } - - VLOG(3) << "Cluster by batch_size and get " << res.size() << " clusters"; - return res; -} - -// Analysis the batch shapes loading from the cache file, and split them to -// different clusters by their size. -// This should works fine for the overall models. -std::vector> AnalysisBatchShapesBySimilarSize( - const space_table_t& space_table, - const std::vector>>& batches, - int interval = 200000) { - PADDLE_ENFORCE_GT(interval, 0); - // cluster to different clusters. - size_t max_size = 0; - for (auto& item : space_table) { - max_size = std::max(item.second, max_size); - } - VLOG(4) << "tensor max size " << max_size; - - std::vector> res; - - // cluster by intervals. - for (size_t interval_size = 0; interval_size <= max_size; - interval_size += interval) { - std::unordered_set cluster; - for (auto& item : space_table) { - if (interval_size <= item.second && - interval_size + interval > item.second) { - cluster.insert(item.first); - } - } - if (!cluster.empty()) { - res.push_back(cluster); - } - } - - VLOG(3) << "Cluster by interval and get " << res.size() << " cluster"; - return res; -} - std::string MemoryOptimizePass::repr() const { return "memory optimize pass"; } -std::pair GetRange( - const std::unordered_map& ave_size) { - auto res = std::make_pair(std::numeric_limits::max(), - std::numeric_limits::min()); - for (auto& item : ave_size) { - res.first = std::min(item.second, res.first); - res.second = std::max(item.second, res.second); - } - return res; -} - void MemoryOptimizePass::RunImpl(Argument* argument) { - // When force update, should not optimize memory. - if (!argument->enable_memory_optim() || - argument->static_memory_optim_force_update()) - return; + // Memory optimization. + // We will perform the following operation: + // 1. Collect all var's lifetime. + // 2. Make reuse plan: the vars can be reused if there is no overlap(on + // lifetime) between + // them. + // The final plan is a mapping table in which the key represents the original + // name of var and the value in the table represents the current name of var. + // 3. Perform reuse plan: Replace all var's name in the model according to the + // mapping table. + if (!argument->enable_memory_optim()) return; graph_ = argument->main_graph_ptr(); - auto path = GetMemoryCachePath( - argument->model_dir_valid() ? argument->model_dir() : "", - argument->model_program_path_valid() ? 
argument->model_program_path() - : ""); - VLOG(3) << "Load memory cache from " << path; - std::vector>> batches; - - if (!(argument->static_memory_optim() && inference::IsFileExists(path))) { - string::PrettyLogInfo("--- Performing dynamic memory optimize"); - // batches = FakeBatchVarShapes(argument->main_program()); - int sort_kind = 0; - std::unordered_map lifecycles; - space_table_t space_table; - std::unordered_map node2cluster; - std::unordered_map cluster_size; - - CollectLifeCycle(&lifecycles, sort_kind); - CollectVarMemorySize(&space_table); - MakeSimpleReusePlan(lifecycles, space_table, &node2cluster, &cluster_size); - UpdateOpDescsByReuse(graph_, node2cluster, sort_kind); - return; - - } else { - string::PrettyLogInfo("--- Performing static memory optimize"); - batches = DeseralizeBatchVarShapes(path); - } - auto var_batch_ave_size = GetBatchAverageSize(batches); - - // Get min and max memory size. - const auto range = GetRange(var_batch_ave_size); - const int cluster_size = std::max( - static_cast((range.second - range.first) / 100 /*cluster num*/), - 1024); - const int cluster_size1 = std::max( - static_cast((range.second - range.first) / 1000 /*cluster num*/), - 1024); - - std::unordered_map tensor_nodes; + int sort_kind = 0; + std::unordered_map lifecycles; space_table_t space_table; - CollectVarMemorySize(var_batch_ave_size, &tensor_nodes, &space_table); - - std::unordered_map reuse_table; - double max_saving_ratio = 0.; - - std::vector> strategies; - - for (int sort_kind = 0; sort_kind < 2; sort_kind++) { - if (argument->static_memory_optim()) { - // This strategy only make scene in static memory optimize. - strategies.emplace_back([&, sort_kind] { - auto clustered_vars_by_batch_size = - AnalysisBatchShapesByBatchSize(batches); - MemoryAllocation allocation; - MakeReusePlan(clustered_vars_by_batch_size, var_batch_ave_size, - space_table, &reuse_table, sort_kind, &allocation); - return allocation; - }); - } - - strategies.emplace_back([&, sort_kind] { - auto clustered_vars_by_ave_size = - AnalysisBatchShapesBySimilarSize(space_table, batches, cluster_size); - MemoryAllocation allocation; - MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size, space_table, - &reuse_table, sort_kind, &allocation); - return allocation; - }); - - strategies.emplace_back([&, sort_kind] { - auto clustered_vars_by_ave_size = - AnalysisBatchShapesBySimilarSize(space_table, batches, cluster_size1); - MemoryAllocation allocation; - MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size, space_table, - &reuse_table, sort_kind, &allocation); - return allocation; - }); - - strategies.emplace_back([&, sort_kind] { - auto clustered_vars_by_ave_size = AnalysisBatchShapesBySimilarSize( - space_table, batches, - std::numeric_limits::max()); // no intervals - MemoryAllocation allocation; - MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size, space_table, - &reuse_table, sort_kind, &allocation); - return allocation; - }); - } - - std::function* best_strategy{nullptr}; - - // Try all strategies to get the best result. 
- for (auto& strategy : strategies) { - auto allocation = strategy(); - string::PrettyLogDetail("--- get strategy saving %f memory for workspace", - allocation.GetSavingRatio()); - if (allocation.GetSavingRatio() > max_saving_ratio) { - max_saving_ratio = allocation.GetSavingRatio(); - best_strategy = &strategy; - } - } - if (!best_strategy) { - LOG(ERROR) << "This model makes poor memory optimize, skip memory optimize"; - return; - } - auto memory_allocation = (*best_strategy)(); - - string::PrettyLogInfo( - "--- Saved %.2f%s memory for workspace(temporary variables)", - memory_allocation.GetSavingRatio() * 100, "%"); - - argument->main_graph().Set(framework::ir::kGraphToProgramVarsToRemove, - new std::unordered_set); - auto& vars2remove = - argument->main_graph().Get>( - framework::ir::kGraphToProgramVarsToRemove); - - PerformReusePlan(reuse_table, memory_allocation.sort_kind, &vars2remove); - argument->SetMemoryOptimSortKind(memory_allocation.sort_kind); + std::unordered_map node2cluster; + std::unordered_map cluster_size; + + CollectLifeCycle(&lifecycles, sort_kind); + CollectVarMemorySize(&space_table); + MakeSimpleReusePlan(lifecycles, space_table, &node2cluster, &cluster_size); + UpdateOpDescsByReuse(graph_, node2cluster, sort_kind); + return; } -float MemoryOptimizePass::MemoryAllocation::GetSavingRatio() const { - return (saved / 1024.) / (allocated / 1024. + saved / 1024.); -} } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h index 90e285da099..77da5d40d8d 100644 --- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h +++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h @@ -25,45 +25,22 @@ namespace paddle { namespace inference { namespace analysis { -/* - * Memory optimization pass for inference with pre-analysis of memory usage - * without GC. - * Different from training, the inference memory reuse strategies doesn't - * include GC for that overhead is too much when batch size equals one. - * - * The inference memory reuse tries to pre-determine the tensor reusing strategy - * without runtime overhead. - * - * To improve the strategy's performance, a warm-up running is introduced: - * - Before officially deploy the inference program, one should warm it up and - * generate some runtime cache, - * - Run the inference program with several batches of data, it will persist - * some runtime information about memory of tensors to disk, we call the - * information the memory reusing cache, - * - With the memory reusing cache, user can deploy the inference to a - * service, before running the model, the inference program will load the - * memory cache, analysis it and generate the best memory reusing strategy, - * and adjust the execution of the network. - * - * With the warm-up and memory reusing cache design, the memory reusing - * algorithm can analysis the real memory consume of the tensors, even with the - * flexible LoDTensor and special shape changing operators such as - * sequence-pooling. - */ +/* Memory optimization. +* We will perform the following operation: +* 1. Collect all var's lifetime. +* 2. Make reuse plan: the vars can be reused if there is no overlap(on lifetime) +* between +* them. +* The final plan is a mapping table in which the key represents the original +* name of var and the value in the table represents the current name of var. +* 3. 
Perform reuse plan: Replace all var's name in the model according to the +* mapping table. +*/ class MemoryOptimizePass : public AnalysisPass { public: using space_table_t = std::unordered_map; using lifecycle_t = std::pair; - struct MemoryAllocation { - size_t allocated; // allocated memory in byte. - size_t saved; // saved memory in byte. - int sort_kind; // the kind of the corresponding sorting algorithm. - - // Get the memory saving ratio of temporary variables. - float GetSavingRatio() const; - }; - virtual ~MemoryOptimizePass() = default; protected: @@ -75,24 +52,6 @@ class MemoryOptimizePass : public AnalysisPass { int sort_kind) const; void CollectVarMemorySize(space_table_t *space_table) const; - void CollectVarMemorySize0(space_table_t *space_table) const; - - void CollectVarMemorySize( - const std::unordered_map &batch_var_ave_dim, - std::unordered_map *tensor_nodes, - space_table_t *space_table) const; - - // Returns percentage of saved memory. - void MakeReusePlan( - const std::vector> &var_clusters, - const std::unordered_map &var_batch_ave_size, - const space_table_t &space_table, - std::unordered_map *reuse_table, int sort_kind, - MemoryAllocation *memory_allocation) const; - - void PerformReusePlan( - const std::unordered_map &reuse_table, - int sort_kind, std::unordered_set *vars2remove) const; public: std::string repr() const override; @@ -102,12 +61,6 @@ class MemoryOptimizePass : public AnalysisPass { mutable int max_lifecycle_{-1}; }; -static std::string GetMemoryCachePath(const std::string &model_path, - const std::string &prog_path) { - auto path = model_path.empty() ? prog_path : model_path; - return path + ".memory_cache"; -} - } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index c08a73d0da7..ace260c7cdb 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -101,8 +101,6 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { CP_MEMBER(memory_pool_init_size_mb_); CP_MEMBER(enable_memory_optim_); - CP_MEMBER(static_memory_optim_); - CP_MEMBER(static_memory_optim_force_update_); // TensorRT related. CP_MEMBER(use_tensorrt_); CP_MEMBER(tensorrt_workspace_size_); @@ -371,8 +369,6 @@ std::string AnalysisConfig::SerializeInfoCache() { ss << tensorrt_min_subgraph_size_; ss << enable_memory_optim_; - ss << static_memory_optim_; - ss << static_memory_optim_force_update_; ss << use_ngraph_; @@ -420,12 +416,8 @@ float AnalysisConfig::fraction_of_gpu_memory_for_pool() const { #endif } -void AnalysisConfig::EnableMemoryOptim(bool static_optim, - bool force_update_static_cache) { +void AnalysisConfig::EnableMemoryOptim() { enable_memory_optim_ = true; - static_memory_optim_ = static_optim; - static_memory_optim_force_update_ = force_update_static_cache; - Update(); } diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 5cf1942cb27..d47bde32de6 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -241,11 +241,6 @@ bool AnalysisPredictor::Run(const std::vector &inputs, return false; } - // Collect variable shapes for memory optimization. 
- if (need_collect_var_shapes_for_memory_optim()) { - CollectVarShapes(); - } - VLOG(3) << "predict cost: " << timer.toc() << "ms"; // All the containers in the scope will be hold in inference, but the @@ -390,9 +385,6 @@ void AnalysisPredictor::PrepareArgument() { argument_.SetGPUDeviceId(config_.gpu_device_id()); argument_.SetEnableAnalysisOptim(config_.enable_ir_optim_); argument_.SetEnableMemoryOptim(config_.enable_memory_optim()); - argument_.SetStaticMemoryOptim(config_.static_memory_optim_); - argument_.SetStaticMemoryOptimForceUpdate( - config_.static_memory_optim_force_update_); argument_.SetModelFromMemory(config_.model_from_memory_); // Analyze inference_program argument_.SetUseAnakin(config_.anakin_engine_enabled()); @@ -818,13 +810,6 @@ AnalysisPredictor::~AnalysisPredictor() { mkldnn_quantizer_ = nullptr; } #endif - - // TODO(Superjomn) deduce the directory path. - std::string out_path = inference::analysis::GetMemoryCachePath( - config_.model_dir(), config_.prog_file()); - if (need_collect_var_shapes_for_memory_optim()) { - SerializeBatchVarShapes(out_path); - } } std::unique_ptr AnalysisPredictor::Clone() { @@ -834,66 +819,6 @@ std::unique_ptr AnalysisPredictor::Clone() { return std::unique_ptr(x); } -void AnalysisPredictor::CollectVarShapes() { - VLOG(4) << "Collecting var shapes"; - if (batch_var_shapes_.size() >= max_shape_collect_count_) return; - std::map> var_shapes; - for (auto var_name : inference_program_->Block(0).LocalVarNames()) { - auto *var = sub_scope_->FindVar(var_name); - PADDLE_ENFORCE_NOT_NULL(var); - if (var->Type() == framework::VarTypeTrait::kId || - var->Type() == framework::VarTypeTrait::kId) { - auto &tensor = var->Get(); - auto shape = framework::vectorize(tensor.dims()); - var_shapes[var_name].assign(shape.begin(), shape.end()); - } - } - batch_var_shapes_.push_back(var_shapes); - LOG_FIRST_N(INFO, 1) << "Collected " << batch_var_shapes_.size() - << " batch of var shapes for analysis"; -} - -void AnalysisPredictor::SerializeBatchVarShapes(const std::string &path) { - LOG(INFO) << "serialize batch var shapes to " << path; - std::ofstream file(path); - if (!file.is_open()) { - LOG(ERROR) << "failed to serialize the var shapes to " << path; - return; - } - - // The sirialized data format: - // :dim0,dim1,dim2,; - for (auto &batch : batch_var_shapes_) { - for (auto &ele : batch) { - file << ele.first << ":"; - for (size_t i = 0; i < ele.second.size() - 1; i++) { - file << ele.second[i] << ","; - } - file << ele.second.back() << ";"; - } - file << "\n"; - } -} - -bool AnalysisPredictor::need_collect_var_shapes_for_memory_optim() { - if (need_collect_var_shapes_ >= 0) return need_collect_var_shapes_; - bool need = false; - // check if the cache exists - if (!config_.enable_memory_optim()) { - need = false; - } else if (config_.static_memory_optim_ && - !inference::IsFileExists(inference::analysis::GetMemoryCachePath( - config_.model_dir(), config_.prog_file()))) { - need = true; - } else if (config_.static_memory_optim_ && - config_.static_memory_optim_force_update_) { - need = true; - } - - need_collect_var_shapes_ = need ? 
1 : 0; - return need; -} - std::string AnalysisPredictor::GetSerializedProgram() const { return inference_program_->Proto()->SerializeAsString(); } diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index 2426e677490..33a2e62303a 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -91,11 +91,6 @@ class AnalysisPredictor : public PaddlePredictor { void SaveOptimModel(const std::string &dir); protected: - // For memory optimization. - bool need_collect_var_shapes_for_memory_optim(); - void CollectVarShapes(); - void SerializeBatchVarShapes(const std::string &path); - bool PrepareProgram(const std::shared_ptr &program); bool PrepareScope(const std::shared_ptr &parent_scope); bool CreateExecutor(); diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index 4ab1ca9588c..7764a498695 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -244,8 +244,7 @@ struct AnalysisConfig { /** Turn on memory optimize * NOTE still in development, will release latter. */ - void EnableMemoryOptim(bool static_optim = false, - bool force_update_static_cache = false); + void EnableMemoryOptim(); /** Tell whether the memory optimization is activated. */ bool enable_memory_optim() const; @@ -309,8 +308,6 @@ struct AnalysisConfig { // memory reuse related. bool enable_memory_optim_{false}; - bool static_memory_optim_{false}; - bool static_memory_optim_force_update_{false}; bool use_ngraph_{false}; bool use_mkldnn_{false}; diff --git a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc index 83bf99ec8aa..78c87b6db50 100644 --- a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc @@ -262,33 +262,6 @@ void compare(bool use_mkldnn = false) { reinterpret_cast(&cfg), input_slots_all); } -// Compare result of NativeConfig and AnalysisConfig with memory optimization. -TEST(Analyzer_dam, compare_with_static_memory_optim) { - // The small dam will core in CI, but works in local. - if (FLAGS_max_turn_num == 9) { - AnalysisConfig cfg, cfg1; - DataRecord data(FLAGS_infer_data, FLAGS_batch_size); - - std::vector> input_slots_all; - SetInput(&input_slots_all); - // Run the first time to force to update memory cache - SetConfig(&cfg); - cfg.EnableMemoryOptim(true, true /*force update*/); - - CompareNativeAndAnalysis( - reinterpret_cast(&cfg), - input_slots_all); - - // Run second time to use the memory cache and perform memory optimization. - SetConfig(&cfg1); - cfg1.EnableMemoryOptim(true, false /*do not force update*/); - - CompareNativeAndAnalysis( - reinterpret_cast(&cfg1), - input_slots_all); - } -} - TEST(Analyzer_dam, compare_with_dynamic_memory_optim) { // The small dam will core in CI, but works in local. if (FLAGS_max_turn_num == 9) { -- GitLab
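
For reference, below is a minimal usage sketch (not part of the patch) of the configuration this fix targets: memory optimization and the TensorRT subgraph engine enabled on the same AnalysisConfig. After this change EnableMemoryOptim() takes no arguments, because the static memory-optimization mode and its force-update flag are removed. The header name, model path, GPU pool size, and TensorRT arguments shown are illustrative assumptions and may differ across Paddle versions.

// Illustrative sketch only; not part of this patch. Paths and TensorRT
// parameters are placeholders chosen for the example.
#include <memory>

#include "paddle_inference_api.h"  // assumed public inference header name

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("/path/to/model_dir");  // hypothetical model directory
  config.EnableUseGpu(100 /* memory pool size in MB */, 0 /* GPU id */);

  // Turn on the TensorRT subgraph engine: workspace size, max batch size,
  // min subgraph size; precision/calibration arguments are left at their
  // defaults here.
  config.EnableTensorRtEngine(1 << 20, 1, 3);

  // After this patch the call takes no arguments: the static optimization
  // mode and the force-update-cache flag no longer exist.
  config.EnableMemoryOptim();

  // Both features can now be active on the same predictor.
  auto predictor = paddle::CreatePaddlePredictor(config);
  // ... prepare inputs and call predictor->Run(...) as usual.
  return 0;
}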