提交 af79b192 编写于 作者: X Xin Pan

add a simple program to graph

上级 7231ef6b
...@@ -94,7 +94,7 @@ else() ...@@ -94,7 +94,7 @@ else()
endif() endif()
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor) cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph)
cc_library(prune SRCS prune.cc DEPS framework_proto) cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
...@@ -168,8 +168,8 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( ...@@ -168,8 +168,8 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
} }
std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Build( std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const { std::unique_ptr<Graph> graph) const {
std::unique_ptr<Graph> graph(new Graph); const ProgramDesc &program = graph->Program();
for (auto *var : program.Block(0).AllVars()) { for (auto *var : program.Block(0).AllVars()) {
all_vars_.emplace(var->Name(), var); all_vars_.emplace(var->Name(), var);
} }
......
...@@ -47,7 +47,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -47,7 +47,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const BuildStrategy &strategy); const BuildStrategy &strategy);
#endif #endif
std::unique_ptr<Graph> Build(const ProgramDesc &program) const override; std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const override;
int GetVarDeviceID(const std::string &varname) const override; int GetVarDeviceID(const std::string &varname) const override;
private: private:
......
...@@ -38,7 +38,7 @@ class SSAGraphBuilder { ...@@ -38,7 +38,7 @@ class SSAGraphBuilder {
public: public:
SSAGraphBuilder() {} SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {}
virtual std::unique_ptr<Graph> Build(const ProgramDesc &program) const = 0; virtual std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const = 0;
virtual int GetVarDeviceID(const std::string &var_name) const = 0; virtual int GetVarDeviceID(const std::string &var_name) const = 0;
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
......
...@@ -29,10 +29,10 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { ...@@ -29,10 +29,10 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
std::unique_ptr<SSAGraphBuilder>&& builder) std::unique_ptr<SSAGraphBuilder>&& builder)
: builder_(std::move(builder)) {} : builder_(std::move(builder)) {}
std::unique_ptr<Graph> Build(const ProgramDesc& program) const override { std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const override {
auto graph = builder_->Build(program); auto new_graph = builder_->Build(std::move(graph));
PADDLE_ENFORCE(IsValidGraph(graph.get())); PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
return graph; return new_graph;
} }
int GetVarDeviceID(const std::string& var_name) const override { int GetVarDeviceID(const std::string& var_name) const override {
......
...@@ -50,10 +50,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { ...@@ -50,10 +50,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
stream_ptr_(std::move(sout)), stream_ptr_(std::move(sout)),
stream_ref_(*stream_ptr_) {} stream_ref_(*stream_ptr_) {}
std::unique_ptr<Graph> Build(const ProgramDesc& program) const override { std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const override {
auto graph = builder_->Build(program); auto new_graph = builder_->Build(std::move(graph));
printer_->Print(*graph, stream_ref_); printer_->Print(*new_graph, stream_ref_);
return graph; return new_graph;
} }
int GetVarDeviceID(const std::string& var_name) const override { int GetVarDeviceID(const std::string& var_name) const override {
......
...@@ -15,5 +15,12 @@ limitations under the License. */ ...@@ -15,5 +15,12 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
namespace paddle { namespace paddle {
namespace framework {} // namespace framework namespace framework {
std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc &program) {
std::unique_ptr<Graph> graph(new Graph(program));
return std::move(graph);
}
} // namespace framework
} // namespace paddle } // namespace paddle
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/variant.h" #include "paddle/fluid/platform/variant.h"
...@@ -28,6 +29,8 @@ namespace framework { ...@@ -28,6 +29,8 @@ namespace framework {
class Graph { class Graph {
public: public:
explicit Graph(const ProgramDesc& program) : program_(program) {}
virtual ~Graph() { virtual ~Graph() {
for (auto& attr : attrs_) { for (auto& attr : attrs_) {
attr_dels_[attr.first](); attr_dels_[attr.first]();
...@@ -36,6 +39,8 @@ class Graph { ...@@ -36,6 +39,8 @@ class Graph {
attr_dels_.clear(); attr_dels_.clear();
} }
const ProgramDesc& Program() const { return program_; }
template <typename AttrType> template <typename AttrType>
AttrType& Get(const std::string& attr_name) const { AttrType& Get(const std::string& attr_name) const {
return *boost::any_cast<AttrType*>(attrs_.at(attr_name)); return *boost::any_cast<AttrType*>(attrs_.at(attr_name));
...@@ -63,9 +68,12 @@ class Graph { ...@@ -63,9 +68,12 @@ class Graph {
std::vector<std::unique_ptr<ir::Node>> nodes; std::vector<std::unique_ptr<ir::Node>> nodes;
private: private:
const ProgramDesc& program_;
std::map<std::string, boost::any> attrs_; std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_; std::map<std::string, std::function<void(void)>> attr_dels_;
}; };
std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc& program);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -30,11 +30,5 @@ class Pass { ...@@ -30,11 +30,5 @@ class Pass {
} }
}; };
std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc& program) {
std::unique_ptr<Graph> g(new Graph);
return std::move(g);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -133,7 +133,7 @@ ParallelExecutor::ParallelExecutor( ...@@ -133,7 +133,7 @@ ParallelExecutor::ParallelExecutor(
} }
builder_ = builder_factory.Create(); builder_ = builder_factory.Create();
std::unique_ptr<Graph> graph = builder_->Build(main_program); std::unique_ptr<Graph> graph = builder_->Build(ProgramToGraph(main_program));
std::unique_ptr<details::SSAGraph> ssa_graph(new details::SSAGraph); std::unique_ptr<details::SSAGraph> ssa_graph(new details::SSAGraph);
ssa_graph->vars_ = std::move(graph->Get<details::GraphVars>("vars")); ssa_graph->vars_ = std::move(graph->Get<details::GraphVars>("vars"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册