提交 f3463ecb 编写于 作者: Y Yancey1989

refine pg execution

上级 46a6cac9
...@@ -35,8 +35,8 @@ static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) { ...@@ -35,8 +35,8 @@ static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) {
// Should fix the allreduce op order if scheduling // Should fix the allreduce op order if scheduling
// them in multiple threads or processes to avoid hang. // them in multiple threads or processes to avoid hang.
return (!strategy.enable_sequential_execution_ && return (!strategy.enable_sequential_execution_ &&
strategy.num_trainers_ > 1) || strategy.num_trainers_ > 1) &&
strategy.enable_parallel_graph_; !strategy.enable_parallel_graph_;
} }
class ParallelExecutorPassBuilder : public ir::PassBuilder { class ParallelExecutorPassBuilder : public ir::PassBuilder {
...@@ -106,7 +106,9 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -106,7 +106,9 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
} }
// Verify that the graph is correct for multi-device executor. // Verify that the graph is correct for multi-device executor.
AppendPass("multi_devices_check_pass"); auto multi_devices_pass = AppendPass("multi_devices_check_pass");
multi_devices_pass->Set<bool>(kEnablePG,
new bool(strategy.enable_parallel_graph_));
if (SeqOnlyAllReduceOps(strategy)) { if (SeqOnlyAllReduceOps(strategy)) {
AppendPass("all_reduce_deps_pass"); AppendPass("all_reduce_deps_pass");
...@@ -180,6 +182,8 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -180,6 +182,8 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
&local_scopes); &local_scopes);
pass->Erase(kNRanks); pass->Erase(kNRanks);
pass->Set<size_t>(kNRanks, new size_t(nranks)); pass->Set<size_t>(kNRanks, new size_t(nranks));
pass->Erase(kEnablePG);
pass->Set<bool>(kEnablePG, new bool(true));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr; platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
......
...@@ -36,11 +36,6 @@ namespace framework { ...@@ -36,11 +36,6 @@ namespace framework {
namespace details { namespace details {
namespace { namespace {
// TODO(panyx0718): Clean this up as well.
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<OpHandleBase *> GraphOps;
const char kGraphOps[] = "ops";
bool OpHaveRole(const ir::Node &node, const framework::OpRole &role) { bool OpHaveRole(const ir::Node &node, const framework::OpRole &role) {
return boost::get<int>( return boost::get<int>(
...@@ -206,7 +201,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl( ...@@ -206,7 +201,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
auto &g_name = backward_vars[i + 1]; auto &g_name = backward_vars[i + 1];
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name; VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
InsertCollectiveOp(&result, p_name, g_name); InsertCollectiveOp(&result, node, p_name, g_name);
} }
} catch (boost::bad_get e) { } catch (boost::bad_get e) {
} }
...@@ -226,7 +221,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl( ...@@ -226,7 +221,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
* Only variables should be the leaves of graph. * Only variables should be the leaves of graph.
*/ */
AddOutputToLeafOps(&result); AddOutputToLeafOps(&result);
result.Erase(kGraphOps); // result.Erase(kGraphOps);
return graph; return graph;
} }
...@@ -391,20 +386,34 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result, ...@@ -391,20 +386,34 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
} }
void MultiDevSSAGraphBuilderBase::CreateAllReduceOp( void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
ir::Graph *result, const std::string &og) const { ir::Graph *result, ir::Node *node, const std::string &og) const {
OpHandleBase *op_handle = nullptr;
auto append_allreduce_op = [&](
std::vector<Scope *> &scopes,
std::vector<platform::Place> &places) -> OpHandleBase * {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_)); scopes, places, nccl_ctxs_));
#else #else
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
local_scopes_, places_)); scopes, places));
#endif #endif
auto *op_handle = result->Get<GraphOps>(kGraphOps).back(); return result->Get<GraphOps>(kGraphOps).back();
};
if (!strategy_.enable_parallel_graph_)
op_handle = append_allreduce_op(local_scopes_, places_);
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto p = places_[i];
std::vector<Scope *> ss{local_scopes_[i]};
std::vector<platform::Place> ps{p};
if (strategy_.enable_parallel_graph_)
op_handle = append_allreduce_op(ss, ps);
SetCommunicationContext(op_handle, p); SetCommunicationContext(op_handle, p);
auto &vars = result->Get<GraphVars>(kGraphVars)[i][og]; auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
...@@ -501,13 +510,13 @@ bool MultiDevSSAGraphBuilderBase::IsSparseGradient( ...@@ -501,13 +510,13 @@ bool MultiDevSSAGraphBuilderBase::IsSparseGradient(
} }
void AllReduceSSAGraphBuilder::InsertCollectiveOp( void AllReduceSSAGraphBuilder::InsertCollectiveOp(
ir::Graph *result, const std::string &p_name, ir::Graph *result, ir::Node *node, const std::string &p_name,
const std::string &g_name) const { const std::string &g_name) const {
if (IsSparseGradient(g_name)) { if (IsSparseGradient(g_name)) {
CreateReduceOp(result, g_name, 0); CreateReduceOp(result, g_name, 0);
CreateBroadcastOp(result, g_name, 0); CreateBroadcastOp(result, g_name, 0);
} else { } else {
CreateAllReduceOp(result, g_name); CreateAllReduceOp(result, node, g_name);
} }
} }
...@@ -580,7 +589,7 @@ void ReduceSSAGraphBuilder::ResetState() const { ...@@ -580,7 +589,7 @@ void ReduceSSAGraphBuilder::ResetState() const {
} }
void ReduceSSAGraphBuilder::InsertCollectiveOp( void ReduceSSAGraphBuilder::InsertCollectiveOp(
ir::Graph *result, const std::string &p_name, ir::Graph *result, ir::Node *node, const std::string &p_name,
const std::string &g_name) const { const std::string &g_name) const {
size_t cur_device_id = GetAppropriateDeviceID({g_name}); size_t cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(result, g_name, cur_device_id); CreateReduceOp(result, g_name, cur_device_id);
...@@ -900,7 +909,7 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ...@@ -900,7 +909,7 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
return op_dev_id; return op_dev_id;
} }
void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, ir::Node *node,
const std::string &p_name, const std::string &p_name,
const std::string &g_name) const { const std::string &g_name) const {
size_t cur_device_id = 0; size_t cur_device_id = 0;
...@@ -915,7 +924,7 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, ...@@ -915,7 +924,7 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
CreateReduceOp(result, g_name, 0); CreateReduceOp(result, g_name, 0);
CreateBroadcastOp(result, g_name, 0); CreateBroadcastOp(result, g_name, 0);
} else { } else {
CreateAllReduceOp(result, g_name); CreateAllReduceOp(result, node, g_name);
} }
break; break;
default: default:
...@@ -966,7 +975,8 @@ static int MultiDevSSAGraphBuilderRegister(const std::string &builder_mode) { ...@@ -966,7 +975,8 @@ static int MultiDevSSAGraphBuilderRegister(const std::string &builder_mode) {
.RequirePassAttr(paddle::framework::details::kPlaces) \ .RequirePassAttr(paddle::framework::details::kPlaces) \
.RequirePassAttr(paddle::framework::details::kLocalScopes) \ .RequirePassAttr(paddle::framework::details::kLocalScopes) \
.RequirePassAttr(paddle::framework::details::kStrategy) \ .RequirePassAttr(paddle::framework::details::kStrategy) \
.RequirePassAttr(paddle::framework::details::kNRanks) .RequirePassAttr(paddle::framework::details::kNRanks) \
.RequirePassAttr(paddle::framework::details::kEnablePG)
REGISTER_MULTI_DEVICES_PASS(reduce_mode_multi_devices_pass, REGISTER_MULTI_DEVICES_PASS(reduce_mode_multi_devices_pass,
paddle::framework::details::ReduceSSAGraphBuilder); paddle::framework::details::ReduceSSAGraphBuilder);
......
...@@ -36,6 +36,7 @@ constexpr char kPlaces[] = "places"; ...@@ -36,6 +36,7 @@ constexpr char kPlaces[] = "places";
constexpr char kLocalScopes[] = "local_scopes"; constexpr char kLocalScopes[] = "local_scopes";
constexpr char kStrategy[] = "strategy"; constexpr char kStrategy[] = "strategy";
constexpr char kNRanks[] = "nranks"; constexpr char kNRanks[] = "nranks";
constexpr char kEnablePG[] = "enable_pg";
class MultiDevSSAGraphBuilderBase : public ir::Pass { class MultiDevSSAGraphBuilderBase : public ir::Pass {
protected: protected:
...@@ -46,7 +47,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { ...@@ -46,7 +47,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const; virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const;
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name, virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
const std::string &p_name,
const std::string &g_name) const = 0; const std::string &g_name) const = 0;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const = 0; virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const = 0;
...@@ -75,7 +77,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { ...@@ -75,7 +77,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
bool IsSparseGradient(const std::string &og) const; bool IsSparseGradient(const std::string &og) const;
void CreateAllReduceOp(ir::Graph *result, const std::string &og) const; void CreateAllReduceOp(ir::Graph *result, ir::Node *node,
const std::string &og) const;
void CreateBroadcastOp(ir::Graph *result, const std::string &p_name, void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
size_t src_dev_id) const; size_t src_dev_id) const;
...@@ -106,7 +109,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { ...@@ -106,7 +109,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase { class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
protected: protected:
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name, virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
const std::string &p_name,
const std::string &g_name) const; const std::string &g_name) const;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const { virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const {
...@@ -135,7 +139,8 @@ class ReduceSSAGraphBuilder : public BalanceVarSSAGraphBuilder { ...@@ -135,7 +139,8 @@ class ReduceSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
protected: protected:
virtual void Init() const; virtual void Init() const;
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name, virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
const std::string &p_name,
const std::string &g_name) const; const std::string &g_name) const;
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const; virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;
...@@ -164,7 +169,8 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder { ...@@ -164,7 +169,8 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
virtual void InsertPostprocessOps(ir::Graph *result) const; virtual void InsertPostprocessOps(ir::Graph *result) const;
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name, virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
const std::string &p_name,
const std::string &g_name) const; const std::string &g_name) const;
virtual void ResetState() const; virtual void ResetState() const;
......
...@@ -36,13 +36,20 @@ namespace details { ...@@ -36,13 +36,20 @@ namespace details {
// map from variable name to variables. The variables, who have the same name, // map from variable name to variables. The variables, who have the same name,
// will have a differsent version. The offset in the // will have a differsent version. The offset in the
// `std::vector<VarHandle*>` is the version of varaibles. // `std::vector<VarHandle*>` is the version of varaibles.
typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle*>>> typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle *>>>
GraphVars; GraphVars;
const char kGraphVars[] = "vars"; const char kGraphVars[] = "vars";
// aux variables to represent dependency. Useful to resolve data hazard. // aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<VarHandleBase*> GraphDepVars; typedef std::unordered_set<VarHandleBase *> GraphDepVars;
const char kGraphDepVars[] = "dep_vars"; const char kGraphDepVars[] = "dep_vars";
// TODO(panyx0718): Clean this up as well.
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<OpHandleBase *> GraphOps;
const char kGraphOps[] = "ops";
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -70,6 +70,9 @@ class OpHandleBase { ...@@ -70,6 +70,9 @@ class OpHandleBase {
auto it = dev_ctxes_.find(place); auto it = dev_ctxes_.find(place);
return it != dev_ctxes_.end() ? it->second : nullptr; return it != dev_ctxes_.end() ? it->second : nullptr;
} }
const std::map<platform::Place, platform::DeviceContext *> &DeviceContext() {
return dev_ctxes_;
}
void SetDeviceContext(platform::Place place, platform::DeviceContext *ctx_) { void SetDeviceContext(platform::Place place, platform::DeviceContext *ctx_) {
dev_ctxes_[place] = ctx_; dev_ctxes_[place] = ctx_;
......
...@@ -13,11 +13,74 @@ ...@@ -13,11 +13,74 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h" #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
const std::vector<platform::Place> &places,
std::unique_ptr<ir::Graph> graph) {
std::vector<std::unique_ptr<ir::Graph>> graphs;
graphs.reserve(places.size());
for (size_t i = 0; i < places.size(); ++i) {
ProgramDesc empty;
graphs.emplace_back(std::unique_ptr<ir::Graph>(new ir::Graph(empty)));
auto &g = graphs.back();
g->Set(kGraphVars, new GraphVars(1UL));
g->Set(kGraphDepVars, new GraphDepVars);
g->Set(kGraphOps, new GraphOps);
}
for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
auto &dev_ctx = op->DeviceContext();
auto &p = dev_ctx.begin()->first;
#ifdef PADDLE_WITH_CUDA
int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &dev_ops = graphs[dev_id]->Get<GraphOps>(kGraphOps);
auto &dev_dummys = graphs[dev_id]->Get<GraphDepVars>(kGraphDepVars);
dev_ops.emplace_back(op);
graphs[dev_id]->AddNode(graph->ReleaseNode(op->Node()).release());
for (auto &var : op->Inputs()) {
auto dummy_ptr = dynamic_cast<DummyVarHandle *>(var);
if (dummy_ptr) {
dev_dummys.insert(var);
if (graph->Nodes().count(var->Node()))
graphs[dev_id]->AddNode(graph->ReleaseNode(var->Node()).release());
}
}
for (auto &var : op->Outputs()) {
auto dummy_ptr = dynamic_cast<DummyVarHandle *>(var);
if (dummy_ptr) {
dev_dummys.insert(var);
if (graph->Nodes().count(var->Node()))
graphs[dev_id]->AddNode(graph->ReleaseNode(var->Node()).release());
}
}
#else
PADDLE_THROW("Parallel Graph Execution only support CUDAPlace.");
#endif
}
for (size_t dev_id = 0; dev_id < places.size(); ++dev_id) {
auto &dev_vars = graphs[dev_id]->Get<GraphVars>(kGraphVars)[0];
auto &origin_vars = graph->Get<GraphVars>(kGraphVars)[dev_id];
for (auto &name_pair : origin_vars) {
dev_vars.emplace(name_pair.first, name_pair.second);
for (auto &version_pair : name_pair.second) {
if (graph->Nodes().count(version_pair->Node())) {
graphs[dev_id]->AddNode(
graph->ReleaseNode(version_pair->Node()).release());
}
}
}
}
return graphs;
}
ParallelSSAGraphExecutor::ParallelSSAGraphExecutor( ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes, const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
...@@ -37,7 +100,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor( ...@@ -37,7 +100,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
<< " to run the operators of the graph on each device."; << " to run the operators of the graph on each device.";
for (size_t i = 0; i < places.size(); ++i) { for (size_t i = 0; i < places.size(); ++i) {
executors_.emplace_back(new details::ThreadedSSAGraphExecutor( executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
strategy_, {local_scopes_[i]}, {places_[i]}, std::move(graphs_[i]))); strategy_, local_scopes_, {places_[i]}, std::move(graphs_.at(i))));
} }
} }
......
...@@ -14,16 +14,24 @@ ...@@ -14,16 +14,24 @@
#pragma once #pragma once
#include <fstream>
#include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
#include "ThreadPool.h" #include "ThreadPool.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
const std::vector<platform::Place> &places,
std::unique_ptr<ir::Graph> graph);
class ParallelSSAGraphExecutor : public SSAGraphExecutor { class ParallelSSAGraphExecutor : public SSAGraphExecutor {
public: public:
ParallelSSAGraphExecutor(const ExecutionStrategy &strategy, ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
...@@ -31,11 +39,14 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor { ...@@ -31,11 +39,14 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
std::vector<std::unique_ptr<ir::Graph>> &&graphs); std::vector<std::unique_ptr<ir::Graph>> &&graphs);
~ParallelSSAGraphExecutor() final = default; ~ParallelSSAGraphExecutor() final = default;
const ir::Graph &Graph() const override { return *graphs_[0]; } const ir::Graph &Graph() const override { return *graphs_[0]; }
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override; FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
private: private:
// std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph();
ExecutionStrategy strategy_; ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::unique_ptr<::ThreadPool> pool_{nullptr}; std::unique_ptr<::ThreadPool> pool_{nullptr};
......
...@@ -56,10 +56,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -56,10 +56,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
} }
} }
} }
for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) { for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
InsertPendingVar(&pending_vars, ready_vars.get(), var); InsertPendingVar(&pending_vars, ready_vars.get(), var);
} }
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) { for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
if (op->Inputs().empty()) { // Special case, Op has no input. if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op); ready_ops.insert(op);
...@@ -219,7 +219,7 @@ void ThreadedSSAGraphExecutor::RunOp( ...@@ -219,7 +219,7 @@ void ThreadedSSAGraphExecutor::RunOp(
VLOG(10) << op << " " << op->Name() << " Done "; VLOG(10) << op << " " << op->Name() << " Done ";
running_ops_--; running_ops_--;
ready_var_q->Extend(op->Outputs()); ready_var_q->Extend(op->Outputs());
VLOG(10) << op << " " << op->Name() << "Signal posted"; VLOG(10) << op << " " << op->Name() << " Signal posted";
} catch (...) { } catch (...) {
exception_holder_.Catch(std::current_exception()); exception_holder_.Catch(std::current_exception());
} }
......
...@@ -167,6 +167,14 @@ class Graph { ...@@ -167,6 +167,14 @@ class Graph {
return ret; return ret;
} }
std::unique_ptr<ir::Node> ReleaseNode(ir::Node *node) {
std::unique_ptr<ir::Node> ret;
ret.reset(nodes_.at(node).release());
nodes_.erase(node);
node_set_.erase(node);
return ret;
}
void RemoveNode(ir::Node *node) { void RemoveNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) != node_set_.end()); PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
node_set_.erase(node); node_set_.erase(node);
...@@ -183,13 +191,6 @@ class Graph { ...@@ -183,13 +191,6 @@ class Graph {
return nullptr; return nullptr;
} }
void ResolveHazard(
const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
private:
std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
const ProgramDesc &program);
// This method takes ownership of `node`. // This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) { ir::Node *AddNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) == node_set_.end()); PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
...@@ -198,6 +199,17 @@ class Graph { ...@@ -198,6 +199,17 @@ class Graph {
return node; return node;
} }
bool ContainNode(ir::Node *node) {
return node_set_.find(node) != node_set_.end();
}
void ResolveHazard(
const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
private:
std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
const ProgramDesc &program);
// NOTE: program_ shouldn't be exposed to user. // NOTE: program_ shouldn't be exposed to user.
const ProgramDesc program_; const ProgramDesc program_;
std::map<std::string, boost::any> attrs_; std::map<std::string, boost::any> attrs_;
......
...@@ -59,7 +59,9 @@ template <typename T> ...@@ -59,7 +59,9 @@ template <typename T>
std::vector<T *> FilterByNodeWrapper(const Graph &graph) { std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
std::vector<T *> ret; std::vector<T *> ret;
for (ir::Node *n : graph.Nodes()) { for (ir::Node *n : graph.Nodes()) {
if (n->IsWrappedBy<T>()) ret.push_back(&n->Wrapper<T>()); if (n->IsWrappedBy<T>()) {
ret.push_back(&n->Wrapper<T>());
}
} }
return ret; return ret;
} }
......
...@@ -26,6 +26,7 @@ limitations under the License. */ ...@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h" #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h" #include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -201,7 +202,6 @@ ParallelExecutor::ParallelExecutor( ...@@ -201,7 +202,6 @@ ParallelExecutor::ParallelExecutor(
member_->use_all_reduce_ = member_->use_all_reduce_ =
build_strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce; build_strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce;
member_->nranks_ = build_strategy.num_trainers_ * places.size(); member_->nranks_ = build_strategy.num_trainers_ * places.size();
if (!member_->use_all_reduce_) { if (!member_->use_all_reduce_) {
PADDLE_ENFORCE(places.size() > 1, PADDLE_ENFORCE(places.size() > 1,
"If you set build_strategy.reduce with 'Reduce'," "If you set build_strategy.reduce with 'Reduce',"
...@@ -229,9 +229,10 @@ ParallelExecutor::ParallelExecutor( ...@@ -229,9 +229,10 @@ ParallelExecutor::ParallelExecutor(
// choice the execution strategy. // choice the execution strategy.
build_strategy.enable_parallel_graph_ = build_strategy.enable_parallel_graph_ =
EnableParallelGraphExecution(main_program, exec_strategy, build_strategy); EnableParallelGraphExecution(main_program, exec_strategy, build_strategy);
if (build_strategy.enable_parallel_graph_)
VLOG(1) << "Enable ParallelGraph Execution: " VLOG(0) << "The Executor would execute the graph by ParallelGraph "
<< build_strategy.enable_parallel_graph_; "Execution which can get better performance,"
<< "you can force it off by env FLAGS_enable_parallel_graph=0";
if (member_->use_cuda_) { if (member_->use_cuda_) {
// Bcast Parameters to all GPUs // Bcast Parameters to all GPUs
...@@ -265,40 +266,25 @@ ParallelExecutor::ParallelExecutor( ...@@ -265,40 +266,25 @@ ParallelExecutor::ParallelExecutor(
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp // ncclOp
std::vector<std::unique_ptr<ir::Graph>> graphs; std::unique_ptr<ir::Graph> graph;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
if (build_strategy.enable_parallel_graph_) { graph = build_strategy.Apply(main_program, member_->places_, loss_var_name,
for (size_t i = 0; i < member_->places_.size(); ++i) { member_->local_scopes_, member_->nranks_,
std::unique_ptr<ir::Graph> graph = build_strategy.Apply( member_->use_cuda_, member_->nccl_ctxs_.get());
main_program, {member_->places_[i]}, loss_var_name,
{member_->local_scopes_[i]}, member_->nranks_, member_->use_cuda_,
member_->nccl_ctxs_.get());
graphs.push_back(std::move(graph));
}
} else {
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, member_->places_, loss_var_name, member_->local_scopes_,
member_->nranks_, member_->use_cuda_, member_->nccl_ctxs_.get());
graphs.push_back(std::move(graph));
}
#else #else
std::unique_ptr<ir::Graph> graph = build_strategy.Apply( graph = build_strategy.Apply(main_program, member_->places_, loss_var_name,
main_program, member_->places_, loss_var_name, member_->local_scopes_, member_->local_scopes_, member_->nranks_,
member_->nranks_, member_->use_cuda_); member_->use_cuda_);
graphs.push_back(std::move(graph));
#endif #endif
auto max_memory_size = GetEagerDeletionThreshold(); auto max_memory_size = GetEagerDeletionThreshold();
if (max_memory_size >= 0) { if (max_memory_size >= 0) {
for (size_t i = 0; i < graphs.size(); ++i) { graph = member_->PrepareGCAndRefCnts(std::move(graph),
graphs[i] = member_->PrepareGCAndRefCnts( static_cast<size_t>(max_memory_size));
std::move(graphs[i]), static_cast<size_t>(max_memory_size));
}
} }
// Step 3. Create vars in each scope. Passes may also create new vars. // Step 3. Create vars in each scope. Passes may also create new vars.
// skip control vars and empty vars // skip control vars and empty vars
std::vector<details::VariableInfo> var_infos; std::vector<details::VariableInfo> var_infos;
for (auto &graph : graphs) {
for (auto &node : graph->Nodes()) { for (auto &node : graph->Nodes()) {
if (node->IsVar() && !node->IsCtrlVar() && node->Var()) { if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
var_infos.emplace_back(); var_infos.emplace_back();
...@@ -307,16 +293,15 @@ ParallelExecutor::ParallelExecutor( ...@@ -307,16 +293,15 @@ ParallelExecutor::ParallelExecutor(
var_infos.back().persistable_ = node->Var()->Persistable(); var_infos.back().persistable_ = node->Var()->Persistable();
} }
} }
}
// If the loss_var_name is given, the number of graph should be only one. // If the loss_var_name is given, the number of graph should be only one.
if (loss_var_name.size()) { if (loss_var_name.size()) {
size_t graph_num = ir::GraphNum(*graphs[0]); size_t graph_num = ir::GraphNum(*graph);
if (graph_num > 1) { if (graph_num > 1) {
LOG(WARNING) LOG(WARNING)
<< "The number of graph should be only one, " << "The number of graph should be only one, "
"but the current graph has " "but the current graph has "
<< ir::GraphNum(*graphs[0]) << ir::GraphNum(*graph)
<< " sub_graphs. If you want to see the nodes of the " << " sub_graphs. If you want to see the nodes of the "
"sub_graphs, you should use 'FLAGS_print_sub_graph_dir' " "sub_graphs, you should use 'FLAGS_print_sub_graph_dir' "
"to specify the output dir. NOTES: if you not do training, " "to specify the output dir. NOTES: if you not do training, "
...@@ -325,18 +310,30 @@ ParallelExecutor::ParallelExecutor( ...@@ -325,18 +310,30 @@ ParallelExecutor::ParallelExecutor(
} }
if (build_strategy.enable_parallel_graph_) { if (build_strategy.enable_parallel_graph_) {
auto parallel_graph =
details::SeparateMultiDevicesGraph(member_->places_, std::move(graph));
auto seq_allreduce_pass =
ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
seq_allreduce_pass->Erase(details::kAllOpDescs);
seq_allreduce_pass->Set<const std::vector<OpDesc *>>(
details::kAllOpDescs,
new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
for (size_t i = 0; i < parallel_graph.size(); ++i) {
parallel_graph[i] =
seq_allreduce_pass->Apply(std::move(parallel_graph[i]));
}
member_->executor_.reset(new details::ParallelSSAGraphExecutor( member_->executor_.reset(new details::ParallelSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, exec_strategy, member_->local_scopes_, member_->places_,
std::move(graphs))); std::move(parallel_graph)));
} else { } else {
if (exec_strategy.type_ == ExecutionStrategy::kDefault) { if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, exec_strategy, member_->local_scopes_, member_->places_,
std::move(graphs[0]))); std::move(graph)));
} else { } else {
member_->executor_.reset(new details::FastThreadedSSAGraphExecutor( member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, exec_strategy, member_->local_scopes_, member_->places_,
std::move(graphs[0]))); std::move(graph)));
} }
} }
...@@ -487,8 +484,8 @@ bool ParallelExecutor::EnableParallelGraphExecution( ...@@ -487,8 +484,8 @@ bool ParallelExecutor::EnableParallelGraphExecution(
} }
} }
if (!member_->use_all_reduce_ || !member_->use_cuda_) // if (!member_->use_all_reduce_ || !member_->use_cuda_)
enable_parallel_graph = false; if (!member_->use_all_reduce_) enable_parallel_graph = false;
if (build_strategy.enable_sequential_execution_ || if (build_strategy.enable_sequential_execution_ ||
exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental) exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental)
......
...@@ -72,6 +72,7 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -72,6 +72,7 @@ class TestParallelExecutorBase(unittest.TestCase):
exe.run(startup) exe.run(startup)
exec_strategy = fluid.ExecutionStrategy() exec_strategy = fluid.ExecutionStrategy()
exec_strategy.allow_op_delay = allow_op_delay exec_strategy.allow_op_delay = allow_op_delay
exec_strategy.num_threads = 1
if use_fast_executor: if use_fast_executor:
exec_strategy.use_experimental_executor = True exec_strategy.use_experimental_executor = True
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
...@@ -99,7 +100,7 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -99,7 +100,7 @@ class TestParallelExecutorBase(unittest.TestCase):
first_loss, = run_executor( first_loss, = run_executor(
exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name]) exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])
for i in range(iter): for _ in range(iter):
run_executor( run_executor(
exe=exe, binary=binary, feed=feed_dict, fetch_list=[]) exe=exe, binary=binary, feed=feed_dict, fetch_list=[])
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import os
os.environ['FLAGS_enable_parallel_graph'] = str(1)
import paddle.fluid.core as core
import os
import paddle.fluid as fluid
from parallel_executor_test_base import TestParallelExecutorBase
def simple_fc_net(use_feed):
    """Build a small MLP classifier and return its mean cross-entropy loss.

    Four fully-connected tanh layers of size 200 feed a 10-way softmax
    classifier over 784-dim 'image' input with int64 'label' targets.

    Args:
        use_feed: unused; kept to match the model-builder signature that
            TestParallelExecutorBase.check_network_convergence expects.

    Returns:
        The mean cross-entropy loss variable.
    """
    image = fluid.layers.data(name='image', shape=[784], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    hidden = image
    for _ in range(4):
        # Constant bias init (1.0) keeps parameter initialization
        # deterministic across executor variants under comparison.
        hidden = fluid.layers.fc(
            hidden,
            size=200,
            act='tanh',
            bias_attr=fluid.ParamAttr(
                initializer=fluid.initializer.Constant(value=1.0)))
    prediction = fluid.layers.fc(hidden, size=10, act='softmax')
    return fluid.layers.mean(
        fluid.layers.cross_entropy(input=prediction, label=label))
class TestMNIST(TestParallelExecutorBase):
    """Convergence tests for simple_fc_net with parallel-graph execution
    (FLAGS_enable_parallel_graph is switched on at module import time)."""

    @classmethod
    def setUpClass(cls):
        # Use 4 CPU places when CUDA is unavailable.
        os.environ['CPU_NUM'] = str(4)

    def _init_data(self):
        """Return a fixed (image, label) batch; seeded so every run and
        every executor variant sees identical data."""
        np.random.seed(5)
        img = np.random.random(size=[32, 784]).astype(np.float32)
        label = np.ones(shape=[32, 1], dtype='int64')
        return img, label

    # simple_fc
    def check_simple_fc_convergence(self, use_cuda, use_reduce=False):
        # Skip CUDA runs on CPU-only builds.
        if use_cuda and not core.is_compiled_with_cuda():
            return

        img, label = self._init_data()

        self.check_network_convergence(
            simple_fc_net,
            feed_dict={"image": img,
                       "label": label},
            use_cuda=use_cuda,
            use_reduce=use_reduce)

    def test_simple_fc(self):
        # use_cuda
        self.check_simple_fc_convergence(True)

    def check_simple_fc_parallel_accuracy(self, use_cuda):
        """Compare first/last losses of single-device vs. parallel
        execution; they must agree to within 1e-6."""
        if use_cuda and not core.is_compiled_with_cuda():
            return

        img, label = self._init_data()

        single_first_loss, single_last_loss = self.check_network_convergence(
            method=simple_fc_net,
            seed=1,
            feed_dict={"image": img,
                       "label": label},
            use_cuda=use_cuda,
            use_parallel_executor=False)
        parallel_first_loss, parallel_last_loss = self.check_network_convergence(
            method=simple_fc_net,
            seed=1,
            feed_dict={"image": img,
                       "label": label},
            use_cuda=use_cuda,
            use_parallel_executor=True)

        # assertAlmostEqual, not the deprecated assertAlmostEquals alias
        # (removed in Python 3.12).
        self.assertAlmostEqual(
            np.mean(parallel_first_loss), single_first_loss, delta=1e-6)
        self.assertAlmostEqual(
            np.mean(parallel_last_loss), single_last_loss, delta=1e-6)

    def test_simple_fc_parallel_accuracy(self):
        self.check_simple_fc_parallel_accuracy(True)
# Run all TestMNIST cases when executed as a script.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册