提交 73005ee0 作者: Yancey1989

cleanup code test=develop

上级 88d3dc94
...@@ -119,8 +119,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -119,8 +119,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Verify that the graph is correct for multi-device executor. // Verify that the graph is correct for multi-device executor.
auto multi_devices_pass = AppendPass("multi_devices_check_pass"); auto multi_devices_pass = AppendPass("multi_devices_check_pass");
multi_devices_pass->Set<bool>(kEnablePG,
new bool(strategy.enable_parallel_graph_));
if (SeqOnlyAllReduceOps(strategy)) { if (SeqOnlyAllReduceOps(strategy)) {
AppendPass("all_reduce_deps_pass"); AppendPass("all_reduce_deps_pass");
...@@ -194,8 +192,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -194,8 +192,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
&local_scopes); &local_scopes);
pass->Erase(kNRanks); pass->Erase(kNRanks);
pass->Set<size_t>(kNRanks, new size_t(nranks)); pass->Set<size_t>(kNRanks, new size_t(nranks));
pass->Erase(kEnablePG);
pass->Set<bool>(kEnablePG, new bool(true));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr; platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
......
...@@ -201,7 +201,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl( ...@@ -201,7 +201,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
auto &g_name = backward_vars[i + 1]; auto &g_name = backward_vars[i + 1];
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name; VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
InsertCollectiveOp(&result, node, p_name, g_name); InsertCollectiveOp(&result, p_name, g_name);
} }
} catch (boost::bad_get e) { } catch (boost::bad_get e) {
} }
...@@ -386,7 +386,7 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result, ...@@ -386,7 +386,7 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
} }
void MultiDevSSAGraphBuilderBase::CreateAllReduceOp( void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
ir::Graph *result, ir::Node *node, const std::string &og) const { ir::Graph *result, const std::string &og) const {
OpHandleBase *op_handle = nullptr; OpHandleBase *op_handle = nullptr;
auto append_allreduce_op = [&]( auto append_allreduce_op = [&](
...@@ -510,13 +510,13 @@ bool MultiDevSSAGraphBuilderBase::IsSparseGradient( ...@@ -510,13 +510,13 @@ bool MultiDevSSAGraphBuilderBase::IsSparseGradient(
} }
void AllReduceSSAGraphBuilder::InsertCollectiveOp( void AllReduceSSAGraphBuilder::InsertCollectiveOp(
ir::Graph *result, ir::Node *node, const std::string &p_name, ir::Graph *result, const std::string &p_name,
const std::string &g_name) const { const std::string &g_name) const {
if (IsSparseGradient(g_name)) { if (IsSparseGradient(g_name)) {
CreateReduceOp(result, g_name, 0); CreateReduceOp(result, g_name, 0);
CreateBroadcastOp(result, g_name, 0); CreateBroadcastOp(result, g_name, 0);
} else { } else {
CreateAllReduceOp(result, node, g_name); CreateAllReduceOp(result, g_name);
} }
} }
...@@ -589,7 +589,7 @@ void ReduceSSAGraphBuilder::ResetState() const { ...@@ -589,7 +589,7 @@ void ReduceSSAGraphBuilder::ResetState() const {
} }
void ReduceSSAGraphBuilder::InsertCollectiveOp( void ReduceSSAGraphBuilder::InsertCollectiveOp(
ir::Graph *result, ir::Node *node, const std::string &p_name, ir::Graph *result, const std::string &p_name,
const std::string &g_name) const { const std::string &g_name) const {
size_t cur_device_id = GetAppropriateDeviceID({g_name}); size_t cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(result, g_name, cur_device_id); CreateReduceOp(result, g_name, cur_device_id);
...@@ -909,7 +909,7 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ...@@ -909,7 +909,7 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
return op_dev_id; return op_dev_id;
} }
void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, ir::Node *node, void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
const std::string &p_name, const std::string &p_name,
const std::string &g_name) const { const std::string &g_name) const {
size_t cur_device_id = 0; size_t cur_device_id = 0;
...@@ -924,7 +924,7 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, ir::Node *node, ...@@ -924,7 +924,7 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, ir::Node *node,
CreateReduceOp(result, g_name, 0); CreateReduceOp(result, g_name, 0);
CreateBroadcastOp(result, g_name, 0); CreateBroadcastOp(result, g_name, 0);
} else { } else {
CreateAllReduceOp(result, node, g_name); CreateAllReduceOp(result, g_name);
} }
break; break;
default: default:
...@@ -975,8 +975,7 @@ static int MultiDevSSAGraphBuilderRegister(const std::string &builder_mode) { ...@@ -975,8 +975,7 @@ static int MultiDevSSAGraphBuilderRegister(const std::string &builder_mode) {
.RequirePassAttr(paddle::framework::details::kPlaces) \ .RequirePassAttr(paddle::framework::details::kPlaces) \
.RequirePassAttr(paddle::framework::details::kLocalScopes) \ .RequirePassAttr(paddle::framework::details::kLocalScopes) \
.RequirePassAttr(paddle::framework::details::kStrategy) \ .RequirePassAttr(paddle::framework::details::kStrategy) \
.RequirePassAttr(paddle::framework::details::kNRanks) \ .RequirePassAttr(paddle::framework::details::kNRanks)
.RequirePassAttr(paddle::framework::details::kEnablePG)
REGISTER_MULTI_DEVICES_PASS(reduce_mode_multi_devices_pass, REGISTER_MULTI_DEVICES_PASS(reduce_mode_multi_devices_pass,
paddle::framework::details::ReduceSSAGraphBuilder); paddle::framework::details::ReduceSSAGraphBuilder);
......
...@@ -36,7 +36,6 @@ constexpr char kPlaces[] = "places"; ...@@ -36,7 +36,6 @@ constexpr char kPlaces[] = "places";
constexpr char kLocalScopes[] = "local_scopes"; constexpr char kLocalScopes[] = "local_scopes";
constexpr char kStrategy[] = "strategy"; constexpr char kStrategy[] = "strategy";
constexpr char kNRanks[] = "nranks"; constexpr char kNRanks[] = "nranks";
constexpr char kEnablePG[] = "enable_pg";
class MultiDevSSAGraphBuilderBase : public ir::Pass { class MultiDevSSAGraphBuilderBase : public ir::Pass {
protected: protected:
...@@ -47,8 +46,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { ...@@ -47,8 +46,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const; virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const;
virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node, virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &p_name,
const std::string &g_name) const = 0; const std::string &g_name) const = 0;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const = 0; virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const = 0;
...@@ -77,8 +75,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { ...@@ -77,8 +75,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
bool IsSparseGradient(const std::string &og) const; bool IsSparseGradient(const std::string &og) const;
void CreateAllReduceOp(ir::Graph *result, ir::Node *node, void CreateAllReduceOp(ir::Graph *result, const std::string &og) const;
const std::string &og) const;
void CreateBroadcastOp(ir::Graph *result, const std::string &p_name, void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
size_t src_dev_id) const; size_t src_dev_id) const;
...@@ -109,8 +106,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { ...@@ -109,8 +106,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase { class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
protected: protected:
virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node, virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &p_name,
const std::string &g_name) const; const std::string &g_name) const;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const { virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const {
...@@ -139,8 +135,7 @@ class ReduceSSAGraphBuilder : public BalanceVarSSAGraphBuilder { ...@@ -139,8 +135,7 @@ class ReduceSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
protected: protected:
virtual void Init() const; virtual void Init() const;
virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node, virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &p_name,
const std::string &g_name) const; const std::string &g_name) const;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const; virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;
...@@ -169,8 +164,7 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder { ...@@ -169,8 +164,7 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
virtual void InsertPostprocessOps(ir::Graph *result) const; virtual void InsertPostprocessOps(ir::Graph *result) const;
virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node, virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &p_name,
const std::string &g_name) const; const std::string &g_name) const;
virtual void ResetState() const; virtual void ResetState() const;
......
...@@ -45,8 +45,6 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor { ...@@ -45,8 +45,6 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override; FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
private: private:
// std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph();
ExecutionStrategy strategy_; ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::unique_ptr<::ThreadPool> pool_{nullptr}; std::unique_ptr<::ThreadPool> pool_{nullptr};
......
...@@ -56,10 +56,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -56,10 +56,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
} }
} }
} }
for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) { for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
InsertPendingVar(&pending_vars, ready_vars.get(), var); InsertPendingVar(&pending_vars, ready_vars.get(), var);
} }
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) { for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
if (op->Inputs().empty()) { // Special case, Op has no input. if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op); ready_ops.insert(op);
......
...@@ -176,12 +176,6 @@ class Graph { ...@@ -176,12 +176,6 @@ class Graph {
return ret; return ret;
} }
void RemoveNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
node_set_.erase(node);
nodes_.erase(node);
}
// NOTE low performance, but simple and secure. // NOTE low performance, but simple and secure.
Node *RetrieveNode(int id) { Node *RetrieveNode(int id) {
for (auto &node : nodes_) { for (auto &node : nodes_) {
...@@ -200,10 +194,6 @@ class Graph { ...@@ -200,10 +194,6 @@ class Graph {
return node; return node;
} }
bool ContainNode(ir::Node *node) {
return node_set_.find(node) != node_set_.end();
}
void ResolveHazard( void ResolveHazard(
const std::map<std::string, std::vector<ir::Node *>> &var_nodes); const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
......
...@@ -64,9 +64,7 @@ template <typename T> ...@@ -64,9 +64,7 @@ template <typename T>
std::vector<T *> FilterByNodeWrapper(const Graph &graph) { std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
std::vector<T *> ret; std::vector<T *> ret;
for (ir::Node *n : graph.Nodes()) { for (ir::Node *n : graph.Nodes()) {
if (n->IsWrappedBy<T>()) { if (n->IsWrappedBy<T>()) ret.push_back(&n->Wrapper<T>());
ret.push_back(&n->Wrapper<T>());
}
} }
return ret; return ret;
} }
......
...@@ -478,12 +478,11 @@ bool ParallelExecutor::EnableParallelGraphExecution( ...@@ -478,12 +478,11 @@ bool ParallelExecutor::EnableParallelGraphExecution(
} }
} }
// if (!member_->use_all_reduce_ || !member_->use_cuda_) if (!member_->use_all_reduce_ || !member_->use_cuda_)
if (!member_->use_all_reduce_) enable_parallel_graph = false;
if (build_strategy.enable_sequential_execution_ || if (build_strategy.enable_sequential_execution_ ||
exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental) exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental)
enable_parallel_graph = false; enable_parallel_graph = false;
return enable_parallel_graph; return enable_parallel_graph;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册