未验证 提交 1fe4513c 编写于 作者: L Leo Chen 提交者: GitHub

Refine new executor (#37074)

* split declaration and implementation

* remove initdevices

* refine VariableMetaInfo

* add ut

* fix compile
上级 0a92c857
...@@ -31,6 +31,36 @@ ...@@ -31,6 +31,36 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// Collects, for `op`, the argument names that lie outside its
// no-need-buffer inputs.
//
// The info is flagged as built first. If the op's info carries a
// NoNeedBufferVarsInferer, it is invoked with the op's inputs, outputs and
// attributes and the result is stored in no_need_buffer_ins_. When that result
// is non-empty, every argument name belonging to an input whose key is NOT in
// no_need_buffer_ins_, and every argument name of every output, is inserted
// into other_args_set_.
void OpInOutInfo::Build(const OperatorBase *op) {
  is_built_ = true;

  auto &inferer = op->Info().NoNeedBufferVarsInferer();
  if (!inferer) {
    // No inferer available: leave both sets untouched.
    return;
  }

  no_need_buffer_ins_ = inferer(op->Inputs(), op->Outputs(), op->Attrs());
  if (no_need_buffer_ins_.empty()) {
    return;
  }

  for (auto &input_slot : op->Inputs()) {
    const bool listed_as_no_need =
        no_need_buffer_ins_.count(input_slot.first) != 0;
    if (!listed_as_no_need) {
      for (auto &arg_name : input_slot.second) {
        other_args_set_.insert(arg_name);
      }
    }
  }

  // Output argument names are always recorded.
  for (auto &output_slot : op->Outputs()) {
    for (auto &arg_name : output_slot.second) {
      other_args_set_.insert(arg_name);
    }
  }
}
// Reports whether the buffer of input argument `in_arg_name` is needed.
//
// Returns true when no_need_buffer_ins_ is empty; otherwise returns true only
// if `in_arg_name` is present in other_args_set_.
bool OpInOutInfo::IsInArgBufferNeeded(const std::string &in_arg_name) const {
  if (no_need_buffer_ins_.empty()) {
    return true;
  }
  return other_args_set_.find(in_arg_name) != other_args_set_.end();
}
static bool VarCanBeDeleted(const std::string &name, const BlockDesc &block, static bool VarCanBeDeleted(const std::string &name, const BlockDesc &block,
const std::unordered_set<std::string> &skip_vars) { const std::unordered_set<std::string> &skip_vars) {
if (skip_vars.count(name) != 0) { if (skip_vars.count(name) != 0) {
......
...@@ -33,38 +33,11 @@ class Scope; ...@@ -33,38 +33,11 @@ class Scope;
struct OpInOutInfo { struct OpInOutInfo {
public: public:
void Build(const OperatorBase *op) { void Build(const OperatorBase *op);
is_built_ = true;
auto &inferer = op->Info().NoNeedBufferVarsInferer();
if (inferer) {
no_need_buffer_ins_ = inferer(op->Inputs(), op->Outputs(), op->Attrs());
if (no_need_buffer_ins_.empty()) return;
for (auto &in_name_pair : op->Inputs()) {
if (no_need_buffer_ins_.count(in_name_pair.first) != 0) {
continue;
}
for (auto &in_arg_name : in_name_pair.second) {
other_args_set_.insert(in_arg_name);
}
}
for (auto &out_name_pair : op->Outputs()) {
for (auto &out_arg_name : out_name_pair.second) {
other_args_set_.insert(out_arg_name);
}
}
}
}
bool IsBuilt() const { return is_built_; } bool IsBuilt() const { return is_built_; }
bool IsInArgBufferNeeded(const std::string &in_arg_name) const { bool IsInArgBufferNeeded(const std::string &in_arg_name) const;
return no_need_buffer_ins_.empty() ||
other_args_set_.count(in_arg_name) != 0;
}
private: private:
// A set to record unused buffer input vars of op // A set to record unused buffer input vars of op
......
...@@ -3,10 +3,11 @@ lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper l ...@@ -3,10 +3,11 @@ lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper l
graph_to_program_pass variable_helper timer monitor nan_inf_utils) graph_to_program_pass variable_helper timer monitor nan_inf_utils)
cc_library(workqueue SRCS workqueue.cc workqueue_utils.cc DEPS enforce) cc_library(workqueue SRCS workqueue.cc workqueue_utils.cc DEPS enforce)
cc_library(interpretercore_garbage_collector SRCS interpretercore_garbage_collector.cc DEPS workqueue ${DEVICE_EVENT_LIBS}) cc_library(new_executor_defs SRCS new_executor_defs.cc DEPS enforce glog scope)
cc_library(interpretercore_util SRCS interpretercore_util.cc DEPS ${INTERPRETERCORE_DEPS} workqueue) cc_library(interpretercore_garbage_collector SRCS interpretercore_garbage_collector.cc DEPS workqueue ${DEVICE_EVENT_LIBS} executor_gc_helper)
cc_library(event_manager SRCS event_manager.cc DEPS ${DEVICE_EVENT_LIBS} glog) cc_library(interpretercore_util SRCS interpretercore_util.cc DEPS ${INTERPRETERCORE_DEPS} workqueue new_executor_defs)
cc_library(stream_analyzer SRCS stream_analyzer.cc DEPS ${DEVICE_EVENT_LIBS} glog device_context) cc_library(event_manager SRCS event_manager.cc DEPS ${DEVICE_EVENT_LIBS} glog new_executor_defs)
cc_library(stream_analyzer SRCS stream_analyzer.cc DEPS ${DEVICE_EVENT_LIBS} glog device_context new_executor_defs)
cc_library(interpretercore SRCS interpretercore.cc DEPS workqueue ${DEVICE_EVENT_LIBS} interpretercore_util interpretercore_garbage_collector stream_analyzer event_manager) cc_library(interpretercore SRCS interpretercore.cc DEPS workqueue ${DEVICE_EVENT_LIBS} interpretercore_util interpretercore_garbage_collector stream_analyzer event_manager)
cc_library(standalone_executor SRCS standalone_executor.cc DEPS interpretercore) cc_library(standalone_executor SRCS standalone_executor.cc DEPS interpretercore)
cc_test(workqueue_test SRCS workqueue_test.cc DEPS workqueue) cc_test(workqueue_test SRCS workqueue_test.cc DEPS workqueue)
......
...@@ -121,6 +121,8 @@ void InterpreterCore::Convert() { ...@@ -121,6 +121,8 @@ void InterpreterCore::Convert() {
for (auto var_id : gc_check_input_list) { for (auto var_id : gc_check_input_list) {
vec_meta_info[var_id].var_ref_count_++; vec_meta_info[var_id].var_ref_count_++;
instr.AddGCCheckVar(var_id); instr.AddGCCheckVar(var_id);
VLOG(4) << "clear " << global_scope_->GetNameById(var_id) << " after "
<< instr.OpBase()->Type();
} }
} }
...@@ -131,6 +133,8 @@ void InterpreterCore::Convert() { ...@@ -131,6 +133,8 @@ void InterpreterCore::Convert() {
if (input_var2op_info_.at(id).size() == 0) { if (input_var2op_info_.at(id).size() == 0) {
// output var not be used by any kernel // output var not be used by any kernel
vec_instruction_[i].AddGCCheckVar(id); vec_instruction_[i].AddGCCheckVar(id);
VLOG(4) << "clear " << global_scope_->GetNameById(id) << " after "
<< vec_instruction_[i].OpBase()->Type();
vec_meta_info[id].var_ref_count_++; vec_meta_info[id].var_ref_count_++;
} }
} }
...@@ -437,6 +441,8 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) { ...@@ -437,6 +441,8 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
try { try {
RunInstruction(instr_node); RunInstruction(instr_node);
// GC information
CheckGC(instr_node);
} catch (platform::EnforceNotMet& ex) { } catch (platform::EnforceNotMet& ex) {
framework::InsertCallStackInfo(op->Type(), op->Attrs(), &ex); framework::InsertCallStackInfo(op->Type(), op->Attrs(), &ex);
exception_holder_.Catch(std::make_exception_ptr(std::move(ex))); exception_holder_.Catch(std::make_exception_ptr(std::move(ex)));
...@@ -463,9 +469,6 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) { ...@@ -463,9 +469,6 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
interpreter::RecordEvent(instr_node, place_); interpreter::RecordEvent(instr_node, place_);
op_run_number_.fetch_add(1, std::memory_order_relaxed); op_run_number_.fetch_add(1, std::memory_order_relaxed);
// GC information
CheckGC(instr_node);
RunNextInstructions(instr_node, &ready_ops); RunNextInstructions(instr_node, &ready_ops);
} }
} }
...@@ -476,6 +479,9 @@ void InterpreterCore::CheckGC(const Instruction& instr) { ...@@ -476,6 +479,9 @@ void InterpreterCore::CheckGC(const Instruction& instr) {
auto& atomic_var_ref = async_work_queue_->AtomicVarRef(); auto& atomic_var_ref = async_work_queue_->AtomicVarRef();
for (auto var_id : instr.GCCheckVars()) { for (auto var_id : instr.GCCheckVars()) {
VLOG(4) << "GC " << global_scope_->GetNameById(var_id) << " "
<< var_scope.VarDesc(var_id);
bool is_ready = bool is_ready =
atomic_var_ref[var_id]->fetch_sub(1, std::memory_order_relaxed) == 1; atomic_var_ref[var_id]->fetch_sub(1, std::memory_order_relaxed) == 1;
// ignore all persistable var while GC // ignore all persistable var while GC
......
...@@ -23,16 +23,14 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, ...@@ -23,16 +23,14 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
: place_(place), : place_(place),
startup_prog_(startup_prog), startup_prog_(startup_prog),
main_prog_(main_prog), main_prog_(main_prog),
outer_scope_(scope), global_scope_(VariableScope(scope)) {
global_scope_(scope) {
paddle::framework::InitDevices();
// init scope // init scope
BuildVariableOuterScope(startup_prog, &global_scope_, scope); BuildVariableScope(startup_prog, &global_scope_);
if (outer_scope_ != nullptr) { if (scope != nullptr) {
auto name_list = outer_scope_->LocalVarNames(); auto name_list = scope->LocalVarNames();
for (auto name : name_list) { for (auto name : name_list) {
auto v = outer_scope_->Var(name); auto v = scope->Var(name);
if (!global_scope_.HasVar(name)) { if (!global_scope_.HasVar(name)) {
global_scope_.AddVar(name, *v); global_scope_.AddVar(name, *v);
} }
...@@ -62,9 +60,8 @@ framework::interpreter::CostInfo StandaloneExecutor::DryRun( ...@@ -62,9 +60,8 @@ framework::interpreter::CostInfo StandaloneExecutor::DryRun(
return core->DryRun(feed_names, feed_tensors); return core->DryRun(feed_names, feed_tensors);
} }
void StandaloneExecutor::BuildVariableOuterScope( void StandaloneExecutor::BuildVariableScope(const framework::ProgramDesc& pdesc,
const framework::ProgramDesc& pdesc, VariableScope* var_scope, VariableScope* var_scope) {
Scope* outer_scope) {
auto& global_block = pdesc.Block(0); auto& global_block = pdesc.Block(0);
for (auto& var : global_block.AllVars()) { for (auto& var : global_block.AllVars()) {
......
...@@ -50,8 +50,8 @@ class StandaloneExecutor : public ExecutorBase { ...@@ -50,8 +50,8 @@ class StandaloneExecutor : public ExecutorBase {
const std::vector<framework::LoDTensor>& feed_tensors); const std::vector<framework::LoDTensor>& feed_tensors);
private: private:
void BuildVariableOuterScope(const framework::ProgramDesc& pdesc, void BuildVariableScope(const framework::ProgramDesc& pdesc,
VariableScope* var_scope, Scope* outer_scope); VariableScope* var_scope);
std::shared_ptr<InterpreterCore> GetInterpreterCore( std::shared_ptr<InterpreterCore> GetInterpreterCore(
const std::vector<std::string>& feed_names, const std::vector<std::string>& feed_names,
...@@ -60,7 +60,6 @@ class StandaloneExecutor : public ExecutorBase { ...@@ -60,7 +60,6 @@ class StandaloneExecutor : public ExecutorBase {
const platform::Place& place_; const platform::Place& place_;
const ProgramDesc& startup_prog_; const ProgramDesc& startup_prog_;
const ProgramDesc& main_prog_; const ProgramDesc& main_prog_;
Scope* outer_scope_;
VariableScope global_scope_; VariableScope global_scope_;
std::unordered_map<std::string, std::shared_ptr<ProgramDesc>> programs_; std::unordered_map<std::string, std::shared_ptr<ProgramDesc>> programs_;
......
...@@ -75,6 +75,7 @@ paddle::framework::ProgramDesc load_from_file(const std::string& file_name) { ...@@ -75,6 +75,7 @@ paddle::framework::ProgramDesc load_from_file(const std::string& file_name) {
} }
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
paddle::framework::InitDevices();
std::cout << "main" << std::endl; std::cout << "main" << std::endl;
int64_t batch_size = std::stoi(argv[1]); int64_t batch_size = std::stoi(argv[1]);
paddle::framework::InitDevices(); paddle::framework::InitDevices();
......
...@@ -266,14 +266,14 @@ Variable* Scope::FindVarLocally(const std::string& name) const { ...@@ -266,14 +266,14 @@ Variable* Scope::FindVarLocally(const std::string& name) const {
return nullptr; return nullptr;
} }
void Scope::AddListener(ScopeListener* listener) { void Scope::AddListener(const std::shared_ptr<ScopeListener>& listener) {
auto it = std::find(listeners_.begin(), listeners_.end(), listener); auto it = std::find(listeners_.begin(), listeners_.end(), listener);
if (it == listeners_.end()) { if (it == listeners_.end()) {
listeners_.push_back(listener); listeners_.push_back(listener);
} }
} }
void Scope::DelListener(ScopeListener* listener) { void Scope::DelListener(const std::shared_ptr<ScopeListener>& listener) {
listeners_.remove(listener); listeners_.remove(listener);
} }
......
...@@ -144,9 +144,9 @@ class Scope : public ScopeBase { ...@@ -144,9 +144,9 @@ class Scope : public ScopeBase {
// Rename variable to a new name and return the new name // Rename variable to a new name and return the new name
std::string Rename(const std::string& origin_name) const; std::string Rename(const std::string& origin_name) const;
void AddListener(ScopeListener* listener); void AddListener(const std::shared_ptr<ScopeListener>& listener);
void DelListener(ScopeListener* listener); void DelListener(const std::shared_ptr<ScopeListener>& listener);
protected: protected:
struct KeyHasher { struct KeyHasher {
...@@ -184,7 +184,7 @@ class Scope : public ScopeBase { ...@@ -184,7 +184,7 @@ class Scope : public ScopeBase {
// Scope in `kids_` are owned by this class. // Scope in `kids_` are owned by this class.
mutable std::list<Scope*> kids_; mutable std::list<Scope*> kids_;
const Scope* parent_{nullptr}; const Scope* parent_{nullptr};
std::list<ScopeListener*> listeners_; std::list<std::shared_ptr<ScopeListener>> listeners_;
DISABLE_COPY_AND_ASSIGN(Scope); DISABLE_COPY_AND_ASSIGN(Scope);
......
...@@ -275,7 +275,7 @@ class TestException(unittest.TestCase): ...@@ -275,7 +275,7 @@ class TestException(unittest.TestCase):
for feed in feeds: for feed in feeds:
out = exe.run(main_program, feed=feed, fetch_list=fetch_vars) out = exe.run(main_program, feed=feed, fetch_list=fetch_vars)
print(out) print(main_program)
return out return out
def run_new_executor(self, feed): def run_new_executor(self, feed):
...@@ -287,10 +287,10 @@ class TestException(unittest.TestCase): ...@@ -287,10 +287,10 @@ class TestException(unittest.TestCase):
def test_exception(self): def test_exception(self):
feed = [{ feed = [{
'id': np.array([1, 2, 3, 4, 5]).astype(np.int64), 'id': np.array([1, 2, 3, 4, 5]).astype(np.int64),
'data': np.array([1, 2, 3, 4]).astype(np.float32), 'data': np.array([1, 2, 3]).astype(np.float32),
}, { }, {
'id': np.array([1, 2, 3, 4, 11]).astype(np.int64), 'id': np.array([1, 2, 3, 4, 11]).astype(np.int64),
'data': np.array([1, 2, 3, 4]).astype(np.float32), 'data': np.array([1, 2, 3]).astype(np.float32),
}] }]
self.assertRaises(ValueError, self.run_new_executor, feed) self.assertRaises(ValueError, self.run_new_executor, feed)
...@@ -307,6 +307,18 @@ class TestException(unittest.TestCase): ...@@ -307,6 +307,18 @@ class TestException(unittest.TestCase):
feed[1]['data'][0] = np.nan feed[1]['data'][0] = np.nan
self.assertRaises(RuntimeError, self.run_new_executor, feed) self.assertRaises(RuntimeError, self.run_new_executor, feed)
def test_scope(self):
    """Run the new executor on two feeds, then check that the variable
    'embedding.tmp_2' can be found in the static global scope."""
    first_batch = {
        'id': np.array([1, 2, 3, 4, 5]).astype(np.int64),
        'data': np.array([1, 2, 3]).astype(np.float32),
    }
    second_batch = {
        'id': np.array([1, 2, 3, 4, 5]).astype(np.int64),
        'data': np.array([2, 2, 2]).astype(np.float32),
    }
    self.run_new_executor([first_batch, second_batch])

    found_var = paddle.static.global_scope().find_var('embedding.tmp_2')
    self.assertIsNotNone(found_var)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册