未验证 提交 ca088f92 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet_executor] Parse pipeline config (#37319)

上级 f11e843a
......@@ -27,4 +27,6 @@ message FleetExecutorDesc {
optional int32 dp_degree = 4 [ default = 1 ];
optional int32 mp_degree = 5 [ default = 1 ];
optional int32 pp_degree = 6 [ default = 1 ];
optional int64 num_micro_batches = 7 [ default = 1 ];
optional int64 num_slots = 8 [ default = 1 ];
}
......@@ -14,6 +14,7 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
......@@ -21,7 +22,8 @@ namespace distributed {
// Constructs an interceptor bound to one task node and immediately spawns
// the worker thread that pools (drains) this interceptor's local mailbox.
// @param interceptor_id  unique id of this interceptor, echoed in logs.
// @param node            task node this interceptor executes; not owned here
//                        (lifetime managed by the caller — TODO confirm).
Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node)
    : interceptor_id_(interceptor_id), node_(node) {
  // NOTE: the thread captures `this`, so the Interceptor must outlive the
  // thread; PoolTheMailbox() is expected to run until a stop message arrives.
  interceptor_thread_ = std::thread([this]() {
    VLOG(3) << "Interceptor " << interceptor_id_
            << " starts the thread pooling its local mailbox.";
    PoolTheMailbox();
  });
}
......
......@@ -96,6 +96,9 @@ class Interceptor {
// local mailbox, written by FetchRemoteMailbox()
// read by PoolTheMailbox()
std::queue<InterceptorMessage> local_mailbox_;
int64_t already_run_times_{0};
int64_t used_slot_nums_{0};
};
class InterceptorFactory {
......
......@@ -136,16 +136,31 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
role_to_ops.at(new_op_role_id).emplace_back(op.get());
}
int64_t cur_rank = exe_desc_.cur_rank();
DistCoordSys coord_sys(exe_desc_.dp_degree(), exe_desc_.pp_degree(),
exe_desc_.mp_degree());
const auto& coord = coord_sys.RankToCoord(cur_rank);
int pipeline_stage = coord.pp_idx;
int64_t num_pipeline_stages = exe_desc_.pp_degree();
// TODO(fleet_executor dev): start up steps should be a config `num_slots`
int64_t start_up_steps = num_pipeline_stages - pipeline_stage - 1;
int64_t num_micro_batches = exe_desc_.num_micro_batches();
int64_t task_id = cur_rank * functionality_order.size();
for (std::size_t i = 0; i < functionality_order.size(); ++i) {
OpRole role = functionality_order[i];
int64_t role_id = static_cast<int64_t>(role);
int64_t max_run_times = num_micro_batches;
int64_t max_slot_nums = start_up_steps;
if (IsLRSched(role_id) || IsOptimize(role_id)) {
max_run_times = 1;
max_slot_nums = 1;
}
if (role_to_ops.find(role_id) == role_to_ops.end()) {
task_nodes_.emplace_back(
TaskNode::CreateEmptyTaskNode(role_id, cur_rank, task_id));
task_nodes_.emplace_back(TaskNode::CreateEmptyTaskNode(
role_id, cur_rank, task_id, max_run_times, max_slot_nums));
} else {
task_nodes_.emplace_back(TaskNode::CreateTaskNode(
role_id, role_to_ops.at(role_id), cur_rank, task_id));
task_nodes_.emplace_back(
TaskNode::CreateTaskNode(role_id, role_to_ops.at(role_id), cur_rank,
task_id, max_run_times, max_slot_nums));
}
++task_id;
}
......
......@@ -22,22 +22,37 @@ using OperatorBase = TaskNode::OperatorBase;
}
TaskNode::TaskNode(int64_t role, const std::vector<OperatorBase*>& ops,
int64_t rank, int64_t task_id)
: ops_(ops), role_(role), rank_(rank), task_id_(task_id) {}
int64_t rank, int64_t task_id, int64_t max_run_times,
int64_t max_slot_nums)
: ops_(ops),
role_(role),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {}
TaskNode::TaskNode(int64_t role, int64_t rank, int64_t task_id)
: role_(role), rank_(rank), task_id_(task_id) {}
TaskNode::TaskNode(int64_t role, int64_t rank, int64_t task_id,
int64_t max_run_times, int64_t max_slot_nums)
: role_(role),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {}
std::unique_ptr<TaskNode> TaskNode::CreateEmptyTaskNode(int64_t role,
int64_t rank,
int64_t task_id) {
return std::make_unique<TaskNode>(role, rank, task_id);
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums) {
return std::make_unique<TaskNode>(role, rank, task_id, max_run_times,
max_slot_nums);
}
std::unique_ptr<TaskNode> TaskNode::CreateTaskNode(
int64_t role, const std::vector<OperatorBase*>& ops, int64_t rank,
int64_t task_id) {
return std::make_unique<TaskNode>(role, ops, rank, task_id);
int64_t task_id, int64_t max_run_times, int64_t max_slot_nums) {
return std::make_unique<TaskNode>(role, ops, rank, task_id, max_run_times,
max_slot_nums);
}
void TaskNode::AddUpstreamTask(int64_t task_id) { upstream_.insert(task_id); }
......
......@@ -28,23 +28,28 @@ namespace distributed {
class TaskNode final {
public:
using OperatorBase = paddle::framework::OperatorBase;
TaskNode(int64_t role, int64_t rank, int64_t task_id);
TaskNode(int64_t role, int64_t rank, int64_t task_id, int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(int64_t role, const std::vector<OperatorBase*>& ops, int64_t rank,
int64_t task_id);
int64_t task_id, int64_t max_run_times, int64_t max_slot_nums);
~TaskNode() = default;
int64_t rank() const { return rank_; }
int64_t task_id() const { return task_id_; }
int64_t role() const { return role_; }
int64_t max_run_times() const { return max_run_times_; }
int64_t max_slot_nums() const { return max_slot_nums_; }
const std::unordered_set<int64_t>& upstream() const { return upstream_; }
const std::unordered_set<int64_t>& downstream() const { return downstream_; }
void AddUpstreamTask(int64_t task_id);
void AddDownstreamTask(int64_t task_id);
static std::unique_ptr<TaskNode> CreateEmptyTaskNode(int64_t role,
int64_t rank,
int64_t task_id);
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
static std::unique_ptr<TaskNode> CreateTaskNode(
int64_t role, const std::vector<OperatorBase*>& ops, int64_t rank,
int64_t task_id);
int64_t task_id, int64_t max_run_times, int64_t max_slot_nums);
private:
DISABLE_COPY_AND_ASSIGN(TaskNode);
......@@ -55,6 +60,8 @@ class TaskNode final {
int64_t role_;
int64_t rank_;
int64_t task_id_;
int64_t max_run_times_;
int64_t max_slot_nums_;
};
} // namespace distributed
......
......@@ -1981,6 +1981,8 @@ class Executor(object):
fleet_exe_desc.dp_degree = fleet_opt["dist_strategy"]["dp_degree"]
fleet_exe_desc.mp_degree = fleet_opt["dist_strategy"]["mp_degree"]
fleet_exe_desc.pp_degree = fleet_opt["dist_strategy"]["pp_degree"]
if "num_micro_batches" in fleet_opt:
fleet_exe_desc.num_micro_batches = fleet_opt["num_micro_batches"]
num_of_gpu = fleet_exe_desc.dp_degree * fleet_exe_desc.mp_degree * fleet_exe_desc.pp_degree
assert nrank == num_of_gpu, "The number of rank is not equal to the number of gpu."
fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
......
......@@ -43,7 +43,11 @@ class TestFleetExecutor(unittest.TestCase):
"mp_degree": 2,
"pp_degree": 2
}
fleet_opt = {"dist_strategy": strategy.sharding_configs}
strategy.pipeline_configs = {"accumulate_steps": 8}
fleet_opt = {
"dist_strategy": strategy.sharding_configs,
"num_micro_batches": strategy.pipeline_configs["accumulate_steps"]
}
if fluid.is_compiled_with_cuda():
self.run_fleet_executor(fluid.CUDAPlace(0), fleet_opt)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册