Unverified · Commit e2818c86 · authored by Yan Chunwei, committed by GitHub

add dynamic memory optim (#15457)

Parent 88bd7e1a
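At a high level, this commit splits inference memory optimization into two modes: a static mode that still relies on a recorded shape cache on disk, and a new dynamic mode that fakes the unknown batch dimension instead. A hedged usage sketch of the new EnableMemoryOptim signature follows; only the calls themselves are taken from this diff, the surrounding predictor setup is assumed:

// Sketch only: configuration helpers other than EnableMemoryOptim are assumed.
contrib::AnalysisConfig cfg;

// Dynamic memory optimize (the new default): batch dims of -1 are faked,
// so no shape-cache file is required.
cfg.EnableMemoryOptim();

// Static memory optimize: reuse the recorded shape cache; the second flag
// forces the cache to be collected again on this run.
// cfg.EnableMemoryOptim(true /*static_optim*/, true /*force_update_static_cache*/);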
@@ -133,7 +133,9 @@ struct Argument {
   // Memory optimized related.
   DECL_ARGUMENT_FIELD(enable_memory_optim, EnableMemoryOptim, bool);
-  DECL_ARGUMENT_FIELD(memory_optim_force_update, MemoryOptimForceUpdate, bool);
+  DECL_ARGUMENT_FIELD(static_memory_optim, StaticMemoryOptim, bool);
+  DECL_ARGUMENT_FIELD(static_memory_optim_force_update,
+                      StaticMemoryOptimForceUpdate, bool);
   // Indicate which kind of sort algorithm is used for operators, the memory
   // optimization relays on the sort algorithm.
   DECL_ARGUMENT_FIELD(memory_optim_sort_kind, MemoryOptimSortKind, int);
......
@@ -444,6 +444,26 @@ std::vector<std::map<std::string, std::vector<int>>> DeseralizeBatchVarShapes(
   return batch_shapes;
 }
+// Replace the -1 in shape to a real number to fake the shape.
+std::vector<std::map<std::string, std::vector<int>>> FakeBatchVarShapes(
+    const framework::ProgramDesc& program) {
+  std::vector<std::map<std::string, std::vector<int>>> res;
+  res.emplace_back();
+  auto& record = res.front();
+  const int fake_batch_size = 3;
+  for (auto* var : program.Block(0).AllVars()) {
+    if (var->GetType() ==
+        framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) {
+      auto shape = var->GetShape();
+      for (auto& v : shape) {
+        if (v < 0) v = fake_batch_size;
+      }
+      record[var->Name()].assign(shape.begin(), shape.end());
+    }
+  }
+  return res;
+}
 // Calculate the average dim of each tensor from the batch shape cache.
 std::unordered_map<std::string, size_t> GetBatchAverageSize(
     const std::vector<std::map<std::string, std::vector<int>>>& batches) {
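FakeBatchVarShapes above is the heart of the dynamic mode: it substitutes a small constant for every -1 (unknown batch) dimension so tensor sizes can be estimated without a recorded shape cache. A rough standalone illustration of the same trick, with a hypothetical helper name and an assumed 4-byte float element type:

#include <cstdint>
#include <vector>

// Hypothetical helper (not part of this commit): replace unknown (-1) dims
// with a fake batch size and return the resulting byte size of one tensor.
int64_t EstimateBytesWithFakeBatch(std::vector<int64_t> shape,
                                   int64_t fake_batch_size = 3) {
  int64_t numel = 1;
  for (auto& d : shape) {
    if (d < 0) d = fake_batch_size;  // the same substitution as above
    numel *= d;
  }
  return numel * sizeof(float);  // assumes float tensors
}

// e.g. a LoDTensor of shape {-1, 128} is treated as {3, 128}
// -> 3 * 128 * 4 = 1536 bytes.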
@@ -478,6 +498,7 @@ std::vector<std::unordered_set<std::string>> AnalysisBatchShapesByBatchSize(
   std::unordered_map<std::string, std::stringstream> var_batchsize_hashes;
   for (auto& batch : batches) {
     for (auto& ele : batch) {
+      PADDLE_ENFORCE(!ele.second.empty());
       int batch_size = ele.second.front();
       // TODO(Superjomn) might consume large memory here, use combine hash.
       var_batchsize_hashes[ele.first] << batch_size;
@@ -538,9 +559,21 @@ std::vector<std::unordered_set<std::string>> AnalysisBatchShapesBySimilarSize(
 std::string MemoryOptimizePass::repr() const { return "memory optimize pass"; }
+std::pair<size_t, size_t> GetRange(
+    const std::unordered_map<std::string, size_t>& ave_size) {
+  auto res = std::make_pair(std::numeric_limits<size_t>::max(),
+                            std::numeric_limits<size_t>::min());
+  for (auto& item : ave_size) {
+    res.first = std::min(item.second, res.first);
+    res.second = std::max(item.second, res.second);
+  }
+  return res;
+}
 void MemoryOptimizePass::RunImpl(Argument* argument) {
   // When force update, should not optimize memory.
-  if (!argument->enable_memory_optim() || argument->memory_optim_force_update())
+  if (!argument->enable_memory_optim() ||
+      argument->static_memory_optim_force_update())
     return;
   graph_ = argument->main_graph_ptr();
@@ -549,11 +582,26 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
       argument->model_program_path_valid() ? argument->model_program_path()
                                            : "");
   VLOG(3) << "Load memory cache from " << path;
-  if (inference::IsFileExists(path)) {
-    VLOG(4) << "Performing memory optimize";
-    auto batches = DeseralizeBatchVarShapes(path);
+  std::vector<std::map<std::string, std::vector<int>>> batches;
+  if (argument->static_memory_optim() && inference::IsFileExists(path)) {
+    string::PrettyLogInfo("--- Performing static memory optimize");
+    batches = DeseralizeBatchVarShapes(path);
+  } else {
+    string::PrettyLogInfo("--- Performing dynamic memory optimize");
+    batches = FakeBatchVarShapes(argument->main_program());
+  }
   auto var_batch_ave_size = GetBatchAverageSize(batches);
+  // Get min and max memory size.
+  const auto range = GetRange(var_batch_ave_size);
+  const int cluster_size = std::max(
+      static_cast<int>((range.second - range.first) / 100 /*cluster num*/),
+      1024);
+  const int cluster_size1 = std::max(
+      static_cast<int>((range.second - range.first) / 1000 /*cluster num*/),
+      1024);
   std::unordered_map<std::string, Node*> tensor_nodes;
   space_table_t space_table;
   CollectVarMemorySize(var_batch_ave_size, &tensor_nodes, &space_table);
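To make the interval arithmetic above concrete, here is a worked example with made-up sizes (not measured on any real model); the two intervals are later passed to AnalysisBatchShapesBySimilarSize so that variables of roughly similar size become candidates for sharing a buffer:

#include <algorithm>
#include <cstddef>

// Worked example: average per-variable sizes ranging from 4 KB to 40 MB.
const size_t min_sz = 4 * 1024;           // smallest average size
const size_t max_sz = 40 * 1024 * 1024;   // largest average size
const int cluster_size  = std::max(static_cast<int>((max_sz - min_sz) / 100), 1024);
const int cluster_size1 = std::max(static_cast<int>((max_sz - min_sz) / 1000), 1024);
// cluster_size  == 419389 bytes (coarse clustering, roughly 100 buckets)
// cluster_size1 == 41938 bytes  (finer clustering, roughly 1000 buckets)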
@@ -564,6 +612,8 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
   std::vector<std::function<MemoryAllocation()>> strategies;
   for (int sort_kind = 0; sort_kind < 2; sort_kind++) {
+    if (argument->static_memory_optim()) {
+      // This strategy only make scene in static memory optimize.
       strategies.emplace_back([&, sort_kind] {
         auto clustered_vars_by_batch_size =
             AnalysisBatchShapesByBatchSize(batches);
@@ -572,22 +622,23 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
             space_table, &reuse_table, sort_kind, &allocation);
         return allocation;
       });
+    }
     strategies.emplace_back([&, sort_kind] {
-      auto clustered_vars_by_ave_size = AnalysisBatchShapesBySimilarSize(
-          space_table, batches, 1024);  // interval 1kb
+      auto clustered_vars_by_ave_size =
+          AnalysisBatchShapesBySimilarSize(space_table, batches, cluster_size);
       MemoryAllocation allocation;
-      MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size,
-                    space_table, &reuse_table, sort_kind, &allocation);
+      MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size, space_table,
+                    &reuse_table, sort_kind, &allocation);
       return allocation;
     });
     strategies.emplace_back([&, sort_kind] {
-      auto clustered_vars_by_ave_size = AnalysisBatchShapesBySimilarSize(
-          space_table, batches, 1024 * 1024);  // interval 1MB
+      auto clustered_vars_by_ave_size =
+          AnalysisBatchShapesBySimilarSize(space_table, batches, cluster_size1);
       MemoryAllocation allocation;
-      MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size,
-                    space_table, &reuse_table, sort_kind, &allocation);
+      MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size, space_table,
+                    &reuse_table, sort_kind, &allocation);
       return allocation;
     });
@@ -596,8 +647,8 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
           space_table, batches,
           std::numeric_limits<int>::max());  // no intervals
       MemoryAllocation allocation;
-      MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size,
-                    space_table, &reuse_table, sort_kind, &allocation);
+      MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size, space_table,
+                    &reuse_table, sort_kind, &allocation);
       return allocation;
     });
   }
@@ -615,19 +666,15 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
     }
   }
   if (!best_strategy) {
-    LOG(ERROR)
-        << "This model makes poor memory optimize, skip memory optimize";
+    LOG(ERROR) << "This model makes poor memory optimize, skip memory optimize";
     return;
   }
   auto memory_allocation = (*best_strategy)();
-  string::PrettyLogH2(
+  string::PrettyLogInfo(
       "--- Saved %.2f%s memory for workspace(temporary variables)",
       memory_allocation.GetSavingRatio() * 100, "%");
-  string::PrettyLogDetail("--- Allocated %d MB",
-                          memory_allocation.allocated / 1024. / 1024.);
-  string::PrettyLogDetail("--- Saved %d MB",
-                          memory_allocation.saved / 1024. / 1024.);
   argument->main_graph().Set(framework::ir::kGraphToProgramVarsToRemove,
                              new std::unordered_set<std::string>);
   auto& vars2remove =
@@ -636,7 +683,6 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
   PerformReusePlan(reuse_table, memory_allocation.sort_kind, &vars2remove);
   argument->SetMemoryOptimSortKind(memory_allocation.sort_kind);
-  }
 }
 float MemoryOptimizePass::MemoryAllocation::GetSavingRatio() const {
......
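The strategies built above are each tried and scored by how much memory they save; the pass keeps the best one and re-runs it to obtain the final reuse plan. A simplified illustration of that selection loop (an assumption about the unshown code around these hunks, not a literal copy of it):

// Simplified sketch: every strategy clusters variables differently, builds a
// reuse plan, and the plan with the largest saving ratio wins.
std::function<MemoryOptimizePass::MemoryAllocation()>* best_strategy = nullptr;
float best_saving_ratio = 0.f;
for (auto& strategy : strategies) {
  auto allocation = strategy();
  if (allocation.GetSavingRatio() > best_saving_ratio) {
    best_saving_ratio = allocation.GetSavingRatio();
    best_strategy = &strategy;
  }
}
// If no strategy helps, best_strategy stays null and the pass logs an error
// and skips the optimization, as the hunk above shows.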
@@ -15,7 +15,7 @@
 #pragma once
 #include "paddle/fluid/inference/analysis/analysis_pass.h"
-#include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
+#include "paddle/fluid/platform/port.h"
 namespace paddle {
 namespace inference {
......
@@ -95,7 +95,8 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) {
   CP_MEMBER(memory_pool_init_size_mb_);
   CP_MEMBER(enable_memory_optim_);
-  CP_MEMBER(memory_optim_force_update_);
+  CP_MEMBER(static_memory_optim_);
+  CP_MEMBER(static_memory_optim_force_update_);
   // TensorRT releated.
   CP_MEMBER(use_tensorrt_);
   CP_MEMBER(tensorrt_workspace_size_);
@@ -238,7 +239,8 @@ std::string contrib::AnalysisConfig::SerializeInfoCache() {
   ss << tensorrt_min_subgraph_size_;
   ss << enable_memory_optim_;
-  ss << memory_optim_force_update_;
+  ss << static_memory_optim_;
+  ss << static_memory_optim_force_update_;
   ss << use_mkldnn_;
   for (auto &item : mkldnn_enabled_op_types_) ss << item;
@@ -278,9 +280,11 @@ float contrib::AnalysisConfig::fraction_of_gpu_memory_for_pool() const {
 #endif
 }
-void contrib::AnalysisConfig::EnableMemoryOptim(bool force_update_cache) {
+void contrib::AnalysisConfig::EnableMemoryOptim(
+    bool static_optim, bool force_update_static_cache) {
   enable_memory_optim_ = true;
-  memory_optim_force_update_ = force_update_cache;
+  static_memory_optim_ = static_optim;
+  static_memory_optim_force_update_ = force_update_static_cache;
   Update();
 }
@@ -300,4 +304,16 @@ void contrib::AnalysisConfig::SetModelBuffer(const char *prog_buffer,
   Update();
 }
+NativeConfig contrib::AnalysisConfig::ToNativeConfig() const {
+  NativeConfig config;
+  config.model_dir = model_dir_;
+  config.prog_file = prog_file_;
+  config.param_file = params_file_;
+  config.use_gpu = use_gpu_;
+  config.device = device_id_;
+  config.fraction_of_gpu_memory = fraction_of_gpu_memory_for_pool();
+  config.specify_input_name = specify_input_name_;
+  return config;
+}
 }  // namespace paddle
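Moving ToNativeConfig out of line (its old inline body is deleted from the AnalysisConfig header further down in this diff) does not change how it is used; a minimal sketch, assuming the config has already been filled in elsewhere:

// Minimal usage sketch; how cfg gets populated is not part of this diff.
contrib::AnalysisConfig cfg;
NativeConfig native = cfg.ToNativeConfig();
// native now carries model_dir / prog_file / param_file, the GPU settings and
// specify_input_name copied from the AnalysisConfig, as the new definition
// above shows.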
@@ -298,15 +298,15 @@ void AnalysisPredictor::GetFetchOne(const framework::LoDTensor &fetch,
 bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
                                  framework::Scope *scope) {
   VLOG(3) << "Predictor::get_fetch";
-  outputs->resize(fetchs_.size());
-  for (size_t i = 0; i < fetchs_.size(); ++i) {
-    int idx = boost::get<int>(fetchs_[i]->GetAttr("col"));
+  outputs->resize(fetches_.size());
+  for (size_t i = 0; i < fetches_.size(); ++i) {
+    int idx = boost::get<int>(fetches_[i]->GetAttr("col"));
     PADDLE_ENFORCE((size_t)idx == i);
     framework::LoDTensor &fetch =
         framework::GetFetchVariable(*scope, "fetch", idx);
     auto type = fetch.type();
     auto output = &(outputs->at(i));
-    output->name = fetchs_[idx]->Input("X")[0];
+    output->name = fetches_[idx]->Input("X")[0];
     if (type == framework::proto::VarType::FP32) {
       GetFetchOne<float>(fetch, output);
       output->dtype = PaddleDType::FLOAT32;
@@ -327,7 +327,9 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
   argument_.SetUseGPU(config_.use_gpu());
   argument_.SetGPUDeviceId(config_.gpu_device_id());
   argument_.SetEnableMemoryOptim(config_.enable_memory_optim());
-  argument_.SetMemoryOptimForceUpdate(config_.memory_optim_force_update_);
+  argument_.SetStaticMemoryOptim(config_.static_memory_optim_);
+  argument_.SetStaticMemoryOptimForceUpdate(
+      config_.static_memory_optim_force_update_);
   argument_.SetModelFromMemory(config_.model_from_memory_);
   // Analyze inference_program
   if (!config_.model_dir().empty()) {
@@ -422,10 +424,10 @@ void AnalysisPredictor::PrepareFeedFetch() {
       feed_names_[op->Output("Out")[0]] = idx;
     } else if (op->Type() == "fetch") {
       int idx = boost::get<int>(op->GetAttr("col"));
-      if (fetchs_.size() <= static_cast<size_t>(idx)) {
-        fetchs_.resize(idx + 1);
+      if (fetches_.size() <= static_cast<size_t>(idx)) {
+        fetches_.resize(idx + 1);
       }
-      fetchs_[idx] = op;
+      fetches_[idx] = op;
     }
   }
 }
@@ -638,12 +640,12 @@ bool AnalysisPredictor::need_collect_var_shapes_for_memory_optim() {
   // check if the cache exists
   if (!config_.enable_memory_optim()) {
     need = false;
-  } else if (config_.enable_memory_optim() &&
+  } else if (config_.static_memory_optim_ &&
              !inference::IsFileExists(inference::analysis::GetMemoryCachePath(
                  config_.model_dir(), config_.prog_file()))) {
     need = true;
-  } else if (config_.enable_memory_optim() &&
-             config_.memory_optim_force_update_) {
+  } else if (config_.static_memory_optim_ &&
+             config_.static_memory_optim_force_update_) {
     need = true;
   }
......
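The predicate above now ties shape collection to the static mode only: collect when static optimization is on and the cache file is missing, or when a forced refresh is requested. A condensed restatement for clarity (an illustrative helper, not the literal method; the initial value of need is assumed to be false):

// Illustrative restatement of need_collect_var_shapes_for_memory_optim();
// cache_exists stands in for the IsFileExists(GetMemoryCachePath(...)) check.
bool NeedCollectVarShapes(bool enable_memory_optim, bool static_memory_optim,
                          bool force_update, bool cache_exists) {
  if (!enable_memory_optim) return false;                  // optimization off
  if (static_memory_optim && !cache_exists) return true;   // first static run
  if (static_memory_optim && force_update) return true;    // rebuild the cache
  return false;  // dynamic mode never records shapes
}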
@@ -115,7 +115,7 @@ class AnalysisPredictor : public PaddlePredictor {
   std::shared_ptr<framework::ProgramDesc> inference_program_;
   std::vector<framework::OpDesc *> feeds_;
   std::map<std::string, size_t> feed_names_;
-  std::vector<framework::OpDesc *> fetchs_;
+  std::vector<framework::OpDesc *> fetches_;
   // Memory buffer for feed inputs. The temporary LoDTensor will cause serious
   // concurrency problems, wrong results and memory leak, so cache them.
   std::vector<framework::LoDTensor> feed_tensors_;
......
@@ -162,17 +162,7 @@ struct AnalysisConfig {
   /** Transform the AnalysisConfig to NativeConfig.
    */
-  NativeConfig ToNativeConfig() const {
-    NativeConfig config;
-    config.model_dir = model_dir_;
-    config.prog_file = prog_file_;
-    config.param_file = params_file_;
-    config.use_gpu = use_gpu_;
-    config.device = device_id_;
-    config.fraction_of_gpu_memory = fraction_of_gpu_memory_for_pool();
-    config.specify_input_name = specify_input_name_;
-    return config;
-  }
+  NativeConfig ToNativeConfig() const;
   /** Specify the operator type list to use MKLDNN acceleration.
    * @param op_list the operator type list.
    */
@@ -195,7 +185,8 @@ struct AnalysisConfig {
   /** Turn on memory optimize
    * NOTE still in development, will release latter.
    */
-  void EnableMemoryOptim(bool force_update_cache = false);
+  void EnableMemoryOptim(bool static_optim = false,
+                         bool force_update_static_cache = false);
   /** Tell whether the memory optimization is activated. */
   bool enable_memory_optim() const;
@@ -241,7 +232,8 @@ struct AnalysisConfig {
   // memory reuse related.
   bool enable_memory_optim_{false};
-  bool memory_optim_force_update_{false};
+  bool static_memory_optim_{false};
+  bool static_memory_optim_force_update_{false};
   bool use_mkldnn_{false};
   std::unordered_set<std::string> mkldnn_enabled_op_types_;
......
@@ -253,7 +253,7 @@ void compare(bool use_mkldnn = false) {
 }
 // Compare result of NativeConfig and AnalysisConfig with memory optimization.
-TEST(Analyzer_dam, compare_with_memory_optim) {
+TEST(Analyzer_dam, compare_with_static_memory_optim) {
   // The small dam will core in CI, but works in local.
   if (FLAGS_max_turn_num == 9) {
     contrib::AnalysisConfig cfg, cfg1;
@@ -263,7 +263,7 @@ TEST(Analyzer_dam, compare_with_memory_optim) {
     SetInput(&input_slots_all);
     // Run the first time to force to update memory cache
     SetConfig(&cfg);
-    cfg.EnableMemoryOptim(true);
+    cfg.EnableMemoryOptim(true, true /*force update*/);
     CompareNativeAndAnalysis(
         reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
@@ -271,7 +271,7 @@ TEST(Analyzer_dam, compare_with_memory_optim) {
     // Run second time to use the memory cache and perform memory optimization.
     SetConfig(&cfg1);
-    cfg1.EnableMemoryOptim();
+    cfg1.EnableMemoryOptim(true, false /*do not force update*/);
     CompareNativeAndAnalysis(
         reinterpret_cast<const PaddlePredictor::Config *>(&cfg1),
@@ -279,6 +279,24 @@ TEST(Analyzer_dam, compare_with_memory_optim) {
   }
 }
+TEST(Analyzer_dam, compare_with_dynamic_memory_optim) {
+  // The small dam will core in CI, but works in local.
+  if (FLAGS_max_turn_num == 9) {
+    contrib::AnalysisConfig cfg, cfg1;
+    DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
+    std::vector<std::vector<PaddleTensor>> input_slots_all;
+    SetInput(&input_slots_all);
+    // Run the first time to force to update memory cache
+    SetConfig(&cfg);
+    cfg.EnableMemoryOptim();
+    CompareNativeAndAnalysis(
+        reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
+        input_slots_all);
+  }
+}
 TEST(Analyzer_dam, compare) { compare(); }
 #ifdef PADDLE_WITH_MKLDNN
......