提交 c3f6e0e8 编写于 作者: X Xin Pan

add namespace to Graph

上级 0b3465d2
...@@ -68,7 +68,8 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( ...@@ -68,7 +68,8 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
} }
} }
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node, void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
ir::Node *node,
size_t place_id) const { size_t place_id) const {
auto p = places_[place_id]; auto p = places_[place_id];
auto *op_handle = result->Get<GraphOps>("ops").back().get(); auto *op_handle = result->Get<GraphOps>("ops").back().get();
...@@ -192,8 +193,9 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( ...@@ -192,8 +193,9 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
// to parameter/gradients before optimizer ops, topo sort is insufficient. ( // to parameter/gradients before optimizer ops, topo sort is insufficient. (
// some optimizer ops might not depend on any nodes), we manually move all // some optimizer ops might not depend on any nodes), we manually move all
// optimizer nodes after last backward nodes. // optimizer nodes after last backward nodes.
std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const Graph &graph) { // However, the assumption by SSAGraphBuilder should be relaxed in the future.
std::vector<ir::Node *> ret = ir::TopologySort(graph); std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
std::vector<ir::Node *> ret = ir::TopologySortOperations(graph);
size_t last_backward = 0; size_t last_backward = 0;
std::vector<ir::Node *> optimize_ops; std::vector<ir::Node *> optimize_ops;
std::vector<ir::Node *> sorted_ret; std::vector<ir::Node *> sorted_ret;
...@@ -232,8 +234,8 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const Graph &graph) { ...@@ -232,8 +234,8 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const Graph &graph) {
return sorted_ret; return sorted_ret;
} }
std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply( std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
std::unique_ptr<Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
// Rebuild the graph structure. // Rebuild the graph structure.
std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph); std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
auto nodes = std::move(graph->nodes); auto nodes = std::move(graph->nodes);
...@@ -245,7 +247,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply( ...@@ -245,7 +247,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
} }
} }
Graph &result = *graph; ir::Graph &result = *graph;
std::unordered_set<std::string> og_has_been_broadcast; std::unordered_set<std::string> og_has_been_broadcast;
// We cannot invoke resize. It is a bug of GCC 4.8 // We cannot invoke resize. It is a bug of GCC 4.8
...@@ -397,7 +399,7 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext( ...@@ -397,7 +399,7 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
#endif #endif
} }
void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
const std::string &p_name, const std::string &p_name,
size_t src_dev_id) const { size_t src_dev_id) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -427,7 +429,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, ...@@ -427,7 +429,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
} }
} }
void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result, void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
ir::Node *node, ir::Node *node,
int dev_id) const { int dev_id) const {
result->Get<GraphOps>("ops").emplace_back( result->Get<GraphOps>("ops").emplace_back(
...@@ -436,7 +438,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result, ...@@ -436,7 +438,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
CreateOpHandleIOs(result, node, dev_id); CreateOpHandleIOs(result, node, dev_id);
} }
void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
const std::string &og) const { const std::string &og) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle( result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
...@@ -466,7 +468,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, ...@@ -466,7 +468,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
} }
void MultiDevSSAGraphBuilder::InsertDataBalanceOp( void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
Graph *result, const std::vector<std::string> &datas) const { ir::Graph *result, const std::vector<std::string> &datas) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle( result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
...@@ -529,7 +531,7 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const { ...@@ -529,7 +531,7 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
return got == var_name_on_devices_.end() ? -1 : got->second; return got == var_name_on_devices_.end() ? -1 : got->second;
} }
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle // Insert ScaleCost OpHandle
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -559,7 +561,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { ...@@ -559,7 +561,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
} }
} }
void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result, void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
ir::Node *node, ir::Node *node,
size_t num_places) const { size_t num_places) const {
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
...@@ -571,7 +573,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result, ...@@ -571,7 +573,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
} }
} }
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
const std::string &og, const std::string &og,
int dst_dev_id) const { int dst_dev_id) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -604,7 +606,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, ...@@ -604,7 +606,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
// Find the first occurence of `prev_op_name` and make current `op` depend // Find the first occurence of `prev_op_name` and make current `op` depend
// on it. // on it.
void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op, void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const { const std::string &prev_op_name) const {
for (auto &prev_op : result->Get<GraphOps>("ops")) { for (auto &prev_op : result->Get<GraphOps>("ops")) {
if (prev_op->Name() == prev_op_name) { if (prev_op->Name() == prev_op_name) {
...@@ -617,7 +619,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op, ...@@ -617,7 +619,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
} }
} }
void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result, void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
ir::Node *node) const { ir::Node *node) const {
int op_dev_id = -1; int op_dev_id = -1;
std::vector<std::string> input_var_names; std::vector<std::string> input_var_names;
...@@ -664,7 +666,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result, ...@@ -664,7 +666,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
} }
// Create RPC related op handles that connects its in ops and out ops. // Create RPC related op handles that connects its in ops and out ops.
void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const { void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
ir::Node *node) const {
int op_dev_id = -1; int op_dev_id = -1;
if (node->Op()->Type() == "send") { if (node->Op()->Type() == "send") {
op_dev_id = GetVarDeviceID(node->inputs[0]->Name()); op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
......
...@@ -46,11 +46,13 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -46,11 +46,13 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const BuildStrategy &strategy); const BuildStrategy &strategy);
#endif #endif
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override; std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override;
int GetVarDeviceID(const std::string &varname) const override; int GetVarDeviceID(const std::string &varname) const override;
private: private:
void CreateOpHandleIOs(Graph *result, ir::Node *node, size_t device_id) const; void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
size_t device_id) const;
private: private:
std::string loss_var_name_; std::string loss_var_name_;
...@@ -64,8 +66,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -64,8 +66,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool IsScaleLossOp(ir::Node *node) const; bool IsScaleLossOp(ir::Node *node) const;
void CreateRPCOp(Graph *result, ir::Node *node) const; void CreateRPCOp(ir::Graph *result, ir::Node *node) const;
void CreateDistTrainOp(Graph *result, ir::Node *node) const; void CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
/** /**
* Is this operator as the end-point operator before/after send operator. * Is this operator as the end-point operator before/after send operator.
...@@ -79,16 +81,17 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -79,16 +81,17 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
std::vector<std::string> FindDistTrainRecvVars( std::vector<std::string> FindDistTrainRecvVars(
const std::vector<std::unique_ptr<ir::Node>> &nodes) const; const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
void ConnectOp(Graph *result, OpHandleBase *op, void ConnectOp(ir::Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const; const std::string &prev_op_name) const;
void CreateComputationalOps(Graph *result, ir::Node *node, void CreateComputationalOps(ir::Graph *result, ir::Node *node,
size_t num_places) const; size_t num_places) const;
void CreateScaleLossGradOp(Graph *result) const; void CreateScaleLossGradOp(ir::Graph *result) const;
VarHandle *CreateReduceOp(Graph *result, const std::string &og, VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
int dst_dev_id) const; int dst_dev_id) const;
void CreateComputationalOp(Graph *result, ir::Node *node, int dev_id) const; void CreateComputationalOp(ir::Graph *result, ir::Node *node,
int dev_id) const;
bool IsParameterGradientOnce( bool IsParameterGradientOnce(
const std::string &og, const std::string &og,
...@@ -96,12 +99,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -96,12 +99,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
int GetOpDeviceID(ir::Node *node) const; int GetOpDeviceID(ir::Node *node) const;
void InsertAllReduceOp(Graph *result, const std::string &og) const; void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
void InsertDataBalanceOp(Graph *result, void InsertDataBalanceOp(ir::Graph *result,
const std::vector<std::string> &datas) const; const std::vector<std::string> &datas) const;
void CreateBroadcastOp(Graph *result, const std::string &p_name, void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
size_t src_dev_id) const; size_t src_dev_id) const;
bool IsSparseGradient(const std::string &og) const; bool IsSparseGradient(const std::string &og) const;
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
for (auto &var_map : graph->Get<GraphVars>("vars")) { for (auto &var_map : graph->Get<GraphVars>("vars")) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
if (name_pair.second.size() <= 1) { if (name_pair.second.size() <= 1) {
...@@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { ...@@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
} }
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
Graph *graph, ir::Node *node, const platform::Place &place, ir::Graph *graph, ir::Node *node, const platform::Place &place,
size_t place_offset) { size_t place_offset) {
auto &var_holders = graph->Get<GraphVars>("vars")[place_offset]; auto &var_holders = graph->Get<GraphVars>("vars")[place_offset];
auto &var_holder = var_holders[node->Name()]; auto &var_holder = var_holders[node->Name()];
...@@ -81,7 +81,7 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( ...@@ -81,7 +81,7 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
return var; return var;
} }
void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
ir::Node *new_node, ir::Node *new_node,
const platform::Place &place, const platform::Place &place,
size_t place_offset) { size_t place_offset) {
...@@ -93,7 +93,7 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, ...@@ -93,7 +93,7 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
op_handle->AddOutput(var); op_handle->AddOutput(var);
} }
void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) { void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) {
for (auto &op : graph->Get<GraphOps>("ops")) { for (auto &op : graph->Get<GraphOps>("ops")) {
if (!op->Outputs().empty()) { if (!op->Outputs().empty()) {
continue; continue;
......
...@@ -64,19 +64,19 @@ class SSAGraphBuilder : public ir::Pass { ...@@ -64,19 +64,19 @@ class SSAGraphBuilder : public ir::Pass {
* *
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR) * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/ */
static void PolishGraphToSupportDataHazards(Graph *graph); static void PolishGraphToSupportDataHazards(ir::Graph *graph);
static VarHandle *CreateOrGetLatestVarHandle(Graph *graph, ir::Node *node, static VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
const platform::Place &place, const platform::Place &place,
size_t place_offset); size_t place_offset);
// Add an output variable (each_var_name, place, place_offset) to op_handle, // Add an output variable (each_var_name, place, place_offset) to op_handle,
// which belongs to graph // which belongs to graph
static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle, static void CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
ir::Node *new_node, const platform::Place &place, ir::Node *new_node, const platform::Place &place,
size_t place_offset); size_t place_offset);
static void AddOutputToLeafOps(Graph *graph); static void AddOutputToLeafOps(ir::Graph *graph);
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
bool SSAGraghBuilderWithChecker::IsValidGraph(const Graph *graph) const { bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
std::unordered_map<OpHandleBase *, size_t> pending_ops; std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars; std::unordered_set<VarHandleBase *> pending_vars;
std::unordered_set<VarHandleBase *> ready_vars; std::unordered_set<VarHandleBase *> ready_vars;
......
...@@ -28,7 +28,8 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { ...@@ -28,7 +28,8 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
std::unique_ptr<SSAGraphBuilder>&& builder) std::unique_ptr<SSAGraphBuilder>&& builder)
: builder_(std::move(builder)) {} : builder_(std::move(builder)) {}
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override { std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override {
auto new_graph = builder_->Apply(std::move(graph)); auto new_graph = builder_->Apply(std::move(graph));
PADDLE_ENFORCE(IsValidGraph(new_graph.get())); PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
return new_graph; return new_graph;
...@@ -38,7 +39,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { ...@@ -38,7 +39,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
return builder_->GetVarDeviceID(var_name); return builder_->GetVarDeviceID(var_name);
} }
bool IsValidGraph(const Graph* graph) const; bool IsValidGraph(const ir::Graph* graph) const;
private: private:
std::unique_ptr<SSAGraphBuilder> builder_; std::unique_ptr<SSAGraphBuilder> builder_;
......
...@@ -21,7 +21,7 @@ namespace framework { ...@@ -21,7 +21,7 @@ namespace framework {
namespace details { namespace details {
template <typename Callback> template <typename Callback>
static inline void IterAllVar(const Graph &graph, Callback callback) { static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
for (auto &each : graph.Get<GraphVars>("vars")) { for (auto &each : graph.Get<GraphVars>("vars")) {
for (auto &pair1 : each) { for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) { for (auto &pair2 : pair1.second) {
...@@ -35,7 +35,7 @@ static inline void IterAllVar(const Graph &graph, Callback callback) { ...@@ -35,7 +35,7 @@ static inline void IterAllVar(const Graph &graph, Callback callback) {
} }
} }
void GraphvizSSAGraphPrinter::Print(const Graph &graph, void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
std::ostream &sout) const { std::ostream &sout) const {
size_t var_id = 0; size_t var_id = 0;
std::unordered_map<const VarHandleBase *, size_t> vars; std::unordered_map<const VarHandleBase *, size_t> vars;
......
...@@ -25,12 +25,12 @@ namespace details { ...@@ -25,12 +25,12 @@ namespace details {
class SSAGraphPrinter { class SSAGraphPrinter {
public: public:
virtual ~SSAGraphPrinter() {} virtual ~SSAGraphPrinter() {}
virtual void Print(const Graph& graph, std::ostream& sout) const = 0; virtual void Print(const ir::Graph& graph, std::ostream& sout) const = 0;
}; };
class GraphvizSSAGraphPrinter : public SSAGraphPrinter { class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
public: public:
void Print(const Graph& graph, std::ostream& sout) const override; void Print(const ir::Graph& graph, std::ostream& sout) const override;
}; };
class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
...@@ -50,7 +50,8 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { ...@@ -50,7 +50,8 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
stream_ptr_(std::move(sout)), stream_ptr_(std::move(sout)),
stream_ref_(*stream_ptr_) {} stream_ref_(*stream_ptr_) {}
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override { std::unique_ptr<ir::Graph> Apply(
std::unique_ptr<ir::Graph> graph) const override {
auto new_graph = builder_->Apply(std::move(graph)); auto new_graph = builder_->Apply(std::move(graph));
printer_->Print(*new_graph, stream_ref_); printer_->Print(*new_graph, stream_ref_);
return new_graph; return new_graph;
......
...@@ -21,7 +21,8 @@ namespace framework { ...@@ -21,7 +21,8 @@ namespace framework {
namespace details { namespace details {
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes, const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, std::unique_ptr<Graph> &&graph) const std::vector<platform::Place> &places,
std::unique_ptr<ir::Graph> &&graph)
: graph_(std::move(graph)), : graph_(std::move(graph)),
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_) pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
: nullptr), : nullptr),
......
...@@ -40,7 +40,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -40,7 +40,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy, ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
std::unique_ptr<Graph> &&graph); std::unique_ptr<ir::Graph> &&graph);
// Run a SSAGraph by a thread pool // Run a SSAGraph by a thread pool
// Use topological sort algorithm // Use topological sort algorithm
...@@ -53,7 +53,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -53,7 +53,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
details::OpHandleBase *op); details::OpHandleBase *op);
private: private:
std::unique_ptr<Graph> graph_; std::unique_ptr<ir::Graph> graph_;
std::unique_ptr<::ThreadPool> pool_; std::unique_ptr<::ThreadPool> pool_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
......
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir {
/* /*
namespace { namespace {
void SortHelper( void SortHelper(
...@@ -41,7 +42,7 @@ void SortHelper( ...@@ -41,7 +42,7 @@ void SortHelper(
ret->push_back(node); ret->push_back(node);
} }
std::vector<ir::Node*> TopologySort( std::vector<ir::Node*> TopologySortOperations(
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list) { const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list) {
std::unordered_set<ir::Node *> visited; std::unordered_set<ir::Node *> visited;
std::vector<ir::Node *> ret; std::vector<ir::Node *> ret;
...@@ -156,7 +157,7 @@ bool HasCircle(const std::map<ir::Node *, std::unordered_set<ir::Node *>> ...@@ -156,7 +157,7 @@ bool HasCircle(const std::map<ir::Node *, std::unordered_set<ir::Node *>>
return false; return false;
} }
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildAdjList( std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
const std::vector<ir::Node*> &nodes) { const std::vector<ir::Node*> &nodes) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list; std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
...@@ -178,17 +179,17 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildAdjList( ...@@ -178,17 +179,17 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildAdjList(
return adj_list; return adj_list;
} }
std::vector<ir::Node *> TopologySortOperationFromInToOut( std::vector<ir::Node *> TopologySortOperationsOperationFromInToOut(
const std::vector<std::unique_ptr<ir::Node>> &nodes) { const std::vector<std::unique_ptr<ir::Node>> &nodes) {
std::vector<ir::Node*> tmp; std::vector<ir::Node*> tmp;
for (auto& n : nodes) { for (auto& n : nodes) {
tmp.push_back(n.get()); tmp.push_back(n.get());
} }
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list = std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
BuildAdjList(tmp); BuildOperationAdjList(tmp);
PADDLE_ENFORCE(!HasCircle(adj_list)); PADDLE_ENFORCE(!HasCircle(adj_list));
std::vector<ir::Node*> ret = TopologySort(adj_list); std::vector<ir::Node*> ret = TopologySortOperations(adj_list);
ir::Node *last_backward = nullptr; ir::Node *last_backward = nullptr;
std::vector<ir::Node *> optimize_ops; std::vector<ir::Node *> optimize_ops;
...@@ -235,5 +236,6 @@ BuildAdjList(tmp); ...@@ -235,5 +236,6 @@ BuildAdjList(tmp);
return ret; return ret;
}*/ }*/
} // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -26,13 +26,13 @@ limitations under the License. */ ...@@ -26,13 +26,13 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir {
class Graph { class Graph {
public: public:
explicit Graph(const ProgramDesc& program); explicit Graph(const ProgramDesc &program);
virtual ~Graph() { virtual ~Graph() {
for (auto& attr : attrs_) { for (auto &attr : attrs_) {
attr_dels_[attr.first](); attr_dels_[attr.first]();
} }
attrs_.clear(); attrs_.clear();
...@@ -40,12 +40,12 @@ class Graph { ...@@ -40,12 +40,12 @@ class Graph {
} }
template <typename AttrType> template <typename AttrType>
AttrType& Get(const std::string& attr_name) const { AttrType &Get(const std::string &attr_name) const {
return *boost::any_cast<AttrType*>(attrs_.at(attr_name)); return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
} }
template <typename AttrType> template <typename AttrType>
void Set(const std::string& attr_name, AttrType* attr) { void Set(const std::string &attr_name, AttrType *attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0); PADDLE_ENFORCE(attrs_.count(attr_name) == 0);
attrs_[attr_name] = attr; attrs_[attr_name] = attr;
attr_dels_[attr_name] = [attr, attr_name]() { attr_dels_[attr_name] = [attr, attr_name]() {
...@@ -54,17 +54,17 @@ class Graph { ...@@ -54,17 +54,17 @@ class Graph {
}; };
} }
ir::Node* CreateVarNode(VarDesc* var_desc) { ir::Node *CreateVarNode(VarDesc *var_desc) {
nodes.emplace_back(new ir::Node(var_desc)); nodes.emplace_back(new ir::Node(var_desc));
return nodes.back().get(); return nodes.back().get();
} }
ir::Node* CreateOpNode(OpDesc* op_desc) { ir::Node *CreateOpNode(OpDesc *op_desc) {
nodes.emplace_back(new ir::Node(op_desc)); nodes.emplace_back(new ir::Node(op_desc));
return nodes.back().get(); return nodes.back().get();
} }
ir::Node* CreateEmptyNode(const std::string& name, ir::Node::Type type) { ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
nodes.emplace_back(new ir::Node(name, type)); nodes.emplace_back(new ir::Node(name, type));
return nodes.back().get(); return nodes.back().get();
} }
...@@ -73,10 +73,10 @@ class Graph { ...@@ -73,10 +73,10 @@ class Graph {
private: private:
// NOTE: program_ shouldn't be exposed to user. // NOTE: program_ shouldn't be exposed to user.
const ProgramDesc& program_; const ProgramDesc &program_;
std::map<std::string, boost::any> attrs_; std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_; std::map<std::string, std::function<void(void)>> attr_dels_;
}; };
} // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -64,7 +64,7 @@ bool HasCircleHelper( ...@@ -64,7 +64,7 @@ bool HasCircleHelper(
bool HasCircle(const Graph &graph) { bool HasCircle(const Graph &graph) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list = std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
BuildAdjList(graph); BuildOperationAdjList(graph);
std::unordered_set<ir::Node *> visited; std::unordered_set<ir::Node *> visited;
std::unordered_set<ir::Node *> in_trace; std::unordered_set<ir::Node *> in_trace;
...@@ -76,9 +76,9 @@ bool HasCircle(const Graph &graph) { ...@@ -76,9 +76,9 @@ bool HasCircle(const Graph &graph) {
return false; return false;
} }
std::vector<ir::Node *> TopologySort(const Graph &graph) { std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list = std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
BuildAdjList(graph); BuildOperationAdjList(graph);
std::unordered_set<ir::Node *> visited; std::unordered_set<ir::Node *> visited;
std::vector<ir::Node *> ret; std::vector<ir::Node *> ret;
for (auto adj : adj_list) { for (auto adj : adj_list) {
...@@ -89,7 +89,7 @@ std::vector<ir::Node *> TopologySort(const Graph &graph) { ...@@ -89,7 +89,7 @@ std::vector<ir::Node *> TopologySort(const Graph &graph) {
return ret; return ret;
} }
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildAdjList( std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
const Graph &graph) { const Graph &graph) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list; std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
......
...@@ -26,9 +26,9 @@ namespace framework { ...@@ -26,9 +26,9 @@ namespace framework {
namespace ir { namespace ir {
bool HasCircle(const Graph &graph); bool HasCircle(const Graph &graph);
std::vector<ir::Node *> TopologySort(const Graph &graph); std::vector<ir::Node *> TopologySortOperations(const Graph &graph);
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildAdjList( std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
const Graph &graph); const Graph &graph);
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
...@@ -93,7 +93,7 @@ TEST(GraphTest, Basic) { ...@@ -93,7 +93,7 @@ TEST(GraphTest, Basic) {
ASSERT_EQ(proto::VarType::LOD_TENSOR, ASSERT_EQ(proto::VarType::LOD_TENSOR,
prog.MutableBlock(0)->Var("test_out")->GetType()); prog.MutableBlock(0)->Var("test_out")->GetType());
std::unique_ptr<Graph> g(new Graph(prog)); std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
ASSERT_EQ(g->nodes[0]->Name(), "sum"); ASSERT_EQ(g->nodes[0]->Name(), "sum");
ASSERT_EQ(g->nodes[0]->inputs[0]->Name(), "test_a"); ASSERT_EQ(g->nodes[0]->inputs[0]->Name(), "test_a");
ASSERT_EQ(g->nodes[0]->inputs[1]->Name(), "test_b"); ASSERT_EQ(g->nodes[0]->inputs[1]->Name(), "test_b");
......
...@@ -132,7 +132,7 @@ ParallelExecutor::ParallelExecutor( ...@@ -132,7 +132,7 @@ ParallelExecutor::ParallelExecutor(
#endif #endif
} }
builder_ = builder_factory.Create(); builder_ = builder_factory.Create();
std::unique_ptr<Graph> graph(new Graph(main_program)); std::unique_ptr<ir::Graph> graph(new ir::Graph(main_program));
graph = builder_->Apply(std::move(graph)); graph = builder_->Apply(std::move(graph));
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph))); exec_strategy, member_->local_scopes_, places, std::move(graph)));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册