提交 73005ee0 作者: Yancey1989

cleanup code test=develop

上级 88d3dc94
......@@ -119,8 +119,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Verify that the graph is correct for multi-device executor.
auto multi_devices_pass = AppendPass("multi_devices_check_pass");
multi_devices_pass->Set<bool>(kEnablePG,
new bool(strategy.enable_parallel_graph_));
if (SeqOnlyAllReduceOps(strategy)) {
AppendPass("all_reduce_deps_pass");
......@@ -194,8 +192,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
&local_scopes);
pass->Erase(kNRanks);
pass->Set<size_t>(kNRanks, new size_t(nranks));
pass->Erase(kEnablePG);
pass->Set<bool>(kEnablePG, new bool(true));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
......
......@@ -201,7 +201,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
auto &g_name = backward_vars[i + 1];
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
InsertCollectiveOp(&result, node, p_name, g_name);
InsertCollectiveOp(&result, p_name, g_name);
}
} catch (boost::bad_get e) {
}
......@@ -386,7 +386,7 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
}
void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
ir::Graph *result, ir::Node *node, const std::string &og) const {
ir::Graph *result, const std::string &og) const {
OpHandleBase *op_handle = nullptr;
auto append_allreduce_op = [&](
......@@ -510,13 +510,13 @@ bool MultiDevSSAGraphBuilderBase::IsSparseGradient(
}
void AllReduceSSAGraphBuilder::InsertCollectiveOp(
ir::Graph *result, ir::Node *node, const std::string &p_name,
ir::Graph *result, const std::string &p_name,
const std::string &g_name) const {
if (IsSparseGradient(g_name)) {
CreateReduceOp(result, g_name, 0);
CreateBroadcastOp(result, g_name, 0);
} else {
CreateAllReduceOp(result, node, g_name);
CreateAllReduceOp(result, g_name);
}
}
......@@ -589,7 +589,7 @@ void ReduceSSAGraphBuilder::ResetState() const {
}
void ReduceSSAGraphBuilder::InsertCollectiveOp(
ir::Graph *result, ir::Node *node, const std::string &p_name,
ir::Graph *result, const std::string &p_name,
const std::string &g_name) const {
size_t cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(result, g_name, cur_device_id);
......@@ -909,7 +909,7 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
return op_dev_id;
}
void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, ir::Node *node,
void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
const std::string &p_name,
const std::string &g_name) const {
size_t cur_device_id = 0;
......@@ -924,7 +924,7 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, ir::Node *node,
CreateReduceOp(result, g_name, 0);
CreateBroadcastOp(result, g_name, 0);
} else {
CreateAllReduceOp(result, node, g_name);
CreateAllReduceOp(result, g_name);
}
break;
default:
......@@ -975,8 +975,7 @@ static int MultiDevSSAGraphBuilderRegister(const std::string &builder_mode) {
.RequirePassAttr(paddle::framework::details::kPlaces) \
.RequirePassAttr(paddle::framework::details::kLocalScopes) \
.RequirePassAttr(paddle::framework::details::kStrategy) \
.RequirePassAttr(paddle::framework::details::kNRanks) \
.RequirePassAttr(paddle::framework::details::kEnablePG)
.RequirePassAttr(paddle::framework::details::kNRanks)
REGISTER_MULTI_DEVICES_PASS(reduce_mode_multi_devices_pass,
paddle::framework::details::ReduceSSAGraphBuilder);
......
......@@ -36,7 +36,6 @@ constexpr char kPlaces[] = "places";
constexpr char kLocalScopes[] = "local_scopes";
constexpr char kStrategy[] = "strategy";
constexpr char kNRanks[] = "nranks";
constexpr char kEnablePG[] = "enable_pg";
class MultiDevSSAGraphBuilderBase : public ir::Pass {
protected:
......@@ -47,8 +46,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const;
virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
const std::string &p_name,
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const = 0;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const = 0;
......@@ -77,8 +75,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
bool IsSparseGradient(const std::string &og) const;
void CreateAllReduceOp(ir::Graph *result, ir::Node *node,
const std::string &og) const;
void CreateAllReduceOp(ir::Graph *result, const std::string &og) const;
void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
size_t src_dev_id) const;
......@@ -109,8 +106,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
protected:
virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
const std::string &p_name,
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const {
......@@ -139,8 +135,7 @@ class ReduceSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
protected:
virtual void Init() const;
virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
const std::string &p_name,
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;
......@@ -169,8 +164,7 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
virtual void InsertPostprocessOps(ir::Graph *result) const;
virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
const std::string &p_name,
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const;
virtual void ResetState() const;
......
......@@ -45,8 +45,6 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
private:
// std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph();
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::unique_ptr<::ThreadPool> pool_{nullptr};
......
......@@ -56,10 +56,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
}
}
}
for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
InsertPendingVar(&pending_vars, ready_vars.get(), var);
}
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op);
......
......@@ -176,12 +176,6 @@ class Graph {
return ret;
}
void RemoveNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
node_set_.erase(node);
nodes_.erase(node);
}
// NOTE low performance, but simple and secure.
Node *RetrieveNode(int id) {
for (auto &node : nodes_) {
......@@ -200,10 +194,6 @@ class Graph {
return node;
}
bool ContainNode(ir::Node *node) {
return node_set_.find(node) != node_set_.end();
}
void ResolveHazard(
const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
......
......@@ -64,9 +64,7 @@ template <typename T>
std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
std::vector<T *> ret;
for (ir::Node *n : graph.Nodes()) {
if (n->IsWrappedBy<T>()) {
ret.push_back(&n->Wrapper<T>());
}
if (n->IsWrappedBy<T>()) ret.push_back(&n->Wrapper<T>());
}
return ret;
}
......
......@@ -478,12 +478,11 @@ bool ParallelExecutor::EnableParallelGraphExecution(
}
}
// if (!member_->use_all_reduce_ || !member_->use_cuda_)
if (!member_->use_all_reduce_) enable_parallel_graph = false;
if (!member_->use_all_reduce_ || !member_->use_cuda_)
if (build_strategy.enable_sequential_execution_ ||
exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental)
enable_parallel_graph = false;
if (build_strategy.enable_sequential_execution_ ||
exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental)
enable_parallel_graph = false;
return enable_parallel_graph;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册