未验证 提交 a058b474 编写于 作者: L Leo Chen 提交者: GitHub

add dependency for send/recv to support pp parallel (#41652)

上级 a688ae2e
...@@ -710,7 +710,12 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -710,7 +710,12 @@ std::map<int, std::list<int>> build_op_downstream_map(
// sequentially // sequentially
const std::set<std::string> random_op_set = { const std::set<std::string> random_op_set = {
"bernoulli", "poisson", "multinomial", "gaussian_random", "bernoulli", "poisson", "multinomial", "gaussian_random",
"uniform_random", "randint", "randperm", "exponential"}; "truncated_gaussian_random", "uniform_random", "randint", "randperm",
"exponential",
"sampling_id"
"dropout",
"class_center_sample",
};
int dependence_op_idx = -1; int dependence_op_idx = -1;
for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) { for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
...@@ -723,13 +728,26 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -723,13 +728,26 @@ std::map<int, std::list<int>> build_op_downstream_map(
} }
// add dependency for communication op // add dependency for communication op
auto is_comm_op = [](std::string op) -> bool {
const std::set<std::string> special_comm_op_set = {
"send", "recv", "send_v2", "recv_v2",
};
const std::string communication_op_prefix = "c_"; const std::string communication_op_prefix = "c_";
if (op.find(communication_op_prefix) != std::string::npos ||
special_comm_op_set.count(op)) {
return true;
}
return false;
};
dependence_op_idx = -1; dependence_op_idx = -1;
for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) { for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
if (vec_instruction[op_idx].OpBase()->Type().find( if (is_comm_op(vec_instruction[op_idx].OpBase()->Type())) {
communication_op_prefix) != std::string::npos) {
if (dependence_op_idx != -1) { if (dependence_op_idx != -1) {
op2dependences[op_idx].insert(dependence_op_idx); op2dependences[op_idx].insert(dependence_op_idx);
VLOG(4) << "Add depend from "
<< vec_instruction[dependence_op_idx].OpBase()->Type() << " to "
<< vec_instruction[op_idx].OpBase()->Type();
} }
dependence_op_idx = op_idx; dependence_op_idx = op_idx;
} }
...@@ -833,10 +851,8 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -833,10 +851,8 @@ std::map<int, std::list<int>> build_op_downstream_map(
for (size_t j = first_read_fused_out_op + 1; j < vec_instruction.size(); for (size_t j = first_read_fused_out_op + 1; j < vec_instruction.size();
++j) { ++j) {
if (j == target + 1 && if (j == target + 1 &&
vec_instruction[target].OpBase()->Type().find( is_comm_op(vec_instruction[target].OpBase()->Type()) &&
communication_op_prefix) != std::string::npos && is_comm_op(vec_instruction[j].OpBase()->Type())) {
vec_instruction[j].OpBase()->Type().find(communication_op_prefix) !=
std::string::npos) {
VLOG(4) << "Found consecutive communication ops, " VLOG(4) << "Found consecutive communication ops, "
<< vec_instruction[target].OpBase()->Type() << " -> " << vec_instruction[target].OpBase()->Type() << " -> "
<< vec_instruction[j].OpBase()->Type(); << vec_instruction[j].OpBase()->Type();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册