From a2b2af90593d0e45e7b122c81c6f426b39b066af Mon Sep 17 00:00:00 2001
From: Wilber
Date: Wed, 7 Sep 2022 14:29:13 +0800
Subject: [PATCH] Optimiza params sync between CPU and GPU. (#45805)

* enable memory optimize when fp16.

* optimiza params sync between cpu and gpu.
---
 .../ir_params_sync_among_devices_pass.cc      | 25 +++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
index 3948ca8a59f..168b99f3d76 100644
--- a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
@@ -21,9 +21,12 @@
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/inference/api/helper.h"
 #include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/place.h"
 #include "paddle/phi/common/data_type.h"

 namespace paddle {
@@ -114,6 +117,28 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
     reserve_cpu_weights = true;
   }

+  int64_t params_total_bytes{0};
+  for (auto *node : paddle::framework::ir::TopologySortOperations(graph)) {
+    if (!node->IsOp()) continue;
+    if (node->Op()->Type() == "feed" || node->Op()->Type() == "fetch") continue;
+    for (auto *var_node : node->inputs) {
+      if (!var_node->Var()->Persistable()) continue;
+      auto var_name = var_node->Var()->Name();
+      auto *var = scope->FindLocalVar(var_name);
+      if (var->IsType<framework::LoDTensor>() ||
+          var->IsType<phi::DenseTensor>()) {
+        auto *t = var->GetMutable<framework::LoDTensor>();
+        params_total_bytes += t->numel() * experimental::SizeOf(t->dtype());
+      }
+    }
+  }
+
+  {
+    // Alloc memory in pool to store all parameters.
+    framework::Tensor ts;
+    ts.mutable_data<int8_t>(place, params_total_bytes);
+  }
+
   std::unordered_set<std::string> visited;
   for (auto *node : paddle::framework::ir::TopologySortOperations(graph)) {
     if (!node->IsOp()) continue;
--
GitLab