未验证 提交 1fe4513c 编写于 作者: L Leo Chen 提交者: GitHub

Refine new executor (#37074)

* split declaration and implementation

* remove initdevices

* refine VariableMetaInfo

* add ut

* fix compile
上级 0a92c857
......@@ -31,6 +31,36 @@
namespace paddle {
namespace framework {
void OpInOutInfo::Build(const OperatorBase *op) {
is_built_ = true;
auto &inferer = op->Info().NoNeedBufferVarsInferer();
if (inferer) {
no_need_buffer_ins_ = inferer(op->Inputs(), op->Outputs(), op->Attrs());
if (no_need_buffer_ins_.empty()) return;
for (auto &in_name_pair : op->Inputs()) {
if (no_need_buffer_ins_.count(in_name_pair.first) != 0) {
continue;
}
for (auto &in_arg_name : in_name_pair.second) {
other_args_set_.insert(in_arg_name);
}
}
for (auto &out_name_pair : op->Outputs()) {
for (auto &out_arg_name : out_name_pair.second) {
other_args_set_.insert(out_arg_name);
}
}
}
}
bool OpInOutInfo::IsInArgBufferNeeded(const std::string &in_arg_name) const {
return no_need_buffer_ins_.empty() || other_args_set_.count(in_arg_name) != 0;
}
static bool VarCanBeDeleted(const std::string &name, const BlockDesc &block,
const std::unordered_set<std::string> &skip_vars) {
if (skip_vars.count(name) != 0) {
......
......@@ -33,38 +33,11 @@ class Scope;
struct OpInOutInfo {
public:
void Build(const OperatorBase *op) {
is_built_ = true;
auto &inferer = op->Info().NoNeedBufferVarsInferer();
if (inferer) {
no_need_buffer_ins_ = inferer(op->Inputs(), op->Outputs(), op->Attrs());
if (no_need_buffer_ins_.empty()) return;
for (auto &in_name_pair : op->Inputs()) {
if (no_need_buffer_ins_.count(in_name_pair.first) != 0) {
continue;
}
for (auto &in_arg_name : in_name_pair.second) {
other_args_set_.insert(in_arg_name);
}
}
for (auto &out_name_pair : op->Outputs()) {
for (auto &out_arg_name : out_name_pair.second) {
other_args_set_.insert(out_arg_name);
}
}
}
}
void Build(const OperatorBase *op);
bool IsBuilt() const { return is_built_; }
bool IsInArgBufferNeeded(const std::string &in_arg_name) const {
return no_need_buffer_ins_.empty() ||
other_args_set_.count(in_arg_name) != 0;
}
bool IsInArgBufferNeeded(const std::string &in_arg_name) const;
private:
// A set to record unused buffer input vars of op
......
......@@ -3,10 +3,11 @@ lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper l
graph_to_program_pass variable_helper timer monitor nan_inf_utils)
cc_library(workqueue SRCS workqueue.cc workqueue_utils.cc DEPS enforce)
cc_library(interpretercore_garbage_collector SRCS interpretercore_garbage_collector.cc DEPS workqueue ${DEVICE_EVENT_LIBS})
cc_library(interpretercore_util SRCS interpretercore_util.cc DEPS ${INTERPRETERCORE_DEPS} workqueue)
cc_library(event_manager SRCS event_manager.cc DEPS ${DEVICE_EVENT_LIBS} glog)
cc_library(stream_analyzer SRCS stream_analyzer.cc DEPS ${DEVICE_EVENT_LIBS} glog device_context)
cc_library(new_executor_defs SRCS new_executor_defs.cc DEPS enforce glog scope)
cc_library(interpretercore_garbage_collector SRCS interpretercore_garbage_collector.cc DEPS workqueue ${DEVICE_EVENT_LIBS} executor_gc_helper)
cc_library(interpretercore_util SRCS interpretercore_util.cc DEPS ${INTERPRETERCORE_DEPS} workqueue new_executor_defs)
cc_library(event_manager SRCS event_manager.cc DEPS ${DEVICE_EVENT_LIBS} glog new_executor_defs)
cc_library(stream_analyzer SRCS stream_analyzer.cc DEPS ${DEVICE_EVENT_LIBS} glog device_context new_executor_defs)
cc_library(interpretercore SRCS interpretercore.cc DEPS workqueue ${DEVICE_EVENT_LIBS} interpretercore_util interpretercore_garbage_collector stream_analyzer event_manager)
cc_library(standalone_executor SRCS standalone_executor.cc DEPS interpretercore)
cc_test(workqueue_test SRCS workqueue_test.cc DEPS workqueue)
......
......@@ -121,6 +121,8 @@ void InterpreterCore::Convert() {
for (auto var_id : gc_check_input_list) {
vec_meta_info[var_id].var_ref_count_++;
instr.AddGCCheckVar(var_id);
VLOG(4) << "clear " << global_scope_->GetNameById(var_id) << " after "
<< instr.OpBase()->Type();
}
}
......@@ -131,6 +133,8 @@ void InterpreterCore::Convert() {
if (input_var2op_info_.at(id).size() == 0) {
// output var not be used by any kernel
vec_instruction_[i].AddGCCheckVar(id);
VLOG(4) << "clear " << global_scope_->GetNameById(id) << " after "
<< vec_instruction_[i].OpBase()->Type();
vec_meta_info[id].var_ref_count_++;
}
}
......@@ -437,6 +441,8 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
try {
RunInstruction(instr_node);
// GC infomation
CheckGC(instr_node);
} catch (platform::EnforceNotMet& ex) {
framework::InsertCallStackInfo(op->Type(), op->Attrs(), &ex);
exception_holder_.Catch(std::make_exception_ptr(std::move(ex)));
......@@ -463,9 +469,6 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
interpreter::RecordEvent(instr_node, place_);
op_run_number_.fetch_add(1, std::memory_order_relaxed);
// GC infomation
CheckGC(instr_node);
RunNextInstructions(instr_node, &ready_ops);
}
}
......@@ -476,6 +479,9 @@ void InterpreterCore::CheckGC(const Instruction& instr) {
auto& atomic_var_ref = async_work_queue_->AtomicVarRef();
for (auto var_id : instr.GCCheckVars()) {
VLOG(4) << "GC " << global_scope_->GetNameById(var_id) << " "
<< var_scope.VarDesc(var_id);
bool is_ready =
atomic_var_ref[var_id]->fetch_sub(1, std::memory_order_relaxed) == 1;
// ignore all persistable var while GC
......
......@@ -23,16 +23,14 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
: place_(place),
startup_prog_(startup_prog),
main_prog_(main_prog),
outer_scope_(scope),
global_scope_(scope) {
paddle::framework::InitDevices();
global_scope_(VariableScope(scope)) {
// init scope
BuildVariableOuterScope(startup_prog, &global_scope_, scope);
BuildVariableScope(startup_prog, &global_scope_);
if (outer_scope_ != nullptr) {
auto name_list = outer_scope_->LocalVarNames();
if (scope != nullptr) {
auto name_list = scope->LocalVarNames();
for (auto name : name_list) {
auto v = outer_scope_->Var(name);
auto v = scope->Var(name);
if (!global_scope_.HasVar(name)) {
global_scope_.AddVar(name, *v);
}
......@@ -62,9 +60,8 @@ framework::interpreter::CostInfo StandaloneExecutor::DryRun(
return core->DryRun(feed_names, feed_tensors);
}
void StandaloneExecutor::BuildVariableOuterScope(
const framework::ProgramDesc& pdesc, VariableScope* var_scope,
Scope* outer_scope) {
void StandaloneExecutor::BuildVariableScope(const framework::ProgramDesc& pdesc,
VariableScope* var_scope) {
auto& global_block = pdesc.Block(0);
for (auto& var : global_block.AllVars()) {
......
......@@ -50,8 +50,8 @@ class StandaloneExecutor : public ExecutorBase {
const std::vector<framework::LoDTensor>& feed_tensors);
private:
void BuildVariableOuterScope(const framework::ProgramDesc& pdesc,
VariableScope* var_scope, Scope* outer_scope);
void BuildVariableScope(const framework::ProgramDesc& pdesc,
VariableScope* var_scope);
std::shared_ptr<InterpreterCore> GetInterpreterCore(
const std::vector<std::string>& feed_names,
......@@ -60,7 +60,6 @@ class StandaloneExecutor : public ExecutorBase {
const platform::Place& place_;
const ProgramDesc& startup_prog_;
const ProgramDesc& main_prog_;
Scope* outer_scope_;
VariableScope global_scope_;
std::unordered_map<std::string, std::shared_ptr<ProgramDesc>> programs_;
......
......@@ -75,6 +75,7 @@ paddle::framework::ProgramDesc load_from_file(const std::string& file_name) {
}
int main(int argc, char* argv[]) {
paddle::framework::InitDevices();
std::cout << "main" << std::endl;
int64_t batch_size = std::stoi(argv[1]);
paddle::framework::InitDevices();
......
......@@ -266,14 +266,14 @@ Variable* Scope::FindVarLocally(const std::string& name) const {
return nullptr;
}
void Scope::AddListener(ScopeListener* listener) {
void Scope::AddListener(const std::shared_ptr<ScopeListener>& listener) {
auto it = std::find(listeners_.begin(), listeners_.end(), listener);
if (it == listeners_.end()) {
listeners_.push_back(listener);
}
}
void Scope::DelListener(ScopeListener* listener) {
void Scope::DelListener(const std::shared_ptr<ScopeListener>& listener) {
listeners_.remove(listener);
}
......
......@@ -144,9 +144,9 @@ class Scope : public ScopeBase {
// Rename variable to a new name and return the new name
std::string Rename(const std::string& origin_name) const;
void AddListener(ScopeListener* listener);
void AddListener(const std::shared_ptr<ScopeListener>& listener);
void DelListener(ScopeListener* listener);
void DelListener(const std::shared_ptr<ScopeListener>& listener);
protected:
struct KeyHasher {
......@@ -184,7 +184,7 @@ class Scope : public ScopeBase {
// Scope in `kids_` are owned by this class.
mutable std::list<Scope*> kids_;
const Scope* parent_{nullptr};
std::list<ScopeListener*> listeners_;
std::list<std::shared_ptr<ScopeListener>> listeners_;
DISABLE_COPY_AND_ASSIGN(Scope);
......
......@@ -275,7 +275,7 @@ class TestException(unittest.TestCase):
for feed in feeds:
out = exe.run(main_program, feed=feed, fetch_list=fetch_vars)
print(out)
print(main_program)
return out
def run_new_executor(self, feed):
......@@ -287,10 +287,10 @@ class TestException(unittest.TestCase):
def test_exception(self):
feed = [{
'id': np.array([1, 2, 3, 4, 5]).astype(np.int64),
'data': np.array([1, 2, 3, 4]).astype(np.float32),
'data': np.array([1, 2, 3]).astype(np.float32),
}, {
'id': np.array([1, 2, 3, 4, 11]).astype(np.int64),
'data': np.array([1, 2, 3, 4]).astype(np.float32),
'data': np.array([1, 2, 3]).astype(np.float32),
}]
self.assertRaises(ValueError, self.run_new_executor, feed)
......@@ -307,6 +307,18 @@ class TestException(unittest.TestCase):
feed[1]['data'][0] = np.nan
self.assertRaises(RuntimeError, self.run_new_executor, feed)
def test_scope(self):
feed = [{
'id': np.array([1, 2, 3, 4, 5]).astype(np.int64),
'data': np.array([1, 2, 3]).astype(np.float32),
}, {
'id': np.array([1, 2, 3, 4, 5]).astype(np.int64),
'data': np.array([2, 2, 2]).astype(np.float32),
}]
self.run_new_executor(feed)
self.assertIsNotNone(paddle.static.global_scope().find_var(
'embedding.tmp_2'))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册