提交 681514e1 编写于 作者: M minqiyang

Make all scope pointers shared (raw Scope* -> std::shared_ptr<Scope>)

上级 ce24a920
...@@ -22,7 +22,8 @@ namespace framework { ...@@ -22,7 +22,8 @@ namespace framework {
namespace details { namespace details {
FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes, const ExecutionStrategy &strategy,
const std::vector<std::shared_ptr<Scope>> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
std::unique_ptr<ir::Graph> &&graph) std::unique_ptr<ir::Graph> &&graph)
: strategy_(strategy), : strategy_(strategy),
......
...@@ -29,8 +29,9 @@ namespace details { ...@@ -29,8 +29,9 @@ namespace details {
class OpHandleBase; class OpHandleBase;
class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
public: public:
FastThreadedSSAGraphExecutor(const ExecutionStrategy &strategy, FastThreadedSSAGraphExecutor(
const std::vector<Scope *> &local_scopes, const ExecutionStrategy &strategy,
const std::vector<std::shared_ptr<Scope>> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
std::unique_ptr<ir::Graph> &&graph); std::unique_ptr<ir::Graph> &&graph);
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override; FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
...@@ -38,7 +39,7 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -38,7 +39,7 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
private: private:
ExecutionStrategy strategy_; ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_; std::vector<std::shared_ptr<Scope>> local_scopes_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
std::unique_ptr<ir::Graph> graph_; std::unique_ptr<ir::Graph> graph_;
......
...@@ -22,7 +22,7 @@ namespace framework { ...@@ -22,7 +22,7 @@ namespace framework {
namespace details { namespace details {
FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset, FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
std::vector<Scope *> *local_scopes) std::vector<std::shared_ptr<Scope>> *local_scopes)
: OpHandleBase(node), : OpHandleBase(node),
data_(data), data_(data),
offset_(offset), offset_(offset),
......
...@@ -29,7 +29,7 @@ namespace details { ...@@ -29,7 +29,7 @@ namespace details {
struct FetchOpHandle : public OpHandleBase { struct FetchOpHandle : public OpHandleBase {
public: public:
FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset, FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
std::vector<Scope *> *local_scopes); std::vector<std::shared_ptr<Scope>> *local_scopes);
~FetchOpHandle(); ~FetchOpHandle();
...@@ -47,7 +47,7 @@ struct FetchOpHandle : public OpHandleBase { ...@@ -47,7 +47,7 @@ struct FetchOpHandle : public OpHandleBase {
private: private:
FeedFetchList *data_; FeedFetchList *data_;
size_t offset_; size_t offset_;
std::vector<Scope *> *local_scopes_; std::vector<std::shared_ptr<Scope>> *local_scopes_;
std::vector<LoDTensor> tensors_; std::vector<LoDTensor> tensors_;
}; };
......
...@@ -23,7 +23,8 @@ namespace paddle { ...@@ -23,7 +23,8 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor( ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
ExecutionStrategy strategy, std::vector<Scope *> local_scopes, ExecutionStrategy strategy,
std::vector<std::shared_ptr<Scope>> local_scopes,
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places, std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
std::unique_ptr<SSAGraphExecutor> &&underlying_executor) std::unique_ptr<SSAGraphExecutor> &&underlying_executor)
: strategy_(std::move(strategy)), : strategy_(std::move(strategy)),
......
...@@ -37,7 +37,8 @@ struct VariableInfo { ...@@ -37,7 +37,8 @@ struct VariableInfo {
class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
public: public:
ScopeBufferedSSAGraphExecutor( ScopeBufferedSSAGraphExecutor(
ExecutionStrategy strategy, std::vector<Scope*> local_scopes, ExecutionStrategy strategy,
std::vector<std::shared_ptr<Scope>> local_scopes,
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places, std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
std::unique_ptr<SSAGraphExecutor>&& underlying_executor); std::unique_ptr<SSAGraphExecutor>&& underlying_executor);
...@@ -52,7 +53,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -52,7 +53,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
ExecutionStrategy strategy_; ExecutionStrategy strategy_;
std::unique_ptr<SSAGraphExecutor> underlying_executor_; std::unique_ptr<SSAGraphExecutor> underlying_executor_;
std::vector<Scope*> local_scopes_; std::vector<std::shared_ptr<Scope>> local_scopes_;
std::vector<VariableInfo> var_infos_; std::vector<VariableInfo> var_infos_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
}; };
......
...@@ -21,7 +21,8 @@ namespace paddle { ...@@ -21,7 +21,8 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes, const ExecutionStrategy &strategy,
const std::vector<std::shared_ptr<Scope>> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
std::unique_ptr<ir::Graph> &&graph) std::unique_ptr<ir::Graph> &&graph)
: graph_(std::move(graph)), : graph_(std::move(graph)),
......
...@@ -38,8 +38,9 @@ namespace details { ...@@ -38,8 +38,9 @@ namespace details {
class ThreadedSSAGraphExecutor : public SSAGraphExecutor { class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
public: public:
ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy, ThreadedSSAGraphExecutor(
const std::vector<Scope *> &local_scopes, const ExecutionStrategy &strategy,
const std::vector<std::shared_ptr<Scope>> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
std::unique_ptr<ir::Graph> &&graph); std::unique_ptr<ir::Graph> &&graph);
...@@ -57,7 +58,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -57,7 +58,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
private: private:
std::unique_ptr<ir::Graph> graph_; std::unique_ptr<ir::Graph> graph_;
std::unique_ptr<::ThreadPool> pool_; std::unique_ptr<::ThreadPool> pool_;
std::vector<Scope *> local_scopes_; std::vector<std::shared_ptr<Scope>> local_scopes_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
platform::DeviceContextPool fetch_ctxs_; platform::DeviceContextPool fetch_ctxs_;
ExceptionHolder exception_holder_; ExceptionHolder exception_holder_;
......
...@@ -39,7 +39,8 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass( ...@@ -39,7 +39,8 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
const ProgramDesc &main_program, const std::vector<platform::Place> &places, const ProgramDesc &main_program, const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::string &loss_var_name,
const std::unordered_set<std::string> &param_names, const std::unordered_set<std::string> &param_names,
const std::vector<Scope *> &local_scopes, const bool use_cuda, const std::vector<std::shared_ptr<Scope>> &local_scopes,
const bool use_cuda,
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
const BuildStrategy &strategy, platform::NCCLContextMap *nccl_ctxs) { const BuildStrategy &strategy, platform::NCCLContextMap *nccl_ctxs) {
#else #else
...@@ -66,8 +67,8 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass( ...@@ -66,8 +67,8 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
&loss_var_name); &loss_var_name);
multi_devices_pass->SetNotOwned<const std::unordered_set<std::string>>( multi_devices_pass->SetNotOwned<const std::unordered_set<std::string>>(
"params", &param_names); "params", &param_names);
multi_devices_pass->SetNotOwned<const std::vector<Scope *>>("local_scopes", multi_devices_pass->SetNotOwned<const std::vector<std::shared_ptr<Scope>>>(
&local_scopes); "local_scopes", &local_scopes);
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy", &strategy); multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy", &strategy);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -100,8 +101,8 @@ class ParallelExecutorPrivate { ...@@ -100,8 +101,8 @@ class ParallelExecutorPrivate {
: places_(places) {} : places_(places) {}
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
std::vector<Scope *> local_scopes_; std::vector<std::shared_ptr<Scope>> local_scopes_;
Scope *global_scope_; std::shared_ptr<Scope> global_scope_;
std::unique_ptr<details::SSAGraphExecutor> executor_; std::unique_ptr<details::SSAGraphExecutor> executor_;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -112,7 +113,7 @@ class ParallelExecutorPrivate { ...@@ -112,7 +113,7 @@ class ParallelExecutorPrivate {
bool use_all_reduce_; bool use_all_reduce_;
}; };
std::vector<Scope *> &ParallelExecutor::GetLocalScopes() { std::vector<std::shared_ptr<Scope>> &ParallelExecutor::GetLocalScopes() {
return member_->local_scopes_; return member_->local_scopes_;
} }
...@@ -121,7 +122,8 @@ ParallelExecutor::ParallelExecutor( ...@@ -121,7 +122,8 @@ ParallelExecutor::ParallelExecutor(
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::unordered_set<std::string> &bcast_vars, const std::unordered_set<std::string> &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name, const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, const std::vector<Scope *> &local_scopes, const std::shared_ptr<Scope> &scope,
const std::vector<std::shared_ptr<Scope>> &local_scopes,
const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy, const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy,
size_t num_trainers, size_t trainer_id) size_t num_trainers, size_t trainer_id)
: member_(new ParallelExecutorPrivate(places)) { : member_(new ParallelExecutorPrivate(places)) {
...@@ -142,13 +144,13 @@ ParallelExecutor::ParallelExecutor( ...@@ -142,13 +144,13 @@ ParallelExecutor::ParallelExecutor(
member_->own_local_scope_ = true; member_->own_local_scope_ = true;
member_->local_scopes_.emplace_back(member_->global_scope_); member_->local_scopes_.emplace_back(member_->global_scope_);
for (size_t i = 1; i < member_->places_.size(); ++i) { for (size_t i = 1; i < member_->places_.size(); ++i) {
member_->local_scopes_.emplace_back(&scope->NewScope()); member_->local_scopes_.emplace_back(scope->NewSharedScope());
} }
} else { } else {
member_->own_local_scope_ = false; member_->own_local_scope_ = false;
PADDLE_ENFORCE_EQ(member_->places_.size(), local_scopes.size()); PADDLE_ENFORCE_EQ(member_->places_.size(), local_scopes.size());
for (size_t i = 0; i < member_->places_.size(); ++i) { for (size_t i = 0; i < member_->places_.size(); ++i) {
member_->local_scopes_.emplace_back(&local_scopes[i]->NewScope()); member_->local_scopes_.emplace_back(local_scopes[i]->NewSharedScope());
} }
} }
...@@ -321,7 +323,7 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes( ...@@ -321,7 +323,7 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes(
for (size_t i = 0; i < tensors.size(); ++i) { for (size_t i = 0; i < tensors.size(); ++i) {
auto &map = tensors[i]; auto &map = tensors[i];
auto *scope = member_->local_scopes_[i]; auto &scope = member_->local_scopes_[i];
for (auto &pair : map) { for (auto &pair : map) {
auto *trg = scope->Var(pair.first)->GetMutable<LoDTensor>(); auto *trg = scope->Var(pair.first)->GetMutable<LoDTensor>();
trg->ShareDataWith(pair.second); trg->ShareDataWith(pair.second);
...@@ -351,8 +353,15 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( ...@@ -351,8 +353,15 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
// Tears down the executor. When the executor created its own local scopes
// (own_local_scope_), they must be unregistered from the global scope.
// We first drop our shared_ptr references (reset), so that the global
// scope's DeleteScope() call releases the last owner and actually frees
// each child; raw pointers are stashed beforehand to keep them addressable.
ParallelExecutor::~ParallelExecutor() {
  if (member_->own_local_scope_) {
    std::vector<Scope *> local_scopes_ptrs;
    // Index 0 aliases global_scope_ itself and must not be deleted,
    // so only size() - 1 entries are collected.
    local_scopes_ptrs.reserve(member_->local_scopes_.size() - 1);
    for (size_t i = 1; i < member_->local_scopes_.size(); ++i) {
      local_scopes_ptrs.emplace_back(member_->local_scopes_[i].get());
      member_->local_scopes_[i].reset();
    }
    for (Scope *scope : local_scopes_ptrs) {
      member_->global_scope_->DeleteScope(scope);
    }
  }
}
......
...@@ -39,19 +39,20 @@ class ParallelExecutor { ...@@ -39,19 +39,20 @@ class ParallelExecutor {
DISABLE_COPY_AND_ASSIGN(ParallelExecutor); DISABLE_COPY_AND_ASSIGN(ParallelExecutor);
public: public:
explicit ParallelExecutor(const std::vector<platform::Place> &places, explicit ParallelExecutor(
const std::vector<platform::Place> &places,
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::unordered_set<std::string> &bcast_vars, const std::unordered_set<std::string> &bcast_vars,
const ProgramDesc &main_program, const ProgramDesc &main_program, const std::string &loss_var_name,
const std::string &loss_var_name, Scope *scope, const std::shared_ptr<Scope> &scope,
const std::vector<Scope *> &local_scopes, const std::vector<std::shared_ptr<Scope>> &local_scopes,
const ExecutionStrategy &exec_strategy, const ExecutionStrategy &exec_strategy,
const BuildStrategy &build_strategy, const BuildStrategy &build_strategy, size_t num_trainers = 1,
size_t num_trainers = 1, size_t trainer_id = 0); size_t trainer_id = 0);
~ParallelExecutor(); ~ParallelExecutor();
std::vector<Scope *> &GetLocalScopes(); std::vector<std::shared_ptr<Scope>> &GetLocalScopes();
/** /**
* Feed tensors to local scopes. The size of tensors should be equal to the * Feed tensors to local scopes. The size of tensors should be equal to the
......
...@@ -38,8 +38,8 @@ Scope::~Scope() { DropKids(); } ...@@ -38,8 +38,8 @@ Scope::~Scope() { DropKids(); }
// Creates a child scope, records shared ownership of it in kids_, and
// returns a reference to the new child. Guarded by mutex_ for
// thread-safety. NOTE: the function returns Scope&, so we must
// dereference the stored shared_ptr rather than return .get() (a Scope*),
// which would not compile.
Scope& Scope::NewScope() const {
  std::unique_lock<std::mutex> lock(mutex_);
  // Scope's constructor is not publicly accessible here, so we cannot use
  // std::make_shared; wrap the raw new in a shared_ptr directly.
  kids_.push_back(std::shared_ptr<Scope>(new Scope(this)));
  return *kids_.back();
}
Variable* Scope::Var(const std::string& name) { Variable* Scope::Var(const std::string& name) {
...@@ -68,7 +68,6 @@ const Scope* Scope::FindScope(const Variable* var) const { ...@@ -68,7 +68,6 @@ const Scope* Scope::FindScope(const Variable* var) const {
void Scope::DropKids() { void Scope::DropKids() {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
for (Scope* s : kids_) delete s;
kids_.clear(); kids_.clear();
} }
...@@ -84,8 +83,12 @@ std::vector<std::string> Scope::LocalVarNames() const { ...@@ -84,8 +83,12 @@ std::vector<std::string> Scope::LocalVarNames() const {
void Scope::DeleteScope(Scope* scope) const { void Scope::DeleteScope(Scope* scope) const {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope); auto it = std::find_if(this->kids_.begin(), this->kids_.end(),
[&scope](const std::shared_ptr<Scope>& kid) {
return kid.get() == scope;
});
PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope); PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope);
it->reset();
this->kids_.erase(it); this->kids_.erase(it);
// When making memory benchmark on Fluid, we have to delete scope sync. // When making memory benchmark on Fluid, we have to delete scope sync.
if (FLAGS_benchmark || FLAGS_eager_delete_scope) { if (FLAGS_benchmark || FLAGS_eager_delete_scope) {
......
...@@ -105,7 +105,7 @@ class Scope { ...@@ -105,7 +105,7 @@ class Scope {
Variable* FindVarLocally(const std::string& name) const; Variable* FindVarLocally(const std::string& name) const;
// Scope in `kids_` are owned by this class. // Scope in `kids_` are owned by this class.
mutable std::list<Scope*> kids_; mutable std::list<std::shared_ptr<Scope>> kids_;
Scope const* parent_{nullptr}; Scope const* parent_{nullptr};
DISABLE_COPY_AND_ASSIGN(Scope); DISABLE_COPY_AND_ASSIGN(Scope);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册