未验证 提交 972c54cd 编写于 作者: G gongweibao 提交者: GitHub

Fix FLAGS_fuse_parameter_memory_size unit from Bytes to MBytes. (#17924)

上级 0a96ec69
...@@ -23,8 +23,8 @@ ...@@ -23,8 +23,8 @@
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
DEFINE_uint64(fuse_parameter_memory_size, 0, // Bytes DEFINE_double(fuse_parameter_memory_size, -1.0, // MBytes
"fuse_parameter_memory_size is up limited memory size " "fuse_parameter_memory_size is up limited memory size(MB)"
"of one group parameters' gradient which is the input " "of one group parameters' gradient which is the input "
"of communication calling(e.g NCCLAllReduce). " "of communication calling(e.g NCCLAllReduce). "
"The default value is 0, it means that " "The default value is 0, it means that "
...@@ -51,13 +51,11 @@ void SetFuseParameterGroupsSize(int group_size) { ...@@ -51,13 +51,11 @@ void SetFuseParameterGroupsSize(int group_size) {
int GetFuseParameterGroupsSize() { return FLAGS_fuse_parameter_groups_size; } int GetFuseParameterGroupsSize() { return FLAGS_fuse_parameter_groups_size; }
void SetFuseParameterMemorySize(uint64_t memory_size) { void SetFuseParameterMemorySize(double memory_size) {
FLAGS_fuse_parameter_memory_size = memory_size; FLAGS_fuse_parameter_memory_size = memory_size;
} }
uint64_t GetFuseParameterMemorySize() { double GetFuseParameterMemorySize() { return FLAGS_fuse_parameter_memory_size; }
return FLAGS_fuse_parameter_memory_size;
}
static framework::proto::VarType::Type kDefaultDtype = static framework::proto::VarType::Type kDefaultDtype =
framework::proto::VarType::Type::VarType_Type_BOOL; framework::proto::VarType::Type::VarType_Type_BOOL;
...@@ -230,15 +228,16 @@ class AllocContinuousSpaceForGradPass : public ir::Pass { ...@@ -230,15 +228,16 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
} }
VLOG(10) << out.str() VLOG(10) << out.str()
<< ", group size:" << group_grads_params->at(i).size() << ", group size:" << group_grads_params->at(i).size()
<< ", group memory size:" << gps_size; << ", group memory size:"
<< static_cast<double>(gps_size) / 1048576.0 << "(MB)";
} }
} }
void SetGroupAccordingToMemorySize( void SetGroupAccordingToMemorySize(
const std::unordered_map<std::string, ir::Node *> &var_nodes, const std::unordered_map<std::string, ir::Node *> &var_nodes,
details::GroupGradsAndParams *group_grads_params) const { details::GroupGradsAndParams *group_grads_params) const {
const uint64_t group_memory_size = GetFuseParameterMemorySize(); const double group_memory_size = GetFuseParameterMemorySize();
if (group_memory_size == 0) { if (group_memory_size <= 0.0) {
return; return;
} }
details::GroupGradsAndParams local_group_grads_params; details::GroupGradsAndParams local_group_grads_params;
...@@ -271,7 +270,8 @@ class AllocContinuousSpaceForGradPass : public ir::Pass { ...@@ -271,7 +270,8 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
break; break;
} }
if (local_group_memory_size >= group_memory_size) { if (static_cast<double>(local_group_memory_size) / 1048576.0 >=
group_memory_size) {
break; break;
} }
} }
...@@ -280,7 +280,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass { ...@@ -280,7 +280,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
std::swap(*group_grads_params, local_group_grads_params); std::swap(*group_grads_params, local_group_grads_params);
VLOG(10) << string::Sprintf( VLOG(10) << string::Sprintf(
"SetGroupAccordingToMemorySize(memory_size: %d):", group_memory_size); "SetGroupAccordingToMemorySize(memory_size: %f):", group_memory_size);
if (VLOG_IS_ON(10)) { if (VLOG_IS_ON(10)) {
PrintGroupInfo(var_nodes, group_grads_params); PrintGroupInfo(var_nodes, group_grads_params);
......
...@@ -21,8 +21,8 @@ namespace ir { ...@@ -21,8 +21,8 @@ namespace ir {
void SetFuseParameterGroupsSize(int group_size); void SetFuseParameterGroupsSize(int group_size);
int GetFuseParameterGroupsSize(); int GetFuseParameterGroupsSize();
void SetFuseParameterMemorySize(uint64_t memory_size); void SetFuseParameterMemorySize(double memory_size);
uint64_t GetFuseParameterMemorySize(); double GetFuseParameterMemorySize();
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册