未验证 提交 972c54cd 编写于 作者: G gongweibao 提交者: GitHub

Fix FLAGS_fuse_parameter_memory_size unit from Bytes to MBytes. (#17924)

上级 0a96ec69
......@@ -23,8 +23,8 @@
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
DEFINE_uint64(fuse_parameter_memory_size, 0, // Bytes
"fuse_parameter_memory_size is up limited memory size "
DEFINE_double(fuse_parameter_memory_size, -1.0, // MBytes
"fuse_parameter_memory_size is up limited memory size(MB)"
"of one group parameters' gradient which is the input "
"of communication calling(e.g NCCLAllReduce). "
"The default value is 0, it means that "
......@@ -51,13 +51,11 @@ void SetFuseParameterGroupsSize(int group_size) {
// Returns the current value of the fuse_parameter_groups_size flag.
int GetFuseParameterGroupsSize() {
  return FLAGS_fuse_parameter_groups_size;
}
void SetFuseParameterMemorySize(uint64_t memory_size) {
// Overwrites the fuse_parameter_memory_size flag with `memory_size`.
// The flag's definition annotates its unit as MBytes.
void SetFuseParameterMemorySize(double memory_size) {
  FLAGS_fuse_parameter_memory_size = memory_size;
}
uint64_t GetFuseParameterMemorySize() {
return FLAGS_fuse_parameter_memory_size;
}
// Returns the current value of the fuse_parameter_memory_size flag, whose
// definition annotates its unit as MBytes.
double GetFuseParameterMemorySize() {
  return FLAGS_fuse_parameter_memory_size;
}
static framework::proto::VarType::Type kDefaultDtype =
framework::proto::VarType::Type::VarType_Type_BOOL;
......@@ -230,15 +228,16 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
}
VLOG(10) << out.str()
<< ", group size:" << group_grads_params->at(i).size()
<< ", group memory size:" << gps_size;
<< ", group memory size:"
<< static_cast<double>(gps_size) / 1048576.0 << "(MB)";
}
}
void SetGroupAccordingToMemorySize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
details::GroupGradsAndParams *group_grads_params) const {
const uint64_t group_memory_size = GetFuseParameterMemorySize();
if (group_memory_size == 0) {
const double group_memory_size = GetFuseParameterMemorySize();
if (group_memory_size <= 0.0) {
return;
}
details::GroupGradsAndParams local_group_grads_params;
......@@ -271,7 +270,8 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
break;
}
if (local_group_memory_size >= group_memory_size) {
if (static_cast<double>(local_group_memory_size) / 1048576.0 >=
group_memory_size) {
break;
}
}
......@@ -280,7 +280,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
std::swap(*group_grads_params, local_group_grads_params);
VLOG(10) << string::Sprintf(
"SetGroupAccordingToMemorySize(memory_size: %d):", group_memory_size);
"SetGroupAccordingToMemorySize(memory_size: %f):", group_memory_size);
if (VLOG_IS_ON(10)) {
PrintGroupInfo(var_nodes, group_grads_params);
......
......@@ -21,8 +21,8 @@ namespace ir {
// Setter for the fuse_parameter_groups_size flag.
// NOTE(review): the definition body is not visible in this view; presumably
// it stores group_size into the flag -- confirm against the .cc file.
void SetFuseParameterGroupsSize(int group_size);
// Returns the current value of the fuse_parameter_groups_size flag.
int GetFuseParameterGroupsSize();
void SetFuseParameterMemorySize(uint64_t memory_size);
uint64_t GetFuseParameterMemorySize();
// Stores memory_size into the fuse_parameter_memory_size flag. The flag's
// definition annotates its unit as MBytes.
void SetFuseParameterMemorySize(double memory_size);
// Returns the current value of the fuse_parameter_memory_size flag.
double GetFuseParameterMemorySize();
} // namespace ir
} // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册