From e5b322051b13811747bc5244a093cdb22caceeb6 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Mon, 17 Sep 2018 18:53:04 +0800 Subject: [PATCH] clean --- .../details/multi_devices_graph_pass.cc | 3 +++ .../details/multi_devices_graph_pass.h | 18 ++++++++---------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 250e093a5f7..8f319116ab8 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -127,6 +127,9 @@ static const char kLocalScopes[] = "local_scopes"; static const char kStrategy[] = "strategy"; void MultiDevSSAGraphBuilder::Init() const { + all_vars_.clear(); + balance_vars_.clear(); + loss_var_name_ = Get(kLossVarName); places_ = Get>(kPlaces); local_scopes_ = Get>(kLocalScopes); diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index 1ca8c4b855f..47aaa80f4d6 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -40,12 +40,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass { size_t device_id) const; void Init() const; - private: - mutable std::string loss_var_name_; - mutable std::vector places_; - mutable std::vector local_scopes_; - mutable std::unordered_set grad_names_; - #ifdef PADDLE_WITH_CUDA mutable platform::NCCLContextMap *nccl_ctxs_; #endif @@ -95,13 +89,17 @@ class MultiDevSSAGraphBuilder : public ir::Pass { size_t GetAppropriateDeviceID( const std::vector &var_names) const; - private: + void SetCommunicationContext(OpHandleBase *op_handle, + const platform::Place &p) const; + + mutable std::string loss_var_name_; + mutable std::vector places_; + mutable std::vector local_scopes_; + mutable std::unordered_set grad_names_; + mutable BuildStrategy strategy_; mutable std::unordered_map all_vars_; mutable std::vector balance_vars_; - - void SetCommunicationContext(OpHandleBase *op_handle, - const platform::Place &p) const; }; } // namespace details } // namespace framework -- GitLab