Unverified commit b38d9cde, authored by cheng cheng, committed by GitHub

Feat: nccl_use_compute_stream support batch accumulation (#4618)

* NCCL logical refine timeshape

* Insert nccl ops after acc interface

* Insert NCCL ops after acc implemented; needs refinement or a new acc_tick_op

* deadlock

* speed up and run

* add acc tick to fix deadlock; add nccl comm debug log

* refine log: rm cc_debug_log and cclog

* use references to speed up

* refine code for review

* fix for review
Co-authored-by: NJuncheng <liujuncheng1022@gmail.com>
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Parent c1180bd5
......@@ -82,6 +82,13 @@ bool IsSpecialOpNotConsiderMergeInChain(const Operator* op) {
|| op_conf.has_acc_tick_conf()) {
return true;
}
if (op_conf.has_user_conf()) {
const std::string& user_type_name = op_conf.user_conf().op_type_name();
if (user_type_name == "repeat" || user_type_name == "acc" || user_type_name == "pack"
|| user_type_name == "unpack") {
return true;
}
}
// NOTE(chengcheng): ONLY nccl_use_compute_stream = false will exclude optimizer pass ops
if (!Global<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()
&& IsOptimizerPassOp(op)) {
......@@ -108,10 +115,17 @@ bool CanBeMergedInChain(const TaskNode* node) {
return true;
}
std::shared_ptr<const Shape> GetTaskNodeTimeShape(const TaskNode* node) {
const auto* fw_comp_node = dynamic_cast<const NormalForwardCompTaskNode*>(node);
CHECK(fw_comp_node != nullptr);
return CHECK_JUST(fw_comp_node->op()->GetOpTimeShape());
}
void TraverseConnectedSubGraphMergeInThisChain(TaskNode* this_node, const int64_t this_chain_id) {
CHECK_NE(this_chain_id, -1);
CHECK_EQ(this_node->chain_id(), -1);
// BFS over all nodes that can be merged into this chain
std::shared_ptr<const Shape> seed_time_shape = GetTaskNodeTimeShape(this_node);
HashSet<TaskNode*> visited_nodes;
std::queue<TaskNode*> queued_nodes;
queued_nodes.push(this_node);
......@@ -125,7 +139,8 @@ void TraverseConnectedSubGraphMergeInThisChain(TaskNode* this_node, const int64_
cur_node->ForEachNodeOnInOutEdge([&](TaskNode* next_node) {
if (visited_nodes.find(next_node) == visited_nodes.end() && CanBeMergedInChain(next_node)
-       && this_node->GlobalWorkStreamId() == next_node->GlobalWorkStreamId()) {
+       && this_node->GlobalWorkStreamId() == next_node->GlobalWorkStreamId()
+       && (*GetTaskNodeTimeShape(next_node)) == (*seed_time_shape)) {
if (next_node->chain_id() == -1) {
queued_nodes.push(next_node);
visited_nodes.insert(next_node);
......
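The chain-merge change above gates BFS expansion on the seed node's time shape, so ops whose time shapes differ (e.g. inside vs. outside an acc region) can no longer land in the same chain. Below is a minimal sketch of that gating idea, using simplified stand-in types (Node, Shape) instead of OneFlow's TaskNode machinery; all names here are illustrative only.

#include <cstdint>
#include <queue>
#include <unordered_set>
#include <vector>

struct Shape {
  std::vector<int64_t> dims;
  bool operator==(const Shape& rhs) const { return dims == rhs.dims; }
};

struct Node {
  Shape time_shape;
  int64_t stream_id = 0;
  int64_t chain_id = -1;  // -1 means not yet assigned to a chain
  std::vector<Node*> neighbors;
};

// Merge into `chain_id` only unvisited neighbors on the same work stream
// whose time shape equals the seed's, mirroring the added seed_time_shape check.
void MergeChainBfs(Node* seed, int64_t chain_id) {
  const Shape seed_time_shape = seed->time_shape;
  std::unordered_set<Node*> visited{seed};
  std::queue<Node*> queued;
  queued.push(seed);
  while (!queued.empty()) {
    Node* cur = queued.front();
    queued.pop();
    cur->chain_id = chain_id;
    for (Node* next : cur->neighbors) {
      if (visited.count(next) == 0 && next->chain_id == -1
          && next->stream_id == seed->stream_id
          && next->time_shape == seed_time_shape) {
        visited.insert(next);
        queued.push(next);
      }
    }
  }
}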
......@@ -32,6 +32,14 @@ std::string GetNcclUniqueIdRpcKey(const std::vector<std::pair<int64_t, int64_t>>
return oss.str();
}
std::string NcclUniqueId2String(ncclUniqueId id) {
std::stringstream ss;
for (int i = 0; i < NCCL_UNIQUE_ID_BYTES; ++i) {
ss << std::hex << std::setfill('0') << std::setw(2) << static_cast<int>(id.internal[i]);
}
return ss.str();
}
} // namespace
EagerNcclCommMgr::~EagerNcclCommMgr() {
......@@ -78,6 +86,8 @@ ncclComm_t EagerNcclCommMgr::GetCommForDevice(
});
}
ncclComm_t comm;
LOG(INFO) << " EagerNcclCommMgr::ncclCommInitRank device_vec.size() = " << device_vec.size()
<< ", nccl_unique_id = " << NcclUniqueId2String(nccl_unique_id) << ", rank = " << rank;
OF_NCCL_CHECK(ncclCommInitRank(&comm, device_vec.size(), nccl_unique_id, rank));
{
std::lock_guard<std::mutex> lock(mutex_);
......
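For reference, the new NcclUniqueId2String helper hex-encodes the NCCL unique id (NCCL_UNIQUE_ID_BYTES bytes) so that ranks joining the same communicator can be matched up in the debug log. A self-contained equivalent of that formatting is sketched below; BytesToHex and the 4-byte sample are made up for illustration, only the iomanip recipe matches the diff.

#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>

// Each byte becomes two lowercase hex digits, zero-padded.
std::string BytesToHex(const unsigned char* data, int n) {
  std::stringstream ss;
  for (int i = 0; i < n; ++i) {
    ss << std::hex << std::setfill('0') << std::setw(2) << static_cast<int>(data[i]);
  }
  return ss.str();
}

int main() {
  const unsigned char id[4] = {0x0a, 0xff, 0x00, 0x3c};
  std::cout << BytesToHex(id, 4) << std::endl;  // prints "0aff003c"
}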
......@@ -63,6 +63,45 @@ std::string ParallelDistributionToString(const ParallelDistribution& parallel_di
return serialized_parallel_distribution;
}
bool IsBreakpointOpNode(const OpNode* node) {
// NOTE(chengcheng): breakpoint ops are special and CANNOT pass through a subgraph, e.g.:
// variable, tick, and repeat/acc/pack/unpack (which change the time shape)
const Operator& op = node->op();
const OperatorConf& op_conf = op.op_conf();
if (op_conf.has_variable_conf() || op_conf.has_tick_conf() || op_conf.has_device_tick_conf()
|| op_conf.has_src_subset_tick_conf() || op_conf.has_dst_subset_tick_conf()
|| op_conf.has_source_tick_conf() || op_conf.has_sink_tick_conf()
|| op_conf.has_acc_tick_conf()) {
return true;
}
if (op_conf.has_user_conf()) {
const std::string& user_type_name = op_conf.user_conf().op_type_name();
if (user_type_name == "repeat" || user_type_name == "acc" || user_type_name == "pack"
|| user_type_name == "unpack") {
return true;
}
}
return false;
}
bool IsAccOpNode(const OpNode* node) {
return node->op().op_conf().has_user_conf()
&& node->op().op_conf().user_conf().op_type_name() == "acc";
}
std::shared_ptr<const Shape> GetOpNodeTimeShape(const OpNode* op_node) {
return CHECK_JUST(op_node->op().GetOpTimeShape());
}
std::shared_ptr<const Shape> GetOpNodeInputTimeShape(const OpNode* op_node) {
return CHECK_JUST(op_node->op().GetInputBlobFastestTimeShape());
}
bool SharedPtrShapeEqual(const std::shared_ptr<const Shape>& lhs,
const std::shared_ptr<const Shape>& rhs) {
return (*lhs) == (*rhs);
}
void FindMaxConnectedSubgraphForGpuExecOrder(HashSet<const OpNode*>* ret, const OpGraph& op_graph,
const std::vector<const OpNode*>& order) {
HashSet<const OpNode*> visited;
......@@ -74,12 +113,12 @@ void FindMaxConnectedSubgraphForGpuExecOrder(HashSet<const OpNode*>* ret, const
// NOTE(chengcheng): ONLY consider GPU op and parallel num > 1.
if (seed_parallel_desc.device_type() != DeviceType::kGPU) { continue; }
if (seed_parallel_desc.parallel_num() <= 1) { continue; }
-   // NODE(chengcheng): Exclude op that change the time shape.
-   //   like pack/unpack, repeat/acc, etc.
-   if (!seed_node->IsTimeShapeIdentity()) { continue; }
+   if (IsBreakpointOpNode(seed_node)) { continue; }
HashSet<const OpNode*> this_subgraph;
std::queue<const OpNode*> queued_nodes;
std::shared_ptr<const Shape> seed_time_shape = GetOpNodeTimeShape(seed_node);
queued_nodes.push(seed_node);
while (!queued_nodes.empty()) {
const OpNode* cur_node = queued_nodes.front();
......@@ -91,7 +130,7 @@ void FindMaxConnectedSubgraphForGpuExecOrder(HashSet<const OpNode*>* ret, const
cur_node->ForEachNodeOnInOutEdge([&](const OpNode* next_node) {
if (visited.find(next_node) == visited.end()
&& next_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc)
-         && next_node->IsTimeShapeIdentity()) {
+         && SharedPtrShapeEqual(GetOpNodeTimeShape(next_node), seed_time_shape)) {
CHECK(visited.insert(next_node).second);
queued_nodes.push(next_node);
}
......@@ -376,8 +415,8 @@ void InsertNcclLogicalOpsAsCloseAsPossibleToSrcNode(
}
if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
LOG(INFO) << "cc_debug_log: insert nccl op: " << nccl_op.name() << " from: ["
<< src_op_name << "](order=" << src_order << ", sbp_parallel_dis="
LOG(INFO) << " insert nccl op: " << nccl_op.name() << " from: [" << src_op_name
<< "](order=" << src_order << ", sbp_parallel_dis="
<< ParallelDistributionToString(src_node->ParallelDistribution4Lbi(lbi))
<< ")->[" << dst_op_name << "](order=" << node2order.at(dst_node)
<< ", sbp_parallel_dis="
......@@ -437,10 +476,9 @@ void InsertNcclLogicalOpsAsCloseAsPossibleToDstNode(
}
if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
LOG(INFO) << "cc_debug_log: insert nccl op: " << nccl_op.name() << " from: ["
<< src_op_name << "](" << node2order.at(src_node) << ")->[" << dst_op_name
<< "](" << dst_order << ") and after: [" << pre_op_name << "](" << dst_order - 1
<< ")\n";
LOG(INFO) << " insert nccl op: " << nccl_op.name() << " from: [" << src_op_name << "]("
<< node2order.at(src_node) << ")->[" << dst_op_name << "](" << dst_order
<< ") and after: [" << pre_op_name << "](" << dst_order - 1 << ")\n";
}
nccl_op_confs->push_back(nccl_op);
// NOTE(chengcheng, guoran): set nccl op as src_node parallel_conf (hierarchy) may check
......@@ -451,6 +489,118 @@ void InsertNcclLogicalOpsAsCloseAsPossibleToDstNode(
}
}
bool IsOpEdgeAllowInsertNccl(const OpEdge* edge,
const std::shared_ptr<const Shape>& seed_time_shape) {
const OpNode* src_node = edge->src_node();
const OpNode* dst_node = edge->dst_node();
const ParallelDesc& src_parallel_desc = src_node->parallel_desc();
return src_parallel_desc.device_type() == DeviceType::kGPU && src_parallel_desc.parallel_num() > 1
&& src_parallel_desc.EqualsIgnoringHierarchy(dst_node->parallel_desc())
&& SharedPtrShapeEqual(GetOpNodeTimeShape(src_node), seed_time_shape)
&& SharedPtrShapeEqual(GetOpNodeTimeShape(dst_node), seed_time_shape);
}
struct InsertedNcclInfo {
OperatorConf nccl_op_conf;
ParallelConf nccl_parallel_conf;
int64_t order;
std::string debug_str;
};
void InsertNcclLogicalOpsAfterAcc(const OpGraph& op_graph,
const std::vector<const OpNode*>& ordered_op_nodes,
const std::vector<const OpNode*>& ordered_acc_op_nodes,
const std::string& bw_sink_tick_op_name,
HashMap<std::string, OperatorConf>* mut_consumer_name2op,
std::vector<OperatorConf>* nccl_op_confs,
std::vector<ParallelConf>* nccl_op_parallel_confs) {
HashMap<const OpNode*, int64_t> op_node2global_order;
op_node2global_order.reserve(ordered_op_nodes.size());
for (int64_t i = 0; i < ordered_op_nodes.size(); ++i) {
CHECK(op_node2global_order.emplace(ordered_op_nodes.at(i), i).second);
}
HashSet<const OpEdge*> visited;
std::shared_ptr<const Shape> seed_time_shape = GetOpNodeTimeShape(ordered_acc_op_nodes.front());
std::vector<InsertedNcclInfo> nccl_op_infos;
for (const OpNode* acc : ordered_acc_op_nodes) {
std::queue<const OpEdge*> queued_edges;
for (const OpEdge* op_edge : acc->out_edges()) {
if (IsOpEdgeAllowInsertNccl(op_edge, seed_time_shape)) {
queued_edges.push(op_edge);
CHECK(visited.insert(op_edge).second);
}
}
// BFS over each edge after acc that allows nccl insertion, and try to insert.
while (!queued_edges.empty()) {
const OpEdge* op_edge = queued_edges.front();
queued_edges.pop();
for (const LogicalBlobId& lbi : op_edge->lbis()) {
OperatorConf nccl_op;
if (!TryBuildNcclLogicalOpConf(&nccl_op, op_edge->src_node(), op_edge->dst_node(), lbi)) {
continue;
}
const OpNode* src_node = op_edge->src_node();
const OpNode* dst_node = op_edge->dst_node();
const std::string& src_op_name = src_node->op().op_name();
const std::string& dst_op_name = dst_node->op().op_name();
auto it = mut_consumer_name2op->find(dst_op_name);
if (it == mut_consumer_name2op->end()) {
auto ret_pair = mut_consumer_name2op->emplace(dst_op_name, dst_node->op().op_conf());
CHECK(ret_pair.second);
it = ret_pair.first;
}
// insert nccl op
user_op::UserOpConfWrapper nccl_op_wrapper(nccl_op);
for (const std::string& ibn : op_edge->lbi2ibns().at(lbi)) {
std::string old_lbn = ReplaceInputLbnInOpCustomizedConf(&(it->second), ibn,
nccl_op_wrapper.output("out", 0));
}
InsertedNcclInfo nccl_op_info;
nccl_op_info.nccl_op_conf = nccl_op;
nccl_op_info.nccl_parallel_conf = src_node->parallel_desc().parallel_conf();
nccl_op_info.order = op_node2global_order.at(src_node);
nccl_op_info.debug_str =
(" After ACC insert nccl op: " + nccl_op.name() + " from: [" + src_op_name + "]("
+ ParallelDistributionToString(src_node->ParallelDistribution4Lbi(lbi)) + ")->["
+ dst_op_name + "]("
+ ParallelDistributionToString(dst_node->ParallelDistribution4Lbi(lbi))
+ "), src_order = " + std::to_string(nccl_op_info.order) + "\n");
nccl_op_infos.push_back(nccl_op_info);
}
for (const OpEdge* dst_node_out_edge : op_edge->dst_node()->out_edges()) {
if (visited.find(dst_node_out_edge) == visited.end()
&& IsOpEdgeAllowInsertNccl(dst_node_out_edge, seed_time_shape)) {
CHECK(visited.insert(dst_node_out_edge).second);
queued_edges.push(dst_node_out_edge);
}
}
}
}
std::sort(nccl_op_infos.begin(), nccl_op_infos.end(),
[](const InsertedNcclInfo& lhs, const InsertedNcclInfo& rhs) {
return lhs.order < rhs.order;
});
for (int64_t i = 0; i < nccl_op_infos.size(); ++i) {
auto& info = nccl_op_infos.at(i);
if (i == 0) {
info.nccl_op_conf.add_ctrl_in_op_name(bw_sink_tick_op_name);
} else {
info.nccl_op_conf.add_ctrl_in_op_name(nccl_op_infos.at(i - 1).nccl_op_conf.name());
}
nccl_op_confs->push_back(info.nccl_op_conf);
nccl_op_parallel_confs->push_back(info.nccl_parallel_conf);
LOG(INFO) << info.debug_str;
}
}
Maybe<void> InsertNcclLogicalOpPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {
auto OpGraphForEachInDataAndCtrlNode = [&](OpNode* node,
const std::function<void(OpNode*)>& Handler) {
......@@ -472,12 +622,16 @@ Maybe<void> InsertNcclLogicalOpPass::Apply(const OpGraph& op_graph, JobBuilder*
std::vector<const OpNode*> subgraph_order;
HashMap<const OpNode*, int64_t> node2order;
std::vector<const OpNode*> ordered_acc_op_nodes;
for (const OpNode* this_node : ordered_op_nodes) {
if (subgraph.find(this_node) != subgraph.end()) {
subgraph_order.push_back(this_node);
node2order.emplace(this_node, subgraph_order.size() - 1);
} else if (IsAccOpNode(this_node)) {
ordered_acc_op_nodes.push_back(this_node);
}
}
CHECK_EQ(subgraph.size(), subgraph_order.size());
HashSet<std::string> mut_op_names;
......@@ -485,7 +639,7 @@ Maybe<void> InsertNcclLogicalOpPass::Apply(const OpGraph& op_graph, JobBuilder*
HashMap<std::string, OperatorConf> subgraph_op_name2conf;
subgraph_op_name2conf.emplace(first_node->op().op_name(), first_node->op().op_conf());
auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable();
-   for (int32_t i = 1; i < subgraph_order.size(); ++i) {
+   for (int64_t i = 1; i < subgraph_order.size(); ++i) {
const OpNode* this_node = subgraph_order.at(i);
const OpNode* pre_node = subgraph_order.at(i - 1);
const std::string& this_op_name = this_node->op().op_name();
......@@ -499,7 +653,7 @@ Maybe<void> InsertNcclLogicalOpPass::Apply(const OpGraph& op_graph, JobBuilder*
}
if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
LOG(INFO) << "cc_debug_log: Try insert nccl logical ops into job: "
LOG(INFO) << " Try insert nccl logical ops into job: "
<< job_builder->job().job_conf().job_name() << ". Begin...\n";
}
......@@ -515,8 +669,76 @@ Maybe<void> InsertNcclLogicalOpPass::Apply(const OpGraph& op_graph, JobBuilder*
subgraph_order, node2order);
}
if (!ordered_acc_op_nodes.empty()) {
const OpNode* bw_sink_op = subgraph_order.back();
const OpNode* first_acc_op = ordered_acc_op_nodes.front();
std::shared_ptr<const Shape> time_shape_before_acc = GetOpNodeTimeShape(bw_sink_op);
std::shared_ptr<const Shape> time_shape_after_acc = GetOpNodeTimeShape(first_acc_op);
LOG(WARNING) << " Find acc op in Job: " << job_builder->job().job_conf().job_name()
<< ", we will try insert special identity and ctrl for "
<< " UNSAFE handle ALL nccl ops between different time shape: "
<< time_shape_before_acc->DebugStr() << "->acc->"
<< time_shape_after_acc->DebugStr() << "\n\n";
CHECK_GT(time_shape_before_acc->elem_cnt(), time_shape_after_acc->elem_cnt());
CHECK_EQ(time_shape_before_acc->elem_cnt() % time_shape_after_acc->elem_cnt(), 0);
for (const OpNode* acc : ordered_acc_op_nodes) {
CHECK(SharedPtrShapeEqual(time_shape_before_acc, GetOpNodeInputTimeShape(acc)));
CHECK(SharedPtrShapeEqual(time_shape_after_acc, GetOpNodeTimeShape(acc)));
}
// NOTE(chengcheng): insert acc_tick after bw_sink_op; this tick op conf controls when the
// after_acc_nccl_ops start.
const auto& obns = bw_sink_op->op().output_bns();
CHECK(!obns.empty());
const std::string bw_sink_op_out_lbn =
GenLogicalBlobName(bw_sink_op->op().BnInOp2Lbi(obns.Get(0)));
LOG(INFO) << " bw_sink_op : " << bw_sink_op->op().op_conf().DebugString();
user_op::UserOpConfWrapper cast_to_tick_op =
user_op::UserOpConfWrapperBuilder("System-CastToTick-" + NewUniqueId())
.OpTypeName("cast_to_tick")
.Input("in", bw_sink_op_out_lbn)
.Output("out")
.Build();
job_builder->AddOp(bw_sink_op->parallel_desc().parallel_conf(), cast_to_tick_op.op_conf());
LOG(INFO) << " Insert cast_to_tick_op : " << cast_to_tick_op.op_conf().DebugString();
OperatorConf bw_sink_acc_tick_conf;
bw_sink_acc_tick_conf.set_name(std::string("System-BwSinkTick-AccTick_") + NewUniqueId());
auto* acc_conf = bw_sink_acc_tick_conf.mutable_acc_tick_conf();
acc_conf->set_one(cast_to_tick_op.output("out", 0));
acc_conf->set_acc("acc");
acc_conf->set_max_acc_num(time_shape_before_acc->elem_cnt() / time_shape_after_acc->elem_cnt());
job_builder->AddOp(bw_sink_op->parallel_desc().parallel_conf(), bw_sink_acc_tick_conf);
LOG(INFO) << " Insert bw_sink_acc_tick_op : " << bw_sink_acc_tick_conf.DebugString();
OperatorConf bw_sink_final_tick_conf;
bw_sink_final_tick_conf.set_name(std::string("System-BwSinkFinalTick-Tick_") + NewUniqueId());
auto* tick_conf = bw_sink_final_tick_conf.mutable_tick_conf();
tick_conf->add_tick(GenLogicalBlobName(bw_sink_acc_tick_conf.name(), "acc"));
tick_conf->set_out("out");
job_builder->AddOp(bw_sink_op->parallel_desc().parallel_conf(), bw_sink_final_tick_conf);
LOG(INFO) << " Insert bw_sink_final_tick_op : " << bw_sink_final_tick_conf.DebugString();
// insert nccl ops after acc
std::vector<OperatorConf> after_acc_nccl_op_confs;
std::vector<ParallelConf> after_acc_nccl_parallel_confs;
HashMap<std::string, OperatorConf> mut_consumer_name2op;
InsertNcclLogicalOpsAfterAcc(op_graph, ordered_op_nodes, ordered_acc_op_nodes,
bw_sink_final_tick_conf.name(), &mut_consumer_name2op,
&after_acc_nccl_op_confs, &after_acc_nccl_parallel_confs);
for (const auto& pair : mut_consumer_name2op) { JUST(job_builder->MutOpOnlyOnce(pair.second)); }
CHECK_EQ(after_acc_nccl_op_confs.size(), after_acc_nccl_parallel_confs.size());
for (int64_t i = 0; i < after_acc_nccl_op_confs.size(); ++i) {
job_builder->AddOp(after_acc_nccl_parallel_confs.at(i), after_acc_nccl_op_confs.at(i));
}
}
if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
LOG(INFO) << "cc_debug_log: Try insert nccl logical ops into job: "
LOG(INFO) << " Try insert nccl logical ops into job: "
<< job_builder->job().job_conf().job_name() << ". ...End\n\n";
}
......
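Two mechanisms in the pass above are worth restating. First, the bw-sink acc tick's max_acc_num is the ratio of the element counts of the time shapes before and after acc; for example, a time shape of (1, 8) before acc and (1, 1) after gives max_acc_num = 8. Second, the post-acc nccl ops are totally ordered: sorted by their source op's global topological order, then chained with ctrl-in edges so that each waits on its predecessor and the first waits on the bw-sink final tick. A minimal sketch with stand-in types (not OneFlow's OperatorConf) follows.

#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

// The pass CHECKs that before > after and before % after == 0.
int64_t MaxAccNum(int64_t elem_cnt_before_acc, int64_t elem_cnt_after_acc) {
  return elem_cnt_before_acc / elem_cnt_after_acc;
}

struct NcclOpInfo {
  std::string name;
  std::vector<std::string> ctrl_in_op_names;  // stand-in for add_ctrl_in_op_name
  int64_t order = 0;                          // source op's global topo order
};

// Sort by order, then serialize execution: op i waits on op i-1, and op 0 on
// the bw-sink final tick, so no post-acc nccl op can start before the backward
// sink has accumulated all micro-batches.
void ChainAfterAccNcclOps(std::vector<NcclOpInfo>* ops, const std::string& bw_sink_tick_name) {
  std::sort(ops->begin(), ops->end(),
            [](const NcclOpInfo& lhs, const NcclOpInfo& rhs) { return lhs.order < rhs.order; });
  for (size_t i = 0; i < ops->size(); ++i) {
    (*ops)[i].ctrl_in_op_names.push_back(i == 0 ? bw_sink_tick_name : (*ops)[i - 1].name);
  }
}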
......@@ -32,11 +32,24 @@ REGISTER_USER_OP("cast_to_tick")
*ctx->Dtype4ArgNameAndIndex("out", 0) = *ctx->Dtype4ArgNameAndIndex("in", 0);
return Maybe<void>::Ok();
})
-   .SetInferSbpSignatureFn([](user_op::InferSbpSignatureFnContext* ctx) -> Maybe<void> {
-     SbpSignature* signature = ctx->mutable_sbp_signature();
-     auto* bn2sbp = signature->mutable_bn_in_op2sbp_parallel();
-     (*bn2sbp)[GenRepeatedBn("in", 0)] = ctx->SbpParallelHint4InputArgNameAndIndex("in", 0);
-     (*bn2sbp)[GenRepeatedBn("out", 0)].mutable_broadcast_parallel();
+   .SetInferParallelDistributionFn([](user_op::InferParallelDistributionFnContext* ctx)
+                                       -> Maybe<void> {
+     const ParallelDistribution& in_dis_hint =
+         ctx->ParallelDistributionHint4InputArgNameAndIndex("in", 0);
+     const Shape& parallel_hierarchy = ctx->parallel_hierarchy();
+     CHECK_EQ(in_dis_hint.sbp_parallel_size(), parallel_hierarchy.NumAxes());
+     ParallelDistribution* in_distribution = ctx->ParallelDistribution4ArgNameAndIndex("in", 0);
+     ParallelDistribution* out_distribution = ctx->ParallelDistribution4ArgNameAndIndex("out", 0);
+     in_distribution->clear_sbp_parallel();
+     out_distribution->clear_sbp_parallel();
+     // in use hint
+     in_distribution->CopyFrom(in_dis_hint);
+     for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) {
+       // out dim1 = broadcast
+       out_distribution->add_sbp_parallel()->mutable_broadcast_parallel();
+     }
return Maybe<void>::Ok();
});
......
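The cast_to_tick change above replaces the old flat SBP-signature inference with hierarchy-aware ParallelDistribution inference: the input keeps its hinted SBP on every hierarchy axis, while the output is forced to broadcast on every axis. A simplified model of that rule is sketched below, with a plain enum standing in for OneFlow's protobuf SBP types.

#include <cassert>
#include <vector>

enum class Sbp { kSplit, kBroadcast, kPartialSum };

// Mirrors the registered inference fn: in = hint, out = broadcast per axis.
void InferCastToTickDistribution(const std::vector<Sbp>& in_hint, int num_hierarchy_axes,
                                 std::vector<Sbp>* in_dist, std::vector<Sbp>* out_dist) {
  assert(static_cast<int>(in_hint.size()) == num_hierarchy_axes);
  *in_dist = in_hint;                                      // in: copy the hint unchanged
  out_dist->assign(num_hierarchy_axes, Sbp::kBroadcast);   // out: broadcast on all axes
}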