未验证 提交 6b48dfe9 编写于 作者: W WangXi 提交者: GitHub

[fleet_executor] Add interceptor gc (#37889)

......@@ -13,7 +13,7 @@ endif()
# fleet_executor library. executor_gc_helper is listed in DEPS because
# compute_interceptor.cc and runtime_graph.cc include
# paddle/fluid/framework/executor_gc_helper.h.
cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
  interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
  DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper op_registry
  executor_gc_helper ${BRPC_DEPS})
if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
......
......@@ -18,6 +18,7 @@
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
......@@ -191,49 +192,70 @@ void Carrier::HandleTmpMessages() {
message_tmp_.clear();
}
// Picks the garbage collector to hand to the interceptors running on `place`.
//
// Returns a null pointer (meaning "no garbage collection") unless all of the
// following hold:
//   * framework::GetEagerDeletionThreshold() is non-negative,
//   * the build has PADDLE_WITH_CUDA or PADDLE_WITH_HIP defined,
//   * `place` is a GPU place, and
//   * framework::IsFastEagerDeletionModeEnabled() is true.
// In that case an UnsafeFastGPUGarbageCollector is created with the threshold
// as its max memory size.
static std::shared_ptr<framework::GarbageCollector> GetGC(
    const platform::Place& place) {
  const int64_t max_memory_size = framework::GetEagerDeletionThreshold();
  if (max_memory_size < 0) {
    // A negative threshold: no collector is created.
    return nullptr;
  }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (platform::is_gpu_place(place) &&
      framework::IsFastEagerDeletionModeEnabled()) {
    return std::make_shared<framework::UnsafeFastGPUGarbageCollector>(
        BOOST_GET_CONST(platform::CUDAPlace, place), max_memory_size);
  }
#endif
  // No collector is implemented here for any other configuration.
  return nullptr;
}
// Creates one Interceptor per (interceptor id, TaskNode) entry of the runtime
// graph and registers it with this carrier.
//
// If the runtime graph has no interceptors the function returns immediately
// and leaves creating_interceptors_ untouched, so the carrier keeps waiting
// for an outside initializer. Otherwise every interceptor receives the
// carrier's place, scopes and one shared garbage collector (which may be
// null, see GetGC); afterwards creating_interceptors_ is cleared and the
// messages buffered during creation are dispatched via HandleTmpMessages().
//
// Raises InvalidArgument (through PADDLE_ENFORCE_LT) when a task node's
// run_at_offset is not smaller than its run_per_steps.
void Carrier::CreateInterceptors() {
  if (runtime_graph_->intercepter_id_to_node().empty()) return;

  // A single collector instance is shared by all interceptors created below.
  auto gc = GetGC(place_);

  // create each Interceptor
  // no auto init since there is no config
  for (const auto& item : runtime_graph_->intercepter_id_to_node()) {
    int64_t interceptor_id = item.first;
    TaskNode* task_node = item.second;

    PADDLE_ENFORCE_LT(
        task_node->run_at_offset(), task_node->run_per_steps(),
        platform::errors::InvalidArgument(
            "Interceptor's run_at_offset must < run_per_steps, must now "
            "run_at_offset=%ld run_per_steps=%ld",
            task_node->run_at_offset(), task_node->run_per_steps()));

    std::unique_ptr<Interceptor> interceptor;
    if (task_node->type().empty()) {
      // TODO(wangxi): delete this in future
      interceptor.reset(new Interceptor(interceptor_id, task_node));
    } else {
      interceptor = InterceptorFactory::Create(task_node->type(),
                                               interceptor_id, task_node);
    }
    interceptor->SetPlace(place_);
    interceptor->SetMiniBatchScope(minibatch_scope_);
    interceptor->SetMicroBatchScope(microbatch_scopes_);
    interceptor->SetRootScope(root_scope_);
    interceptor->SetGC(gc);

    SetInterceptor(interceptor_id, std::move(interceptor));
    VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
            << " with type: " << task_node->type() << ".";

    // A task node without upstream is recorded as a source interceptor.
    if (task_node->upstream().empty()) {
      source_interceptor_ids_.emplace_back(interceptor_id);
    }
  }

  // The carrier will be always waiting for outside initializer
  // since there is no interceptor has been created during auto init
  creating_flag_mutex_.lock();
  creating_interceptors_ = false;
  creating_flag_mutex_.unlock();
  HandleTmpMessages();
}
} // namespace distributed
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
......@@ -172,6 +173,11 @@ void ComputeInterceptor::RunOps() {
<< step_ + 1 << " time.";
for (auto op : node_->ops()) {
op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_);
if (gc_) {
framework::DeleteUnusedTensors(
*microbatch_scopes_[step_ % node_->max_run_times()], op,
node_->unused_vars(), gc_.get());
}
}
}
......
......@@ -31,6 +31,7 @@
namespace paddle {
namespace framework {
class Scope;
class GarbageCollector;
}
namespace distributed {
......@@ -73,6 +74,9 @@ class Interceptor {
// Copies the given vector of micro-batch scope pointers into this
// interceptor. The scopes themselves are not copied, only the pointers.
void SetMicroBatchScope(const std::vector<framework::Scope*>& scopes) {
microbatch_scopes_ = scopes;
}
// Stores the garbage collector this interceptor shares with its creator.
// A null pointer is accepted: ComputeInterceptor::RunOps only deletes unused
// tensors when gc_ is non-null.
void SetGC(const std::shared_ptr<framework::GarbageCollector>& gc) {
gc_ = gc;
}
// Returns the TaskNode pointer held in node_.
TaskNode* GetTaskNode() const { return node_; }
......@@ -94,6 +98,7 @@ class Interceptor {
// Root scope pointer; null until assigned (presumably via SetRootScope,
// whose definition is not visible here).
framework::Scope* root_scope_{nullptr};
// Mini-batch scope pointer; null until assigned (presumably via
// SetMiniBatchScope, whose definition is not visible here).
framework::Scope* minibatch_scope_{nullptr};
// Micro-batch scope pointers assigned by SetMicroBatchScope().
// ComputeInterceptor::RunOps indexes this with step_ % max_run_times().
std::vector<framework::Scope*> microbatch_scopes_{};
// Garbage collector assigned by SetGC(); null means no tensors are deleted
// after an op runs.
std::shared_ptr<framework::GarbageCollector> gc_{nullptr};
private:
// pool the local mailbox, parse the Message
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
......@@ -101,16 +102,7 @@ RuntimeGraph::RuntimeGraph(const ProgramDesc& program,
const FleetExecutorDesc& exe_desc)
: exe_desc_(exe_desc) {
if (exe_desc.pp_degree() == 1) {
int64_t cur_rank = exe_desc_.cur_rank();
int64_t max_run_times = exe_desc_.num_micro_batches();
int64_t max_slot_nums = exe_desc_.num_slots();
auto task_node = std::make_unique<TaskNode>(program, cur_rank,
max_run_times, max_slot_nums);
task_node->SetType("Compute");
task_nodes_.emplace_back(std::move(task_node));
int64_t task_id = task_nodes_[0]->task_id();
intercepter_id_to_rank_.insert({task_id, cur_rank});
intercepter_id_to_node_.insert({task_id, task_nodes_[0].get()});
OriginProgramCompile(program);
} else {
SplitProgramBasedFunctionality(program);
AssignTaskToIntercepter();
......@@ -119,10 +111,31 @@ RuntimeGraph::RuntimeGraph(const ProgramDesc& program,
}
}
void RuntimeGraph::OriginProgramCompile(const ProgramDesc& program) {
int64_t cur_rank = exe_desc_.cur_rank();
int64_t max_run_times = exe_desc_.num_micro_batches();
int64_t max_slot_nums = exe_desc_.num_slots();
auto task_node = std::make_unique<TaskNode>(program, cur_rank, max_run_times,
max_slot_nums);
// TODO(wangxi): add skip vars
auto unused_vars =
framework::GetUnusedVars(program.Block(0), task_node->unique_ops(), {});
task_node->SetType("Compute");
task_node->SetUnusedVars(unused_vars);
task_nodes_.emplace_back(std::move(task_node));
int64_t task_id = task_nodes_[0]->task_id();
intercepter_id_to_rank_.insert({task_id, cur_rank});
intercepter_id_to_node_.insert({task_id, task_nodes_[0].get()});
}
void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
for (const auto& op_desc : program.Block(0).AllOps()) {
ops_.emplace_back(OpRegistry::CreateOp(*op_desc));
}
// TODO(wangxi): how to gc pipeline backward send
auto unused_vars = framework::GetUnusedVars(program.Block(0), ops_, {});
std::unordered_map<int32_t, std::vector<OperatorBase*>> role_to_ops;
for (const auto& op : ops_) {
......@@ -183,6 +196,7 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
} else {
task_node->SetType("Compute");
}
task_node->SetUnusedVars(unused_vars);
task_nodes_.emplace_back(std::move(task_node));
++task_id;
}
......
......@@ -52,6 +52,7 @@ class RuntimeGraph final {
void FakeDependence();
void AssignTaskToIntercepter();
void FakeRuntimeInfo();
// Builds the runtime graph from the unsplit program as a single task node;
// the constructor calls it when exe_desc.pp_degree() == 1.
void OriginProgramCompile(const ProgramDesc& program);
// LRSched, Forward, Backward, Optimize
static std::vector<paddle::framework::OpRole> functionality_order;
std::vector<std::unique_ptr<TaskNode>> task_nodes_;
......
......@@ -57,12 +57,24 @@ class TaskNode final {
// Returns the type string last stored by SetType().
const std::string& type() const { return type_; }
// Returns a read-only reference to the ProgramDesc stored in this node.
const paddle::framework::ProgramDesc& program() const { return program_; }
// Returns the operators of this node as raw pointers. ComputeInterceptor
// iterates over this list to run each op.
// NOTE(review): presumably non-owning views of the ops in ops_vec_ - confirm.
const std::vector<OperatorBase*>& ops() const { return ops_; }
// Returns the operators held by unique_ptr in ops_vec_ (read-only).
// RuntimeGraph::OriginProgramCompile passes this list to
// framework::GetUnusedVars.
const std::vector<std::unique_ptr<OperatorBase>>& unique_ops() const {
return ops_vec_;
}
// Returns the operator -> variable-name-list map stored by SetUnusedVars().
// ComputeInterceptor::RunOps forwards it to framework::DeleteUnusedTensors
// after each op has run.
const std::unordered_map<const OperatorBase*, std::vector<std::string>>&
unused_vars() const {
return unused_vars_;
}
void SetRunPerSteps(int64_t value);
void SetRunAtOffset(int64_t value);
void SetReplyUpPerSteps(int64_t value);
void SetSendDownPerSteps(int64_t value);
// Stores the node's type string. Carrier::CreateInterceptors passes a
// non-empty type to InterceptorFactory::Create and falls back to a plain
// Interceptor when the type is empty.
void SetType(const std::string& type) { type_ = type; }
// Stores a copy of the given operator -> variable-name-list map; it is later
// returned by unused_vars().
void SetUnusedVars(
const std::unordered_map<const OperatorBase*, std::vector<std::string>>&
unused_vars) {
unused_vars_ = unused_vars;
}
// upstream need buffs?
bool AddUpstreamTask(int64_t task_id, int64_t buff_size = 1);
......@@ -79,6 +91,9 @@ class TaskNode final {
std::unordered_map<int64_t, int64_t> downstream_;
framework::ProgramDesc program_;
// Operators held by unique_ptr; exposed read-only through unique_ops().
std::vector<std::unique_ptr<OperatorBase>> ops_vec_;
// Per-operator lists of variable names, written by SetUnusedVars() and read
// through unused_vars().
std::unordered_map<const OperatorBase*, std::vector<std::string>>
unused_vars_;
int32_t role_;
int64_t rank_;
int64_t task_id_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部