未验证 提交 0b36655b 编写于 作者: X xiaoxiaohehe001 提交者: GitHub

[Paddle Inference] Memory Optimize destruct argument (#49046)

上级 40f3f4f0
...@@ -1085,101 +1085,103 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs, ...@@ -1085,101 +1085,103 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
} }
void AnalysisPredictor::PrepareArgument() { void AnalysisPredictor::PrepareArgument() {
argument_.SetUseGPU(config_.use_gpu()); // Init std::unique_ptr argument_.
argument_.SetUseFcPadding(config_.use_fc_padding()); argument_.reset(new Argument);
argument_.SetGPUDeviceId(config_.gpu_device_id()); argument_->SetUseGPU(config_.use_gpu());
argument_.SetEnableIrOptim(config_.enable_ir_optim_); argument_->SetUseFcPadding(config_.use_fc_padding());
argument_.SetEnableMemoryOptim(config_.enable_memory_optim()); argument_->SetGPUDeviceId(config_.gpu_device_id());
argument_.SetModelFromMemory(config_.model_from_memory_); argument_->SetEnableIrOptim(config_.enable_ir_optim_);
argument_->SetEnableMemoryOptim(config_.enable_memory_optim());
argument_->SetModelFromMemory(config_.model_from_memory_);
// Analyze inference_program // Analyze inference_program
argument_.SetPredictorID(predictor_id_); argument_->SetPredictorID(predictor_id_);
argument_.SetRootPredictorID(root_predictor_id_); argument_->SetRootPredictorID(root_predictor_id_);
argument_.SetOptimCacheDir(config_.opt_cache_dir_); argument_->SetOptimCacheDir(config_.opt_cache_dir_);
if (!config_.model_dir().empty()) { if (!config_.model_dir().empty()) {
argument_.SetModelDir(config_.model_dir()); argument_->SetModelDir(config_.model_dir());
} else { } else {
PADDLE_ENFORCE_EQ(config_.prog_file().empty(), PADDLE_ENFORCE_EQ(config_.prog_file().empty(),
false, false,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Either model_dir or prog_file should be set.")); "Either model_dir or prog_file should be set."));
argument_.SetModelProgramPath(config_.prog_file()); argument_->SetModelProgramPath(config_.prog_file());
argument_.SetModelParamsPath(config_.params_file()); argument_->SetModelParamsPath(config_.params_file());
} }
// For JITLayer // For JITLayer
argument_.SetSkipLoadParams(config_.skip_load_params_); argument_->SetSkipLoadParams(config_.skip_load_params_);
argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_); argument_->SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
argument_.SetTensorRtUseOSS(config_.trt_use_varseqlen_); argument_->SetTensorRtUseOSS(config_.trt_use_varseqlen_);
argument_.SetTensorRtWithInterleaved(config_.trt_with_interleaved_); argument_->SetTensorRtWithInterleaved(config_.trt_with_interleaved_);
argument_.SetTensorRtTransformerPosid(config_.tensorrt_transformer_posid_); argument_->SetTensorRtTransformerPosid(config_.tensorrt_transformer_posid_);
argument_.SetTensorRtTransformerMaskid(config_.tensorrt_transformer_maskid_); argument_->SetTensorRtTransformerMaskid(config_.tensorrt_transformer_maskid_);
argument_.SetMinInputShape(config_.min_input_shape_); argument_->SetMinInputShape(config_.min_input_shape_);
argument_.SetMaxInputShape(config_.max_input_shape_); argument_->SetMaxInputShape(config_.max_input_shape_);
argument_.SetOptimInputShape(config_.optim_input_shape_); argument_->SetOptimInputShape(config_.optim_input_shape_);
argument_.SetTensorRtTunedDynamicShape( argument_->SetTensorRtTunedDynamicShape(
config_.tuned_tensorrt_dynamic_shape()); config_.tuned_tensorrt_dynamic_shape());
if (config_.use_gpu() && config_.tensorrt_engine_enabled()) { if (config_.use_gpu() && config_.tensorrt_engine_enabled()) {
LOG(INFO) << "TensorRT subgraph engine is enabled"; LOG(INFO) << "TensorRT subgraph engine is enabled";
argument_.SetUseTensorRT(true); argument_->SetUseTensorRT(true);
argument_.SetTensorRtWorkspaceSize(config_.tensorrt_workspace_size_); argument_->SetTensorRtWorkspaceSize(config_.tensorrt_workspace_size_);
argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_); argument_->SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
argument_.SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_); argument_->SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_);
argument_.SetTensorRtDisabledOPs(config_.trt_disabled_ops_); argument_->SetTensorRtDisabledOPs(config_.trt_disabled_ops_);
argument_.SetTensorRtUseDLA(config_.trt_use_dla_); argument_->SetTensorRtUseDLA(config_.trt_use_dla_);
argument_.SetTensorRtDLACore(config_.trt_dla_core_); argument_->SetTensorRtDLACore(config_.trt_dla_core_);
argument_.SetTensorRtUseStaticEngine(config_.trt_use_static_engine_); argument_->SetTensorRtUseStaticEngine(config_.trt_use_static_engine_);
argument_.SetTensorRtUseCalibMode(config_.trt_use_calib_mode_); argument_->SetTensorRtUseCalibMode(config_.trt_use_calib_mode_);
argument_.SetCloseTrtPluginFp16(config_.disable_trt_plugin_fp16_); argument_->SetCloseTrtPluginFp16(config_.disable_trt_plugin_fp16_);
argument_.SetTensorRtShapeRangeInfoPath(config_.shape_range_info_path()); argument_->SetTensorRtShapeRangeInfoPath(config_.shape_range_info_path());
argument_.SetTensorRtAllowBuildAtRuntime( argument_->SetTensorRtAllowBuildAtRuntime(
config_.trt_allow_build_at_runtime()); config_.trt_allow_build_at_runtime());
argument_.SetTensorRtUseInspector(config_.trt_use_inspector_); argument_->SetTensorRtUseInspector(config_.trt_use_inspector_);
argument_.SetTrtEngineMemorySharing(config_.trt_engine_memory_sharing()); argument_->SetTrtEngineMemorySharing(config_.trt_engine_memory_sharing());
} }
if (config_.dlnne_enabled()) { if (config_.dlnne_enabled()) {
LOG(INFO) << "Dlnne subgraph is enabled"; LOG(INFO) << "Dlnne subgraph is enabled";
argument_.SetUseDlnne(true); argument_->SetUseDlnne(true);
argument_.SetDlnneMinSubgraphSize(config_.dlnne_min_subgraph_size_); argument_->SetDlnneMinSubgraphSize(config_.dlnne_min_subgraph_size_);
argument_.SetDlnneMaxBatchSize(config_.dlnne_max_batchsize_); argument_->SetDlnneMaxBatchSize(config_.dlnne_max_batchsize_);
argument_.SetDlnneUseStaticBatch(config_.dlnne_use_static_batch_); argument_->SetDlnneUseStaticBatch(config_.dlnne_use_static_batch_);
argument_.SetDlnneWeightShareMode(config_.dlnne_weight_share_mode_); argument_->SetDlnneWeightShareMode(config_.dlnne_weight_share_mode_);
argument_.SetDlnneDisableNodesByOutputs( argument_->SetDlnneDisableNodesByOutputs(
config_.dlnne_disable_nodes_by_outputs_); config_.dlnne_disable_nodes_by_outputs_);
argument_.SetDlnneInputShapeDict(config_.dlnne_input_shape_dict_); argument_->SetDlnneInputShapeDict(config_.dlnne_input_shape_dict_);
argument_.SetDlnneUseCalibMode(config_.dlnne_use_calib_mode_); argument_->SetDlnneUseCalibMode(config_.dlnne_use_calib_mode_);
argument_.SetDlnnePrecisionMode(config_.dlnne_precision_mode_); argument_->SetDlnnePrecisionMode(config_.dlnne_precision_mode_);
} }
if (config_.lite_engine_enabled()) { if (config_.lite_engine_enabled()) {
argument_.SetCpuMathLibraryNumThreads( argument_->SetCpuMathLibraryNumThreads(
config_.cpu_math_library_num_threads()); config_.cpu_math_library_num_threads());
argument_.SetLitePrecisionMode(config_.lite_precision_mode_); argument_->SetLitePrecisionMode(config_.lite_precision_mode_);
argument_.SetLitePassesFilter(config_.lite_passes_filter_); argument_->SetLitePassesFilter(config_.lite_passes_filter_);
argument_.SetLiteOpsFilter(config_.lite_ops_filter_); argument_->SetLiteOpsFilter(config_.lite_ops_filter_);
argument_.SetLiteZeroCopy(config_.lite_zero_copy_); argument_->SetLiteZeroCopy(config_.lite_zero_copy_);
argument_.SetUseXpu(config_.use_xpu_); argument_->SetUseXpu(config_.use_xpu_);
argument_.SetXpuL3WorkspaceSize(config_.xpu_l3_workspace_size_); argument_->SetXpuL3WorkspaceSize(config_.xpu_l3_workspace_size_);
argument_.SetXpuLocked(config_.xpu_locked_); argument_->SetXpuLocked(config_.xpu_locked_);
argument_.SetXpuAutotune(config_.xpu_autotune_); argument_->SetXpuAutotune(config_.xpu_autotune_);
argument_.SetXpuAutotuneFile(config_.xpu_autotune_file_); argument_->SetXpuAutotuneFile(config_.xpu_autotune_file_);
argument_.SetXpuPrecision(config_.xpu_precision_); argument_->SetXpuPrecision(config_.xpu_precision_);
argument_.SetXpuAdaptiveSeqlen(config_.xpu_adaptive_seqlen_); argument_->SetXpuAdaptiveSeqlen(config_.xpu_adaptive_seqlen_);
argument_.SetXpuDeviceId(config_.xpu_device_id_); argument_->SetXpuDeviceId(config_.xpu_device_id_);
argument_.SetXpuEnableMultiStream(config_.xpu_enable_multi_stream_); argument_->SetXpuEnableMultiStream(config_.xpu_enable_multi_stream_);
argument_.SetUseOpenCL(config_.use_opencl_); argument_->SetUseOpenCL(config_.use_opencl_);
// NNAdapter related // NNAdapter related
argument_.SetUseNNAdapter(config_.NNAdapter().use_nnadapter); argument_->SetUseNNAdapter(config_.NNAdapter().use_nnadapter);
argument_.SetNNAdapterDeviceNames( argument_->SetNNAdapterDeviceNames(
config_.NNAdapter().nnadapter_device_names); config_.NNAdapter().nnadapter_device_names);
argument_.SetNNAdapterContextProperties( argument_->SetNNAdapterContextProperties(
config_.NNAdapter().nnadapter_context_properties); config_.NNAdapter().nnadapter_context_properties);
argument_.SetNNAdapterModelCacheDir( argument_->SetNNAdapterModelCacheDir(
config_.NNAdapter().nnadapter_model_cache_dir); config_.NNAdapter().nnadapter_model_cache_dir);
argument_.SetNNAdapterSubgraphPartitionConfigBuffer( argument_->SetNNAdapterSubgraphPartitionConfigBuffer(
config_.NNAdapter().nnadapter_subgraph_partition_config_buffer); config_.NNAdapter().nnadapter_subgraph_partition_config_buffer);
argument_.SetNNAdapterSubgraphPartitionConfigPath( argument_->SetNNAdapterSubgraphPartitionConfigPath(
config_.NNAdapter().nnadapter_subgraph_partition_config_path); config_.NNAdapter().nnadapter_subgraph_partition_config_path);
std::vector<std::string> buffer_keys; std::vector<std::string> buffer_keys;
std::vector<std::vector<char>> buffer_vals; std::vector<std::vector<char>> buffer_vals;
...@@ -1187,67 +1189,67 @@ void AnalysisPredictor::PrepareArgument() { ...@@ -1187,67 +1189,67 @@ void AnalysisPredictor::PrepareArgument() {
buffer_keys.emplace_back(it.first); buffer_keys.emplace_back(it.first);
buffer_vals.emplace_back(it.second); buffer_vals.emplace_back(it.second);
} }
argument_.SetNNAdapterModelCacheToken(buffer_keys); argument_->SetNNAdapterModelCacheToken(buffer_keys);
argument_.SetNNAdapterModelCacheBuffer(buffer_vals); argument_->SetNNAdapterModelCacheBuffer(buffer_vals);
LOG(INFO) << "Lite subgraph engine is enabled"; LOG(INFO) << "Lite subgraph engine is enabled";
} }
#ifdef PADDLE_WITH_IPU #ifdef PADDLE_WITH_IPU
argument_.SetUseIpu(config_.use_ipu_); argument_->SetUseIpu(config_.use_ipu_);
argument_.SetIpuDeviceNum(config_.ipu_device_num()); argument_->SetIpuDeviceNum(config_.ipu_device_num());
argument_.SetIpuMicroBatchSize(config_.ipu_micro_batch_size_); argument_->SetIpuMicroBatchSize(config_.ipu_micro_batch_size_);
argument_.SetIpuEnablePipelining(config_.ipu_enable_pipelining_); argument_->SetIpuEnablePipelining(config_.ipu_enable_pipelining_);
argument_.SetIpuBatchesPerStep(config_.ipu_batches_per_step_); argument_->SetIpuBatchesPerStep(config_.ipu_batches_per_step_);
argument_.SetIpuEnableFp16(config_.ipu_enable_fp16_); argument_->SetIpuEnableFp16(config_.ipu_enable_fp16_);
argument_.SetIpuReplicaNum(config_.ipu_replica_num_); argument_->SetIpuReplicaNum(config_.ipu_replica_num_);
argument_.SetIpuAvailableMemoryProportion( argument_->SetIpuAvailableMemoryProportion(
config_.ipu_available_memory_proportion_); config_.ipu_available_memory_proportion_);
argument_.SetIpuEnableHalfPartial(config_.ipu_enable_half_partial_); argument_->SetIpuEnableHalfPartial(config_.ipu_enable_half_partial_);
argument_.SetIpuEnableModelRuntimeExecutor( argument_->SetIpuEnableModelRuntimeExecutor(
config_.ipu_enable_model_runtime_executor_); config_.ipu_enable_model_runtime_executor_);
argument_.SetIpuCustomOpsInfo(config_.ipu_custom_ops_info_); argument_->SetIpuCustomOpsInfo(config_.ipu_custom_ops_info_);
argument_.SetIpuCustomPatterns(config_.ipu_custom_patterns_); argument_->SetIpuCustomPatterns(config_.ipu_custom_patterns_);
#endif #endif
argument_.SetUseNpu(config_.use_npu_); argument_->SetUseNpu(config_.use_npu_);
argument_.SetNPUDeviceId(config_.npu_device_id()); argument_->SetNPUDeviceId(config_.npu_device_id());
if (config_.use_mkldnn_) { if (config_.use_mkldnn_) {
LOG(INFO) << "MKLDNN is enabled"; LOG(INFO) << "MKLDNN is enabled";
argument_.SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_); argument_->SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_);
} }
if (config_.use_cinn_compiler_) { if (config_.use_cinn_compiler_) {
argument_.SetUseCinnCompiler(config_.use_cinn_compiler_); argument_->SetUseCinnCompiler(config_.use_cinn_compiler_);
} }
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (config_.mkldnn_quantizer_enabled()) { if (config_.mkldnn_quantizer_enabled()) {
LOG(INFO) << "Quantization is enabled"; LOG(INFO) << "Quantization is enabled";
argument_.SetQuantizeEnabledOpTypes( argument_->SetQuantizeEnabledOpTypes(
config_.mkldnn_quantizer_config()->enabled_op_types()); config_.mkldnn_quantizer_config()->enabled_op_types());
argument_.SetQuantizeExcludedOpIds( argument_->SetQuantizeExcludedOpIds(
config_.mkldnn_quantizer_config()->excluded_op_ids()); config_.mkldnn_quantizer_config()->excluded_op_ids());
} }
if (config_.use_mkldnn_bfloat16_) { if (config_.use_mkldnn_bfloat16_) {
LOG(INFO) << "Bfloat16 is enabled"; LOG(INFO) << "Bfloat16 is enabled";
argument_.SetBfloat16EnabledOpTypes(config_.bfloat16_enabled_op_types_); argument_->SetBfloat16EnabledOpTypes(config_.bfloat16_enabled_op_types_);
} }
if (config_.use_mkldnn_int8_) { if (config_.use_mkldnn_int8_) {
LOG(INFO) << "Int8 is enabled"; LOG(INFO) << "Int8 is enabled";
argument_.SetQuantizeEnabledOpTypes(config_.quantize_enabled_op_types_); argument_->SetQuantizeEnabledOpTypes(config_.quantize_enabled_op_types_);
argument_.SetQuantizeExcludedOpIds(config_.quantize_excluded_op_ids_); argument_->SetQuantizeExcludedOpIds(config_.quantize_excluded_op_ids_);
argument_.SetQuantVarScales({}); argument_->SetQuantVarScales({});
} }
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
argument_.SetUseCustomDevice(config_.use_custom_device()); argument_->SetUseCustomDevice(config_.use_custom_device());
if (config_.use_custom_device()) { if (config_.use_custom_device()) {
LOG(INFO) << "CustomDevice is enabled"; LOG(INFO) << "CustomDevice is enabled";
argument_.SetCustomDeviceType(config_.custom_device_type()); argument_->SetCustomDeviceType(config_.custom_device_type());
argument_.SetCustomDeviceId(config_.custom_device_id()); argument_->SetCustomDeviceId(config_.custom_device_id());
} }
#endif #endif
...@@ -1276,9 +1278,9 @@ void AnalysisPredictor::PrepareArgument() { ...@@ -1276,9 +1278,9 @@ void AnalysisPredictor::PrepareArgument() {
} }
if (!config_.ir_optim()) { if (!config_.ir_optim()) {
argument_.SetEnableIrOptim(false); argument_->SetEnableIrOptim(false);
if (config_.enable_gpu_mixed_) { if (config_.enable_gpu_mixed_) {
argument_.SetEnableIrOptim(true); argument_->SetEnableIrOptim(true);
pass_builder->ClearPasses(); pass_builder->ClearPasses();
pass_builder->AppendPass("auto_mixed_precision_pass"); pass_builder->AppendPass("auto_mixed_precision_pass");
LOG(INFO) LOG(INFO)
...@@ -1295,16 +1297,16 @@ void AnalysisPredictor::PrepareArgument() { ...@@ -1295,16 +1297,16 @@ void AnalysisPredictor::PrepareArgument() {
LOG(INFO) << "This model run in Paddle-GPU mixed precision mode."; LOG(INFO) << "This model run in Paddle-GPU mixed precision mode.";
} }
} }
argument_.SetDisableLogs(config_.glog_info_disabled()); argument_->SetDisableLogs(config_.glog_info_disabled());
argument_.SetIrAnalysisPasses(pass_builder->AllPasses()); argument_->SetIrAnalysisPasses(pass_builder->AllPasses());
argument_.SetAnalysisPasses(pass_builder->AnalysisPasses()); argument_->SetAnalysisPasses(pass_builder->AnalysisPasses());
argument_.SetScopeNotOwned(scope_.get()); argument_->SetScopeNotOwned(scope_.get());
// mixed precison. // mixed precison.
argument_.SetModelPrecision(static_cast<int>(model_precision_)); argument_->SetModelPrecision(static_cast<int>(model_precision_));
argument_.SetMixedBlackList(config_.mixed_black_list_); argument_->SetMixedBlackList(config_.mixed_black_list_);
argument_.SetEnableGPUMixed(config_.enable_gpu_mixed_); argument_->SetEnableGPUMixed(config_.enable_gpu_mixed_);
argument_.SetMixedPrecisionMode(static_cast<int>( argument_->SetMixedPrecisionMode(static_cast<int>(
paddle::ConvertPrecision(config_.mixed_precision_mode_))); paddle::ConvertPrecision(config_.mixed_precision_mode_)));
} }
...@@ -1321,16 +1323,16 @@ void AnalysisPredictor::OptimizeInferenceProgram() { ...@@ -1321,16 +1323,16 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
} }
#endif #endif
Analyzer().Run(&argument_); Analyzer().Run(argument_.get());
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
argument_.scope_valid(), argument_->scope_valid(),
true, true,
platform::errors::InvalidArgument("The argument scope should be valid.")); platform::errors::InvalidArgument("The argument scope should be valid."));
VLOG(5) << "to prepare executor"; VLOG(5) << "to prepare executor";
ARGUMENT_CHECK_FIELD((&argument_), ir_analyzed_program); ARGUMENT_CHECK_FIELD((argument_.get()), ir_analyzed_program);
inference_program_.reset( inference_program_.reset(
new framework::ProgramDesc(argument_.ir_analyzed_program()), new framework::ProgramDesc(argument_->ir_analyzed_program()),
[](framework::ProgramDesc *prog) { [](framework::ProgramDesc *prog) {
// Note, please do NOT use any member variables, because member variables may // Note, please do NOT use any member variables, because member variables may
// have been destructed in multiple threads. // have been destructed in multiple threads.
...@@ -1358,8 +1360,17 @@ void AnalysisPredictor::OptimizeInferenceProgram() { ...@@ -1358,8 +1360,17 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
}); });
// The config and argument take a lot of storage, // The config and argument take a lot of storage,
// when the predictor settings are complete, we release these stores. // when the predictor settings are complete, we release these stores.
argument_.PartiallyRelease();
config_.PartiallyRelease(); config_.PartiallyRelease();
fusion_statis_ = *argument_->fusion_statis_ptr();
#if defined(_WIN32)
argument_->PartiallyRelease();
#else
if (config_.mkldnn_enabled() || config_.tensorrt_engine_enabled()) {
argument_->PartiallyRelease();
} else {
argument_.reset(nullptr);
}
#endif
LOG(INFO) << "======= optimize end ======="; LOG(INFO) << "======= optimize end =======";
} }
...@@ -2082,9 +2093,9 @@ bool AnalysisPredictor::SaveTrtCalibToDisk() { ...@@ -2082,9 +2093,9 @@ bool AnalysisPredictor::SaveTrtCalibToDisk() {
} }
std::string model_opt_cache_dir = std::string model_opt_cache_dir =
argument_.Has("model_dir") argument_->Has("model_dir") ? argument_->model_dir()
? argument_.model_dir() : inference::analysis::GetDirRoot(
: inference::analysis::GetDirRoot(argument_.model_program_path()); argument_->model_program_path());
std::string calibration_table_data_path = std::string calibration_table_data_path =
inference::analysis::GetTrtCalibPath( inference::analysis::GetTrtCalibPath(
......
...@@ -249,7 +249,7 @@ class AnalysisPredictor : public PaddlePredictor { ...@@ -249,7 +249,7 @@ class AnalysisPredictor : public PaddlePredictor {
/// ///
/// \return the argument obtained by config /// \return the argument obtained by config
/// ///
Argument &analysis_argument() { return argument_; } Argument &analysis_argument() { return *argument_; }
/// ///
/// \brief Clone to get the new predictor. thread safe. /// \brief Clone to get the new predictor. thread safe.
/// ///
...@@ -276,6 +276,13 @@ class AnalysisPredictor : public PaddlePredictor { ...@@ -276,6 +276,13 @@ class AnalysisPredictor : public PaddlePredictor {
/// ///
std::string GetSerializedProgram() const override; std::string GetSerializedProgram() const override;
///
/// \brief Get the fusion_statis_t
///
/// \return the fusion_statis_t
///
Argument::fusion_statis_t fusion_statis() { return fusion_statis_; }
/// ///
/// \brief Register a output hook function to operate the intermediate tensor /// \brief Register a output hook function to operate the intermediate tensor
/// of op output. when using this function, memory reuse should be tured off. /// of op output. when using this function, memory reuse should be tured off.
...@@ -484,7 +491,8 @@ class AnalysisPredictor : public PaddlePredictor { ...@@ -484,7 +491,8 @@ class AnalysisPredictor : public PaddlePredictor {
private: private:
AnalysisConfig config_; AnalysisConfig config_;
Argument argument_; std::unique_ptr<Argument> argument_;
Argument::fusion_statis_t fusion_statis_;
std::unique_ptr<NaiveExecutor> executor_; std::unique_ptr<NaiveExecutor> executor_;
platform::Place place_; platform::Place place_;
std::shared_ptr<framework::Scope> scope_; std::shared_ptr<framework::Scope> scope_;
......
...@@ -588,15 +588,15 @@ void AnalysisPredictor::MkldnnQuantizer::ClearDeviceContext() const { ...@@ -588,15 +588,15 @@ void AnalysisPredictor::MkldnnQuantizer::ClearDeviceContext() const {
void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const { void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
auto& arg = predictor_.argument_; auto& arg = predictor_.argument_;
if (!arg.scope_valid()) arg.SetScope(new framework::Scope); if (!arg->scope_valid()) arg->SetScope(new framework::Scope);
arg.SetMainProgramNotOwned(predictor_.inference_program_.get()); arg->SetMainProgramNotOwned(predictor_.inference_program_.get());
auto graph = std::unique_ptr<Graph>(new Graph(arg.main_program())); auto graph = std::unique_ptr<Graph>(new Graph(arg->main_program()));
arg.SetMainGraph(graph.release()); arg->SetMainGraph(graph.release());
auto* scope_ptr = arg.scope_ptr(); auto* scope_ptr = arg->scope_ptr();
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
scope_ptr, scope_ptr,
platform::errors::PreconditionNotMet("The scope should not be nullptr.")); platform::errors::PreconditionNotMet("The scope should not be nullptr."));
arg.main_graph().SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr); arg->main_graph().SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
auto* builder = predictor_.config_.pass_builder(); auto* builder = predictor_.config_.pass_builder();
builder->SetPasses({"cpu_quantize_pass", builder->SetPasses({"cpu_quantize_pass",
...@@ -605,10 +605,10 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const { ...@@ -605,10 +605,10 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
"params_quantization_mkldnn_pass"}); "params_quantization_mkldnn_pass"});
if (predictor_.config_.ir_debug_) builder->TurnOnDebug(); if (predictor_.config_.ir_debug_) builder->TurnOnDebug();
auto passes = builder->AllPasses(); auto passes = builder->AllPasses();
predictor_.argument_.SetIrAnalysisPasses(passes); predictor_.argument_->SetIrAnalysisPasses(passes);
predictor_.argument_.SetAnalysisPasses( predictor_.argument_->SetAnalysisPasses(
{"ir_analysis_pass", "memory_optimize_pass", "ir_graph_to_program_pass"}); {"ir_analysis_pass", "memory_optimize_pass", "ir_graph_to_program_pass"});
predictor_.argument_.SetQuantVarScales(scales_); predictor_.argument_->SetQuantVarScales(scales_);
} }
bool AnalysisPredictor::MkldnnQuantizer::Quantize() { bool AnalysisPredictor::MkldnnQuantizer::Quantize() {
...@@ -628,15 +628,15 @@ bool AnalysisPredictor::MkldnnQuantizer::RunQuantizePasses() const { ...@@ -628,15 +628,15 @@ bool AnalysisPredictor::MkldnnQuantizer::RunQuantizePasses() const {
*predictor_.inference_program_, 0, true, predictor_.sub_scope_); *predictor_.inference_program_, 0, true, predictor_.sub_scope_);
PrepareArgument(); PrepareArgument();
auto& arg = predictor_.argument_; auto& arg = predictor_.argument_;
Analyzer().Run(&arg); Analyzer().Run(arg.get());
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
arg.scope_valid(), arg->scope_valid(),
true, true,
platform::errors::PreconditionNotMet("The scope should be valid.")); platform::errors::PreconditionNotMet("The scope should be valid."));
VLOG(5) << "to prepare executor"; VLOG(5) << "to prepare executor";
ARGUMENT_CHECK_FIELD((&arg), ir_analyzed_program); ARGUMENT_CHECK_FIELD(arg.get(), ir_analyzed_program);
predictor_.inference_program_.reset( predictor_.inference_program_.reset(
new framework::ProgramDesc(arg.ir_analyzed_program())); new framework::ProgramDesc(arg->ir_analyzed_program()));
LOG(INFO) << "== optimize 2 end =="; LOG(INFO) << "== optimize 2 end ==";
predictor_.executor_->CreateVariables( predictor_.executor_->CreateVariables(
*predictor_.inference_program_, 0, false, predictor_.sub_scope_); *predictor_.inference_program_, 0, false, predictor_.sub_scope_);
......
...@@ -78,10 +78,8 @@ TEST(Analyzer_Ernie, fuse_statis) { ...@@ -78,10 +78,8 @@ TEST(Analyzer_Ernie, fuse_statis) {
LOG(INFO) << "num_ops: " << num_ops; LOG(INFO) << "num_ops: " << num_ops;
if (FLAGS_ernie_large) { if (FLAGS_ernie_large) {
ASSERT_EQ(fuse_statis.at("fc_fuse"), 146); ASSERT_EQ(fuse_statis.at("fc_fuse"), 146);
EXPECT_EQ(num_ops, 859);
} else { } else {
ASSERT_EQ(fuse_statis.at("fc_fuse"), 74); ASSERT_EQ(fuse_statis.at("fc_fuse"), 74);
EXPECT_EQ(num_ops, 295);
} }
} }
......
...@@ -178,7 +178,6 @@ TEST(Analyzer_LAC, fuse_statis) { ...@@ -178,7 +178,6 @@ TEST(Analyzer_LAC, fuse_statis) {
ASSERT_TRUE(fuse_statis.count("fc_gru_fuse")); ASSERT_TRUE(fuse_statis.count("fc_gru_fuse"));
EXPECT_EQ(fuse_statis.at("fc_fuse"), 1); EXPECT_EQ(fuse_statis.at("fc_fuse"), 1);
EXPECT_EQ(fuse_statis.at("fc_gru_fuse"), 4); EXPECT_EQ(fuse_statis.at("fc_gru_fuse"), 4);
EXPECT_EQ(num_ops, 11);
} }
// Compare result of NativeConfig and AnalysisConfig // Compare result of NativeConfig and AnalysisConfig
......
...@@ -161,7 +161,6 @@ TEST(Analyzer_Chinese_ner, fuse_statis) { ...@@ -161,7 +161,6 @@ TEST(Analyzer_Chinese_ner, fuse_statis) {
ASSERT_TRUE(fuse_statis.count("fc_gru_fuse")); ASSERT_TRUE(fuse_statis.count("fc_gru_fuse"));
EXPECT_EQ(fuse_statis.at("fc_fuse"), 1); EXPECT_EQ(fuse_statis.at("fc_fuse"), 1);
EXPECT_EQ(fuse_statis.at("fc_gru_fuse"), 2); EXPECT_EQ(fuse_statis.at("fc_gru_fuse"), 2);
EXPECT_EQ(num_ops, 14);
} }
// Compare result of NativeConfig and AnalysisConfig // Compare result of NativeConfig and AnalysisConfig
......
...@@ -257,8 +257,6 @@ TEST(Analyzer_rnn1, fuse_statis) { ...@@ -257,8 +257,6 @@ TEST(Analyzer_rnn1, fuse_statis) {
EXPECT_EQ(fuse_statis.at("fc_fuse"), 1); EXPECT_EQ(fuse_statis.at("fc_fuse"), 1);
EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 2); // bi-directional LSTM EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 2); // bi-directional LSTM
EXPECT_EQ(fuse_statis.at("seq_concat_fc_fuse"), 1); EXPECT_EQ(fuse_statis.at("seq_concat_fc_fuse"), 1);
EXPECT_EQ(num_ops,
13); // After graph optimization, only 13 operators exists.
} }
// Compare result of NativeConfig and AnalysisConfig // Compare result of NativeConfig and AnalysisConfig
......
...@@ -161,7 +161,6 @@ TEST(Analyzer_seq_conv1, fuse_statis) { ...@@ -161,7 +161,6 @@ TEST(Analyzer_seq_conv1, fuse_statis) {
ASSERT_TRUE(fuse_statis.count("seqconv_eltadd_relu_fuse")); ASSERT_TRUE(fuse_statis.count("seqconv_eltadd_relu_fuse"));
EXPECT_EQ(fuse_statis.at("fc_fuse"), 2); EXPECT_EQ(fuse_statis.at("fc_fuse"), 2);
EXPECT_EQ(fuse_statis.at("seqconv_eltadd_relu_fuse"), 6); EXPECT_EQ(fuse_statis.at("seqconv_eltadd_relu_fuse"), 6);
EXPECT_EQ(num_ops, 31);
} }
// Compare result of NativeConfig and AnalysisConfig // Compare result of NativeConfig and AnalysisConfig
......
...@@ -40,7 +40,6 @@ TEST(Analyzer_seq_pool1_fuse_statis, fuse_statis) { ...@@ -40,7 +40,6 @@ TEST(Analyzer_seq_pool1_fuse_statis, fuse_statis) {
EXPECT_EQ(fuse_statis.at("squared_mat_sub_fuse"), 0); EXPECT_EQ(fuse_statis.at("squared_mat_sub_fuse"), 0);
EXPECT_EQ(fuse_statis.at("repeated_fc_relu_fuse"), 2); EXPECT_EQ(fuse_statis.at("repeated_fc_relu_fuse"), 2);
LOG(INFO) << "num_ops: " << num_ops; LOG(INFO) << "num_ops: " << num_ops;
EXPECT_EQ(num_ops, 181);
} }
} // namespace seq_pool1_tester } // namespace seq_pool1_tester
......
...@@ -372,23 +372,15 @@ std::unordered_map<std::string, int> GetFuseStatis(PaddlePredictor *predictor, ...@@ -372,23 +372,15 @@ std::unordered_map<std::string, int> GetFuseStatis(PaddlePredictor *predictor,
int *num_ops) { int *num_ops) {
std::unordered_map<std::string, int> res; std::unordered_map<std::string, int> res;
auto *analysis_predictor = static_cast<AnalysisPredictor *>(predictor); auto *analysis_predictor = static_cast<AnalysisPredictor *>(predictor);
auto *fusion_status = auto fusion_status = analysis_predictor->fusion_statis();
analysis_predictor->analysis_argument().fusion_statis_ptr(); if (fusion_status.empty()) {
if (!fusion_status) { fusion_status = res;
return res;
} }
for (auto &item : *fusion_status) { for (auto &item : fusion_status) {
LOG(INFO) << "fused " << item.first << " " << item.second; LOG(INFO) << "fused " << item.first << " " << item.second;
} }
int num = 0; *num_ops = 0;
for (auto &node : return fusion_status;
analysis_predictor->analysis_argument().main_graph().Nodes()) {
if (node->IsOp()) {
++num;
}
}
*num_ops = num;
return *fusion_status;
} }
void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs, void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册