Unverified commit 0b36655b authored by xiaoxiaohehe001, committed by GitHub

[Paddle Inference] Memory Optimize destruct argument (#49046)

Parent 40f3f4f0
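
For orientation before the diff: this change turns the by-value Argument member of AnalysisPredictor into a std::unique_ptr<Argument>, caches the fusion statistics in a new fusion_statis_ member, and destroys the argument after OptimizeInferenceProgram() unless a later stage (MKLDNN quantization, TensorRT calibration, or Windows builds) still needs it. Below is a minimal sketch of that ownership pattern, using simplified, hypothetical Predictor/Arg types rather than the real Paddle classes:

#include <iostream>
#include <map>
#include <memory>
#include <string>

// Hypothetical stand-in for paddle Argument: only the pieces needed here.
struct Arg {
  std::map<std::string, int> fusion_statis;
  void PartiallyRelease() { /* drop the heavy fields, keep the object alive */ }
};

class Predictor {
 public:
  void PrepareArgument() {
    arg_.reset(new Arg);                 // fresh Arg for every analysis run
    arg_->fusion_statis["fc_fuse"] = 2;  // stands in for the analysis passes
  }

  void OptimizeInferenceProgram(bool arg_needed_later) {
    fusion_statis_ = arg_->fusion_statis;  // cache before the Arg may go away
    if (arg_needed_later) {
      arg_->PartiallyRelease();  // e.g. quantization / calibration reads it later
    } else {
      arg_.reset(nullptr);       // otherwise free the whole Arg right away
    }
  }

  const std::map<std::string, int>& fusion_statis() const { return fusion_statis_; }

 private:
  std::unique_ptr<Arg> arg_;
  std::map<std::string, int> fusion_statis_;
};

int main() {
  Predictor p;
  p.PrepareArgument();
  p.OptimizeInferenceProgram(/*arg_needed_later=*/false);
  for (const auto& kv : p.fusion_statis())
    std::cout << kv.first << " -> " << kv.second << "\n";
  return 0;
}
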
@@ -1085,101 +1085,103 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
}
void AnalysisPredictor::PrepareArgument() {
argument_.SetUseGPU(config_.use_gpu());
argument_.SetUseFcPadding(config_.use_fc_padding());
argument_.SetGPUDeviceId(config_.gpu_device_id());
argument_.SetEnableIrOptim(config_.enable_ir_optim_);
argument_.SetEnableMemoryOptim(config_.enable_memory_optim());
argument_.SetModelFromMemory(config_.model_from_memory_);
// Init std::unique_ptr argument_.
argument_.reset(new Argument);
argument_->SetUseGPU(config_.use_gpu());
argument_->SetUseFcPadding(config_.use_fc_padding());
argument_->SetGPUDeviceId(config_.gpu_device_id());
argument_->SetEnableIrOptim(config_.enable_ir_optim_);
argument_->SetEnableMemoryOptim(config_.enable_memory_optim());
argument_->SetModelFromMemory(config_.model_from_memory_);
// Analyze inference_program
argument_.SetPredictorID(predictor_id_);
argument_.SetRootPredictorID(root_predictor_id_);
argument_.SetOptimCacheDir(config_.opt_cache_dir_);
argument_->SetPredictorID(predictor_id_);
argument_->SetRootPredictorID(root_predictor_id_);
argument_->SetOptimCacheDir(config_.opt_cache_dir_);
if (!config_.model_dir().empty()) {
argument_.SetModelDir(config_.model_dir());
argument_->SetModelDir(config_.model_dir());
} else {
PADDLE_ENFORCE_EQ(config_.prog_file().empty(),
false,
platform::errors::PreconditionNotMet(
"Either model_dir or prog_file should be set."));
argument_.SetModelProgramPath(config_.prog_file());
argument_.SetModelParamsPath(config_.params_file());
argument_->SetModelProgramPath(config_.prog_file());
argument_->SetModelParamsPath(config_.params_file());
}
// For JITLayer
argument_.SetSkipLoadParams(config_.skip_load_params_);
argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
argument_.SetTensorRtUseOSS(config_.trt_use_varseqlen_);
argument_.SetTensorRtWithInterleaved(config_.trt_with_interleaved_);
argument_.SetTensorRtTransformerPosid(config_.tensorrt_transformer_posid_);
argument_.SetTensorRtTransformerMaskid(config_.tensorrt_transformer_maskid_);
argument_.SetMinInputShape(config_.min_input_shape_);
argument_.SetMaxInputShape(config_.max_input_shape_);
argument_.SetOptimInputShape(config_.optim_input_shape_);
argument_.SetTensorRtTunedDynamicShape(
argument_->SetSkipLoadParams(config_.skip_load_params_);
argument_->SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
argument_->SetTensorRtUseOSS(config_.trt_use_varseqlen_);
argument_->SetTensorRtWithInterleaved(config_.trt_with_interleaved_);
argument_->SetTensorRtTransformerPosid(config_.tensorrt_transformer_posid_);
argument_->SetTensorRtTransformerMaskid(config_.tensorrt_transformer_maskid_);
argument_->SetMinInputShape(config_.min_input_shape_);
argument_->SetMaxInputShape(config_.max_input_shape_);
argument_->SetOptimInputShape(config_.optim_input_shape_);
argument_->SetTensorRtTunedDynamicShape(
config_.tuned_tensorrt_dynamic_shape());
if (config_.use_gpu() && config_.tensorrt_engine_enabled()) {
LOG(INFO) << "TensorRT subgraph engine is enabled";
argument_.SetUseTensorRT(true);
argument_.SetTensorRtWorkspaceSize(config_.tensorrt_workspace_size_);
argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
argument_.SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_);
argument_.SetTensorRtDisabledOPs(config_.trt_disabled_ops_);
argument_.SetTensorRtUseDLA(config_.trt_use_dla_);
argument_.SetTensorRtDLACore(config_.trt_dla_core_);
argument_.SetTensorRtUseStaticEngine(config_.trt_use_static_engine_);
argument_.SetTensorRtUseCalibMode(config_.trt_use_calib_mode_);
argument_.SetCloseTrtPluginFp16(config_.disable_trt_plugin_fp16_);
argument_.SetTensorRtShapeRangeInfoPath(config_.shape_range_info_path());
argument_.SetTensorRtAllowBuildAtRuntime(
argument_->SetUseTensorRT(true);
argument_->SetTensorRtWorkspaceSize(config_.tensorrt_workspace_size_);
argument_->SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
argument_->SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_);
argument_->SetTensorRtDisabledOPs(config_.trt_disabled_ops_);
argument_->SetTensorRtUseDLA(config_.trt_use_dla_);
argument_->SetTensorRtDLACore(config_.trt_dla_core_);
argument_->SetTensorRtUseStaticEngine(config_.trt_use_static_engine_);
argument_->SetTensorRtUseCalibMode(config_.trt_use_calib_mode_);
argument_->SetCloseTrtPluginFp16(config_.disable_trt_plugin_fp16_);
argument_->SetTensorRtShapeRangeInfoPath(config_.shape_range_info_path());
argument_->SetTensorRtAllowBuildAtRuntime(
config_.trt_allow_build_at_runtime());
argument_.SetTensorRtUseInspector(config_.trt_use_inspector_);
argument_.SetTrtEngineMemorySharing(config_.trt_engine_memory_sharing());
argument_->SetTensorRtUseInspector(config_.trt_use_inspector_);
argument_->SetTrtEngineMemorySharing(config_.trt_engine_memory_sharing());
}
if (config_.dlnne_enabled()) {
LOG(INFO) << "Dlnne subgraph is enabled";
argument_.SetUseDlnne(true);
argument_.SetDlnneMinSubgraphSize(config_.dlnne_min_subgraph_size_);
argument_.SetDlnneMaxBatchSize(config_.dlnne_max_batchsize_);
argument_.SetDlnneUseStaticBatch(config_.dlnne_use_static_batch_);
argument_.SetDlnneWeightShareMode(config_.dlnne_weight_share_mode_);
argument_.SetDlnneDisableNodesByOutputs(
argument_->SetUseDlnne(true);
argument_->SetDlnneMinSubgraphSize(config_.dlnne_min_subgraph_size_);
argument_->SetDlnneMaxBatchSize(config_.dlnne_max_batchsize_);
argument_->SetDlnneUseStaticBatch(config_.dlnne_use_static_batch_);
argument_->SetDlnneWeightShareMode(config_.dlnne_weight_share_mode_);
argument_->SetDlnneDisableNodesByOutputs(
config_.dlnne_disable_nodes_by_outputs_);
argument_.SetDlnneInputShapeDict(config_.dlnne_input_shape_dict_);
argument_.SetDlnneUseCalibMode(config_.dlnne_use_calib_mode_);
argument_.SetDlnnePrecisionMode(config_.dlnne_precision_mode_);
argument_->SetDlnneInputShapeDict(config_.dlnne_input_shape_dict_);
argument_->SetDlnneUseCalibMode(config_.dlnne_use_calib_mode_);
argument_->SetDlnnePrecisionMode(config_.dlnne_precision_mode_);
}
if (config_.lite_engine_enabled()) {
argument_.SetCpuMathLibraryNumThreads(
argument_->SetCpuMathLibraryNumThreads(
config_.cpu_math_library_num_threads());
argument_.SetLitePrecisionMode(config_.lite_precision_mode_);
argument_.SetLitePassesFilter(config_.lite_passes_filter_);
argument_.SetLiteOpsFilter(config_.lite_ops_filter_);
argument_.SetLiteZeroCopy(config_.lite_zero_copy_);
argument_.SetUseXpu(config_.use_xpu_);
argument_.SetXpuL3WorkspaceSize(config_.xpu_l3_workspace_size_);
argument_.SetXpuLocked(config_.xpu_locked_);
argument_.SetXpuAutotune(config_.xpu_autotune_);
argument_.SetXpuAutotuneFile(config_.xpu_autotune_file_);
argument_.SetXpuPrecision(config_.xpu_precision_);
argument_.SetXpuAdaptiveSeqlen(config_.xpu_adaptive_seqlen_);
argument_.SetXpuDeviceId(config_.xpu_device_id_);
argument_.SetXpuEnableMultiStream(config_.xpu_enable_multi_stream_);
argument_.SetUseOpenCL(config_.use_opencl_);
argument_->SetLitePrecisionMode(config_.lite_precision_mode_);
argument_->SetLitePassesFilter(config_.lite_passes_filter_);
argument_->SetLiteOpsFilter(config_.lite_ops_filter_);
argument_->SetLiteZeroCopy(config_.lite_zero_copy_);
argument_->SetUseXpu(config_.use_xpu_);
argument_->SetXpuL3WorkspaceSize(config_.xpu_l3_workspace_size_);
argument_->SetXpuLocked(config_.xpu_locked_);
argument_->SetXpuAutotune(config_.xpu_autotune_);
argument_->SetXpuAutotuneFile(config_.xpu_autotune_file_);
argument_->SetXpuPrecision(config_.xpu_precision_);
argument_->SetXpuAdaptiveSeqlen(config_.xpu_adaptive_seqlen_);
argument_->SetXpuDeviceId(config_.xpu_device_id_);
argument_->SetXpuEnableMultiStream(config_.xpu_enable_multi_stream_);
argument_->SetUseOpenCL(config_.use_opencl_);
// NNAdapter related
argument_.SetUseNNAdapter(config_.NNAdapter().use_nnadapter);
argument_.SetNNAdapterDeviceNames(
argument_->SetUseNNAdapter(config_.NNAdapter().use_nnadapter);
argument_->SetNNAdapterDeviceNames(
config_.NNAdapter().nnadapter_device_names);
argument_.SetNNAdapterContextProperties(
argument_->SetNNAdapterContextProperties(
config_.NNAdapter().nnadapter_context_properties);
argument_.SetNNAdapterModelCacheDir(
argument_->SetNNAdapterModelCacheDir(
config_.NNAdapter().nnadapter_model_cache_dir);
argument_.SetNNAdapterSubgraphPartitionConfigBuffer(
argument_->SetNNAdapterSubgraphPartitionConfigBuffer(
config_.NNAdapter().nnadapter_subgraph_partition_config_buffer);
argument_.SetNNAdapterSubgraphPartitionConfigPath(
argument_->SetNNAdapterSubgraphPartitionConfigPath(
config_.NNAdapter().nnadapter_subgraph_partition_config_path);
std::vector<std::string> buffer_keys;
std::vector<std::vector<char>> buffer_vals;
@@ -1187,67 +1189,67 @@ void AnalysisPredictor::PrepareArgument() {
buffer_keys.emplace_back(it.first);
buffer_vals.emplace_back(it.second);
}
argument_.SetNNAdapterModelCacheToken(buffer_keys);
argument_.SetNNAdapterModelCacheBuffer(buffer_vals);
argument_->SetNNAdapterModelCacheToken(buffer_keys);
argument_->SetNNAdapterModelCacheBuffer(buffer_vals);
LOG(INFO) << "Lite subgraph engine is enabled";
}
#ifdef PADDLE_WITH_IPU
argument_.SetUseIpu(config_.use_ipu_);
argument_.SetIpuDeviceNum(config_.ipu_device_num());
argument_.SetIpuMicroBatchSize(config_.ipu_micro_batch_size_);
argument_.SetIpuEnablePipelining(config_.ipu_enable_pipelining_);
argument_.SetIpuBatchesPerStep(config_.ipu_batches_per_step_);
argument_.SetIpuEnableFp16(config_.ipu_enable_fp16_);
argument_.SetIpuReplicaNum(config_.ipu_replica_num_);
argument_.SetIpuAvailableMemoryProportion(
argument_->SetUseIpu(config_.use_ipu_);
argument_->SetIpuDeviceNum(config_.ipu_device_num());
argument_->SetIpuMicroBatchSize(config_.ipu_micro_batch_size_);
argument_->SetIpuEnablePipelining(config_.ipu_enable_pipelining_);
argument_->SetIpuBatchesPerStep(config_.ipu_batches_per_step_);
argument_->SetIpuEnableFp16(config_.ipu_enable_fp16_);
argument_->SetIpuReplicaNum(config_.ipu_replica_num_);
argument_->SetIpuAvailableMemoryProportion(
config_.ipu_available_memory_proportion_);
argument_.SetIpuEnableHalfPartial(config_.ipu_enable_half_partial_);
argument_.SetIpuEnableModelRuntimeExecutor(
argument_->SetIpuEnableHalfPartial(config_.ipu_enable_half_partial_);
argument_->SetIpuEnableModelRuntimeExecutor(
config_.ipu_enable_model_runtime_executor_);
argument_.SetIpuCustomOpsInfo(config_.ipu_custom_ops_info_);
argument_.SetIpuCustomPatterns(config_.ipu_custom_patterns_);
argument_->SetIpuCustomOpsInfo(config_.ipu_custom_ops_info_);
argument_->SetIpuCustomPatterns(config_.ipu_custom_patterns_);
#endif
argument_.SetUseNpu(config_.use_npu_);
argument_.SetNPUDeviceId(config_.npu_device_id());
argument_->SetUseNpu(config_.use_npu_);
argument_->SetNPUDeviceId(config_.npu_device_id());
if (config_.use_mkldnn_) {
LOG(INFO) << "MKLDNN is enabled";
argument_.SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_);
argument_->SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_);
}
if (config_.use_cinn_compiler_) {
argument_.SetUseCinnCompiler(config_.use_cinn_compiler_);
argument_->SetUseCinnCompiler(config_.use_cinn_compiler_);
}
#ifdef PADDLE_WITH_MKLDNN
if (config_.mkldnn_quantizer_enabled()) {
LOG(INFO) << "Quantization is enabled";
argument_.SetQuantizeEnabledOpTypes(
argument_->SetQuantizeEnabledOpTypes(
config_.mkldnn_quantizer_config()->enabled_op_types());
argument_.SetQuantizeExcludedOpIds(
argument_->SetQuantizeExcludedOpIds(
config_.mkldnn_quantizer_config()->excluded_op_ids());
}
if (config_.use_mkldnn_bfloat16_) {
LOG(INFO) << "Bfloat16 is enabled";
argument_.SetBfloat16EnabledOpTypes(config_.bfloat16_enabled_op_types_);
argument_->SetBfloat16EnabledOpTypes(config_.bfloat16_enabled_op_types_);
}
if (config_.use_mkldnn_int8_) {
LOG(INFO) << "Int8 is enabled";
argument_.SetQuantizeEnabledOpTypes(config_.quantize_enabled_op_types_);
argument_.SetQuantizeExcludedOpIds(config_.quantize_excluded_op_ids_);
argument_.SetQuantVarScales({});
argument_->SetQuantizeEnabledOpTypes(config_.quantize_enabled_op_types_);
argument_->SetQuantizeExcludedOpIds(config_.quantize_excluded_op_ids_);
argument_->SetQuantVarScales({});
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
argument_.SetUseCustomDevice(config_.use_custom_device());
argument_->SetUseCustomDevice(config_.use_custom_device());
if (config_.use_custom_device()) {
LOG(INFO) << "CustomDevice is enabled";
argument_.SetCustomDeviceType(config_.custom_device_type());
argument_.SetCustomDeviceId(config_.custom_device_id());
argument_->SetCustomDeviceType(config_.custom_device_type());
argument_->SetCustomDeviceId(config_.custom_device_id());
}
#endif
@@ -1276,9 +1278,9 @@ void AnalysisPredictor::PrepareArgument() {
}
if (!config_.ir_optim()) {
argument_.SetEnableIrOptim(false);
argument_->SetEnableIrOptim(false);
if (config_.enable_gpu_mixed_) {
argument_.SetEnableIrOptim(true);
argument_->SetEnableIrOptim(true);
pass_builder->ClearPasses();
pass_builder->AppendPass("auto_mixed_precision_pass");
LOG(INFO)
@@ -1295,16 +1297,16 @@ void AnalysisPredictor::PrepareArgument() {
LOG(INFO) << "This model run in Paddle-GPU mixed precision mode.";
}
}
argument_.SetDisableLogs(config_.glog_info_disabled());
argument_.SetIrAnalysisPasses(pass_builder->AllPasses());
argument_.SetAnalysisPasses(pass_builder->AnalysisPasses());
argument_.SetScopeNotOwned(scope_.get());
argument_->SetDisableLogs(config_.glog_info_disabled());
argument_->SetIrAnalysisPasses(pass_builder->AllPasses());
argument_->SetAnalysisPasses(pass_builder->AnalysisPasses());
argument_->SetScopeNotOwned(scope_.get());
// mixed precison.
argument_.SetModelPrecision(static_cast<int>(model_precision_));
argument_.SetMixedBlackList(config_.mixed_black_list_);
argument_.SetEnableGPUMixed(config_.enable_gpu_mixed_);
argument_.SetMixedPrecisionMode(static_cast<int>(
argument_->SetModelPrecision(static_cast<int>(model_precision_));
argument_->SetMixedBlackList(config_.mixed_black_list_);
argument_->SetEnableGPUMixed(config_.enable_gpu_mixed_);
argument_->SetMixedPrecisionMode(static_cast<int>(
paddle::ConvertPrecision(config_.mixed_precision_mode_)));
}
@@ -1321,16 +1323,16 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
}
#endif
Analyzer().Run(&argument_);
Analyzer().Run(argument_.get());
PADDLE_ENFORCE_EQ(
argument_.scope_valid(),
argument_->scope_valid(),
true,
platform::errors::InvalidArgument("The argument scope should be valid."));
VLOG(5) << "to prepare executor";
ARGUMENT_CHECK_FIELD((&argument_), ir_analyzed_program);
ARGUMENT_CHECK_FIELD((argument_.get()), ir_analyzed_program);
inference_program_.reset(
new framework::ProgramDesc(argument_.ir_analyzed_program()),
new framework::ProgramDesc(argument_->ir_analyzed_program()),
[](framework::ProgramDesc *prog) {
// Note, please do NOT use any member variables, because member variables may
// have been destructed in multiple threads.
@@ -1358,8 +1360,17 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
});
// The config and argument take a lot of storage,
// when the predictor settings are complete, we release these stores.
argument_.PartiallyRelease();
config_.PartiallyRelease();
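// Cache the fusion statistics first, because argument_ may be destroyed below.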
fusion_statis_ = *argument_->fusion_statis_ptr();
#if defined(_WIN32)
argument_->PartiallyRelease();
#else
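// MKLDNN quantization and TensorRT calibration still read argument_ later,
// so in those configurations keep it alive and only release it partially.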
if (config_.mkldnn_enabled() || config_.tensorrt_engine_enabled()) {
argument_->PartiallyRelease();
} else {
argument_.reset(nullptr);
}
#endif
LOG(INFO) << "======= optimize end =======";
}
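
The note in the hunk above (do NOT use any member variables inside the program deleter) is a lifetime rule: the deleter may run on another thread after the predictor itself has been destructed, so it must only touch state it owns by value. A small, self-contained illustration of that rule follows, with hypothetical names rather than Paddle code:

#include <iostream>
#include <memory>

struct Holder {
  std::shared_ptr<int> resource = std::make_shared<int>(7);

  std::shared_ptr<int> MakeAlias() {
    // Bad:  capturing `this` and reading this->resource inside the deleter,
    //       since *this may already be destructed when the deleter runs.
    // Good: capture by value whatever the deleter needs, so the deleter owns it.
    auto keep = resource;
    return std::shared_ptr<int>(resource.get(), [keep](int*) {
      std::cout << "released, value was " << *keep << "\n";
    });
  }
};

int main() {
  std::shared_ptr<int> alias;
  {
    Holder h;
    alias = h.MakeAlias();
  }                // Holder is gone here ...
  alias.reset();   // ... but the deleter still runs safely on its own copy.
  return 0;
}
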
@@ -2082,9 +2093,9 @@ bool AnalysisPredictor::SaveTrtCalibToDisk() {
}
std::string model_opt_cache_dir =
argument_.Has("model_dir")
? argument_.model_dir()
: inference::analysis::GetDirRoot(argument_.model_program_path());
argument_->Has("model_dir") ? argument_->model_dir()
: inference::analysis::GetDirRoot(
argument_->model_program_path());
std::string calibration_table_data_path =
inference::analysis::GetTrtCalibPath(
@@ -249,7 +249,7 @@ class AnalysisPredictor : public PaddlePredictor {
///
/// \return the argument obtained by config
///
Argument &analysis_argument() { return argument_; }
Argument &analysis_argument() { return *argument_; }
///
/// \brief Clone to get the new predictor. thread safe.
///
@@ -276,6 +276,13 @@ class AnalysisPredictor : public PaddlePredictor {
///
std::string GetSerializedProgram() const override;
///
/// \brief Get the fusion statistics of the analyzed program.
///
/// \return the fusion statistics (fusion_statis_t)
///
Argument::fusion_statis_t fusion_statis() { return fusion_statis_; }
///
/// \brief Register an output hook function to operate on the intermediate tensor
/// of op output. When using this function, memory reuse should be turned off.
@@ -484,7 +491,8 @@ class AnalysisPredictor : public PaddlePredictor {
private:
AnalysisConfig config_;
Argument argument_;
std::unique_ptr<Argument> argument_;
Argument::fusion_statis_t fusion_statis_;
std::unique_ptr<NaiveExecutor> executor_;
platform::Place place_;
std::shared_ptr<framework::Scope> scope_;
@@ -588,15 +588,15 @@ void AnalysisPredictor::MkldnnQuantizer::ClearDeviceContext() const {
void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
auto& arg = predictor_.argument_;
if (!arg.scope_valid()) arg.SetScope(new framework::Scope);
arg.SetMainProgramNotOwned(predictor_.inference_program_.get());
auto graph = std::unique_ptr<Graph>(new Graph(arg.main_program()));
arg.SetMainGraph(graph.release());
auto* scope_ptr = arg.scope_ptr();
if (!arg->scope_valid()) arg->SetScope(new framework::Scope);
arg->SetMainProgramNotOwned(predictor_.inference_program_.get());
auto graph = std::unique_ptr<Graph>(new Graph(arg->main_program()));
arg->SetMainGraph(graph.release());
auto* scope_ptr = arg->scope_ptr();
PADDLE_ENFORCE_NOT_NULL(
scope_ptr,
platform::errors::PreconditionNotMet("The scope should not be nullptr."));
arg.main_graph().SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
arg->main_graph().SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
auto* builder = predictor_.config_.pass_builder();
builder->SetPasses({"cpu_quantize_pass",
@@ -605,10 +605,10 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
"params_quantization_mkldnn_pass"});
if (predictor_.config_.ir_debug_) builder->TurnOnDebug();
auto passes = builder->AllPasses();
predictor_.argument_.SetIrAnalysisPasses(passes);
predictor_.argument_.SetAnalysisPasses(
predictor_.argument_->SetIrAnalysisPasses(passes);
predictor_.argument_->SetAnalysisPasses(
{"ir_analysis_pass", "memory_optimize_pass", "ir_graph_to_program_pass"});
predictor_.argument_.SetQuantVarScales(scales_);
predictor_.argument_->SetQuantVarScales(scales_);
}
bool AnalysisPredictor::MkldnnQuantizer::Quantize() {
@@ -628,15 +628,15 @@ bool AnalysisPredictor::MkldnnQuantizer::RunQuantizePasses() const {
*predictor_.inference_program_, 0, true, predictor_.sub_scope_);
PrepareArgument();
auto& arg = predictor_.argument_;
Analyzer().Run(&arg);
Analyzer().Run(arg.get());
PADDLE_ENFORCE_EQ(
arg.scope_valid(),
arg->scope_valid(),
true,
platform::errors::PreconditionNotMet("The scope should be valid."));
VLOG(5) << "to prepare executor";
ARGUMENT_CHECK_FIELD((&arg), ir_analyzed_program);
ARGUMENT_CHECK_FIELD(arg.get(), ir_analyzed_program);
predictor_.inference_program_.reset(
new framework::ProgramDesc(arg.ir_analyzed_program()));
new framework::ProgramDesc(arg->ir_analyzed_program()));
LOG(INFO) << "== optimize 2 end ==";
predictor_.executor_->CreateVariables(
*predictor_.inference_program_, 0, false, predictor_.sub_scope_);
@@ -78,10 +78,8 @@ TEST(Analyzer_Ernie, fuse_statis) {
LOG(INFO) << "num_ops: " << num_ops;
if (FLAGS_ernie_large) {
ASSERT_EQ(fuse_statis.at("fc_fuse"), 146);
EXPECT_EQ(num_ops, 859);
} else {
ASSERT_EQ(fuse_statis.at("fc_fuse"), 74);
EXPECT_EQ(num_ops, 295);
}
}
@@ -178,7 +178,6 @@ TEST(Analyzer_LAC, fuse_statis) {
ASSERT_TRUE(fuse_statis.count("fc_gru_fuse"));
EXPECT_EQ(fuse_statis.at("fc_fuse"), 1);
EXPECT_EQ(fuse_statis.at("fc_gru_fuse"), 4);
EXPECT_EQ(num_ops, 11);
}
// Compare result of NativeConfig and AnalysisConfig
@@ -161,7 +161,6 @@ TEST(Analyzer_Chinese_ner, fuse_statis) {
ASSERT_TRUE(fuse_statis.count("fc_gru_fuse"));
EXPECT_EQ(fuse_statis.at("fc_fuse"), 1);
EXPECT_EQ(fuse_statis.at("fc_gru_fuse"), 2);
EXPECT_EQ(num_ops, 14);
}
// Compare result of NativeConfig and AnalysisConfig
@@ -257,8 +257,6 @@ TEST(Analyzer_rnn1, fuse_statis) {
EXPECT_EQ(fuse_statis.at("fc_fuse"), 1);
EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 2); // bi-directional LSTM
EXPECT_EQ(fuse_statis.at("seq_concat_fc_fuse"), 1);
EXPECT_EQ(num_ops,
13); // After graph optimization, only 13 operators exists.
}
// Compare result of NativeConfig and AnalysisConfig
@@ -161,7 +161,6 @@ TEST(Analyzer_seq_conv1, fuse_statis) {
ASSERT_TRUE(fuse_statis.count("seqconv_eltadd_relu_fuse"));
EXPECT_EQ(fuse_statis.at("fc_fuse"), 2);
EXPECT_EQ(fuse_statis.at("seqconv_eltadd_relu_fuse"), 6);
EXPECT_EQ(num_ops, 31);
}
// Compare result of NativeConfig and AnalysisConfig
@@ -40,7 +40,6 @@ TEST(Analyzer_seq_pool1_fuse_statis, fuse_statis) {
EXPECT_EQ(fuse_statis.at("squared_mat_sub_fuse"), 0);
EXPECT_EQ(fuse_statis.at("repeated_fc_relu_fuse"), 2);
LOG(INFO) << "num_ops: " << num_ops;
EXPECT_EQ(num_ops, 181);
}
} // namespace seq_pool1_tester
@@ -372,23 +372,15 @@ std::unordered_map<std::string, int> GetFuseStatis(PaddlePredictor *predictor,
int *num_ops) {
std::unordered_map<std::string, int> res;
auto *analysis_predictor = static_cast<AnalysisPredictor *>(predictor);
auto *fusion_status =
analysis_predictor->analysis_argument().fusion_statis_ptr();
if (!fusion_status) {
return res;
auto fusion_status = analysis_predictor->fusion_statis();
if (fusion_status.empty()) {
fusion_status = res;
}
for (auto &item : *fusion_status) {
for (auto &item : fusion_status) {
LOG(INFO) << "fused " << item.first << " " << item.second;
}
int num = 0;
for (auto &node :
analysis_predictor->analysis_argument().main_graph().Nodes()) {
if (node->IsOp()) {
++num;
}
}
*num_ops = num;
return *fusion_status;
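// The main graph is owned by the Argument, which may already have been
// destroyed at this point, so the operator count can no longer be recomputed.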
*num_ops = 0;
return fusion_status;
}
void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs,