未验证 提交 106b5514 编写于 作者: B baoachun 提交者: GitHub

support npu weight unified H2D copy before inference (#39160)

* support npu weight unified H2D copy

* remove redundant variable
上级 b1a458ac
......@@ -282,6 +282,10 @@ struct Argument {
DECL_ARGUMENT_FIELD(ipu_batch_size, IpuBatchSize, int);
DECL_ARGUMENT_FIELD(ipu_need_avg_shard, IpuNeedAvgShard, bool);
// npu related
DECL_ARGUMENT_FIELD(use_npu, UseNpu, bool);
DECL_ARGUMENT_FIELD(npu_device_id, NPUDeviceId, int);
private:
std::unordered_set<std::string> valid_fields_;
};
......
......@@ -22,16 +22,50 @@ namespace paddle {
namespace inference {
namespace analysis {
void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
PADDLE_ENFORCE_EQ(
argument->scope_valid(), true,
platform::errors::PreconditionNotMet("The scope field should be valid"));
PADDLE_ENFORCE_EQ(argument->use_gpu_valid(), true,
#ifdef PADDLE_WITH_ASCEND_CL
void IrParamsSyncAmongDevicesPass::CopyParamsToNpu(Argument *argument) {
if (!argument->use_npu()) return;
auto &graph = argument->main_graph();
std::vector<std::string> repetitive_params;
if (graph.Has(framework::ir::kRepetitiveParamAttr))
repetitive_params = graph.Get<std::vector<std::string>>(
framework::ir::kRepetitiveParamAttr);
LOG(INFO) << "Sync params from CPU to NPU";
PADDLE_ENFORCE_EQ(argument->npu_device_id_valid(), true,
platform::errors::PreconditionNotMet(
"The use_gpu field should be valid"));
"The npu_device_id field should be valid"));
platform::Place place = platform::NPUPlace(argument->npu_device_id());
auto *scope = argument->scope_ptr();
std::vector<std::string> all_vars = scope->LocalVarNames();
platform::Place place;
for (auto &var_name : all_vars) {
auto *var = scope->FindLocalVar(var_name);
PADDLE_ENFORCE_NOT_NULL(var, platform::errors::PreconditionNotMet(
"The var should not be nullptr"));
if (var->IsType<framework::LoDTensor>() ||
var->IsType<framework::Tensor>()) {
auto *t = var->GetMutable<framework::LoDTensor>();
platform::CPUPlace cpu_place;
framework::LoDTensor temp_tensor;
temp_tensor.Resize(t->dims());
temp_tensor.mutable_data<float>(cpu_place);
paddle::framework::TensorCopySync(*t, cpu_place, &temp_tensor);
t->clear();
paddle::framework::TensorCopySync(temp_tensor, place, t);
}
}
}
#else
void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
// The parameters are on the cpu, therefore, synchronization is not necessary.
if (!argument->use_gpu()) return;
......@@ -47,8 +81,7 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
PADDLE_ENFORCE_EQ(argument->gpu_device_id_valid(), true,
platform::errors::PreconditionNotMet(
"The gpu_device_id field should be valid"));
place = platform::CUDAPlace(argument->gpu_device_id());
platform::Place place = platform::CUDAPlace(argument->gpu_device_id());
auto *scope = argument->scope_ptr();
std::vector<std::string> all_vars = scope->LocalVarNames();
......@@ -100,6 +133,22 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
}
}
#endif
void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
PADDLE_ENFORCE_EQ(
argument->scope_valid(), true,
platform::errors::PreconditionNotMet("The scope field should be valid"));
#ifdef PADDLE_WITH_ASCEND_CL
if (!argument->use_npu_valid()) return;
CopyParamsToNpu(argument);
#else
if (!argument->use_gpu_valid()) return;
CopyParamsToGpu(argument);
#endif
}
std::string IrParamsSyncAmongDevicesPass::repr() const {
return "ir-params-sync-among-devices-pass";
}
......
......@@ -33,6 +33,13 @@ class IrParamsSyncAmongDevicesPass : public AnalysisPass {
public:
void RunImpl(Argument *argument) override;
std::string repr() const override;
private:
#ifdef PADDLE_WITH_ASCEND_CL
void CopyParamsToNpu(Argument *argument);
#else
void CopyParamsToGpu(Argument *argument);
#endif
};
} // namespace analysis
......
......@@ -668,6 +668,9 @@ void AnalysisPredictor::PrepareArgument() {
argument_.SetIpuBatchSize(config_.ipu_batch_size_);
argument_.SetIpuNeedAvgShard(config_.ipu_need_avg_shard_);
argument_.SetUseNpu(config_.use_npu_);
argument_.SetNPUDeviceId(config_.npu_device_id());
if (config_.use_mkldnn_) {
LOG(INFO) << "MKLDNN is enabled";
argument_.SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册