Commit cb8a24be authored by Yancey1989

clean code

Parent c9de6f1b
@@ -46,9 +46,6 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
 #endif
 
 void AllReduceOpHandle::RunImpl() {
-  int64_t start_ts = GetTS();
-  int64_t func_ts = GetTS();
-  VLOG(5) << "all_reduce_op_handle::RunImpl start";
   platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second);
 
   // FIXME(typhoonzero): If scope0(global scope) have NCCL_ID_VAR,
@@ -62,11 +59,7 @@ void AllReduceOpHandle::RunImpl() {
     return;  // No need to all reduce when GPU count = 1;
   } else {
     // Wait input done
-    start_ts = GetTS();
     WaitInputVarGenerated();
-    VLOG(5) << "all_reduce_op_handle wait input var spent: "
-            << GetTS() - start_ts << " (ns).";
-    start_ts = GetTS();
     auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
     auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
     PADDLE_ENFORCE_EQ(
@@ -107,8 +100,6 @@ void AllReduceOpHandle::RunImpl() {
         }
         int dev_id = boost::get<platform::CUDAPlace>(p).device;
-        VLOG(5) << "call allreduce: " << in_var_handles[i]->name_
-                << " on dev: " << dev_id;
         auto &nccl_ctx = nccl_ctxs_->at(dev_id);
         auto stream = nccl_ctx.stream();
         auto comm = nccl_ctx.comm_;
@@ -118,6 +109,7 @@ void AllReduceOpHandle::RunImpl() {
                            ncclSum, comm, stream));
         });
       }
       this->RunAndRecordEvent([&] {
         // TODO(Yancey1989): need allreduce operator to avoid this flag
         if (nccl_ctxs_->need_group_call_) {
@@ -162,8 +154,6 @@ void AllReduceOpHandle::RunImpl() {
       }
     }
   }
-  VLOG(5) << "all_reduce_op_handle Impl spent: " << GetTS() - func_ts
-          << " (ns).";
 }
 
 std::string AllReduceOpHandle::Name() const { return "all_reduce"; }
......
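The `need_group_call_` branch above exists because, when one process drives several GPUs, NCCL requires the per-device collective calls to be issued inside a `ncclGroupStart()` / `ncclGroupEnd()` pair so they are submitted as a single operation and cannot deadlock on each other. A minimal sketch of that pattern, independent of the handle classes in this diff (the communicators, streams, and device buffers are assumed to have been created already):

```cpp
#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

// Sum-all-reduce one buffer per GPU, all GPUs owned by this process.
void AllReduceSum(const std::vector<ncclComm_t> &comms,
                  const std::vector<cudaStream_t> &streams,
                  const std::vector<float *> &bufs, size_t count) {
  // Every per-device call must sit inside one group; otherwise the first
  // ncclAllReduce would block waiting for peers that were never issued.
  ncclGroupStart();
  for (size_t i = 0; i < comms.size(); ++i) {
    ncclAllReduce(bufs[i], bufs[i], count, ncclFloat, ncclSum, comms[i],
                  streams[i]);
  }
  ncclGroupEnd();
}
```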
@@ -33,18 +33,10 @@ void ComputationOpHandle::RunImpl() {
     op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
   };
 
-  if (Name().compare("conv2d") || Name().compare("conv2d_grad")) {
-    int64_t start_ts = GetTS();
-    auto varname = DynamicCast<VarHandle>(this->Outputs())[0]->name_;
+  if (is_lock_and_record_event_free_) {
     run_func();
-    VLOG(5) << Name() << "_op_handle: " << varname
-            << " spent: " << GetTS() - start_ts << " (ns).";
   } else {
-    if (is_lock_and_record_event_free_) {
-      run_func();
-    } else {
-      this->RunAndRecordEvent(run_func);
-    }
+    this->RunAndRecordEvent(run_func);
   }
 }
......
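On the removed condition in the hunk above: `std::string::compare` returns 0 only when the strings are equal, so `Name().compare("conv2d") || Name().compare("conv2d_grad")` is true for every op name (no name can equal both literals at once), which is why that branch was effectively always-on timing instrumentation rather than a conv2d special case. A tiny standalone check of that behaviour:

```cpp
#include <iostream>
#include <string>

int main() {
  const std::string name = "conv2d";
  // compare() yields 0 on equality and non-zero otherwise, so at least one
  // operand of the || is always non-zero, making the whole condition true.
  std::cout << std::boolalpha
            << (name.compare("conv2d") || name.compare("conv2d_grad")) << "\n"
            << (std::string("relu").compare("conv2d") ||
                std::string("relu").compare("conv2d_grad"))
            << "\n";  // both print true
  return 0;
}
```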
@@ -41,7 +41,6 @@ OpHandleBase::~OpHandleBase() {
 
 void OpHandleBase::Run(bool use_cuda) {
 #ifdef PADDLE_WITH_CUDA
-  int64_t start_ts = 0;
   if (events_.empty() && use_cuda) {
     for (auto &p : dev_ctxes_) {
       int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
@@ -125,7 +124,6 @@ bool OpHandleBase::NeedWait(VarHandleBase *in_var) {
 
 void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
 #ifdef PADDLE_WITH_CUDA
   if (!events_.empty()) {  // Use event
-    VLOG(5) << "events not empty";
     std::function<void()> method = callback;
     for (auto &p : dev_ctxes_) {
       method = [method, p, this]() {
......
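The `method = [method, p, this]() { ... }` loop above is a callback-chaining pattern: each device context wraps the previous callback, so the original work runs once and an event is then recorded on every context. A small standalone sketch of the same wrapping, with plain prints standing in for CUDA event recording:

```cpp
#include <functional>
#include <iostream>
#include <vector>

int main() {
  std::function<void()> method = [] { std::cout << "run op\n"; };

  // Wrap the callback once per device; each layer runs the inner callback
  // first and then performs its own "record event" step.
  const std::vector<int> devices = {0, 1, 2};
  for (int dev : devices) {
    method = [method, dev]() {
      method();
      std::cout << "record event on device " << dev << "\n";
    };
  }

  method();  // "run op" once, then one record line per device
  return 0;
}
```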
@@ -21,19 +21,20 @@ namespace details {
 
 ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
     const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places,
-    std::vector<std::unique_ptr<ir::Graph>> graphs)
+    std::vector<std::unique_ptr<ir::Graph>> &&graphs)
     : strategy_(std::move(strategy)),
       local_scopes_(std::move(local_scopes)),
+      pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
       places_(std::move(places)),
-      graphs_(std::move(graphs)),
-      pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr) {
+      graphs_(std::move(graphs)) {
   PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
+  // do not use threadpool for each graph execution.
+  strategy_.num_threads_ = 1UL;
   for (size_t i = 0; i < places.size(); ++i) {
-    std::vector<framework::Scope *> scopes = {local_scopes_[i]};
-    std::vector<platform::Place> places = {places_[i]};
     executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
-        strategy_, scopes, places, std::move(graphs_[i])));
+        strategy_, {local_scopes_[i]}, {places_[i]}, std::move(graphs_[i])));
   }
-  VLOG(1) << "pool size: " << places_.size();
 }
 
 FeedFetchList ParallelSSAGraphExecutor::Run(
......
@@ -30,7 +30,7 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
   ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
                            const std::vector<Scope *> &local_scopes,
                            const std::vector<platform::Place> &places,
-                           std::vector<std::unique_ptr<ir::Graph>> graphs);
+                           std::vector<std::unique_ptr<ir::Graph>> &&graphs);
   ~ParallelSSAGraphExecutor() final = default;
 
   const ir::Graph &Graph() const override { return *graphs_[0]; }
@@ -39,9 +39,9 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
  private:
   ExecutionStrategy strategy_;
   std::vector<Scope *> local_scopes_;
+  std::unique_ptr<::ThreadPool> pool_{nullptr};
   std::vector<platform::Place> places_;
   std::vector<std::unique_ptr<ir::Graph>> graphs_;
-  std::unique_ptr<::ThreadPool> pool_;
 
   std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
 };
......
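The member reorder above interacts with a C++ rule worth keeping in mind: non-static data members are initialized in declaration order, not in the order they appear in the constructor's initializer list, so an initializer that reads a parameter after another member has already moved from it silently sees the moved-from value. A minimal standalone illustration of that pitfall (the class and member names here are invented for the example):

```cpp
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

class Runner {
 public:
  explicit Runner(std::vector<int> places)
      // Members are initialized in declaration order: places_ first, then
      // pool_size_.  Because places_ moves from the parameter, the later
      // places.size() call sees a moved-from (typically empty) vector.
      : pool_size_(places.size()), places_(std::move(places)) {}

  std::size_t pool_size() const { return pool_size_; }

 private:
  std::vector<int> places_;  // declared first, so initialized first
  std::size_t pool_size_;    // initialized second, after the move
};

int main() {
  Runner r({0, 1, 2, 3});
  std::cout << r.pool_size() << "\n";  // typically prints 0, not 4
  return 0;
}
```

Keeping the declaration order and the initializer list in sync (and initializing members that consume a parameter before the member that moves from it) avoids this class of bug and the compiler's reorder warnings.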
@@ -54,7 +54,6 @@ class ParallelExecutorPrivate {
   std::vector<Scope *> local_scopes_;
   Scope *global_scope_;  // not owned
   std::unique_ptr<details::SSAGraphExecutor> executor_;
-  std::vector<std::unique_ptr<details::SSAGraphExecutor>> executors_;
 
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
@@ -142,6 +141,7 @@ ParallelExecutor::ParallelExecutor(
   std::vector<std::unique_ptr<ir::Graph>> graphs;
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   if (exec_strategy.type_ == ExecutionStrategy::kParallelGraph) {
+    VLOG(1) << "kParallelGraph mode!!";
     for (size_t i = 0; i < member_->places_.size(); ++i) {
       std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
           main_program, {member_->places_[i]}, loss_var_name, params,
@@ -222,38 +222,17 @@ ParallelExecutor::ParallelExecutor(
   }
 
   if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
-    /**
-    for (size_t i = 0; i < member_->places_.size(); ++i) {
-      std::vector<details::VariableInfo> var_infos;
-      for (auto &node : graphs[i]->Nodes()) {
-        if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
-          var_infos.emplace_back();
-          var_infos.back().name_ = node->Var()->Name();
-          var_infos.back().type_ = node->Var()->GetType();
-          var_infos.back().persistable_ = node->Var()->Persistable();
-        }
-      }
-      std::vector<platform::Place> places = {member_->places_[i]};
-      std::vector<framework::Scope *> scopes = {member_->local_scopes_[i]};
-      std::unique_ptr<details::ThreadedSSAGraphExecutor> p(new
-      details::ThreadedSSAGraphExecutor(
-          exec_strategy, scopes, places, std::move(graphs[i])));
-      member_->executors_.push_back(std::move(p));
-      member_->executors_[i].reset(new details::ScopeBufferedSSAGraphExecutor(
-          exec_strategy, scopes, std::move(var_infos), places,
-          std::move(member_->executors_[i])));
-    }**/
     member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
-        exec_strategy, member_->local_scopes_, places, std::move(graphs[0])));
+        exec_strategy, member_->local_scopes_, member_->places_,
+        std::move(graphs[0])));
   } else if (exec_strategy.type_ == ExecutionStrategy::kParallelGraph) {
     member_->executor_.reset(new details::ParallelSSAGraphExecutor(
-        exec_strategy, member_->local_scopes_, places, graphs));
+        exec_strategy, member_->local_scopes_, member_->places_,
+        std::move(graphs)));
   } else {
     member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
-        exec_strategy, member_->local_scopes_, places, std::move(graphs[0])));
+        exec_strategy, member_->local_scopes_, member_->places_,
+        std::move(graphs[0])));
   }
 
   member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
......
@@ -105,7 +105,7 @@ struct NCCLContextMap {
     }
     std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
     // if num_trainers == 1, should create a new nccl id for local comms.
-    if (num_trainers == 1 && nccl_id != nullptr) {
+    if (num_trainers == 1 && nccl_id == nullptr) {
       std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
       PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
           comms.get(), static_cast<int>(order_.size()), order_.data()));
......
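On the corrected condition above: when there is a single trainer and no externally provided NCCL id, every communicator lives in the same process, so they can all be created in one shot with `ncclCommInitAll`; broadcasting a `ncclUniqueId` and calling `ncclCommInitRank` per rank is only needed when ranks span processes or machines. A minimal sketch of the single-process path, using only standard NCCL calls (outside of Paddle's wrappers):

```cpp
#include <nccl.h>
#include <vector>

// Create one communicator per visible GPU inside a single process.
// No ncclUniqueId exchange is needed because no other process takes part.
std::vector<ncclComm_t> InitLocalComms(int ndev) {
  std::vector<int> devs(ndev);
  for (int i = 0; i < ndev; ++i) devs[i] = i;

  std::vector<ncclComm_t> comms(ndev);
  // One call initializes all ndev communicators; the multi-process
  // alternative is ncclGetUniqueId + ncclCommInitRank on every rank.
  ncclCommInitAll(comms.data(), ndev, devs.data());
  return comms;
}
```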