Unverified commit df207283, authored by Ruibiao Chen, committed by GitHub

Use StandaloneExecutor in FleetExecutor (#50239)

* Use StandaloneExecutor in FleetExecutor

* Update FLAGS

* Fix CI errors

* Update code

* Add force_root_scope_vars config

* Update code

* Fix CI errors

* Fix test_layer_new errors
Parent 905cefd4
@@ -44,6 +44,7 @@ cc_library(
   message_bus.cc
   dist_model_tensor_wrapper.cc
   DEPS proto_desc
+       standalone_executor
        fleet_executor_desc_proto
        interceptor_message_proto
        task_loop_thread_pool
...
@@ -28,6 +28,13 @@
 #include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/framework/variable_helper.h"
 
+PADDLE_DEFINE_EXPORTED_bool(
+    fleet_executor_with_standalone,
+    false,
+    "Use standalone executor to run ops. Temporary FLAGS, will be removed "
+    "after all fleet executor cases are modified to run ops with standalone "
+    "executor.");
+
 namespace paddle {
 namespace distributed {
@@ -95,7 +102,7 @@ void Carrier::Init(
   thread_pool_.SetThreadNum(thread_num_);
   thread_pool_.Start();
 
-  CreateInterceptors();
+  CreateInterceptors(inference_root_scope_vars);
   is_init_ = true;
 }
@@ -279,7 +286,8 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
   return gc;
 }
 
-void Carrier::CreateInterceptors() {
+void Carrier::CreateInterceptors(
+    const std::vector<std::string>& inference_root_scope_vars) {
   if (interceptor_id_to_node_.empty()) return;
 
   auto gc = GetGC(place_);
@@ -343,7 +351,48 @@ void Carrier::CreateInterceptors() {
     interceptor->SetMiniBatchScope(minibatch_scope_);
     interceptor->SetMicroBatchScope(microbatch_scopes_);
     interceptor->SetRootScope(root_scope_);
-    interceptor->SetGC(gc);
+
+    if (FLAGS_fleet_executor_with_standalone &&
+        (task_node->type() == "Amplifier" || task_node->type() == "Compute")) {
+      std::vector<std::shared_ptr<InterpreterCore>> cores;
+
+      framework::interpreter::ExecutionConfig execution_config;
+      execution_config.create_local_scope = false;
+      execution_config.force_root_scope_vars = std::set<std::string>(
+          inference_root_scope_vars.begin(), inference_root_scope_vars.end());
+
+      const framework::ProgramDesc* program = task_node->program();
+      PADDLE_ENFORCE_NOT_NULL(
+          program,
+          phi::errors::InvalidArgument("TaskNode %d's program is not set.",
+                                       interceptor_id));
+      std::vector<framework::VarDesc*> all_vars = program->Block(0).AllVars();
+      for (framework::VarDesc* var : all_vars) {
+        execution_config.skip_gc_vars.insert(var->Name());
+      }
+
+      // ONLY unused vars can be GCed.
+      const std::unordered_map<const framework::OperatorBase*,
+                               std::vector<std::string>>& unused_vars =
+          task_node->unused_vars();
+      for (auto& item : unused_vars) {
+        for (const std::string& unused_var : item.second) {
+          execution_config.skip_gc_vars.erase(unused_var);
+        }
+      }
+
+      for (framework::Scope* scope : microbatch_scopes_) {
+        cores.push_back(std::make_shared<InterpreterCore>(
+            place_, task_node->program()->Block(0), scope, execution_config));
+      }
+
+      for (size_t i = 1; i < cores.size(); ++i) {
+        cores[i]->ShareWorkQueueFrom(cores[i - 1]);
+      }
+
+      interceptor->SetInterpreterCore(cores);
+    } else {
+      interceptor->SetGC(gc);
+    }
 
     SetInterceptor(interceptor_id, std::move(interceptor));
     VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
...
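
The GC-exemption rule introduced above is easy to lose inside the hunk: start from every variable in the task's block, drop the ones the TaskNode reports as unused, and everything that remains is protected from garbage collection by the standalone executor. A condensed, self-contained sketch of that computation follows; the helper name and the flattened parameter types are illustrative only, since the patch itself keys unused_vars by OperatorBase* and writes the result into ExecutionConfig::skip_gc_vars.

#include <set>
#include <string>
#include <unordered_map>
#include <vector>

// Illustrative sketch of the skip_gc_vars computation added above.
// all_block_vars: every variable name in the task's block.
// unused_vars:    per-op lists of variables the TaskNode marks as unused.
// The returned set is handed to the InterpreterCore as skip_gc_vars, so only
// variables reported as unused remain eligible for garbage collection.
std::set<std::string> ComputeSkipGcVars(
    const std::vector<std::string>& all_block_vars,
    const std::unordered_map<std::string, std::vector<std::string>>&
        unused_vars) {
  std::set<std::string> skip_gc_vars(all_block_vars.begin(),
                                     all_block_vars.end());
  for (const auto& item : unused_vars) {
    for (const std::string& unused_var : item.second) {
      skip_gc_vars.erase(unused_var);
    }
  }
  return skip_gc_vars;
}
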
@@ -94,7 +94,8 @@ class Carrier final {
   Carrier() = delete;
 
   // create each Interceptor
-  void CreateInterceptors();
+  void CreateInterceptors(
+      const std::vector<std::string>& inference_root_scope_vars = {});
 
   int64_t GetRank(int64_t interceptor_id) const;
...
@@ -187,7 +187,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
 }
 
 void ComputeInterceptor::RunOps() {
-  for (auto op : node_->ops()) {
+  if (!cores_.empty() || !node_->ops().empty()) {
     PADDLE_ENFORCE_LT(cur_scope_id_,
                       microbatch_scopes_.size(),
                       platform::errors::InvalidArgument(
@@ -195,12 +195,19 @@ void ComputeInterceptor::RunOps() {
                           "microbatch_scopes, but recevice scope index %ld",
                           microbatch_scopes_.size(),
                           cur_scope_id_));
-    op->Run(*microbatch_scopes_[cur_scope_id_], place_);
-    if (gc_) {
-      framework::DeleteUnusedTensors(*microbatch_scopes_[cur_scope_id_],
-                                     op,
-                                     node_->unused_vars(),
-                                     gc_.get());
+  }
+
+  if (!cores_.empty()) {
+    cores_[cur_scope_id_]->Run(/*feed_names=*/{}, /*need_fetch=*/false);
+  } else {
+    for (auto op : node_->ops()) {
+      op->Run(*microbatch_scopes_[cur_scope_id_], place_);
+      if (gc_) {
+        framework::DeleteUnusedTensors(*microbatch_scopes_[cur_scope_id_],
+                                       op,
+                                       node_->unused_vars(),
+                                       gc_.get());
+      }
     }
   }
 }
...
@@ -24,6 +24,7 @@
 #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
 #include "paddle/fluid/framework/blocking_queue.h"
+#include "paddle/fluid/framework/new_executor/interpretercore.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/errors.h"
 #include "paddle/fluid/platform/macros.h"
@@ -40,6 +41,8 @@ class TaskNode;
 class Carrier;
 class TaskLoop;
 
+using InterpreterCore = framework::InterpreterCore;
+
 constexpr int64_t SOURCE_ID = -1;
 constexpr int64_t SINK_ID = -2;
@@ -75,6 +78,10 @@ class Interceptor {
   void SetMicroBatchScope(const std::vector<framework::Scope*>& scopes) {
     microbatch_scopes_ = scopes;
   }
+  void SetInterpreterCore(
+      const std::vector<std::shared_ptr<InterpreterCore>> cores) {
+    cores_ = cores;
+  }
 
   void SetGC(const std::shared_ptr<framework::GarbageCollector>& gc) {
     gc_ = gc;
   }
@@ -100,6 +107,7 @@ class Interceptor {
   framework::Scope* root_scope_{nullptr};
   framework::Scope* minibatch_scope_{nullptr};
   std::vector<framework::Scope*> microbatch_scopes_{};
+  std::vector<std::shared_ptr<InterpreterCore>> cores_{};
   std::shared_ptr<framework::GarbageCollector> gc_{nullptr};
 
   Carrier* carrier_;
...
@@ -128,7 +128,7 @@ class TaskNode final {
   // task_id-->type
   std::unordered_map<int64_t, DependType> id_to_dep_type_;
 
-  framework::ProgramDesc* program_;
+  framework::ProgramDesc* program_{nullptr};
   std::string cond_var_;
   std::vector<std::unique_ptr<OperatorBase>> ops_vec_;
   std::unordered_map<const OperatorBase*, std::vector<std::string>>
...
@@ -303,12 +303,11 @@ std::shared_ptr<InterpreterCore> CreateInterpreterCoreInfoToCache(
               "all cache!";
     interpretercore_info_cache.Finalize();
   }
+  interpreter::ExecutionConfig execution_config;
+  execution_config.create_local_scope = false;
+  execution_config.used_for_jit = true;
   auto core = std::make_shared<InterpreterCore>(
-      place,
-      program_desc.Block(0),
-      /*skip_gc_vars=*/std::set<std::string>(),
-      scope,
-      /*used_for_jit=*/true);
+      place, program_desc.Block(0), scope, execution_config);
   auto &cached_value =
       interpretercore_info_cache.GetMutable(program_id, is_grad);
   cached_value.core_ = core;
...
@@ -118,21 +118,43 @@ inline std::tuple<int, int> GetThreadPoolConfig(const phi::Place& place,
   return std::make_tuple(num_host_threads, num_device_threads);
 }
 
-ExecutionConfig::ExecutionConfig(const phi::Place& place, size_t op_num) {
-  std::tie(host_num_threads, deivce_num_threads) =
-      GetThreadPoolConfig(place, op_num);
+void ExecutionConfig::AnalyzeThreadPoolConfig(const phi::Place& place,
+                                              size_t op_num) {
+  if (host_num_threads == 0 || device_num_threads == 0) {
+    std::tie(host_num_threads, device_num_threads) =
+        GetThreadPoolConfig(place, op_num);
+  }
 }
 
 void ExecutionConfig::Log(int log_level) {
-  VLOG(log_level) << "ExecutionConfig:";
-  VLOG(log_level) << "used_for_jit = " << used_for_jit;
-  VLOG(log_level) << "create_local_scope = " << create_local_scope;
-  VLOG(log_level) << "host_num_threads = " << host_num_threads;
-  VLOG(log_level) << "deivce_num_threads = " << deivce_num_threads;
-  VLOG(log_level) << "skip_gc_vars = ";
+  std::stringstream log_str;
+  log_str << "ExecutionConfig:\n"
+          << "create_local_scope = " << create_local_scope << "\n"
+          << "used_for_cinn = " << used_for_cinn << "\n"
+          << "used_for_control_flow_op = " << used_for_control_flow_op << "\n"
+          << "used_for_jit = " << used_for_jit << "\n"
+          << "deivce_num_threads = " << device_num_threads << "\n"
+          << "host_num_threads = " << host_num_threads << "\n";
+
+  log_str << "force_root_scope_vars = [";
+  for (const std::string& var : force_root_scope_vars) {
+    log_str << var << " ";
+  }
+  log_str << "]\n";
+
+  log_str << "jit_input_vars = [";
+  for (const std::string& var : jit_input_vars) {
+    log_str << var << " ";
+  }
+  log_str << "]\n";
+
+  log_str << "skip_gc_vars = [";
   for (const std::string& var : skip_gc_vars) {
-    VLOG(log_level) << var;
+    log_str << var << " ";
   }
+  log_str << "]\n";
+
+  VLOG(log_level) << log_str.str();
 }
 
 }  // namespace interpreter
...
@@ -24,18 +24,20 @@ namespace framework {
 namespace interpreter {
 
 struct ExecutionConfig {
-  bool used_for_jit{false};
+  bool create_local_scope{true};
   bool used_for_cinn{false};
   bool used_for_control_flow_op{false};
+  bool used_for_jit{false};
 
-  bool create_local_scope{true};
-
-  size_t host_num_threads;
-  size_t deivce_num_threads;
+  size_t device_num_threads{0};
+  size_t host_num_threads{0};
 
-  std::set<std::string> skip_gc_vars;
+  std::set<std::string> force_root_scope_vars;
   std::set<std::string> jit_input_vars;
+  std::set<std::string> skip_gc_vars;
 
-  ExecutionConfig(const phi::Place& place, size_t op_num);
+  void AnalyzeThreadPoolConfig(const phi::Place& place, size_t op_num);
   void Log(int log_level);
 };
...
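
With the reworked struct above, callers now fill in an ExecutionConfig and hand it to the InterpreterCore constructor instead of passing skip_gc_vars and a row of used_for_* booleans. A minimal sketch of that pattern, assuming the header and namespaces touched by this patch; the MakeControlFlowCore wrapper itself is hypothetical:

#include <memory>
#include <set>
#include <string>

#include "paddle/fluid/framework/new_executor/interpretercore.h"

// Hypothetical wrapper showing the new construction pattern; field and type
// names come from the declarations changed in this patch.
std::shared_ptr<paddle::framework::InterpreterCore> MakeControlFlowCore(
    const paddle::platform::Place& place,
    const paddle::framework::BlockDesc& block,
    paddle::framework::Scope* scope,
    const std::set<std::string>& skip_gc_vars) {
  paddle::framework::interpreter::ExecutionConfig execution_config;
  execution_config.create_local_scope = false;       // run in the given scope
  execution_config.used_for_control_flow_op = true;  // purpose flag for the caller
  execution_config.skip_gc_vars = skip_gc_vars;      // variables exempt from GC
  // host/device thread counts stay 0 here; the constructor fills them in via
  // AnalyzeThreadPoolConfig() when they are unset.
  return std::make_shared<paddle::framework::InterpreterCore>(
      place, block, scope, execution_config);
}

The same pattern appears at the ConditionalBlockOp, WhileOp, CinnLaunchContext, and InterpreterEngine call sites later in this diff.
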
@@ -276,15 +276,16 @@ GetUnusedVars(const BlockDesc& block,
 }
 
 void BuildVariableScope(const framework::BlockDesc& block,
-                        VariableScope* var_scope,
-                        bool use_local_scope) {
+                        const ExecutionConfig& execution_config,
+                        VariableScope* var_scope) {
   VLOG(3) << "Creating Variables";
   auto inner_scope = var_scope->GetMutableScope();
 
   // NOTE(zhiqiu): if create_local_scope_ is true, the persistable is
   // created in var_scope.scope_ , and other scope is created in local scope.
-  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
-                                       : var_scope->GetMutableScope();
+  Scope* local_scope = execution_config.create_local_scope
+                           ? var_scope->GetMutableLocalScope()
+                           : var_scope->GetMutableScope();
 
   for (auto& var_desc : block.AllVars()) {
     auto var_name = var_desc->Name();
@@ -295,7 +296,8 @@ void BuildVariableScope(const framework::BlockDesc& block,
       continue;
     }
 
-    if (var_desc->Persistable()) {
+    if (var_desc->Persistable() ||
+        execution_config.force_root_scope_vars.count(var_name)) {
       // In principle, we should put all trainable parameters in global scope,
       // which means the root of the scope tree. Some cases like quantization
       // will look up these parameters in global scope.
@@ -305,7 +307,6 @@ void BuildVariableScope(const framework::BlockDesc& block,
       }
 
       auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var_name);
-      VLOG(3) << "Initialize Variable " << var_name;
       // NOTE(zhiqiu): if var exists in scope and the type is right,
       // InitializeVariable will not create a new variable.
       InitializeVariable(ptr, var_desc->GetType());
@@ -315,8 +316,7 @@ void BuildVariableScope(const framework::BlockDesc& block,
       auto* ptr = local_scope->Var(var_name);
       InitializeVariable(ptr, var_desc->GetType());
       VLOG(3) << "Create Variable " << var_name << " locally, which pointer is "
-              << ptr << "Variable Type "
-              << static_cast<int>(var_desc->GetType());
+              << ptr << " type is " << static_cast<int>(var_desc->GetType());
     }
     var_scope->AddVar(var_name, var_desc);
   }
...
@@ -91,8 +91,8 @@ bool BuildOpFuncList(const platform::Place& place,
                      bool use_local_scope = true);
 
 void BuildVariableScope(const framework::BlockDesc& block,
-                        VariableScope* var_scope,
-                        bool use_local_scope = true);
+                        const ExecutionConfig& execution_config,
+                        VariableScope* var_scope);
 
 void LogDeviceMemoryStats(const platform::Place& place);
...
@@ -106,32 +106,24 @@ inline void SetDeviceId(const platform::Place& place) {
   }
 }
 
-// TODO(Ruibiao): Pass skip_gc_vars, used_for_jit, and other config messages by
-// constructing an interpreter::ExecutionConfig
 InterpreterCore::InterpreterCore(const platform::Place& place,
                                  const BlockDesc& block,
-                                 const std::set<std::string>& skip_gc_vars,
                                  framework::Scope* scope,
-                                 bool used_for_jit,
-                                 bool used_for_control_flow_op,
-                                 bool used_for_cinn)
+                                 const ExecutionConfig& execution_config)
     : place_(place),
       block_(block),
-      execution_config_(place, block.OpSize()),
       stream_analyzer_(place),
+      execution_config_(execution_config),
       var_scope_(scope) {
   VLOG(4) << "InterpreterCore(): " << this << " on " << place_;
 
   exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught);
   completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion);
 
-  execution_config_.used_for_jit = used_for_jit;
-  execution_config_.used_for_cinn = used_for_cinn;
-  execution_config_.used_for_control_flow_op = used_for_control_flow_op;
-  execution_config_.create_local_scope =
-      !used_for_jit && FLAGS_new_executor_use_local_scope &&
-      !used_for_control_flow_op && !used_for_cinn;
-  execution_config_.skip_gc_vars = skip_gc_vars;
+  if (!FLAGS_new_executor_use_local_scope) {
+    execution_config_.create_local_scope = false;
+  }
+  execution_config_.AnalyzeThreadPoolConfig(place, block.OpSize());
   execution_config_.Log(/*log_level=*/8);
 
   if (execution_config_.create_local_scope) {
@@ -280,7 +272,7 @@ paddle::framework::FetchList InterpreterCore::Run(
   if (!is_build_) {
     LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
     paddle::framework::interpreter::BuildVariableScope(
-        block_, &var_scope_, HasLocalScope());
+        block_, execution_config_, &var_scope_);
 
     std::vector<paddle::framework::OpFuncNode> op_func_nodes;
     auto skip_run = paddle::framework::interpreter::BuildOpFuncList(
@@ -410,7 +402,7 @@ std::shared_ptr<interpreter::AsyncWorkQueue> InterpreterCore::GetWorkQueue() {
   if (async_work_queue_ == nullptr) {
     async_work_queue_ = std::make_shared<interpreter::AsyncWorkQueue>(
         execution_config_.host_num_threads,
-        execution_config_.deivce_num_threads,
+        execution_config_.device_num_threads,
         nullptr);
   }
   return async_work_queue_;
@@ -1271,7 +1263,7 @@ void InterpreterCore::Prepare(const std::vector<std::string>& feed_names,
   if (!is_build_) {
     paddle::framework::interpreter::BuildVariableScope(
-        block_, &var_scope_, HasLocalScope());
+        block_, execution_config_, &var_scope_);
     FeedInput();
     std::vector<paddle::framework::OpFuncNode> op_func_nodes;
     auto skip_run = paddle::framework::interpreter::BuildOpFuncList(
@@ -1309,24 +1301,6 @@ void InterpreterCore::SetFeedVarsInplaceSkip(
 
 bool InterpreterCore::HasLocalScope() const { return local_scope_ != nullptr; }
 
-std::shared_ptr<InterpreterCore> CreateInterpreterCore(
-    const platform::Place& place,
-    const ProgramDesc& prog,
-    Scope* scope,
-    const std::vector<std::string>& fetch_names,
-    const std::set<std::string>& skip_gc_vars) {
-  std::shared_ptr<InterpreterCore> core = nullptr;
-  // NOTE(Aurelius84): `AddFetch` will modify BlockDesc, so we should copy
-  // a new program.
-  auto new_prog = std::make_shared<framework::ProgramDesc>(prog);
-  auto* block = new_prog->MutableBlock(0);
-  interpreter::AddFetch(fetch_names, block);
-
-  core = std::make_shared<InterpreterCore>(place, *block, skip_gc_vars, scope);
-  core->SetCopyProgram(new_prog);
-  return core;
-}
-
 // Note(zhangbo):
 // (1) What is "Trace"?
 // The OP execute scheduling rule adopted by Interpretercore by default is a
@@ -1462,5 +1436,24 @@ void InterpreterCore::AnalyseExecuteOrderForTrace() {
   trace_execute_order_ = trace_order;
 }
 
+std::shared_ptr<InterpreterCore> CreateInterpreterCore(
+    const platform::Place& place,
+    const ProgramDesc& prog,
+    Scope* scope,
+    const std::vector<std::string>& fetch_names,
+    const interpreter::ExecutionConfig& execution_config) {
+  std::shared_ptr<InterpreterCore> core = nullptr;
+  // NOTE(Aurelius84): `AddFetch` will modify BlockDesc, so we should copy
+  // a new program.
+  auto new_prog = std::make_shared<framework::ProgramDesc>(prog);
+  auto* block = new_prog->MutableBlock(0);
+  interpreter::AddFetch(fetch_names, block);
+
+  core =
+      std::make_shared<InterpreterCore>(place, *block, scope, execution_config);
+  core->SetCopyProgram(new_prog);
+  return core;
+}
+
 }  // namespace framework
 }  // namespace paddle
@@ -40,14 +40,18 @@ namespace paddle {
 namespace framework {
 
 class InterpreterCore {
+  using ExecutionConfig = interpreter::ExecutionConfig;
+  using InstructionSchedulingPriorityLess = std::function<bool(size_t, size_t)>;
+  using SchedulingQueue =
+      std::priority_queue<size_t,
+                          std::vector<size_t>,
+                          InstructionSchedulingPriorityLess>;
+
  public:
   InterpreterCore(const platform::Place& place,
                   const BlockDesc& block,
-                  const std::set<std::string>& skip_gc_vars,
                   Scope* scope,
-                  bool used_for_jit = false,
-                  bool used_for_control_flow_op = false,
-                  bool used_for_cinn = false);
+                  const ExecutionConfig& execution_config = ExecutionConfig());
 
   ~InterpreterCore();
@@ -79,12 +83,7 @@ class InterpreterCore {
   const platform::Place& GetPlace() const { return place_; }
 
  private:
-  using InstructionSchedulingPriorityLess = std::function<bool(size_t, size_t)>;
-  using SchedulingQueue =
-      std::priority_queue<size_t,
-                          std::vector<size_t>,
-                          InstructionSchedulingPriorityLess>;
+  DISABLE_COPY_AND_ASSIGN(InterpreterCore);
 
   // build graph
   void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
   void BuildOperatorDependences();
@@ -135,11 +134,10 @@ class InterpreterCore {
  private:
   bool is_build_{false};
 
-  platform::Place place_;
+  const platform::Place place_;
   const BlockDesc& block_;  // not owned
 
   interpreter::DependencyBuilder dependency_builder_;
-  interpreter::ExecutionConfig execution_config_;
   interpreter::StreamAnalyzer stream_analyzer_;
 
   // NOTE(zhiqiu): when add fetch ops in GetInterpreterCore, we will
@@ -156,6 +154,9 @@ class InterpreterCore {
   std::vector<Instruction> vec_instruction_;  // deconstruct before OpFuncNode
 
   std::atomic<size_t> unfinished_op_number_{0};
 
+  ExecutionConfig execution_config_;
+
   VariableScope var_scope_;
   Scope* local_scope_{nullptr};  // not owned
@@ -189,9 +190,10 @@ class InterpreterCore {
 std::shared_ptr<InterpreterCore> CreateInterpreterCore(
     const platform::Place& place,
     const ProgramDesc& prog,
-    Scope* global_scope,
+    Scope* scope,
     const std::vector<std::string>& fetch_names = {},
-    const std::set<std::string>& skip_gc_vars = {});
+    const interpreter::ExecutionConfig& execution_config =
+        interpreter::ExecutionConfig());
 
 }  // namespace framework
 }  // namespace paddle
@@ -71,11 +71,7 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
     if (add_fetch_op) {
       core = CreateInterpreterCore(place_, prog, scope, fetch_names);
     } else {
-      core = std::make_shared<InterpreterCore>(
-          place_,
-          prog.Block(0),
-          /*skip_gc_vars=*/std::set<std::string>(),
-          scope);
+      core = std::make_shared<InterpreterCore>(place_, prog.Block(0), scope);
     }
 
     interpretercores_.emplace(oss.str(), core);
     return core;
...
@@ -188,8 +188,11 @@ TEST(InterpreterCore, skip_gc_vars) {
                                        "elementwise_add_0.tmp_0",
                                        "tmp_0"};
 
-  std::shared_ptr<InterpreterCore> main_core =
-      CreateInterpreterCore(place, main_prog, &scope, {}, skip_gc_vars);
+  interpreter::ExecutionConfig execution_config;
+  execution_config.skip_gc_vars = skip_gc_vars;
+
+  std::shared_ptr<InterpreterCore> main_core = CreateInterpreterCore(
+      place, main_prog, &scope, /*fetch_names=*/{}, execution_config);
 
   auto check_gc_result =
       [](Scope& scope, std::set<std::string>& vars, bool is_skip_gc) {
...
@@ -59,18 +59,17 @@ void InterpreterEngine::CreateInterpreterCore() {
   GraphToProgram(graph, &converted_prog_, nullptr);
 
+  framework::interpreter::ExecutionConfig execution_config;
+  execution_config.create_local_scope = false;
+  execution_config.used_for_jit = true;
+
   auto in_names = info_->InputArgNames();
   auto out_names = info_->OutputArgNames();
-  std::set<std::string> skip_gc_vars;
-  skip_gc_vars.insert(in_names.begin(), in_names.end());
-  skip_gc_vars.insert(out_names.begin(), out_names.end());
-
-  inner_interpreter_ =
-      std::make_shared<InterpreterCore>(place_,
-                                        converted_prog_.Block(0),
-                                        /*skip_gc_vars=*/skip_gc_vars,
-                                        &scope_,
-                                        /*used_for_jit=*/true);
+  execution_config.skip_gc_vars.insert(in_names.begin(), in_names.end());
+  execution_config.skip_gc_vars.insert(out_names.begin(), out_names.end());
+
+  inner_interpreter_ = std::make_shared<InterpreterCore>(
+      place_, converted_prog_.Block(0), &scope_, execution_config);
 }
 
 std::vector<Tensor> InterpreterEngine::operator()(
...
@@ -488,14 +488,12 @@ framework::InterpreterCore* CinnLaunchContext::InitializeInterpreterCore(
                             framework::proto::VarType::LOD_TENSOR);
   }
   if (!interpreter_core_) {
+    framework::interpreter::ExecutionConfig execution_config;
+    execution_config.create_local_scope = false;
+    execution_config.used_for_cinn = true;
+    execution_config.skip_gc_vars = skip_gc_vars_;
     interpreter_core_ = std::make_unique<framework::InterpreterCore>(
-        place,
-        runtime_program_desc_->Block(0),
-        skip_gc_vars_,
-        scope,
-        /*used_for_jit*/ false,
-        /*used_for_control_flow_op*/ false,
-        /*used_for_cinn*/ true);
+        place, runtime_program_desc_->Block(0), scope, execution_config);
   } else {
     interpreter_core_->reset_scope(scope);
   }
...
@@ -99,17 +99,18 @@ class ConditionalBlockOp : public ConditionalOp {
       LOG_FIRST_N(INFO, 1)
           << "[ControlFlow][ConditionalBlock] New Executor is Running.";
       if (!core_ || !platform::is_same_place(core_->GetPlace(), dev_place)) {
-        std::set<std::string> skip_gc_vars(skip_vars.begin(),
-                                           skip_vars.end());
         VLOG(10) << "[interpreterCore cache]" << core_.get();
         VLOG_IF(10, core_)
             << platform::is_same_place(core_->GetPlace(), dev_place);
-        core_.reset(new InterpreterCore(dev_place,
-                                        *block,
-                                        skip_gc_vars,
-                                        &cur_scope,
-                                        /* used_for_jit */ false,
-                                        /* used_for_control_flow_op */ true));
+
+        framework::interpreter::ExecutionConfig execution_config;
+        execution_config.create_local_scope = false;
+        execution_config.used_for_control_flow_op = true;
+        execution_config.skip_gc_vars =
+            std::set<std::string>(skip_vars.begin(), skip_vars.end());
+
+        core_.reset(new InterpreterCore(
+            dev_place, *block, &cur_scope, execution_config));
         VLOG(10) << "[interpreterCore cache]"
                  << "new created:" << core_;
       } else {
@@ -214,14 +215,15 @@ class ConditionalBlockGradOp : public ConditionalOp {
         VLOG(10) << "[interpreterCore cache]" << core_.get();
         VLOG_IF(10, core_)
             << platform::is_same_place(core_->GetPlace(), dev_place);
-        std::set<std::string> skip_gc_vars(inside_grads.begin(),
-                                           inside_grads.end());
-        core_.reset(new InterpreterCore(dev_place,
-                                        *block,
-                                        skip_gc_vars,
-                                        &cur_scope,
-                                        /* used_for_jit */ false,
-                                        /* used_for_control_flow_op */ true));
+
+        framework::interpreter::ExecutionConfig execution_config;
+        execution_config.create_local_scope = false;
+        execution_config.used_for_control_flow_op = true;
+        execution_config.skip_gc_vars =
+            std::set<std::string>(inside_grads.begin(), inside_grads.end());
+
+        core_.reset(new InterpreterCore(
+            dev_place, *block, &cur_scope, execution_config));
         VLOG(10) << "[interpreterCore cache]"
                  << "new created:" << core_;
       } else {
...
@@ -202,16 +202,16 @@ class WhileOp : public framework::OperatorBase {
     if (FLAGS_control_flow_use_new_executor) {
       LOG_FIRST_N(INFO, 1) << "[ControlFlow][WhileOp] New Executor is Running.";
       if (!core_ || !platform::is_same_place(core_->GetPlace(), dev_place)) {
-        std::set<std::string> skip_gc_vars(skip_vars.begin(), skip_vars.end());
         framework::Scope placeholder;  // Don't care if it's valid, just for
                                        // initialize InterpreterCore
+        framework::interpreter::ExecutionConfig execution_config;
+        execution_config.create_local_scope = false;
+        execution_config.used_for_control_flow_op = true;
+        execution_config.skip_gc_vars =
+            std::set<std::string>(skip_vars.begin(), skip_vars.end());
+
         core_.reset(new framework::InterpreterCore(
-            dev_place,
-            *block,
-            skip_gc_vars,
-            &placeholder,
-            /* used_for_jit */ false,
-            /* used_for_control_flow_op */ true));
+            dev_place, *block, &placeholder, execution_config));
       }
     } else {
       if (!executor_ ||
@@ -398,13 +398,14 @@ class WhileGradOp : public framework::OperatorBase {
         std::set<std::string> skip_gc_vars(skip_vars.begin(), skip_vars.end());
         framework::Scope placeholder;  // Don't care if it's valid, just for
                                        // initialize InterpreterCore
+        framework::interpreter::ExecutionConfig execution_config;
+        execution_config.create_local_scope = false;
+        execution_config.used_for_control_flow_op = true;
+        execution_config.skip_gc_vars =
+            std::set<std::string>(skip_vars.begin(), skip_vars.end());
+
         core_.reset(new framework::InterpreterCore(
-            dev_place,
-            *block,
-            skip_gc_vars,
-            &placeholder,
-            /* used_for_jit */ false,
-            /* used_for_control_flow_op */ true));
+            dev_place, *block, &placeholder, execution_config));
       }
     } else {
       if (!executor_ ||
...
@@ -1624,6 +1624,12 @@ class Executor:
             if "fleet_opt" in program._pipeline_opt:
                 # Move prepare here for port conflict with nccl in startup program
                 if self._fleet_executor is None:
+                    # Temporary manual enable standalone executor for fleet executor,
+                    # delete this code after the FLAGS is removed.
+                    if 'tasks' in program._pipeline_opt["fleet_opt"]:
+                        set_flags(
+                            {"FLAGS_fleet_executor_with_standalone": True}
+                        )
                     self._fleet_executor = _prepare_fleet_executor()
                 return self._run_using_fleet_executor(
                     program=program,
...