diff --git a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
index 3948ca8a59fd59d72a3e3dce8003138ab65363a1..168b99f3d7649acbd4f0ec29a793018295ff12e8 100644
--- a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
@@ -21,9 +21,12 @@
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/inference/api/helper.h"
 #include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/place.h"
 #include "paddle/phi/common/data_type.h"
 
 namespace paddle {
@@ -114,6 +117,28 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
     reserve_cpu_weights = true;
   }
 
+  // Sum the byte size of every persistable parameter consumed by the graph
+  // (skipping feed/fetch ops), so one pooled allocation can cover them all.
+  int64_t params_total_bytes{0};
+  for (auto *node : paddle::framework::ir::TopologySortOperations(graph)) {
+    if (!node->IsOp()) continue;
+    if (node->Op()->Type() == "feed" || node->Op()->Type() == "fetch") continue;
+    for (auto *var_node : node->inputs) {
+      if (!var_node->Var()->Persistable()) continue;
+      auto var_name = var_node->Var()->Name();
+      auto *var = scope->FindLocalVar(var_name);
+      if (var->IsType<framework::LoDTensor>() ||
+          var->IsType<framework::Tensor>()) {
+        auto *t = var->GetMutable<framework::LoDTensor>();
+        params_total_bytes += t->numel() * experimental::SizeOf(t->dtype());
+      }
+    }
+  }
+
+  {
+    // Alloc memory in pool to store all parameters.
+    framework::Tensor ts;
+    ts.mutable_data<int8_t>(place, params_total_bytes);
+  }
+
   std::unordered_set<std::string> visited;
 for (auto *node : paddle::framework::ir::TopologySortOperations(graph)) {
   if (!node->IsOp()) continue;