未验证 提交 f306965d 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet_executor] Add amplify interceptor info runtime graph (#37783)

上级 cc2b4662
......@@ -27,28 +27,6 @@ AmplifierInterceptor::AmplifierInterceptor(int64_t interceptor_id,
run_at_offset_ = node->run_at_offset();
reply_up_per_steps_ = node->reply_up_per_steps();
send_down_per_steps_ = node->send_down_per_steps();
PADDLE_ENFORCE_GE(
run_per_steps_, 1,
platform::errors::InvalidArgument(
"run_per_steps must >= 1, but now is %ld", run_per_steps_));
PADDLE_ENFORCE_GE(
run_at_offset_, 0,
platform::errors::InvalidArgument(
"run_at_offset must >= 0, but now is %ld", run_at_offset_));
PADDLE_ENFORCE_LT(run_at_offset_, run_per_steps_,
platform::errors::InvalidArgument(
"run_at_offset must < run_per_steps, must now "
"run_at_offset=%ld run_per_steps=%ld",
run_at_offset_, run_per_steps_));
PADDLE_ENFORCE_GE(
reply_up_per_steps_, 1,
platform::errors::InvalidArgument(
"reply_up_per_steps must >= 1, but now is %ld", reply_up_per_steps_));
PADDLE_ENFORCE_GE(send_down_per_steps_, 1,
platform::errors::InvalidArgument(
"send_down_per_steps must >= 1, but now is %ld",
send_down_per_steps_));
}
void AmplifierInterceptor::RunOps() {
......
......@@ -199,6 +199,13 @@ void Carrier::CreateInterceptors() {
int64_t interceptor_id = item.first;
TaskNode* task_node = item.second;
PADDLE_ENFORCE_LT(
task_node->run_at_offset(), task_node->run_per_steps(),
platform::errors::InvalidArgument(
"Interceptor's run_at_offset must < run_per_steps, must now "
"run_at_offset=%ld run_per_steps=%ld",
task_node->run_at_offset(), task_node->run_per_steps()));
std::unique_ptr<Interceptor> interceptor;
if (task_node->type().empty()) {
// TODO(wangxi): delete this in future
......@@ -214,7 +221,7 @@ void Carrier::CreateInterceptors() {
SetInterceptor(interceptor_id, std::move(interceptor));
VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
<< ".";
<< " with type: " << task_node->type() << ".";
if (task_node->upstream().empty()) {
source_interceptor_ids_.emplace_back(interceptor_id);
......
......@@ -161,7 +161,8 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
}
void ComputeInterceptor::RunOps() {
VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops.";
VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the "
<< step_ << " time.";
for (auto op : node_->ops()) {
op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_);
}
......@@ -180,6 +181,8 @@ void ComputeInterceptor::Run() {
ReplyCompletedToUpStream();
// Try to stop Carrier
if (is_last_ && (step_ % node_->max_run_times() == 0)) {
VLOG(3) << "Interceptor " << GetInterceptorId()
<< " is stopping carrier.";
StopCarrier();
}
}
......
......@@ -161,22 +161,30 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
int64_t num_micro_batches = exe_desc_.num_micro_batches();
int64_t task_id = cur_rank * functionality_order.size();
for (std::size_t i = 0; i < functionality_order.size(); ++i) {
VLOG(3) << "Runtime graph is creating task node for: " << task_id << ".";
OpRole role = functionality_order[i];
int32_t role_id = static_cast<int64_t>(role);
int64_t max_run_times = num_micro_batches;
int64_t max_slot_nums = start_up_steps;
// NOTE: use short path, each interceptor should run for max_run_times
std::vector<OperatorBase*> task_ops{};
if (role_to_ops.find(role_id) != role_to_ops.end()) {
task_ops = role_to_ops.at(role_id);
}
std::unique_ptr<TaskNode> task_node = std::make_unique<TaskNode>(
role_id, task_ops, cur_rank, task_id, max_run_times, max_slot_nums);
if (IsLRSched(role_id) || IsOptimize(role_id)) {
max_run_times = 1;
max_slot_nums = 1;
task_node->SetType("Amplifier");
if (IsLRSched(role_id)) {
task_node->SetRunPerSteps(max_run_times);
} else {
task_node->SetRunAtOffset(max_run_times - 1);
task_node->SetRunPerSteps(max_run_times);
}
if (role_to_ops.find(role_id) == role_to_ops.end()) {
task_nodes_.emplace_back(TaskNode::CreateEmptyTaskNode(
role_id, cur_rank, task_id, max_run_times, max_slot_nums));
} else {
task_nodes_.emplace_back(
TaskNode::CreateTaskNode(role_id, role_to_ops.at(role_id), cur_rank,
task_id, max_run_times, max_slot_nums));
task_node->SetType("Compute");
}
task_nodes_.emplace_back(std::move(task_node));
++task_id;
}
}
......@@ -227,6 +235,8 @@ void RuntimeGraph::FakeDependence() {
void RuntimeGraph::AssignTaskToIntercepter() {
for (const auto& task : task_nodes_) {
int64_t intercepter_id = task->task_id();
VLOG(3) << "Runtime graph is assigning task to interceptor: "
<< intercepter_id << " with type: " << task->type() << ".";
if (intercepter_id_to_node_.find(intercepter_id) !=
intercepter_id_to_node_.end()) {
PADDLE_THROW(platform::errors::PreconditionNotMet(
......
......@@ -57,22 +57,6 @@ TaskNode::TaskNode(int32_t role, int64_t rank, int64_t task_id,
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {}
std::unique_ptr<TaskNode> TaskNode::CreateEmptyTaskNode(int32_t role,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums) {
return std::make_unique<TaskNode>(role, rank, task_id, max_run_times,
max_slot_nums);
}
std::unique_ptr<TaskNode> TaskNode::CreateTaskNode(
int32_t role, const std::vector<OperatorBase*>& ops, int64_t rank,
int64_t task_id, int64_t max_run_times, int64_t max_slot_nums) {
return std::make_unique<TaskNode>(role, ops, rank, task_id, max_run_times,
max_slot_nums);
}
bool TaskNode::AddUpstreamTask(int64_t task_id) {
const auto& ret = upstream_.insert(task_id);
return *ret.first == task_id;
......@@ -92,5 +76,34 @@ std::string TaskNode::DebugString() const {
os << "\n";
return os.str();
}
void TaskNode::SetRunPerSteps(int64_t value) {
PADDLE_ENFORCE_GE(value, 1,
platform::errors::InvalidArgument(
"run_per_steps must >= 1, but received %ld", value));
run_per_steps_ = value;
}
void TaskNode::SetRunAtOffset(int64_t value) {
PADDLE_ENFORCE_GE(value, 0,
platform::errors::InvalidArgument(
"run_at_offset must >= 0, but received %ld", value));
run_at_offset_ = value;
}
void TaskNode::SetReplyUpPerSteps(int64_t value) {
PADDLE_ENFORCE_GE(
value, 1, platform::errors::InvalidArgument(
"reply_up_per_steps must >= 1, but received %ld", value));
reply_up_per_steps_ = value;
}
void TaskNode::SetSendDownPerSteps(int64_t value) {
PADDLE_ENFORCE_GE(
value, 1, platform::errors::InvalidArgument(
"send_down_per_steps must >= 1, but received %ld", value));
send_down_per_steps_ = value;
}
} // namespace distributed
} // namespace paddle
......@@ -54,25 +54,16 @@ class TaskNode final {
const paddle::framework::ProgramDesc& program() const { return program_; }
const std::vector<OperatorBase*>& ops() const { return ops_; }
void SetRunPerSteps(int64_t value) { run_per_steps_ = value; }
void SetRunAtOffset(int64_t value) { run_at_offset_ = value; }
void SetReplyUpPerSteps(int64_t value) { reply_up_per_steps_ = value; }
void SetSendDownPerSteps(int64_t value) { send_down_per_steps_ = value; }
void SetRunPerSteps(int64_t value);
void SetRunAtOffset(int64_t value);
void SetReplyUpPerSteps(int64_t value);
void SetSendDownPerSteps(int64_t value);
void SetType(const std::string& type) { type_ = type; }
bool AddUpstreamTask(int64_t task_id);
bool AddDownstreamTask(int64_t task_id);
std::string DebugString() const;
static std::unique_ptr<TaskNode> CreateEmptyTaskNode(int32_t role,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
static std::unique_ptr<TaskNode> CreateTaskNode(
int32_t role, const std::vector<OperatorBase*>& ops, int64_t rank,
int64_t task_id, int64_t max_run_times, int64_t max_slot_nums);
private:
DISABLE_COPY_AND_ASSIGN(TaskNode);
TaskNode() = default;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册