提交 e5b32205 编写于 作者: X Xin Pan

clean

上级 ec6ee0a2
...@@ -127,6 +127,9 @@ static const char kLocalScopes[] = "local_scopes"; ...@@ -127,6 +127,9 @@ static const char kLocalScopes[] = "local_scopes";
static const char kStrategy[] = "strategy"; static const char kStrategy[] = "strategy";
void MultiDevSSAGraphBuilder::Init() const { void MultiDevSSAGraphBuilder::Init() const {
all_vars_.clear();
balance_vars_.clear();
loss_var_name_ = Get<const std::string>(kLossVarName); loss_var_name_ = Get<const std::string>(kLossVarName);
places_ = Get<const std::vector<platform::Place>>(kPlaces); places_ = Get<const std::vector<platform::Place>>(kPlaces);
local_scopes_ = Get<const std::vector<Scope *>>(kLocalScopes); local_scopes_ = Get<const std::vector<Scope *>>(kLocalScopes);
......
...@@ -40,12 +40,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass { ...@@ -40,12 +40,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
size_t device_id) const; size_t device_id) const;
void Init() const; void Init() const;
private:
mutable std::string loss_var_name_;
mutable std::vector<platform::Place> places_;
mutable std::vector<Scope *> local_scopes_;
mutable std::unordered_set<std::string> grad_names_;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
mutable platform::NCCLContextMap *nccl_ctxs_; mutable platform::NCCLContextMap *nccl_ctxs_;
#endif #endif
...@@ -95,13 +89,17 @@ class MultiDevSSAGraphBuilder : public ir::Pass { ...@@ -95,13 +89,17 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
size_t GetAppropriateDeviceID( size_t GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const; const std::vector<std::string> &var_names) const;
private: void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const;
mutable std::string loss_var_name_;
mutable std::vector<platform::Place> places_;
mutable std::vector<Scope *> local_scopes_;
mutable std::unordered_set<std::string> grad_names_;
mutable BuildStrategy strategy_; mutable BuildStrategy strategy_;
mutable std::unordered_map<std::string, VarDesc *> all_vars_; mutable std::unordered_map<std::string, VarDesc *> all_vars_;
mutable std::vector<int64_t> balance_vars_; mutable std::vector<int64_t> balance_vars_;
void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const;
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册