未验证 提交 5ed487bf 编写于 作者: R Ruibiao Chen 提交者: GitHub

Dispatch computation OPs before communication in standalone executor (#47471)

* Dispatch computation OPs before communication in standalone executor

* Update code

* Fix CI errors
上级 6f7a80c3
......@@ -129,14 +129,13 @@ void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
}
}
bool IsCommunicationOp(const Instruction& instr) {
bool IsCommunicationOp(const std::string& op_name) {
const std::set<std::string> special_comm_op_set = {
"send",
"recv",
"send_v2",
"recv_v2",
};
const std::string& op_name = instr.OpBase()->Type();
const std::string communication_op_prefix = "c_";
if (op_name.find(communication_op_prefix) != std::string::npos ||
special_comm_op_set.count(op_name)) {
......@@ -145,6 +144,10 @@ bool IsCommunicationOp(const Instruction& instr) {
return false;
}
bool IsCommunicationOp(const Instruction& instr) {
return IsCommunicationOp(instr.OpBase()->Type());
}
// True when the instruction's device context is bound to a CPU place.
bool IsCpuOp(const Instruction& instr) {
  const auto& place = instr.DeviceContext().GetPlace();
  return platform::is_cpu_place(place);
}
......
......@@ -65,6 +65,8 @@ class AsyncWorkQueue {
std::unique_ptr<WorkQueueGroup> queue_group_;
};
bool IsCommunicationOp(const std::string& op_name);
bool IsCommunicationOp(const Instruction& instr);
bool IsCpuOp(const Instruction& instr);
......
......@@ -528,7 +528,12 @@ void InterpreterCore::Convert(
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
auto& op_func_node = nodes[op_idx];
auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_);
Priority priority =
interpreter::IsCommunicationOp(op_func_node.operator_base_->Type())
? Priority::kLowest
: Priority::kNormal;
vec_instruction_.emplace_back(
op_idx, std::move(op_func_node), *dev_ctx_, priority);
}
BuildOperatorDependences();
......@@ -835,7 +840,7 @@ void InterpreterCore::ExecuteInstructionList(
}
void InterpreterCore::RunNextInstructions(
const Instruction& instr, std::queue<size_t>* reserved_next_ops) {
const Instruction& instr, std::deque<size_t>* reserved_next_ops) {
platform::RecordEvent record(
"RunNextInstructions", platform::TracerEventType::UserDefined, 10);
auto& next_instr = instr.NextInstructions();
......@@ -848,7 +853,7 @@ void InterpreterCore::RunNextInstructions(
if (instr.KernelType() == OpFuncType::kQueueAsync) {
// move all sync_ops into other threads
for (auto next_id : next_instr.SyncRunIds()) {
for (size_t next_id : next_instr.SyncRunIds()) {
if (IsReady(next_id)) {
async_work_queue_->AddTask(
vec_instruction_[next_id].KernelType(),
......@@ -856,14 +861,22 @@ void InterpreterCore::RunNextInstructions(
}
}
// keep all async_ops running in current thread
for (auto next_id : next_instr.DirectRunIds()) {
for (size_t next_id : next_instr.DirectRunIds()) {
if (IsReady(next_id)) {
reserved_next_ops->push(next_id);
if (vec_instruction_[next_id].GetPriority() == Priority::kLowest) {
reserved_next_ops->push_back(next_id);
} else {
reserved_next_ops->push_front(next_id);
}
}
}
for (auto next_id : next_instr.EventRunIds()) {
for (size_t next_id : next_instr.EventRunIds()) {
if (IsReady(next_id)) {
reserved_next_ops->push(next_id);
if (vec_instruction_[next_id].GetPriority() == Priority::kLowest) {
reserved_next_ops->push_back(next_id);
} else {
reserved_next_ops->push_front(next_id);
}
}
}
} else {
......@@ -895,16 +908,18 @@ void InterpreterCore::RunNextInstructions(
[this, next_id] { RunInstructionAsync(next_id); });
}
}
if (first_op != -1) reserved_next_ops->push(first_op);
if (first_op != -1) {
reserved_next_ops->push_front(first_op);
}
}
}
void InterpreterCore::RunInstructionAsync(size_t instr_id) {
std::queue<size_t> ready_ops;
ready_ops.push(instr_id);
std::deque<size_t> ready_ops;
ready_ops.push_back(instr_id);
while (!ready_ops.empty()) {
instr_id = ready_ops.front();
ready_ops.pop();
ready_ops.pop_front();
auto& instr_node = vec_instruction_.at(instr_id);
VLOG(5) << __func__ << " OP id:" << instr_node.Id()
<< " name:" << instr_node.OpBase()->Type() << " type:"
......
......@@ -92,7 +92,7 @@ class InterpreterCore {
void RunInstructionAsync(size_t instr_id);
void RunInstruction(const Instruction& instr_node);
void RunNextInstructions(const Instruction& instr_id,
std::queue<size_t>* reserved_next_ops);
std::deque<size_t>* reserved_next_ops);
// only used when program contains no feed op
void Prepare(const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
......
......@@ -673,8 +673,12 @@ void VariableScope::CheckExist(const std::string& name) const {
Instruction::Instruction(size_t id,
OpFuncNode&& op_func_node,
const platform::DeviceContext& dev_ctx)
: id_(id), op_func_node_(op_func_node), dev_ctx_(dev_ctx) {
const platform::DeviceContext& dev_ctx,
const Priority priority)
: id_(id),
op_func_node_(op_func_node),
dev_ctx_(dev_ctx),
priority_(priority) {
PADDLE_ENFORCE_GE(id,
0,
platform::errors::PreconditionNotMet(
......
......@@ -40,6 +40,8 @@ constexpr const char* kDefaultStream = "DefaultStream";
constexpr const char* kD2HStream = "D2HStream";
constexpr const char* kH2DStream = "H2DStream";
// Dispatch priority of an instruction. Communication ops are tagged
// kLowest so the scheduler appends them to the back of the ready deque
// (push_back), while kNormal computation ops are prepended (push_front)
// and therefore dispatched first.
enum class Priority { kLowest, kNormal };
class InterpretercoreInferShapeContext : public InferShapeContext {
public:
InterpretercoreInferShapeContext(const OperatorBase& op,
......@@ -300,7 +302,8 @@ class Instruction {
public:
Instruction(size_t id,
OpFuncNode&& op_func_node,
const platform::DeviceContext& dev_ctx);
const platform::DeviceContext& dev_ctx,
const Priority priority);
size_t Id() const;
......@@ -362,10 +365,13 @@ class Instruction {
std::shared_ptr<platform::DeviceEvent> event,
platform::DeviceType waiter_type);
Priority GetPriority() const { return priority_; }
private:
size_t id_;
OpFuncNode op_func_node_;
const platform::DeviceContext& dev_ctx_; // not owned
const Priority priority_;
std::shared_ptr<RuntimeContext> runtime_ctx_;
std::shared_ptr<InterpretercoreInferShapeContext> infershape_ctx_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册