提交 af79b192 编写于 作者: X Xin Pan

add a simple program to graph

上级 7231ef6b
......@@ -94,7 +94,7 @@ else()
endif()
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph)
cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
......@@ -168,8 +168,8 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
}
std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const {
std::unique_ptr<Graph> graph(new Graph);
std::unique_ptr<Graph> graph) const {
const ProgramDesc &program = graph->Program();
for (auto *var : program.Block(0).AllVars()) {
all_vars_.emplace(var->Name(), var);
}
......
......@@ -47,7 +47,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const BuildStrategy &strategy);
#endif
std::unique_ptr<Graph> Build(const ProgramDesc &program) const override;
std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const override;
int GetVarDeviceID(const std::string &varname) const override;
private:
......
......@@ -38,7 +38,7 @@ class SSAGraphBuilder {
public:
SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {}
virtual std::unique_ptr<Graph> Build(const ProgramDesc &program) const = 0;
virtual std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const = 0;
virtual int GetVarDeviceID(const std::string &var_name) const = 0;
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
......
......@@ -29,10 +29,10 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
std::unique_ptr<SSAGraphBuilder>&& builder)
: builder_(std::move(builder)) {}
std::unique_ptr<Graph> Build(const ProgramDesc& program) const override {
auto graph = builder_->Build(program);
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph;
std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const override {
auto new_graph = builder_->Build(std::move(graph));
PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
return new_graph;
}
int GetVarDeviceID(const std::string& var_name) const override {
......
......@@ -50,10 +50,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
stream_ptr_(std::move(sout)),
stream_ref_(*stream_ptr_) {}
std::unique_ptr<Graph> Build(const ProgramDesc& program) const override {
auto graph = builder_->Build(program);
printer_->Print(*graph, stream_ref_);
return graph;
std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const override {
auto new_graph = builder_->Build(std::move(graph));
printer_->Print(*new_graph, stream_ref_);
return new_graph;
}
int GetVarDeviceID(const std::string& var_name) const override {
......
......@@ -15,5 +15,12 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {} // namespace framework
namespace framework {
std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc &program) {
std::unique_ptr<Graph> graph(new Graph(program));
return std::move(graph);
}
} // namespace framework
} // namespace paddle
......@@ -20,6 +20,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/variant.h"
......@@ -28,6 +29,8 @@ namespace framework {
class Graph {
public:
explicit Graph(const ProgramDesc& program) : program_(program) {}
virtual ~Graph() {
for (auto& attr : attrs_) {
attr_dels_[attr.first]();
......@@ -36,6 +39,8 @@ class Graph {
attr_dels_.clear();
}
const ProgramDesc& Program() const { return program_; }
template <typename AttrType>
AttrType& Get(const std::string& attr_name) const {
return *boost::any_cast<AttrType*>(attrs_.at(attr_name));
......@@ -63,9 +68,12 @@ class Graph {
std::vector<std::unique_ptr<ir::Node>> nodes;
private:
const ProgramDesc& program_;
std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_;
};
std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc& program);
} // namespace framework
} // namespace paddle
......@@ -30,11 +30,5 @@ class Pass {
}
};
std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc& program) {
std::unique_ptr<Graph> g(new Graph);
return std::move(g);
}
} // namespace framework
} // namespace paddle
......@@ -133,7 +133,7 @@ ParallelExecutor::ParallelExecutor(
}
builder_ = builder_factory.Create();
std::unique_ptr<Graph> graph = builder_->Build(main_program);
std::unique_ptr<Graph> graph = builder_->Build(ProgramToGraph(main_program));
std::unique_ptr<details::SSAGraph> ssa_graph(new details::SSAGraph);
ssa_graph->vars_ = std::move(graph->Get<details::GraphVars>("vars"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册