diff --git a/paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.cc b/paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.cc index 9e1cf5832bdc11089c0aacd9b9005dfb788dc764..4da889b400071f9bb4bb6476d25e2ba5957ea2ee 100644 --- a/paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.cc +++ b/paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.cc @@ -23,8 +23,8 @@ #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/op_registry.h" -DEFINE_uint64(fuse_parameter_memory_size, 0, // Bytes - "fuse_parameter_memory_size is up limited memory size " +DEFINE_double(fuse_parameter_memory_size, -1.0, // MBytes + "fuse_parameter_memory_size is up limited memory size(MB)" "of one group parameters' gradient which is the input " "of communication calling(e.g NCCLAllReduce). " "The default value is 0, it means that " @@ -51,13 +51,11 @@ void SetFuseParameterGroupsSize(int group_size) { int GetFuseParameterGroupsSize() { return FLAGS_fuse_parameter_groups_size; } -void SetFuseParameterMemorySize(uint64_t memory_size) { +void SetFuseParameterMemorySize(double memory_size) { FLAGS_fuse_parameter_memory_size = memory_size; } -uint64_t GetFuseParameterMemorySize() { - return FLAGS_fuse_parameter_memory_size; -} +double GetFuseParameterMemorySize() { return FLAGS_fuse_parameter_memory_size; } static framework::proto::VarType::Type kDefaultDtype = framework::proto::VarType::Type::VarType_Type_BOOL; @@ -230,15 +228,16 @@ class AllocContinuousSpaceForGradPass : public ir::Pass { } VLOG(10) << out.str() << ", group size:" << group_grads_params->at(i).size() - << ", group memory size:" << gps_size; + << ", group memory size:" + << static_cast(gps_size) / 1048576.0 << "(MB)"; } } void SetGroupAccordingToMemorySize( const std::unordered_map &var_nodes, details::GroupGradsAndParams *group_grads_params) const { - const uint64_t group_memory_size = GetFuseParameterMemorySize(); - if (group_memory_size == 0) { + const double group_memory_size = GetFuseParameterMemorySize(); + if (group_memory_size <= 0.0) { return; } details::GroupGradsAndParams local_group_grads_params; @@ -271,7 +270,8 @@ class AllocContinuousSpaceForGradPass : public ir::Pass { break; } - if (local_group_memory_size >= group_memory_size) { + if (static_cast(local_group_memory_size) / 1048576.0 >= + group_memory_size) { break; } } @@ -280,7 +280,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass { std::swap(*group_grads_params, local_group_grads_params); VLOG(10) << string::Sprintf( - "SetGroupAccordingToMemorySize(memory_size: %d):", group_memory_size); + "SetGroupAccordingToMemorySize(memory_size: %f):", group_memory_size); if (VLOG_IS_ON(10)) { PrintGroupInfo(var_nodes, group_grads_params); diff --git a/paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.h b/paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.h index b20eda96f0fb622ccd318d9418ddb15f2997f8e6..38dc4c99fc27f03d64704b479478065b636af63a 100644 --- a/paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.h +++ b/paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.h @@ -21,8 +21,8 @@ namespace ir { void SetFuseParameterGroupsSize(int group_size); int GetFuseParameterGroupsSize(); -void SetFuseParameterMemorySize(uint64_t memory_size); -uint64_t GetFuseParameterMemorySize(); +void SetFuseParameterMemorySize(double memory_size); +double GetFuseParameterMemorySize(); } // namespace ir } // namespace framework