提交 cb8a24be 编写于 作者: Y Yancey1989

clean code

上级 c9de6f1b
......@@ -46,9 +46,6 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
#endif
void AllReduceOpHandle::RunImpl() {
int64_t start_ts = GetTS();
int64_t func_ts = GetTS();
VLOG(5) << "all_reduce_op_handle::RunImpl start";
platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second);
// FIXME(typhoonzero): If scope0(global scope) have NCCL_ID_VAR,
......@@ -62,11 +59,7 @@ void AllReduceOpHandle::RunImpl() {
return; // No need to all reduce when GPU count = 1;
} else {
// Wait input done
start_ts = GetTS();
WaitInputVarGenerated();
VLOG(5) << "all_reduce_op_handle wait input var spent: "
<< GetTS() - start_ts << " (ns).";
start_ts = GetTS();
auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE_EQ(
......@@ -107,8 +100,6 @@ void AllReduceOpHandle::RunImpl() {
}
int dev_id = boost::get<platform::CUDAPlace>(p).device;
VLOG(5) << "call allreduce: " << in_var_handles[i]->name_
<< " on dev: " << dev_id;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
......@@ -118,6 +109,7 @@ void AllReduceOpHandle::RunImpl() {
ncclSum, comm, stream));
});
}
this->RunAndRecordEvent([&] {
// TODO(Yancey1989): need allreduce operator to avoid this flag
if (nccl_ctxs_->need_group_call_) {
......@@ -162,8 +154,6 @@ void AllReduceOpHandle::RunImpl() {
}
}
}
VLOG(5) << "all_reduce_op_handle Impl spent: " << GetTS() - func_ts
<< " (ns).";
}
// Name of this op handle: always the literal "all_reduce".
std::string AllReduceOpHandle::Name() const {
  return "all_reduce";
}
......
......@@ -33,19 +33,11 @@ void ComputationOpHandle::RunImpl() {
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
};
if (Name().compare("conv2d") || Name().compare("conv2d_grad")) {
int64_t start_ts = GetTS();
auto varname = DynamicCast<VarHandle>(this->Outputs())[0]->name_;
run_func();
VLOG(5) << Name() << "_op_handle: " << varname
<< " spent: " << GetTS() - start_ts << " (ns).";
} else {
if (is_lock_and_record_event_free_) {
run_func();
} else {
this->RunAndRecordEvent(run_func);
}
}
}
bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
......
......@@ -41,7 +41,6 @@ OpHandleBase::~OpHandleBase() {
void OpHandleBase::Run(bool use_cuda) {
#ifdef PADDLE_WITH_CUDA
int64_t start_ts = 0;
if (events_.empty() && use_cuda) {
for (auto &p : dev_ctxes_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
......@@ -125,7 +124,6 @@ bool OpHandleBase::NeedWait(VarHandleBase *in_var) {
void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
#ifdef PADDLE_WITH_CUDA
if (!events_.empty()) { // Use event
VLOG(5) << "events not empty";
std::function<void()> method = callback;
for (auto &p : dev_ctxes_) {
method = [method, p, this]() {
......
......@@ -21,19 +21,20 @@ namespace details {
// Builds a ParallelSSAGraphExecutor that owns one ThreadedSSAGraphExecutor per
// entry of `places`; executor i is created from local_scopes_[i], places_[i]
// and graphs_[i].
//
// NOTE(review): this span is a rendered diff whose +/- markers were stripped,
// so the old and the new form of several lines appear back to back (the
// `graphs` parameter, the pool_ and graphs_ initializers, and the arguments of
// executors_.emplace_back). Confirm against the repository which variant is
// live before relying on the exact text.
ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::vector<std::unique_ptr<ir::Graph>> graphs)
std::vector<std::unique_ptr<ir::Graph>> &&graphs)
// strategy, local_scopes and places are const references, so std::move on
// them still results in a copy rather than a move.
: strategy_(std::move(strategy)),
local_scopes_(std::move(local_scopes)),
// A thread pool (one thread per place) is only allocated when there are at
// least two places; otherwise pool_ stays null.
pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
places_(std::move(places)),
graphs_(std::move(graphs)),
pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr) {
graphs_(std::move(graphs)) {
// Every place needs a matching local scope.
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
// Do not use a thread pool for each graph execution: every per-graph executor
// is configured with a single thread.
strategy_.num_threads_ = 1UL;
for (size_t i = 0; i < places.size(); ++i) {
// Single-element scope/place lists for the i-th executor. Note that the local
// `places` declared here shadows the constructor parameter inside the loop
// body.
std::vector<framework::Scope *> scopes = {local_scopes_[i]};
std::vector<platform::Place> places = {places_[i]};
// NOTE(review): graphs_[i] is passed through std::move; if the callee takes
// ownership, graphs_[i] is left empty afterwards — confirm this is compatible
// with any later use of graphs_.
executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
strategy_, scopes, places, std::move(graphs_[i])));
strategy_, {local_scopes_[i]}, {places_[i]}, std::move(graphs_[i])));
}
VLOG(1) << "pool size: " << places_.size();
}
FeedFetchList ParallelSSAGraphExecutor::Run(
......
......@@ -30,7 +30,7 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::vector<std::unique_ptr<ir::Graph>> graphs);
std::vector<std::unique_ptr<ir::Graph>> &&graphs);
~ParallelSSAGraphExecutor() final = default;
// Returns a reference to the graph held by the first element of graphs_.
const ir::Graph &Graph() const override {
  return *graphs_.front();
}
......@@ -39,9 +39,9 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
private:
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::unique_ptr<::ThreadPool> pool_{nullptr};
std::vector<platform::Place> places_;
std::vector<std::unique_ptr<ir::Graph>> graphs_;
std::unique_ptr<::ThreadPool> pool_;
std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
};
......
......@@ -54,7 +54,6 @@ class ParallelExecutorPrivate {
std::vector<Scope *> local_scopes_;
Scope *global_scope_; // not owned
std::unique_ptr<details::SSAGraphExecutor> executor_;
std::vector<std::unique_ptr<details::SSAGraphExecutor>> executors_;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
......@@ -142,6 +141,7 @@ ParallelExecutor::ParallelExecutor(
std::vector<std::unique_ptr<ir::Graph>> graphs;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
if (exec_strategy.type_ == ExecutionStrategy::kParallelGraph) {
VLOG(1) << "kParallelGraph mode!!";
for (size_t i = 0; i < member_->places_.size(); ++i) {
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, {member_->places_[i]}, loss_var_name, params,
......@@ -222,38 +222,17 @@ ParallelExecutor::ParallelExecutor(
}
if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
/**
for (size_t i = 0; i < member_->places_.size(); ++i) {
std::vector<details::VariableInfo> var_infos;
for (auto &node : graphs[i]->Nodes()) {
if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
var_infos.emplace_back();
var_infos.back().name_ = node->Var()->Name();
var_infos.back().type_ = node->Var()->GetType();
var_infos.back().persistable_ = node->Var()->Persistable();
}
}
std::vector<platform::Place> places = {member_->places_[i]};
std::vector<framework::Scope *> scopes = {member_->local_scopes_[i]};
std::unique_ptr<details::ThreadedSSAGraphExecutor> p(new
details::ThreadedSSAGraphExecutor(
exec_strategy, scopes, places, std::move(graphs[i])));
member_->executors_.push_back(std::move(p));
member_->executors_[i].reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, scopes, std::move(var_infos), places,
std::move(member_->executors_[i])));
}**/
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graphs[0])));
exec_strategy, member_->local_scopes_, member_->places_,
std::move(graphs[0])));
} else if (exec_strategy.type_ == ExecutionStrategy::kParallelGraph) {
member_->executor_.reset(new details::ParallelSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, graphs));
exec_strategy, member_->local_scopes_, member_->places_,
std::move(graphs)));
} else {
member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graphs[0])));
exec_strategy, member_->local_scopes_, member_->places_,
std::move(graphs[0])));
}
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
......
......@@ -105,7 +105,7 @@ struct NCCLContextMap {
}
std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
// if num_trainers == 1, should create a new nccl id for local comms.
if (num_trainers == 1 && nccl_id != nullptr) {
if (num_trainers == 1 && nccl_id == nullptr) {
std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
comms.get(), static_cast<int>(order_.size()), order_.data()));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册