diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index 59703332efe9594c3f2130eeff79bea6c690839e..63fcf0cffaa84c47412958da3c5a44f42a27a091 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -709,8 +709,13 @@ std::map> build_op_downstream_map( // add dependences for random op, make sure that the random op is scheduled // sequentially const std::set random_op_set = { - "bernoulli", "poisson", "multinomial", "gaussian_random", - "uniform_random", "randint", "randperm", "exponential"}; + "bernoulli", "poisson", "multinomial", "gaussian_random", + "truncated_gaussian_random", "uniform_random", "randint", "randperm", + "exponential", + "sampling_id" + "dropout", + "class_center_sample", + }; int dependence_op_idx = -1; for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) { @@ -723,13 +728,26 @@ std::map> build_op_downstream_map( } // add dependency for communication op - const std::string communication_op_prefix = "c_"; + auto is_comm_op = [](std::string op) -> bool { + const std::set special_comm_op_set = { + "send", "recv", "send_v2", "recv_v2", + }; + const std::string communication_op_prefix = "c_"; + if (op.find(communication_op_prefix) != std::string::npos || + special_comm_op_set.count(op)) { + return true; + } + return false; + }; + dependence_op_idx = -1; for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) { - if (vec_instruction[op_idx].OpBase()->Type().find( - communication_op_prefix) != std::string::npos) { + if (is_comm_op(vec_instruction[op_idx].OpBase()->Type())) { if (dependence_op_idx != -1) { op2dependences[op_idx].insert(dependence_op_idx); + VLOG(4) << "Add depend from " + << vec_instruction[dependence_op_idx].OpBase()->Type() << " to " + << vec_instruction[op_idx].OpBase()->Type(); } dependence_op_idx = op_idx; } @@ -833,10 +851,8 @@ std::map> build_op_downstream_map( for (size_t j = first_read_fused_out_op + 1; j < vec_instruction.size(); ++j) { if (j == target + 1 && - vec_instruction[target].OpBase()->Type().find( - communication_op_prefix) != std::string::npos && - vec_instruction[j].OpBase()->Type().find(communication_op_prefix) != - std::string::npos) { + is_comm_op(vec_instruction[target].OpBase()->Type()) && + is_comm_op(vec_instruction[j].OpBase()->Type())) { VLOG(4) << "Found consecutive communication ops, " << vec_instruction[target].OpBase()->Type() << " -> " << vec_instruction[j].OpBase()->Type();