Unverified · Commit df207283 · Authored by: Ruibiao Chen · Committed by: GitHub

Use StandaloneExecutor in FleetExecutor (#50239)

* Use StandaloneExecutor in FleetExecutor

* Update FLAGS

* Fix CI errors

* Update code

* Add force_root_scope_vars config

* Update code

* Fix CI errors

* Fix test_layer_new errors
Parent commit: 905cefd4
......@@ -44,6 +44,7 @@ cc_library(
message_bus.cc
dist_model_tensor_wrapper.cc
DEPS proto_desc
standalone_executor
fleet_executor_desc_proto
interceptor_message_proto
task_loop_thread_pool
......
......@@ -28,6 +28,13 @@
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
// Temporary switch: when true, the fleet executor runs each Amplifier/Compute
// task's ops through the standalone executor (InterpreterCore) instead of
// executing the ops one by one. Intended to be removed once all fleet
// executor code paths have migrated to the standalone executor.
PADDLE_DEFINE_EXPORTED_bool(
fleet_executor_with_standalone,
false,
"Use standalone executor to run ops. Temporary FLAGS, will be removed "
"after all fleet executor cases are modified to run ops with standalone "
"executor.");
namespace paddle {
namespace distributed {
......@@ -95,7 +102,7 @@ void Carrier::Init(
thread_pool_.SetThreadNum(thread_num_);
thread_pool_.Start();
CreateInterceptors();
CreateInterceptors(inference_root_scope_vars);
is_init_ = true;
}
......@@ -279,7 +286,8 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
return gc;
}
void Carrier::CreateInterceptors() {
void Carrier::CreateInterceptors(
const std::vector<std::string>& inference_root_scope_vars) {
if (interceptor_id_to_node_.empty()) return;
auto gc = GetGC(place_);
......@@ -343,7 +351,48 @@ void Carrier::CreateInterceptors() {
interceptor->SetMiniBatchScope(minibatch_scope_);
interceptor->SetMicroBatchScope(microbatch_scopes_);
interceptor->SetRootScope(root_scope_);
interceptor->SetGC(gc);
if (FLAGS_fleet_executor_with_standalone &&
(task_node->type() == "Amplifier" || task_node->type() == "Compute")) {
std::vector<std::shared_ptr<InterpreterCore>> cores;
framework::interpreter::ExecutionConfig execution_config;
execution_config.create_local_scope = false;
execution_config.force_root_scope_vars = std::set<std::string>(
inference_root_scope_vars.begin(), inference_root_scope_vars.end());
const framework::ProgramDesc* program = task_node->program();
PADDLE_ENFORCE_NOT_NULL(
program,
phi::errors::InvalidArgument("TaskNode %d's program is not set.",
interceptor_id));
std::vector<framework::VarDesc*> all_vars = program->Block(0).AllVars();
for (framework::VarDesc* var : all_vars) {
execution_config.skip_gc_vars.insert(var->Name());
}
// ONLY unused vars can be GCed.
const std::unordered_map<const framework::OperatorBase*,
std::vector<std::string>>& unused_vars =
task_node->unused_vars();
for (auto& item : unused_vars) {
for (const std::string& unused_var : item.second) {
execution_config.skip_gc_vars.erase(unused_var);
}
}
for (framework::Scope* scope : microbatch_scopes_) {
cores.push_back(std::make_shared<InterpreterCore>(
place_, task_node->program()->Block(0), scope, execution_config));
}
for (size_t i = 1; i < cores.size(); ++i) {
cores[i]->ShareWorkQueueFrom(cores[i - 1]);
}
interceptor->SetInterpreterCore(cores);
} else {
interceptor->SetGC(gc);
}
SetInterceptor(interceptor_id, std::move(interceptor));
VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
......
......@@ -94,7 +94,8 @@ class Carrier final {
Carrier() = delete;
// create each Interceptor
void CreateInterceptors();
void CreateInterceptors(
const std::vector<std::string>& inference_root_scope_vars = {});
int64_t GetRank(int64_t interceptor_id) const;
......
......@@ -187,7 +187,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
}
void ComputeInterceptor::RunOps() {
for (auto op : node_->ops()) {
if (!cores_.empty() || !node_->ops().empty()) {
PADDLE_ENFORCE_LT(cur_scope_id_,
microbatch_scopes_.size(),
platform::errors::InvalidArgument(
......@@ -195,12 +195,19 @@ void ComputeInterceptor::RunOps() {
"microbatch_scopes, but recevice scope index %ld",
microbatch_scopes_.size(),
cur_scope_id_));
op->Run(*microbatch_scopes_[cur_scope_id_], place_);
if (gc_) {
framework::DeleteUnusedTensors(*microbatch_scopes_[cur_scope_id_],
op,
node_->unused_vars(),
gc_.get());
}
if (!cores_.empty()) {
cores_[cur_scope_id_]->Run(/*feed_names=*/{}, /*need_fetch=*/false);
} else {
for (auto op : node_->ops()) {
op->Run(*microbatch_scopes_[cur_scope_id_], place_);
if (gc_) {
framework::DeleteUnusedTensors(*microbatch_scopes_[cur_scope_id_],
op,
node_->unused_vars(),
gc_.get());
}
}
}
}
......
......@@ -24,6 +24,7 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
......@@ -40,6 +41,8 @@ class TaskNode;
class Carrier;
class TaskLoop;
using InterpreterCore = framework::InterpreterCore;
constexpr int64_t SOURCE_ID = -1;
constexpr int64_t SINK_ID = -2;
......@@ -75,6 +78,10 @@ class Interceptor {
// Attach the per-microbatch scopes (one framework::Scope per microbatch);
// later indexed by the current scope id when running ops.
void SetMicroBatchScope(const std::vector<framework::Scope*>& scopes) {
microbatch_scopes_ = scopes;
}
// Attach the per-microbatch InterpreterCore instances used by the standalone
// executor path (see ComputeInterceptor::RunOps, which indexes cores_ by the
// current scope id).
//
// Fix: the original took the vector by value with a top-level const
// (`const std::vector<...> cores`), which cannot be moved from and copies
// every shared_ptr (an atomic refcount bump each) an extra time on the way
// into cores_. Taking it by const reference keeps the call sites unchanged
// and performs exactly one copy, into cores_.
void SetInterpreterCore(
    const std::vector<std::shared_ptr<InterpreterCore>>& cores) {
  cores_ = cores;
}
// Attach the garbage collector used to free unused tensors after ops run
// (only consulted on the non-standalone path; see RunOps' gc_ branch).
void SetGC(const std::shared_ptr<framework::GarbageCollector>& gc) {
gc_ = gc;
}
......@@ -100,6 +107,7 @@ class Interceptor {
framework::Scope* root_scope_{nullptr};
framework::Scope* minibatch_scope_{nullptr};
std::vector<framework::Scope*> microbatch_scopes_{};
std::vector<std::shared_ptr<InterpreterCore>> cores_{};
std::shared_ptr<framework::GarbageCollector> gc_{nullptr};
Carrier* carrier_;
......
......@@ -128,7 +128,7 @@ class TaskNode final {
// task_id-->type
std::unordered_map<int64_t, DependType> id_to_dep_type_;
framework::ProgramDesc* program_;
framework::ProgramDesc* program_{nullptr};
std::string cond_var_;
std::vector<std::unique_ptr<OperatorBase>> ops_vec_;
std::unordered_map<const OperatorBase*, std::vector<std::string>>
......
......@@ -303,12 +303,11 @@ std::shared_ptr<InterpreterCore> CreateInterpreterCoreInfoToCache(
"all cache!";
interpretercore_info_cache.Finalize();
}
interpreter::ExecutionConfig execution_config;
execution_config.create_local_scope = false;
execution_config.used_for_jit = true;
auto core = std::make_shared<InterpreterCore>(
place,
program_desc.Block(0),
/*skip_gc_vars=*/std::set<std::string>(),
scope,
/*used_for_jit=*/true);
place, program_desc.Block(0), scope, execution_config);
auto &cached_value =
interpretercore_info_cache.GetMutable(program_id, is_grad);
cached_value.core_ = core;
......
......@@ -118,21 +118,43 @@ inline std::tuple<int, int> GetThreadPoolConfig(const phi::Place& place,
return std::make_tuple(num_host_threads, num_device_threads);
}
ExecutionConfig::ExecutionConfig(const phi::Place& place, size_t op_num) {
std::tie(host_num_threads, deivce_num_threads) =
GetThreadPoolConfig(place, op_num);
// Lazily derive the host/device thread pool sizes from the place and the
// number of ops. Sizes the caller has already set (both non-zero) are kept.
void ExecutionConfig::AnalyzeThreadPoolConfig(const phi::Place& place,
                                              size_t op_num) {
  const bool already_configured =
      host_num_threads != 0 && device_num_threads != 0;
  if (already_configured) {
    return;
  }
  std::tie(host_num_threads, device_num_threads) =
      GetThreadPoolConfig(place, op_num);
}
// Dump the whole execution configuration — flags, thread counts, and the
// three variable-name sets — to VLOG at the given level as a single message
// (one multi-line message rather than one VLOG call per field, so the dump
// cannot be interleaved with other log output).
void ExecutionConfig::Log(int log_level) {
  std::stringstream log_str;
  log_str << "ExecutionConfig:\n"
          << "create_local_scope = " << create_local_scope << "\n"
          << "used_for_cinn = " << used_for_cinn << "\n"
          << "used_for_control_flow_op = " << used_for_control_flow_op << "\n"
          << "used_for_jit = " << used_for_jit << "\n"
          // Fixed label typo: was "deivce_num_threads".
          << "device_num_threads = " << device_num_threads << "\n"
          << "host_num_threads = " << host_num_threads << "\n";

  // Print one named set of variable names as `name = [a b c ]`.
  auto log_var_set = [&log_str](const char* name,
                                const std::set<std::string>& vars) {
    log_str << name << " = [";
    for (const std::string& var : vars) {
      log_str << var << " ";
    }
    log_str << "]\n";
  };
  log_var_set("force_root_scope_vars", force_root_scope_vars);
  log_var_set("jit_input_vars", jit_input_vars);
  log_var_set("skip_gc_vars", skip_gc_vars);

  VLOG(log_level) << log_str.str();
}
} // namespace interpreter
......
......@@ -24,18 +24,20 @@ namespace framework {
namespace interpreter {
// Knobs controlling how an InterpreterCore builds and runs a block.
// (This span previously showed the pre- and post-commit member lists
// superposed — duplicate used_for_jit / create_local_scope / skip_gc_vars,
// both `deivce_num_threads` and `device_num_threads`, and a stale two-arg
// constructor — which is not valid C++; this is the deduplicated post-commit
// definition.)
struct ExecutionConfig {
  // Run in a freshly created local scope (true) or directly in the scope
  // handed to InterpreterCore (false).
  bool create_local_scope{true};
  // Which subsystem owns this interpreter; each toggles executor behavior.
  bool used_for_cinn{false};
  bool used_for_control_flow_op{false};
  bool used_for_jit{false};

  // Thread pool sizes; 0 means "derive from place/op count" via
  // AnalyzeThreadPoolConfig.
  size_t device_num_threads{0};
  size_t host_num_threads{0};

  // Vars created in the root scope even when not persistable
  // (see BuildVariableScope's Persistable() || force_root_scope_vars check).
  std::set<std::string> force_root_scope_vars;
  // Input var names when the core is used for JIT — exact semantics not
  // visible in this view; confirm against the JIT engine.
  std::set<std::string> jit_input_vars;
  // Vars exempted from garbage collection.
  std::set<std::string> skip_gc_vars;

  void AnalyzeThreadPoolConfig(const phi::Place& place, size_t op_num);
  void Log(int log_level);
};
......
......@@ -276,15 +276,16 @@ GetUnusedVars(const BlockDesc& block,
}
void BuildVariableScope(const framework::BlockDesc& block,
VariableScope* var_scope,
bool use_local_scope) {
const ExecutionConfig& execution_config,
VariableScope* var_scope) {
VLOG(3) << "Creating Variables";
auto inner_scope = var_scope->GetMutableScope();
// NOTE(zhiqiu): if create_local_scope_ is true, the persistable is
// created in var_scope.scope_ , and other scope is created in local scope.
Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope();
Scope* local_scope = execution_config.create_local_scope
? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope();
for (auto& var_desc : block.AllVars()) {
auto var_name = var_desc->Name();
......@@ -295,7 +296,8 @@ void BuildVariableScope(const framework::BlockDesc& block,
continue;
}
if (var_desc->Persistable()) {
if (var_desc->Persistable() ||
execution_config.force_root_scope_vars.count(var_name)) {
// In principle, we should put all trainable parameters in global scope,
// which means the root of the scope tree. Some cases like quantization
// will look up these parameters in global scope.
......@@ -305,7 +307,6 @@ void BuildVariableScope(const framework::BlockDesc& block,
}
auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var_name);
VLOG(3) << "Initialize Variable " << var_name;
// NOTE(zhiqiu): if var exists in scope and the type is right,
// InitializeVariable will not create a new variable.
InitializeVariable(ptr, var_desc->GetType());
......@@ -315,8 +316,7 @@ void BuildVariableScope(const framework::BlockDesc& block,
auto* ptr = local_scope->Var(var_name);
InitializeVariable(ptr, var_desc->GetType());
VLOG(3) << "Create Variable " << var_name << " locally, which pointer is "
<< ptr << "Variable Type "
<< static_cast<int>(var_desc->GetType());
<< ptr << " type is " << static_cast<int>(var_desc->GetType());
}
var_scope->AddVar(var_name, var_desc);
}
......
......@@ -91,8 +91,8 @@ bool BuildOpFuncList(const platform::Place& place,
bool use_local_scope = true);
void BuildVariableScope(const framework::BlockDesc& block,
VariableScope* var_scope,
bool use_local_scope = true);
const ExecutionConfig& execution_config,
VariableScope* var_scope);
void LogDeviceMemoryStats(const platform::Place& place);
......
......@@ -106,32 +106,24 @@ inline void SetDeviceId(const platform::Place& place) {
}
}
// TODO(Ruibiao): Pass skip_gc_vars, used_for_jit, and other config messages by
// constructing an interpreter::ExecutionConfig
InterpreterCore::InterpreterCore(const platform::Place& place,
const BlockDesc& block,
const std::set<std::string>& skip_gc_vars,
framework::Scope* scope,
bool used_for_jit,
bool used_for_control_flow_op,
bool used_for_cinn)
const ExecutionConfig& execution_config)
: place_(place),
block_(block),
execution_config_(place, block.OpSize()),
stream_analyzer_(place),
execution_config_(execution_config),
var_scope_(scope) {
VLOG(4) << "InterpreterCore(): " << this << " on " << place_;
exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught);
completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion);
execution_config_.used_for_jit = used_for_jit;
execution_config_.used_for_cinn = used_for_cinn;
execution_config_.used_for_control_flow_op = used_for_control_flow_op;
execution_config_.create_local_scope =
!used_for_jit && FLAGS_new_executor_use_local_scope &&
!used_for_control_flow_op && !used_for_cinn;
execution_config_.skip_gc_vars = skip_gc_vars;
if (!FLAGS_new_executor_use_local_scope) {
execution_config_.create_local_scope = false;
}
execution_config_.AnalyzeThreadPoolConfig(place, block.OpSize());
execution_config_.Log(/*log_level=*/8);
if (execution_config_.create_local_scope) {
......@@ -280,7 +272,7 @@ paddle::framework::FetchList InterpreterCore::Run(
if (!is_build_) {
LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
paddle::framework::interpreter::BuildVariableScope(
block_, &var_scope_, HasLocalScope());
block_, execution_config_, &var_scope_);
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
auto skip_run = paddle::framework::interpreter::BuildOpFuncList(
......@@ -410,7 +402,7 @@ std::shared_ptr<interpreter::AsyncWorkQueue> InterpreterCore::GetWorkQueue() {
if (async_work_queue_ == nullptr) {
async_work_queue_ = std::make_shared<interpreter::AsyncWorkQueue>(
execution_config_.host_num_threads,
execution_config_.deivce_num_threads,
execution_config_.device_num_threads,
nullptr);
}
return async_work_queue_;
......@@ -1271,7 +1263,7 @@ void InterpreterCore::Prepare(const std::vector<std::string>& feed_names,
if (!is_build_) {
paddle::framework::interpreter::BuildVariableScope(
block_, &var_scope_, HasLocalScope());
block_, execution_config_, &var_scope_);
FeedInput();
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
auto skip_run = paddle::framework::interpreter::BuildOpFuncList(
......@@ -1309,24 +1301,6 @@ void InterpreterCore::SetFeedVarsInplaceSkip(
bool InterpreterCore::HasLocalScope() const { return local_scope_ != nullptr; }
std::shared_ptr<InterpreterCore> CreateInterpreterCore(
const platform::Place& place,
const ProgramDesc& prog,
Scope* scope,
const std::vector<std::string>& fetch_names,
const std::set<std::string>& skip_gc_vars) {
std::shared_ptr<InterpreterCore> core = nullptr;
// NOTE(Aurelius84): `AddFetch` will modify BlockDesc, so we should copy
// a new program.
auto new_prog = std::make_shared<framework::ProgramDesc>(prog);
auto* block = new_prog->MutableBlock(0);
interpreter::AddFetch(fetch_names, block);
core = std::make_shared<InterpreterCore>(place, *block, skip_gc_vars, scope);
core->SetCopyProgram(new_prog);
return core;
}
// Note(zhangbo):
// (1) What is "Trace"?
// The OP execute scheduling rule adopted by Interpretercore by default is a
......@@ -1462,5 +1436,24 @@ void InterpreterCore::AnalyseExecuteOrderForTrace() {
trace_execute_order_ = trace_order;
}
// Build an InterpreterCore for `prog` that additionally fetches
// `fetch_names`. The program is deep-copied first because AddFetch mutates
// the BlockDesc; the copy is then registered on the core (SetCopyProgram)
// so the block the core runs stays alive as long as the core does.
std::shared_ptr<InterpreterCore> CreateInterpreterCore(
const platform::Place& place,
const ProgramDesc& prog,
Scope* scope,
const std::vector<std::string>& fetch_names,
const interpreter::ExecutionConfig& execution_config) {
std::shared_ptr<InterpreterCore> core = nullptr;
// NOTE(Aurelius84): `AddFetch` will modify BlockDesc, so we should copy
// a new program.
auto new_prog = std::make_shared<framework::ProgramDesc>(prog);
auto* block = new_prog->MutableBlock(0);
// Append fetch ops for `fetch_names` to block 0 of the copy.
interpreter::AddFetch(fetch_names, block);
core =
std::make_shared<InterpreterCore>(place, *block, scope, execution_config);
core->SetCopyProgram(new_prog);
return core;
}
} // namespace framework
} // namespace paddle
......@@ -40,14 +40,18 @@ namespace paddle {
namespace framework {
class InterpreterCore {
using ExecutionConfig = interpreter::ExecutionConfig;
using InstructionSchedulingPriorityLess = std::function<bool(size_t, size_t)>;
using SchedulingQueue =
std::priority_queue<size_t,
std::vector<size_t>,
InstructionSchedulingPriorityLess>;
public:
InterpreterCore(const platform::Place& place,
const BlockDesc& block,
const std::set<std::string>& skip_gc_vars,
Scope* scope,
bool used_for_jit = false,
bool used_for_control_flow_op = false,
bool used_for_cinn = false);
const ExecutionConfig& execution_config = ExecutionConfig());
~InterpreterCore();
......@@ -79,12 +83,7 @@ class InterpreterCore {
const platform::Place& GetPlace() const { return place_; }
private:
using InstructionSchedulingPriorityLess = std::function<bool(size_t, size_t)>;
using SchedulingQueue =
std::priority_queue<size_t,
std::vector<size_t>,
InstructionSchedulingPriorityLess>;
DISABLE_COPY_AND_ASSIGN(InterpreterCore);
// build graph
void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
void BuildOperatorDependences();
......@@ -135,11 +134,10 @@ class InterpreterCore {
private:
bool is_build_{false};
platform::Place place_;
const platform::Place place_;
const BlockDesc& block_; // not owned
interpreter::DependencyBuilder dependency_builder_;
interpreter::ExecutionConfig execution_config_;
interpreter::StreamAnalyzer stream_analyzer_;
// NOTE(zhiqiu): when add fetch ops in GetInterpreterCore, we will
......@@ -156,6 +154,9 @@ class InterpreterCore {
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
std::atomic<size_t> unfinished_op_number_{0};
ExecutionConfig execution_config_;
VariableScope var_scope_;
Scope* local_scope_{nullptr}; // not owned
......@@ -189,9 +190,10 @@ class InterpreterCore {
std::shared_ptr<InterpreterCore> CreateInterpreterCore(
const platform::Place& place,
const ProgramDesc& prog,
Scope* global_scope,
Scope* scope,
const std::vector<std::string>& fetch_names = {},
const std::set<std::string>& skip_gc_vars = {});
const interpreter::ExecutionConfig& execution_config =
interpreter::ExecutionConfig());
} // namespace framework
} // namespace paddle
......@@ -71,11 +71,7 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
if (add_fetch_op) {
core = CreateInterpreterCore(place_, prog, scope, fetch_names);
} else {
core = std::make_shared<InterpreterCore>(
place_,
prog.Block(0),
/*skip_gc_vars=*/std::set<std::string>(),
scope);
core = std::make_shared<InterpreterCore>(place_, prog.Block(0), scope);
}
interpretercores_.emplace(oss.str(), core);
return core;
......
......@@ -188,8 +188,11 @@ TEST(InterpreterCore, skip_gc_vars) {
"elementwise_add_0.tmp_0",
"tmp_0"};
std::shared_ptr<InterpreterCore> main_core =
CreateInterpreterCore(place, main_prog, &scope, {}, skip_gc_vars);
interpreter::ExecutionConfig execution_config;
execution_config.skip_gc_vars = skip_gc_vars;
std::shared_ptr<InterpreterCore> main_core = CreateInterpreterCore(
place, main_prog, &scope, /*fetch_names=*/{}, execution_config);
auto check_gc_result =
[](Scope& scope, std::set<std::string>& vars, bool is_skip_gc) {
......
......@@ -59,18 +59,17 @@ void InterpreterEngine::CreateInterpreterCore() {
GraphToProgram(graph, &converted_prog_, nullptr);
framework::interpreter::ExecutionConfig execution_config;
execution_config.create_local_scope = false;
execution_config.used_for_jit = true;
auto in_names = info_->InputArgNames();
auto out_names = info_->OutputArgNames();
std::set<std::string> skip_gc_vars;
skip_gc_vars.insert(in_names.begin(), in_names.end());
skip_gc_vars.insert(out_names.begin(), out_names.end());
inner_interpreter_ =
std::make_shared<InterpreterCore>(place_,
converted_prog_.Block(0),
/*skip_gc_vars=*/skip_gc_vars,
&scope_,
/*used_for_jit=*/true);
execution_config.skip_gc_vars.insert(in_names.begin(), in_names.end());
execution_config.skip_gc_vars.insert(out_names.begin(), out_names.end());
inner_interpreter_ = std::make_shared<InterpreterCore>(
place_, converted_prog_.Block(0), &scope_, execution_config);
}
std::vector<Tensor> InterpreterEngine::operator()(
......
......@@ -488,14 +488,12 @@ framework::InterpreterCore* CinnLaunchContext::InitializeInterpreterCore(
framework::proto::VarType::LOD_TENSOR);
}
if (!interpreter_core_) {
framework::interpreter::ExecutionConfig execution_config;
execution_config.create_local_scope = false;
execution_config.used_for_cinn = true;
execution_config.skip_gc_vars = skip_gc_vars_;
interpreter_core_ = std::make_unique<framework::InterpreterCore>(
place,
runtime_program_desc_->Block(0),
skip_gc_vars_,
scope,
/*used_for_jit*/ false,
/*used_for_control_flow_op*/ false,
/*used_for_cinn*/ true);
place, runtime_program_desc_->Block(0), scope, execution_config);
} else {
interpreter_core_->reset_scope(scope);
}
......
......@@ -99,17 +99,18 @@ class ConditionalBlockOp : public ConditionalOp {
LOG_FIRST_N(INFO, 1)
<< "[ControlFlow][ConditionalBlock] New Executor is Running.";
if (!core_ || !platform::is_same_place(core_->GetPlace(), dev_place)) {
std::set<std::string> skip_gc_vars(skip_vars.begin(),
skip_vars.end());
VLOG(10) << "[interpreterCore cache]" << core_.get();
VLOG_IF(10, core_)
<< platform::is_same_place(core_->GetPlace(), dev_place);
core_.reset(new InterpreterCore(dev_place,
*block,
skip_gc_vars,
&cur_scope,
/* used_for_jit */ false,
/* used_for_control_flow_op */ true));
framework::interpreter::ExecutionConfig execution_config;
execution_config.create_local_scope = false;
execution_config.used_for_control_flow_op = true;
execution_config.skip_gc_vars =
std::set<std::string>(skip_vars.begin(), skip_vars.end());
core_.reset(new InterpreterCore(
dev_place, *block, &cur_scope, execution_config));
VLOG(10) << "[interpreterCore cache]"
<< "new created:" << core_;
} else {
......@@ -214,14 +215,15 @@ class ConditionalBlockGradOp : public ConditionalOp {
VLOG(10) << "[interpreterCore cache]" << core_.get();
VLOG_IF(10, core_)
<< platform::is_same_place(core_->GetPlace(), dev_place);
std::set<std::string> skip_gc_vars(inside_grads.begin(),
inside_grads.end());
core_.reset(new InterpreterCore(dev_place,
*block,
skip_gc_vars,
&cur_scope,
/* used_for_jit */ false,
/* used_for_control_flow_op */ true));
framework::interpreter::ExecutionConfig execution_config;
execution_config.create_local_scope = false;
execution_config.used_for_control_flow_op = true;
execution_config.skip_gc_vars =
std::set<std::string>(inside_grads.begin(), inside_grads.end());
core_.reset(new InterpreterCore(
dev_place, *block, &cur_scope, execution_config));
VLOG(10) << "[interpreterCore cache]"
<< "new created:" << core_;
} else {
......
......@@ -202,16 +202,16 @@ class WhileOp : public framework::OperatorBase {
if (FLAGS_control_flow_use_new_executor) {
LOG_FIRST_N(INFO, 1) << "[ControlFlow][WhileOp] New Executor is Running.";
if (!core_ || !platform::is_same_place(core_->GetPlace(), dev_place)) {
std::set<std::string> skip_gc_vars(skip_vars.begin(), skip_vars.end());
framework::Scope placeholder; // Don't care if it's valid, just for
// initialize InterpreterCore
framework::interpreter::ExecutionConfig execution_config;
execution_config.create_local_scope = false;
execution_config.used_for_control_flow_op = true;
execution_config.skip_gc_vars =
std::set<std::string>(skip_vars.begin(), skip_vars.end());
core_.reset(new framework::InterpreterCore(
dev_place,
*block,
skip_gc_vars,
&placeholder,
/* used_for_jit */ false,
/* used_for_control_flow_op */ true));
dev_place, *block, &placeholder, execution_config));
}
} else {
if (!executor_ ||
......@@ -398,13 +398,14 @@ class WhileGradOp : public framework::OperatorBase {
std::set<std::string> skip_gc_vars(skip_vars.begin(), skip_vars.end());
framework::Scope placeholder; // Don't care if it's valid, just for
// initialize InterpreterCore
framework::interpreter::ExecutionConfig execution_config;
execution_config.create_local_scope = false;
execution_config.used_for_control_flow_op = true;
execution_config.skip_gc_vars =
std::set<std::string>(skip_vars.begin(), skip_vars.end());
core_.reset(new framework::InterpreterCore(
dev_place,
*block,
skip_gc_vars,
&placeholder,
/* used_for_jit */ false,
/* used_for_control_flow_op */ true));
dev_place, *block, &placeholder, execution_config));
}
} else {
if (!executor_ ||
......
......@@ -1624,6 +1624,12 @@ class Executor:
if "fleet_opt" in program._pipeline_opt:
# Move prepare here for port conflict with nccl in startup program
if self._fleet_executor is None:
# Temporary manual enable standalone executor for fleet executor,
# delete this code after the FLAGS is removed.
if 'tasks' in program._pipeline_opt["fleet_opt"]:
set_flags(
{"FLAGS_fleet_executor_with_standalone": True}
)
self._fleet_executor = _prepare_fleet_executor()
return self._run_using_fleet_executor(
program=program,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.