未验证 提交 5ed487bf 编写于 作者: R Ruibiao Chen 提交者: GitHub

Dispatch computation OPs before communication in standalone executor (#47471)

* Dispath computation OPs before communication in standalone executor

* Update code

* Fix CI errors
上级 6f7a80c3
...@@ -129,14 +129,13 @@ void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type, ...@@ -129,14 +129,13 @@ void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
} }
} }
bool IsCommunicationOp(const Instruction& instr) { bool IsCommunicationOp(const std::string& op_name) {
const std::set<std::string> special_comm_op_set = { const std::set<std::string> special_comm_op_set = {
"send", "send",
"recv", "recv",
"send_v2", "send_v2",
"recv_v2", "recv_v2",
}; };
const std::string& op_name = instr.OpBase()->Type();
const std::string communication_op_prefix = "c_"; const std::string communication_op_prefix = "c_";
if (op_name.find(communication_op_prefix) != std::string::npos || if (op_name.find(communication_op_prefix) != std::string::npos ||
special_comm_op_set.count(op_name)) { special_comm_op_set.count(op_name)) {
...@@ -145,6 +144,10 @@ bool IsCommunicationOp(const Instruction& instr) { ...@@ -145,6 +144,10 @@ bool IsCommunicationOp(const Instruction& instr) {
return false; return false;
} }
bool IsCommunicationOp(const Instruction& instr) {
return IsCommunicationOp(instr.OpBase()->Type());
}
bool IsCpuOp(const Instruction& instr) { bool IsCpuOp(const Instruction& instr) {
return platform::is_cpu_place(instr.DeviceContext().GetPlace()); return platform::is_cpu_place(instr.DeviceContext().GetPlace());
} }
......
...@@ -65,6 +65,8 @@ class AsyncWorkQueue { ...@@ -65,6 +65,8 @@ class AsyncWorkQueue {
std::unique_ptr<WorkQueueGroup> queue_group_; std::unique_ptr<WorkQueueGroup> queue_group_;
}; };
bool IsCommunicationOp(const std::string& op_name);
bool IsCommunicationOp(const Instruction& instr); bool IsCommunicationOp(const Instruction& instr);
bool IsCpuOp(const Instruction& instr); bool IsCpuOp(const Instruction& instr);
......
...@@ -528,7 +528,12 @@ void InterpreterCore::Convert( ...@@ -528,7 +528,12 @@ void InterpreterCore::Convert(
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) { for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
auto& op_func_node = nodes[op_idx]; auto& op_func_node = nodes[op_idx];
auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node); auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_); Priority priority =
interpreter::IsCommunicationOp(op_func_node.operator_base_->Type())
? Priority::kLowest
: Priority::kNormal;
vec_instruction_.emplace_back(
op_idx, std::move(op_func_node), *dev_ctx_, priority);
} }
BuildOperatorDependences(); BuildOperatorDependences();
...@@ -835,7 +840,7 @@ void InterpreterCore::ExecuteInstructionList( ...@@ -835,7 +840,7 @@ void InterpreterCore::ExecuteInstructionList(
} }
void InterpreterCore::RunNextInstructions( void InterpreterCore::RunNextInstructions(
const Instruction& instr, std::queue<size_t>* reserved_next_ops) { const Instruction& instr, std::deque<size_t>* reserved_next_ops) {
platform::RecordEvent record( platform::RecordEvent record(
"RunNextInstructions", platform::TracerEventType::UserDefined, 10); "RunNextInstructions", platform::TracerEventType::UserDefined, 10);
auto& next_instr = instr.NextInstructions(); auto& next_instr = instr.NextInstructions();
...@@ -848,7 +853,7 @@ void InterpreterCore::RunNextInstructions( ...@@ -848,7 +853,7 @@ void InterpreterCore::RunNextInstructions(
if (instr.KernelType() == OpFuncType::kQueueAsync) { if (instr.KernelType() == OpFuncType::kQueueAsync) {
// move all sync_ops into other threads // move all sync_ops into other threads
for (auto next_id : next_instr.SyncRunIds()) { for (size_t next_id : next_instr.SyncRunIds()) {
if (IsReady(next_id)) { if (IsReady(next_id)) {
async_work_queue_->AddTask( async_work_queue_->AddTask(
vec_instruction_[next_id].KernelType(), vec_instruction_[next_id].KernelType(),
...@@ -856,14 +861,22 @@ void InterpreterCore::RunNextInstructions( ...@@ -856,14 +861,22 @@ void InterpreterCore::RunNextInstructions(
} }
} }
// keep all async_ops running in current thread // keep all async_ops running in current thread
for (auto next_id : next_instr.DirectRunIds()) { for (size_t next_id : next_instr.DirectRunIds()) {
if (IsReady(next_id)) { if (IsReady(next_id)) {
reserved_next_ops->push(next_id); if (vec_instruction_[next_id].GetPriority() == Priority::kLowest) {
reserved_next_ops->push_back(next_id);
} else {
reserved_next_ops->push_front(next_id);
} }
} }
for (auto next_id : next_instr.EventRunIds()) { }
for (size_t next_id : next_instr.EventRunIds()) {
if (IsReady(next_id)) { if (IsReady(next_id)) {
reserved_next_ops->push(next_id); if (vec_instruction_[next_id].GetPriority() == Priority::kLowest) {
reserved_next_ops->push_back(next_id);
} else {
reserved_next_ops->push_front(next_id);
}
} }
} }
} else { } else {
...@@ -895,16 +908,18 @@ void InterpreterCore::RunNextInstructions( ...@@ -895,16 +908,18 @@ void InterpreterCore::RunNextInstructions(
[this, next_id] { RunInstructionAsync(next_id); }); [this, next_id] { RunInstructionAsync(next_id); });
} }
} }
if (first_op != -1) reserved_next_ops->push(first_op); if (first_op != -1) {
reserved_next_ops->push_front(first_op);
}
} }
} }
void InterpreterCore::RunInstructionAsync(size_t instr_id) { void InterpreterCore::RunInstructionAsync(size_t instr_id) {
std::queue<size_t> ready_ops; std::deque<size_t> ready_ops;
ready_ops.push(instr_id); ready_ops.push_back(instr_id);
while (!ready_ops.empty()) { while (!ready_ops.empty()) {
instr_id = ready_ops.front(); instr_id = ready_ops.front();
ready_ops.pop(); ready_ops.pop_front();
auto& instr_node = vec_instruction_.at(instr_id); auto& instr_node = vec_instruction_.at(instr_id);
VLOG(5) << __func__ << " OP id:" << instr_node.Id() VLOG(5) << __func__ << " OP id:" << instr_node.Id()
<< " name:" << instr_node.OpBase()->Type() << " type:" << " name:" << instr_node.OpBase()->Type() << " type:"
......
...@@ -92,7 +92,7 @@ class InterpreterCore { ...@@ -92,7 +92,7 @@ class InterpreterCore {
void RunInstructionAsync(size_t instr_id); void RunInstructionAsync(size_t instr_id);
void RunInstruction(const Instruction& instr_node); void RunInstruction(const Instruction& instr_node);
void RunNextInstructions(const Instruction& instr_id, void RunNextInstructions(const Instruction& instr_id,
std::queue<size_t>* reserved_next_ops); std::deque<size_t>* reserved_next_ops);
// only used when program contains no feed op // only used when program contains no feed op
void Prepare(const std::vector<std::string>& feed_names, void Prepare(const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors, const std::vector<phi::DenseTensor>& feed_tensors,
......
...@@ -673,8 +673,12 @@ void VariableScope::CheckExist(const std::string& name) const { ...@@ -673,8 +673,12 @@ void VariableScope::CheckExist(const std::string& name) const {
Instruction::Instruction(size_t id, Instruction::Instruction(size_t id,
OpFuncNode&& op_func_node, OpFuncNode&& op_func_node,
const platform::DeviceContext& dev_ctx) const platform::DeviceContext& dev_ctx,
: id_(id), op_func_node_(op_func_node), dev_ctx_(dev_ctx) { const Priority priority)
: id_(id),
op_func_node_(op_func_node),
dev_ctx_(dev_ctx),
priority_(priority) {
PADDLE_ENFORCE_GE(id, PADDLE_ENFORCE_GE(id,
0, 0,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
......
...@@ -40,6 +40,8 @@ constexpr const char* kDefaultStream = "DefaultStream"; ...@@ -40,6 +40,8 @@ constexpr const char* kDefaultStream = "DefaultStream";
constexpr const char* kD2HStream = "D2HStream"; constexpr const char* kD2HStream = "D2HStream";
constexpr const char* kH2DStream = "H2DStream"; constexpr const char* kH2DStream = "H2DStream";
enum class Priority { kLowest, kNormal };
class InterpretercoreInferShapeContext : public InferShapeContext { class InterpretercoreInferShapeContext : public InferShapeContext {
public: public:
InterpretercoreInferShapeContext(const OperatorBase& op, InterpretercoreInferShapeContext(const OperatorBase& op,
...@@ -300,7 +302,8 @@ class Instruction { ...@@ -300,7 +302,8 @@ class Instruction {
public: public:
Instruction(size_t id, Instruction(size_t id,
OpFuncNode&& op_func_node, OpFuncNode&& op_func_node,
const platform::DeviceContext& dev_ctx); const platform::DeviceContext& dev_ctx,
const Priority priority);
size_t Id() const; size_t Id() const;
...@@ -362,10 +365,13 @@ class Instruction { ...@@ -362,10 +365,13 @@ class Instruction {
std::shared_ptr<platform::DeviceEvent> event, std::shared_ptr<platform::DeviceEvent> event,
platform::DeviceType waiter_type); platform::DeviceType waiter_type);
Priority GetPriority() const { return priority_; }
private: private:
size_t id_; size_t id_;
OpFuncNode op_func_node_; OpFuncNode op_func_node_;
const platform::DeviceContext& dev_ctx_; // not owned const platform::DeviceContext& dev_ctx_; // not owned
const Priority priority_;
std::shared_ptr<RuntimeContext> runtime_ctx_; std::shared_ptr<RuntimeContext> runtime_ctx_;
std::shared_ptr<InterpretercoreInferShapeContext> infershape_ctx_; std::shared_ptr<InterpretercoreInferShapeContext> infershape_ctx_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册