未验证 提交 a058b474 编写于 作者: L Leo Chen 提交者: GitHub

add dependency for send/recv to support pp parallel (#41652)

上级 a688ae2e
......@@ -709,8 +709,13 @@ std::map<int, std::list<int>> build_op_downstream_map(
// add dependences for random op, make sure that the random op is scheduled
// sequentially
const std::set<std::string> random_op_set = {
"bernoulli", "poisson", "multinomial", "gaussian_random",
"uniform_random", "randint", "randperm", "exponential"};
"bernoulli", "poisson", "multinomial", "gaussian_random",
"truncated_gaussian_random", "uniform_random", "randint", "randperm",
"exponential",
"sampling_id"
"dropout",
"class_center_sample",
};
int dependence_op_idx = -1;
for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
......@@ -723,13 +728,26 @@ std::map<int, std::list<int>> build_op_downstream_map(
}
// add dependency for communication op
const std::string communication_op_prefix = "c_";
auto is_comm_op = [](std::string op) -> bool {
const std::set<std::string> special_comm_op_set = {
"send", "recv", "send_v2", "recv_v2",
};
const std::string communication_op_prefix = "c_";
if (op.find(communication_op_prefix) != std::string::npos ||
special_comm_op_set.count(op)) {
return true;
}
return false;
};
dependence_op_idx = -1;
for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
if (vec_instruction[op_idx].OpBase()->Type().find(
communication_op_prefix) != std::string::npos) {
if (is_comm_op(vec_instruction[op_idx].OpBase()->Type())) {
if (dependence_op_idx != -1) {
op2dependences[op_idx].insert(dependence_op_idx);
VLOG(4) << "Add depend from "
<< vec_instruction[dependence_op_idx].OpBase()->Type() << " to "
<< vec_instruction[op_idx].OpBase()->Type();
}
dependence_op_idx = op_idx;
}
......@@ -833,10 +851,8 @@ std::map<int, std::list<int>> build_op_downstream_map(
for (size_t j = first_read_fused_out_op + 1; j < vec_instruction.size();
++j) {
if (j == target + 1 &&
vec_instruction[target].OpBase()->Type().find(
communication_op_prefix) != std::string::npos &&
vec_instruction[j].OpBase()->Type().find(communication_op_prefix) !=
std::string::npos) {
is_comm_op(vec_instruction[target].OpBase()->Type()) &&
is_comm_op(vec_instruction[j].OpBase()->Type())) {
VLOG(4) << "Found consecutive communication ops, "
<< vec_instruction[target].OpBase()->Type() << " -> "
<< vec_instruction[j].OpBase()->Type();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册