Unverified commit a83a4fab, authored by Xin Pan, committed by GitHub

Merge pull request #13441 from panyx0718/ir2

simplify and hide bcast_params
@@ -127,6 +127,9 @@ static const char kLocalScopes[] = "local_scopes";
 static const char kStrategy[] = "strategy";
 
 void MultiDevSSAGraphBuilder::Init() const {
+  all_vars_.clear();
+  balance_vars_.clear();
+
   loss_var_name_ = Get<const std::string>(kLossVarName);
   places_ = Get<const std::vector<platform::Place>>(kPlaces);
   local_scopes_ = Get<const std::vector<Scope *>>(kLocalScopes);
...
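Note on the hunk above: all_vars_ and balance_vars_ are mutable members that get filled while the pass runs, so clearing them at the top of Init() presumably keeps the builder re-entrant, i.e. a second application of the pass starts from an empty cache instead of stale entries. A minimal sketch of that pattern (hypothetical CountingPass class, not Paddle code):

```cpp
#include <iostream>
#include <vector>

// Minimal sketch (hypothetical class, not Paddle code): a pass-like object
// whose const Run() caches per-run results in mutable members and clears
// them first so repeated runs do not accumulate stale entries.
class CountingPass {
 public:
  void Run(const std::vector<int> &input) const {
    cache_.clear();  // same idea as all_vars_.clear() / balance_vars_.clear()
    for (int v : input) cache_.push_back(v * 2);
  }
  size_t CachedSize() const { return cache_.size(); }

 private:
  mutable std::vector<int> cache_;
};

int main() {
  CountingPass pass;
  pass.Run({1, 2, 3});
  pass.Run({4, 5});  // without clear(), the cache would keep growing
  std::cout << pass.CachedSize() << "\n";  // prints 2
  return 0;
}
```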
@@ -40,12 +40,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
                          size_t device_id) const;
   void Init() const;
 
- private:
-  mutable std::string loss_var_name_;
-  mutable std::vector<platform::Place> places_;
-  mutable std::vector<Scope *> local_scopes_;
-  mutable std::unordered_set<std::string> grad_names_;
-
 #ifdef PADDLE_WITH_CUDA
   mutable platform::NCCLContextMap *nccl_ctxs_;
 #endif
@@ -95,13 +89,17 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
   size_t GetAppropriateDeviceID(
       const std::vector<std::string> &var_names) const;
 
- private:
+  void SetCommunicationContext(OpHandleBase *op_handle,
+                               const platform::Place &p) const;
+
+  mutable std::string loss_var_name_;
+  mutable std::vector<platform::Place> places_;
+  mutable std::vector<Scope *> local_scopes_;
+  mutable std::unordered_set<std::string> grad_names_;
+
   mutable BuildStrategy strategy_;
   mutable std::unordered_map<std::string, VarDesc *> all_vars_;
   mutable std::vector<int64_t> balance_vars_;
-
-  void SetCommunicationContext(OpHandleBase *op_handle,
-                               const platform::Place &p) const;
 };
 
 }  // namespace details
 }  // namespace framework
...
@@ -233,30 +233,9 @@ ParallelExecutor::ParallelExecutor(
 void ParallelExecutor::BCastParamsToDevices(
     const std::unordered_set<std::string> &vars) const {
-  // the initializing bcast, all vars would be bcast from device(0),
-  // otherwise
-  // bcast from the specified device.
-  bool initializing = member_->executor_ ? false : true;
+  // the initializing bcast, all vars would be bcast from device(0).
   for (auto &var : vars) {
-    int var_dev_id = -1;
-    if (member_->executor_) {
-      auto &sharded_var_device =
-          member_->executor_->Graph().Get<details::ShardedVarDevice>(
-              details::kShardedVarDevice);
-      if (sharded_var_device.find(var) != sharded_var_device.end()) {
-        var_dev_id = sharded_var_device.at(var);
-      }
-    }
-
-    if (!initializing && var_dev_id == -1) continue;
-
-    framework::Variable *main_var = nullptr;
-    if (initializing) {
-      main_var = member_->local_scopes_[0]->FindVar(var);
-    } else {
-      main_var = member_->local_scopes_[var_dev_id]->FindVar(var);
-    }
-
+    framework::Variable *main_var = member_->local_scopes_[0]->FindVar(var);
     if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
       continue;
     }
@@ -272,8 +251,7 @@ void ParallelExecutor::BCastParamsToDevices(
       auto place = member_->places_[i];
       void *buffer;
-      if ((initializing && i == 0) ||
-          (!initializing && static_cast<int>(i) == var_dev_id)) {
+      if (i == 0) {
         buffer = const_cast<void *>(main_tensor.data<void>());
       } else {
         auto local_scope = member_->local_scopes_[i];
@@ -290,29 +268,18 @@ void ParallelExecutor::BCastParamsToDevices(
         platform::NCCLGroupGuard guard;
         for (size_t i = 0; i < member_->places_.size(); ++i) {
           auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]);
-          if (initializing) {
-            platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
-                                         nccl_ctx.comm_, nccl_ctx.stream());
-          } else {
-            if (var_dev_id >= 0) {
-              platform::dynload::ncclBcast(buffers[i], numel, data_type,
-                                           var_dev_id, nccl_ctx.comm_,
-                                           nccl_ctx.stream());
-            }
-          }
+          platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
+                                       nccl_ctx.comm_, nccl_ctx.stream());
         }
         member_->nccl_ctxs_->WaitAll();
       }
 #else
       PADDLE_THROW("Not compiled with CUDA");
 #endif
     } else {
       platform::CPUPlace cpu;
       for (size_t i = 0; i < member_->places_.size(); ++i) {
-        if ((initializing && i == 0) ||
-            (!initializing && static_cast<int>(i) == var_dev_id))
-          continue;
+        if (i == 0) continue;
 
         auto local_scope = member_->local_scopes_[i];
         auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
...
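The two hunks above are the "simplify" part of the title: the initializing / var_dev_id branching is gone, and every parameter is always broadcast from device 0 (root 0 for ncclBcast on GPU, scope 0 as the copy source on CPU). A condensed, CPU-only sketch of that control flow follows; Tensor, Scope, and BCastParamsFromDevice0 are simplified illustrative stand-ins, not the Paddle types:

```cpp
#include <iostream>
#include <map>
#include <string>
#include <unordered_set>
#include <vector>

// Simplified stand-ins for Paddle's Scope / LoDTensor (illustrative only).
using Tensor = std::vector<float>;
using Scope = std::map<std::string, Tensor>;

// Broadcast every named parameter from scope 0 to all other scopes,
// mirroring the simplified BCastParamsToDevices: device 0 is always the root.
void BCastParamsFromDevice0(std::vector<Scope> &scopes,
                            const std::unordered_set<std::string> &vars) {
  for (const auto &var : vars) {
    auto it = scopes[0].find(var);
    if (it == scopes[0].end()) continue;  // skip vars missing on device 0
    const Tensor &main_tensor = it->second;
    for (size_t i = 1; i < scopes.size(); ++i) {
      // plain copy here; the real code uses TensorCopy (CPU) or ncclBcast (GPU)
      scopes[i][var] = main_tensor;
    }
  }
}

int main() {
  std::vector<Scope> scopes(3);
  scopes[0]["fc.w"] = {0.1f, 0.2f, 0.3f};
  BCastParamsFromDevice0(scopes, {"fc.w"});
  std::cout << scopes[2]["fc.w"].size() << "\n";  // 3: parameter copied to device 2
  return 0;
}
```

Starting the inner loop at i = 1 plays the same role as the `if (i == 0) continue;` in the CPU branch above: device 0 already holds the source tensor.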
@@ -72,9 +72,9 @@ class ParallelExecutor {
   void Run(const std::vector<std::string> &fetch_tensors,
            const std::string &fetched_var_name);
 
+ private:
   void BCastParamsToDevices(const std::unordered_set<std::string> &vars) const;
 
- private:
   ParallelExecutorPrivate *member_;
 
 #ifdef PADDLE_WITH_CUDA
...
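The header hunk above is the "hide" part of the title: BCastParamsToDevices() now sits in the private section of ParallelExecutor, so parameter broadcast becomes an internal detail of the executor rather than part of its public API.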