diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index 175bc55dcff17e46aa47e1d2d187e3a8c8c4b43d..febfdec0b5cf500c30d44feccf4bed7e029feef4 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -282,6 +282,10 @@ struct Argument {
   DECL_ARGUMENT_FIELD(ipu_batch_size, IpuBatchSize, int);
   DECL_ARGUMENT_FIELD(ipu_need_avg_shard, IpuNeedAvgShard, bool);
 
+  // npu related
+  DECL_ARGUMENT_FIELD(use_npu, UseNpu, bool);
+  DECL_ARGUMENT_FIELD(npu_device_id, NPUDeviceId, int);
+
  private:
   std::unordered_set<std::string> valid_fields_;
 };
diff --git a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
index 06a353d5622a7093760c8680bcb8c1e245496ae8..daa18d8c78bf875ebcc6571bf955a7f634948e4f 100644
--- a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
@@ -22,16 +22,50 @@ namespace paddle {
 namespace inference {
 namespace analysis {
 
-void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
-  PADDLE_ENFORCE_EQ(
-      argument->scope_valid(), true,
-      platform::errors::PreconditionNotMet("The scope field should be valid"));
-  PADDLE_ENFORCE_EQ(argument->use_gpu_valid(), true,
+#ifdef PADDLE_WITH_ASCEND_CL
+void IrParamsSyncAmongDevicesPass::CopyParamsToNpu(Argument *argument) {
+  if (!argument->use_npu()) return;
+
+  auto &graph = argument->main_graph();
+  std::vector<std::string> repetitive_params;
+
+  if (graph.Has(framework::ir::kRepetitiveParamAttr))
+    repetitive_params = graph.Get<std::vector<std::string>>(
+        framework::ir::kRepetitiveParamAttr);
+
+  LOG(INFO) << "Sync params from CPU to NPU";
+
+  PADDLE_ENFORCE_EQ(argument->npu_device_id_valid(), true,
                     platform::errors::PreconditionNotMet(
-                        "The use_gpu field should be valid"));
+                        "The npu_device_id field should be valid"));
+  platform::Place place = platform::NPUPlace(argument->npu_device_id());
+  auto *scope = argument->scope_ptr();
+  std::vector<std::string> all_vars = scope->LocalVarNames();
 
-  platform::Place place;
+  for (auto &var_name : all_vars) {
+    auto *var = scope->FindLocalVar(var_name);
+    PADDLE_ENFORCE_NOT_NULL(var, platform::errors::PreconditionNotMet(
+                                     "The var should not be nullptr"));
+
+    if (var->IsType<framework::LoDTensor>() ||
+        var->IsType<framework::Tensor>()) {
+      auto *t = var->GetMutable<framework::LoDTensor>();
+
+      platform::CPUPlace cpu_place;
+      framework::LoDTensor temp_tensor;
+      temp_tensor.Resize(t->dims());
+      temp_tensor.mutable_data<float>(cpu_place);
+
+      paddle::framework::TensorCopySync(*t, cpu_place, &temp_tensor);
+      t->clear();
+      paddle::framework::TensorCopySync(temp_tensor, place, t);
+    }
+  }
+}
+
+#else
+
+void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
   // The parameters are on the cpu, therefore, synchronization is not necessary.
   if (!argument->use_gpu()) return;
@@ -47,8 +81,7 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
   PADDLE_ENFORCE_EQ(argument->gpu_device_id_valid(), true,
                     platform::errors::PreconditionNotMet(
                         "The gpu_device_id field should be valid"));
-  place = platform::CUDAPlace(argument->gpu_device_id());
-
+  platform::Place place = platform::CUDAPlace(argument->gpu_device_id());
   auto *scope = argument->scope_ptr();
   std::vector<std::string> all_vars = scope->LocalVarNames();
 
@@ -100,6 +133,22 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
   }
 }
 
+#endif
+
+void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
+  PADDLE_ENFORCE_EQ(
+      argument->scope_valid(), true,
+      platform::errors::PreconditionNotMet("The scope field should be valid"));
+
+#ifdef PADDLE_WITH_ASCEND_CL
+  if (!argument->use_npu_valid()) return;
+  CopyParamsToNpu(argument);
+#else
+  if (!argument->use_gpu_valid()) return;
+  CopyParamsToGpu(argument);
+#endif
+}
+
 std::string IrParamsSyncAmongDevicesPass::repr() const {
   return "ir-params-sync-among-devices-pass";
 }
diff --git a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h
index 61990150a30db147418c4301359428cf3c6db541..d5e98ec886e65f829a1496b1431f23aad6c4bc4c 100644
--- a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h
+++ b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h
@@ -33,6 +33,13 @@ class IrParamsSyncAmongDevicesPass : public AnalysisPass {
  public:
   void RunImpl(Argument *argument) override;
   std::string repr() const override;
+
+ private:
+#ifdef PADDLE_WITH_ASCEND_CL
+  void CopyParamsToNpu(Argument *argument);
+#else
+  void CopyParamsToGpu(Argument *argument);
+#endif
 };
 
 }  // namespace analysis
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index a86329a2b2b25df7cb256c47200598644af84bfe..628d974c1237862c81c9e124851004c50d07d377 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -668,6 +668,9 @@ void AnalysisPredictor::PrepareArgument() {
   argument_.SetIpuBatchSize(config_.ipu_batch_size_);
   argument_.SetIpuNeedAvgShard(config_.ipu_need_avg_shard_);
 
+  argument_.SetUseNpu(config_.use_npu_);
+  argument_.SetNPUDeviceId(config_.npu_device_id());
+
   if (config_.use_mkldnn_) {
     LOG(INFO) << "MKLDNN is enabled";
     argument_.SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_);
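
Note on usage (not part of the diff): the new Argument fields are only consumed when the predictor config requests NPU execution. Below is a minimal sketch of how the pass would be exercised end to end, assuming a build with PADDLE_WITH_ASCEND_CL and assuming AnalysisConfig exposes an EnableNpu(device_id) switch that sets the use_npu_ / npu_device_id() values read in PrepareArgument(); verify the exact setter name against paddle_analysis_config.h.

#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::AnalysisConfig config;
  // Hypothetical model path; replace with a real saved inference model.
  config.SetModel("./model_dir");
  // Assumed switch: sets use_npu_ = true and npu_device_id_ = 0, which
  // PrepareArgument() forwards via SetUseNpu / SetNPUDeviceId above.
  config.EnableNpu(0);

  // Building the predictor runs the analysis passes, including
  // ir-params-sync-among-devices-pass; with the flags above, RunImpl()
  // dispatches to CopyParamsToNpu(), which stages each LoDTensor/Tensor
  // parameter through a CPU copy and TensorCopySync's it to NPUPlace(0).
  auto predictor = paddle::CreatePaddlePredictor(config);
  return 0;
}

One design note on the dispatch in RunImpl(): because the NPU/GPU split is a compile-time #ifdef rather than a runtime check, an Ascend build compiles only CopyParamsToNpu() and a non-Ascend build only CopyParamsToGpu(), so a single binary can sync parameters to one device family but never both.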