Unverified · Commit 25a88f21 · authored by cheng cheng, committed by GitHub

disable_group_boxing and change nccl logical order to dst (#4236)

* disable_group_boxing and change nccl logical order to dst

* remove note

* both support insert nccl logical ops as close as possible to Src/Dst node
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Parent commit: d5b2fee2
......@@ -44,4 +44,5 @@ message Resource {
// NOTE(chengcheng) to reuse nccl memory and speed up
optional bool nccl_use_compute_stream = 30 [default = false];
optional bool disable_group_boxing_by_dst_parallel = 31 [default = false];
}
......@@ -158,6 +158,117 @@ bool TryBuildNcclLogicalOpConf(OperatorConf* ret, const OpNode* src_node, const
return false;
}
bool ReverseOrderInsertNcclLogicalOps() {
return Global<ResourceDesc, ForSession>::Get()->resource().disable_group_boxing_by_dst_parallel();
}
// Insert nccl logical ops immediately after their producer (src) node.
//
// Walks the subgraph in strict execution order; for every data edge whose two
// endpoints are both inside the subgraph, tries to build a nccl logical op for
// each logical blob on the edge and rewires the consumer to read the nccl op's
// output instead. Ctrl edges are added so that (a) multiple nccl ops created
// for one src node execute in creation order, and (b) each nccl op finishes
// before the op scheduled right after its src node.
//
// subgraph_op_name2conf: op name -> mutable OperatorConf of every subgraph op.
// mut_op_names:          out-param; collects names of confs that were mutated.
// nccl_op_confs:         out-param; collects newly built nccl ops, in order.
// subgraph_order:        subgraph nodes in strict execution order.
// node2order:            node -> index into subgraph_order.
void InsertNcclLogicalOpsAsCloseAsPossibleToSrcNode(
    HashMap<std::string, OperatorConf>* subgraph_op_name2conf, HashSet<std::string>* mut_op_names,
    std::vector<OperatorConf>* nccl_op_confs, const std::vector<const OpNode*>& subgraph_order,
    const HashMap<const OpNode*, int64_t>& node2order) {
  for (const OpNode* src_node : subgraph_order) {
    const std::string& src_op_name = src_node->op().op_name();
    for (const OpEdge* op_edge : src_node->out_edges()) {
      const OpNode* dst_node = op_edge->dst_node();
      const std::string& dst_op_name = dst_node->op().op_name();
      CHECK(src_node != dst_node);
      if (subgraph_op_name2conf->find(dst_op_name) == subgraph_op_name2conf->end()) {
        // NOTE(chengcheng): child node is not in this subgraph.
        continue;
      }
      for (const LogicalBlobId& lbi : op_edge->lbis()) {
        OperatorConf nccl_op;
        // build nccl op
        if (!TryBuildNcclLogicalOpConf(&nccl_op, src_node, dst_node, lbi)) { continue; }
        mut_op_names->insert(dst_op_name);
        // insert nccl op: redirect every consumer input bound to this lbi to
        // the nccl op's output.
        user_op::UserOpConfWrapper nccl_op_wrapper(nccl_op);
        for (const std::string& ibn : op_edge->lbi2ibns().at(lbi)) {
          std::string old_lbn = ReplaceInputLbnInOpCustomizedConf(
              &subgraph_op_name2conf->at(dst_op_name), ibn, nccl_op_wrapper.output("out", 0));
          // Sanity check (kept consistent with the ...ToDstNode variant): the
          // replaced input must have been exactly the blob carried by this edge.
          CHECK(old_lbn == GenLogicalBlobName(lbi));
        }
        if (nccl_op_confs->size() >= 1) {
          // NOTE(chengcheng): MUST add ctrl edge between nccl ops for 1 src node insert multi-nccl
          const std::string& pre_nccl_op_name = nccl_op_confs->at(nccl_op_confs->size() - 1).name();
          nccl_op.add_ctrl_in_op_name(pre_nccl_op_name);
        }
        // NOTE(chengcheng): src_node MUST not be the last node in subgraph, find the next op
        int64_t src_order = node2order.at(src_node);
        CHECK(src_order + 1 < subgraph_order.size());
        const std::string& next_op_name = subgraph_order.at(src_order + 1)->op().op_name();
        if (dst_op_name != next_op_name) {
          // NOTE(chengcheng): MUST add ctrl edge for strict exec order
          subgraph_op_name2conf->at(next_op_name).add_ctrl_in_op_name(nccl_op.name());
          mut_op_names->insert(next_op_name);
        }
        if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
          std::cout << "cc_debug_log: insert nccl op from: [" << src_op_name << "](" << src_order
                    << ")->[" << dst_op_name << "](" << node2order.at(dst_node) << ") and before: ["
                    << next_op_name << "](" << src_order + 1 << ")\n";
        }
        nccl_op_confs->push_back(nccl_op);
      }
    }
  }
}
// Insert nccl logical ops immediately before their consumer (dst) node.
//
// Mirror of InsertNcclLogicalOpsAsCloseAsPossibleToSrcNode: walks the subgraph
// in strict execution order, and for every in-edge whose producer is also in
// the subgraph, builds a nccl logical op per logical blob and rewires the dst
// op to consume the nccl op's output. Ctrl edges keep (a) multiple nccl ops of
// one dst node in creation order and (b) each nccl op after the op that is
// scheduled immediately before its dst node.
//
// subgraph_op_name2conf: op name -> mutable OperatorConf of every subgraph op.
// mut_op_names:          out-param; collects names of confs that were mutated.
// nccl_op_confs:         out-param; collects newly built nccl ops, in order.
// subgraph_order:        subgraph nodes in strict execution order.
// node2order:            node -> index into subgraph_order.
void InsertNcclLogicalOpsAsCloseAsPossibleToDstNode(
    HashMap<std::string, OperatorConf>* subgraph_op_name2conf, HashSet<std::string>* mut_op_names,
    std::vector<OperatorConf>* nccl_op_confs, const std::vector<const OpNode*>& subgraph_order,
    const HashMap<const OpNode*, int64_t>& node2order) {
  for (const OpNode* dst_node : subgraph_order) {
    const std::string& dst_op_name = dst_node->op().op_name();
    for (const OpEdge* op_edge : dst_node->in_edges()) {
      const OpNode* src_node = op_edge->src_node();
      const std::string& src_op_name = src_node->op().op_name();
      CHECK(src_node != dst_node);
      if (subgraph_op_name2conf->find(src_op_name) == subgraph_op_name2conf->end()) {
        // NOTE(chengcheng): parent node is not in this subgraph.
        continue;
      }
      for (const LogicalBlobId& lbi : op_edge->lbis()) {
        OperatorConf nccl_op;
        // build nccl op
        if (!TryBuildNcclLogicalOpConf(&nccl_op, src_node, dst_node, lbi)) { continue; }
        mut_op_names->insert(dst_op_name);
        // insert nccl op: redirect every dst input bound to this lbi to the
        // nccl op's output; the replaced lbn must be exactly this edge's blob.
        user_op::UserOpConfWrapper nccl_op_wrapper(nccl_op);
        for (const std::string& ibn : op_edge->lbi2ibns().at(lbi)) {
          std::string old_lbn = ReplaceInputLbnInOpCustomizedConf(
              &subgraph_op_name2conf->at(dst_op_name), ibn, nccl_op_wrapper.output("out", 0));
          CHECK(old_lbn == GenLogicalBlobName(lbi));
        }
        // add necessary ctrl edge for strict order
        if (nccl_op_confs->size() >= 1) {
          // NOTE(chengcheng): MUST add ctrl edge between nccl ops for 1 dst node insert multi-nccl
          const std::string& pre_nccl_op_name = nccl_op_confs->at(nccl_op_confs->size() - 1).name();
          nccl_op.add_ctrl_in_op_name(pre_nccl_op_name);
        }
        // NOTE(chengcheng): dst_node MUST not be the first node in subgraph; find the immediately
        // previous op of dst_node.
        int64_t dst_order = node2order.at(dst_node);
        CHECK_GT(dst_order, 0);
        const std::string& pre_op_name = subgraph_order.at(dst_order - 1)->op().op_name();
        if (src_op_name != pre_op_name) {
          // NOTE(chengcheng): MUST add ctrl edge for strict exec order
          nccl_op.add_ctrl_in_op_name(pre_op_name);
        }
        if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
          std::cout << "cc_debug_log: insert nccl op from: [" << src_op_name << "]("
                    << node2order.at(src_node) << ")->[" << dst_op_name << "](" << dst_order
                    << ") and after: [" << pre_op_name << "](" << dst_order - 1 << ")\n";
        }
        nccl_op_confs->push_back(nccl_op);
      }
    }
  }
}
Maybe<void> InsertNcclLogicalOpPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {
auto OpGraphForEachInDataAndCtrlNode = [&](OpNode* node,
const std::function<void(OpNode*)>& Handler) {
......@@ -205,46 +316,23 @@ Maybe<void> InsertNcclLogicalOpPass::Apply(const OpGraph& op_graph, JobBuilder*
}
}
std::vector<OperatorConf> nccl_op_confs;
for (const OpNode* src_node : subgraph_order) {
for (const OpEdge* op_edge : src_node->out_edges()) {
const OpNode* dst_node = op_edge->dst_node();
const std::string& dst_op_name = dst_node->op().op_name();
CHECK(src_node != dst_node);
if (subgraph_op_name2conf.find(dst_op_name) == subgraph_op_name2conf.end()) {
// NOTE(chengcheng): child node is not in this subgraph.
continue;
}
for (const LogicalBlobId& lbi : op_edge->lbis()) {
OperatorConf nccl_op;
if (!TryBuildNcclLogicalOpConf(&nccl_op, src_node, dst_node, lbi)) { continue; }
mut_op_names.insert(dst_op_name);
// insert nccl op
user_op::UserOpConfWrapper nccl_op_wrapper(nccl_op);
for (const std::string& ibn : op_edge->lbi2ibns().at(lbi)) {
std::string old_lbn = ReplaceInputLbnInOpCustomizedConf(
&subgraph_op_name2conf.at(dst_op_name), ibn, nccl_op_wrapper.output("out", 0));
}
if (nccl_op_confs.size() >= 1) {
// NOTE(chengcheng): MUST add ctrl edge between nccl ops for 1 src node insert multi-nccl
const std::string& pre_nccl_op_name = nccl_op_confs.at(nccl_op_confs.size() - 1).name();
nccl_op.add_ctrl_in_op_name(pre_nccl_op_name);
}
if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
std::cout << "cc_debug_log: Try insert nccl logical ops into job: "
<< job_builder->job().job_conf().job_name() << ". Begin...\n";
}
// NOTE(chengcheng): src_node MUST not the last node in subgraph, find the next op
int64_t src_order = node2order.at(src_node);
CHECK(src_order + 1 < subgraph_order.size());
const std::string& next_op_name = subgraph_order.at(src_order + 1)->op().op_name();
if (dst_op_name != next_op_name) {
// NOTE(chengcheng): MUST add ctrl edge for strict exec order
subgraph_op_name2conf.at(next_op_name).add_ctrl_in_op_name(nccl_op.name());
mut_op_names.insert(next_op_name);
}
std::vector<OperatorConf> nccl_op_confs;
if (ReverseOrderInsertNcclLogicalOps()) {
InsertNcclLogicalOpsAsCloseAsPossibleToDstNode(&subgraph_op_name2conf, &mut_op_names,
&nccl_op_confs, subgraph_order, node2order);
} else {
InsertNcclLogicalOpsAsCloseAsPossibleToSrcNode(&subgraph_op_name2conf, &mut_op_names,
&nccl_op_confs, subgraph_order, node2order);
}
nccl_op_confs.push_back(nccl_op);
}
}
if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
std::cout << "cc_debug_log: Try insert nccl logical ops into job: "
<< job_builder->job().job_conf().job_name() << ". ...End\n\n";
}
std::vector<OperatorConf> mut_op_confs;
......
......@@ -96,7 +96,10 @@ void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder
void JobCompleter::Complete(Job* job) const {
JobPassCtx job_pass_ctx(GlobalJobDesc());
JobPass4Name("DumpTimeShapeAndBlobParallelConfPass")(job, &job_pass_ctx);
WithOpGraphAndMutJobBuilder(job, &GroupBoxingByDstParallel);
// NOTE(chengcheng): disable this pass for reduce boxing memory life cycle to memory cost.
if (!Global<ResourceDesc, ForSession>::Get()->resource().disable_group_boxing_by_dst_parallel()) {
WithOpGraphAndMutJobBuilder(job, &GroupBoxingByDstParallel);
}
WithOpGraphAndMutJobBuilder(job, &SetCtrlInOpName4VariableOp);
// complete tick ops
WithOpGraphAndMutJobBuilder(job, &AutoPrependTick);
......
......@@ -464,6 +464,23 @@ def nccl_use_compute_stream(val=False):
sess.config_proto.resource.nccl_use_compute_stream = val
@oneflow_export("config.disable_group_boxing_by_dst_parallel")
def api_disable_group_boxing_by_dst_parallel(val: bool = False) -> None:
r"""Whether or not disable group boxing by dst parallel pass to reduce boxing memory life cycle.
Args:
val (bool, optional): True or False. Defaults to False.
"""
return enable_if.unique([disable_group_boxing_by_dst_parallel, do_nothing])(val=val)
@enable_if.condition(hob.in_normal_mode & ~hob.session_initialized)
def disable_group_boxing_by_dst_parallel(val=False):
    # Record the flag on the default session's resource config; only callable
    # before the session is initialized (enforced by the enable_if condition).
    default_sess = session_ctx.GetDefaultSession()
    assert type(val) is bool
    default_sess.config_proto.resource.disable_group_boxing_by_dst_parallel = val
@oneflow_export("config.collective_boxing.nccl_num_streams")
def api_nccl_num_streams(val: int) -> None:
r"""Set up the number of nccl parallel streams while use boxing
......
......@@ -28,6 +28,7 @@ def _test_split_to_split_enable_all_to_all(
flow.clear_default_session()
flow.config.gpu_device_num(2)
flow.config.nccl_use_compute_stream(True)
flow.config.disable_group_boxing_by_dst_parallel(True)
func_config = flow.FunctionConfig()
func_config.default_data_type(flow.float)
func_config.default_logical_view(flow.scope.consistent_view())
......@@ -50,6 +51,7 @@ def _test_split_to_broadcast(
flow.clear_default_session()
flow.config.gpu_device_num(2)
flow.config.nccl_use_compute_stream(True)
flow.config.disable_group_boxing_by_dst_parallel(True)
func_config = flow.FunctionConfig()
func_config.default_data_type(flow.float)
func_config.default_logical_view(flow.scope.consistent_view())
......@@ -72,6 +74,7 @@ def _test_partial_sum_to_split(
flow.clear_default_session()
flow.config.gpu_device_num(2)
flow.config.nccl_use_compute_stream(True)
flow.config.disable_group_boxing_by_dst_parallel(True)
func_config = flow.FunctionConfig()
func_config.default_data_type(flow.float)
func_config.default_logical_view(flow.scope.consistent_view())
......@@ -93,6 +96,7 @@ def _test_partial_sum_to_broadcast(test_case):
flow.clear_default_session()
flow.config.gpu_device_num(2)
flow.config.nccl_use_compute_stream(True)
flow.config.disable_group_boxing_by_dst_parallel(True)
func_config = flow.FunctionConfig()
func_config.default_data_type(flow.float)
func_config.default_logical_view(flow.scope.consistent_view())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册