未验证 提交 73eacf3e 编写于 作者: G gongweibao 提交者: GitHub

Polish codes of old prs (#17981)

上级 fe43b2ee
...@@ -42,6 +42,9 @@ DEFINE_int32( ...@@ -42,6 +42,9 @@ DEFINE_int32(
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
// unit of the FLAGS_fuse_parameter_memory_size.
static constexpr double kMB = 1048576.0;
// SetFuseParameterGroupsSize and SetFuseParameterMemorySize are used in unit // SetFuseParameterGroupsSize and SetFuseParameterMemorySize are used in unit
// test, because it is invalid that seting 'FLAGS_fuse_parameter_memory_size' // test, because it is invalid that seting 'FLAGS_fuse_parameter_memory_size'
// and 'FLAGS_fuse_parameter_groups_size' in unit test. // and 'FLAGS_fuse_parameter_groups_size' in unit test.
...@@ -228,8 +231,8 @@ class AllocContinuousSpaceForGradPass : public ir::Pass { ...@@ -228,8 +231,8 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
} }
VLOG(10) << out.str() VLOG(10) << out.str()
<< ", group size:" << group_grads_params->at(i).size() << ", group size:" << group_grads_params->at(i).size()
<< ", group memory size:" << ", group memory size:" << static_cast<double>(gps_size) / kMB
<< static_cast<double>(gps_size) / 1048576.0 << "(MB)"; << "(MB)";
} }
} }
...@@ -270,7 +273,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass { ...@@ -270,7 +273,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
break; break;
} }
if (static_cast<double>(local_group_memory_size) / 1048576.0 >= if (static_cast<double>(local_group_memory_size) / kMB >=
group_memory_size) { group_memory_size) {
break; break;
} }
......
...@@ -164,6 +164,13 @@ def start_procs(args): ...@@ -164,6 +164,13 @@ def start_procs(args):
", node_ips:", node_ips, ", nranks:", nranks) ", node_ips:", node_ips, ", nranks:", nranks)
current_env = copy.copy(default_env) current_env = copy.copy(default_env)
# paddle broadcast ncclUniqueId use socket, and
# proxy maybe make trainers unreachable, so delete them.
# if we set them to "", grpc will log error message "bad uri"
# so just delete them.
current_env.pop("http_proxy", None)
current_env.pop("https_proxy", None)
procs = [] procs = []
cmds = [] cmds = []
for i in range(0, selected_gpus_num): for i in range(0, selected_gpus_num):
...@@ -173,11 +180,7 @@ def start_procs(args): ...@@ -173,11 +180,7 @@ def start_procs(args):
"PADDLE_CURRENT_ENDPOINT": "PADDLE_CURRENT_ENDPOINT":
"%s:%d" % (current_node_ip, args.started_port + i), "%s:%d" % (current_node_ip, args.started_port + i),
"PADDLE_TRAINERS_NUM": "%d" % nranks, "PADDLE_TRAINERS_NUM": "%d" % nranks,
"PADDLE_TRAINER_ENDPOINTS": trainers_endpoints, "PADDLE_TRAINER_ENDPOINTS": trainers_endpoints
# paddle broadcast ncclUniqueId use socket, and
# proxy maybe make trainers unreachable, so set them to ""
"http_proxy": "",
"https_proxy": ""
}) })
cmd = [sys.executable, "-u", args.training_script cmd = [sys.executable, "-u", args.training_script
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册