diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index a72c1fe7622136ed80e2a98ed382c2b964f1937a..c386bdcb2e45cee3077fb20758b4cae3f4cd5744 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -371,6 +371,11 @@ struct Argument {
   // cinn compiler related
   DECL_ARGUMENT_FIELD(use_cinn_compiler, UseCinnCompiler, bool);
 
+  // custom device
+  DECL_ARGUMENT_FIELD(use_custom_device, UseCustomDevice, bool);
+  DECL_ARGUMENT_FIELD(custom_device_type, CustomDeviceType, std::string);
+  DECL_ARGUMENT_FIELD(custom_device_id, CustomDeviceId, int);
+
  private:
   std::unordered_set<std::string> valid_fields_;
 };
diff --git a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
index 8e6470b2c1a0bac336c1b332c9786fa5bbf21d4a..e3241d78e6bd2a6c128559436bc788767b2a059b 100644
--- a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h"
 
+#include <gflags/gflags.h>
 #include <map>
 #include <memory>
@@ -26,6 +27,11 @@
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/phi/common/data_type.h"
 
+DEFINE_bool(
+    custom_model_save_cpu,
+    false,
+    "Keep old mode for developers, the model is saved on cpu not device.");
+
 namespace paddle {
 namespace inference {
 namespace analysis {
@@ -71,9 +77,9 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToNpu(Argument *argument) {
     }
   }
 }
+#endif
 
-#else
-
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
   // The parameters are on the cpu, therefore, synchronization is not necessary.
   if (!argument->use_gpu()) return;
@@ -148,7 +154,62 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
     }
   }
 }
+#endif
+
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+void IrParamsSyncAmongDevicesPass::CopyParamsToCustomDevice(
+    Argument *argument) {
+  if (!argument->use_custom_device()) return;
+
+  // On old mode, the model is saved on cpu not device.
+  if (argument->custom_device_type() == "OpenCL") {
+    PADDLE_ENFORCE_EQ(
+        FLAGS_custom_model_save_cpu,
+        false,
+        phi::errors::InvalidArgument(
+            "'FLAGS_custom_model_save_cpu = false' is only for the developers "
+            "who have not completed custom device memory settings. Setting to "
+            "true will make "
+            "model memory reserve on the cpu, and make inference slower."));
+  }
+
+  if (FLAGS_custom_model_save_cpu) return;
+
+  auto &graph = argument->main_graph();
+  std::vector<std::string> repetitive_params;
+
+  if (graph.Has(framework::ir::kRepetitiveParamAttr))
+    repetitive_params = graph.Get<std::vector<std::string>>(
+        framework::ir::kRepetitiveParamAttr);
+  LOG(INFO) << "Sync params from CPU to CustomDevice"
+            << argument->custom_device_type() << "/"
+            << argument->custom_device_id();
+
+  platform::Place place = platform::CustomPlace(argument->custom_device_type(),
+                                                argument->custom_device_id());
+  auto *scope = argument->scope_ptr();
+  std::vector<std::string> all_vars = scope->LocalVarNames();
+
+  for (auto &var_name : all_vars) {
+    auto *var = scope->FindLocalVar(var_name);
+    PADDLE_ENFORCE_NOT_NULL(
+        var,
+        platform::errors::PreconditionNotMet("The var should not be nullptr"));
+
+    if (var->IsType<phi::DenseTensor>() || var->IsType<phi::SelectedRows>()) {
+      auto *t = var->GetMutable<phi::DenseTensor>();
+
+      platform::CPUPlace cpu_place;
+      phi::DenseTensor temp_tensor;
+      temp_tensor.Resize(t->dims());
+
+      paddle::framework::TensorCopySync(*t, cpu_place, &temp_tensor);
+      t->clear();
+      paddle::framework::TensorCopySync(temp_tensor, place, t);
+    }
+  }
+}
 #endif
 
@@ -156,13 +217,20 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
       argument->scope_valid(),
       true,
       platform::errors::PreconditionNotMet("The scope field should be valid"));
-
 #ifdef PADDLE_WITH_ASCEND_CL
-  if (!argument->use_npu_valid()) return;
-  CopyParamsToNpu(argument);
-#else
-  if (!argument->use_gpu_valid()) return;
-  CopyParamsToGpu(argument);
+  if (argument->use_npu_valid()) {
+    CopyParamsToNpu(argument);
+  }
+#endif
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  if (argument->use_gpu_valid()) {
+    CopyParamsToGpu(argument);
+  }
+#endif
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+  if (argument->use_custom_device_valid()) {
+    CopyParamsToCustomDevice(argument);
+  }
 #endif
 }
diff --git a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h
index d5e98ec886e65f829a1496b1431f23aad6c4bc4c..bc91bd6a1aea18b87382c9d05a2be9060c05c527 100644
--- a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h
+++ b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h
@@ -37,9 +37,15 @@ class IrParamsSyncAmongDevicesPass : public AnalysisPass {
  private:
 #ifdef PADDLE_WITH_ASCEND_CL
   void CopyParamsToNpu(Argument *argument);
-#else
+#endif
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   void CopyParamsToGpu(Argument *argument);
 #endif
+
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+  void CopyParamsToCustomDevice(Argument *argument);
+#endif
 };
 
 }  // namespace analysis
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index af4d83f55a6ee2fef289e9e693fb7190cfcb2c4a..983506c7c02bcf5a085f67de9396d6632854ed2a 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -1242,6 +1242,15 @@ void AnalysisPredictor::PrepareArgument() {
   }
 #endif
 
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+  argument_.SetUseCustomDevice(config_.use_custom_device());
+  if (config_.use_custom_device()) {
+    LOG(INFO) << "CustomDevice is enabled";
+    argument_.SetCustomDeviceType(config_.custom_device_type());
+    argument_.SetCustomDeviceId(config_.custom_device_id());
+  }
+#endif
+
   auto *pass_builder = config_.pass_builder();
   // TODO(inference): Need to reconstruct the pass_builder, pass should be
   // processed in a single