提交 68aa5004 编写于 作者: X Xin Pan

polish attrs

上级 9b960330
...@@ -167,7 +167,7 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( ...@@ -167,7 +167,7 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
return dev_id; return dev_id;
} }
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const { const ProgramDesc &program) const {
std::unique_ptr<Graph> graph(new Graph); std::unique_ptr<Graph> graph(new Graph);
for (auto *var : program.Block(0).AllVars()) { for (auto *var : program.Block(0).AllVars()) {
...@@ -301,12 +301,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -301,12 +301,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
* Only variables should be the leaves of graph. * Only variables should be the leaves of graph.
*/ */
AddOutputToLeafOps(&result); AddOutputToLeafOps(&result);
return std::move(graph);
std::unique_ptr<SSAGraph> ssa_graph(new SSAGraph);
ssa_graph->vars_ = std::move(*graph->Erase<GraphVars>("vars"));
ssa_graph->ops_ = std::move(*graph->Erase<GraphOps>("ops"));
ssa_graph->dep_vars_ = std::move(*graph->Erase<GraphDepVars>("dep_vars"));
return std::move(ssa_graph);
} }
bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const { bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
......
...@@ -47,7 +47,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -47,7 +47,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const BuildStrategy &strategy); const BuildStrategy &strategy);
#endif #endif
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override; std::unique_ptr<Graph> Build(const ProgramDesc &program) const override;
int GetVarDeviceID(const std::string &varname) const override; int GetVarDeviceID(const std::string &varname) const override;
private: private:
......
...@@ -38,7 +38,7 @@ class SSAGraphBuilder { ...@@ -38,7 +38,7 @@ class SSAGraphBuilder {
public: public:
SSAGraphBuilder() {} SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {}
virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0; virtual std::unique_ptr<Graph> Build(const ProgramDesc &program) const = 0;
virtual int GetVarDeviceID(const std::string &var_name) const = 0; virtual int GetVarDeviceID(const std::string &var_name) const = 0;
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const { bool SSAGraghBuilderWithChecker::IsValidGraph(const Graph *graph) const {
std::unordered_map<OpHandleBase *, size_t> pending_ops; std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars; std::unordered_set<VarHandleBase *> pending_vars;
std::unordered_set<VarHandleBase *> ready_vars; std::unordered_set<VarHandleBase *> ready_vars;
...@@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const { ...@@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const {
} }
}; };
for (auto &var_map : graph->vars_) { for (auto &var_map : graph->Get<GraphVars>("vars")) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair.get()); insert_pending_var(version_pair.get());
...@@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const { ...@@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const {
} }
} }
for (auto &var : graph->dep_vars_) { for (auto &var : graph->Get<GraphDepVars>("dep_vars")) {
insert_pending_var(var.get()); insert_pending_var(var.get());
} }
for (auto &op : graph->ops_) { for (auto &op : graph->Get<GraphOps>("ops")) {
if (op->Inputs().empty()) { if (op->Inputs().empty()) {
ready_ops.insert(op.get()); ready_ops.insert(op.get());
} else { } else {
......
...@@ -29,7 +29,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { ...@@ -29,7 +29,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
std::unique_ptr<SSAGraphBuilder>&& builder) std::unique_ptr<SSAGraphBuilder>&& builder)
: builder_(std::move(builder)) {} : builder_(std::move(builder)) {}
std::unique_ptr<SSAGraph> Build(const ProgramDesc& program) const override { std::unique_ptr<Graph> Build(const ProgramDesc& program) const override {
auto graph = builder_->Build(program); auto graph = builder_->Build(program);
PADDLE_ENFORCE(IsValidGraph(graph.get())); PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph; return graph;
...@@ -39,7 +39,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { ...@@ -39,7 +39,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
return builder_->GetVarDeviceID(var_name); return builder_->GetVarDeviceID(var_name);
} }
bool IsValidGraph(const SSAGraph* graph) const; bool IsValidGraph(const Graph* graph) const;
private: private:
std::unique_ptr<SSAGraphBuilder> builder_; std::unique_ptr<SSAGraphBuilder> builder_;
......
...@@ -21,8 +21,8 @@ namespace framework { ...@@ -21,8 +21,8 @@ namespace framework {
namespace details { namespace details {
template <typename Callback> template <typename Callback>
static inline void IterAllVar(const SSAGraph &graph, Callback callback) { static inline void IterAllVar(const Graph &graph, Callback callback) {
for (auto &each : graph.vars_) { for (auto &each : graph.Get<GraphVars>("vars")) {
for (auto &pair1 : each) { for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) { for (auto &pair2 : pair1.second) {
callback(*pair2); callback(*pair2);
...@@ -30,12 +30,12 @@ static inline void IterAllVar(const SSAGraph &graph, Callback callback) { ...@@ -30,12 +30,12 @@ static inline void IterAllVar(const SSAGraph &graph, Callback callback) {
} }
} }
for (auto &var : graph.dep_vars_) { for (auto &var : graph.Get<GraphDepVars>("dep_vars")) {
callback(*var); callback(*var);
} }
} }
void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph, void GraphvizSSAGraphPrinter::Print(const Graph &graph,
std::ostream &sout) const { std::ostream &sout) const {
size_t var_id = 0; size_t var_id = 0;
std::unordered_map<const VarHandleBase *, size_t> vars; std::unordered_map<const VarHandleBase *, size_t> vars;
...@@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph, ...@@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph,
}); });
size_t op_id = 0; size_t op_id = 0;
for (auto &op : graph.ops_) { for (auto &op : graph.Get<GraphOps>("ops")) {
std::string op_name = "op_" + std::to_string(op_id++); std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl; << std::endl;
......
...@@ -25,12 +25,12 @@ struct SSAGraph; ...@@ -25,12 +25,12 @@ struct SSAGraph;
class SSAGraphPrinter { class SSAGraphPrinter {
public: public:
virtual ~SSAGraphPrinter() {} virtual ~SSAGraphPrinter() {}
virtual void Print(const SSAGraph& graph, std::ostream& sout) const = 0; virtual void Print(const Graph& graph, std::ostream& sout) const = 0;
}; };
class GraphvizSSAGraphPrinter : public SSAGraphPrinter { class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
public: public:
void Print(const SSAGraph& graph, std::ostream& sout) const override; void Print(const Graph& graph, std::ostream& sout) const override;
}; };
class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
...@@ -50,7 +50,7 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { ...@@ -50,7 +50,7 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
stream_ptr_(std::move(sout)), stream_ptr_(std::move(sout)),
stream_ref_(*stream_ptr_) {} stream_ref_(*stream_ptr_) {}
std::unique_ptr<SSAGraph> Build(const ProgramDesc& program) const override { std::unique_ptr<Graph> Build(const ProgramDesc& program) const override {
auto graph = builder_->Build(program); auto graph = builder_->Build(program);
printer_->Print(*graph, stream_ref_); printer_->Print(*graph, stream_ref_);
return graph; return graph;
......
...@@ -26,71 +26,45 @@ limitations under the License. */ ...@@ -26,71 +26,45 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class Graph;
template <typename AttrType>
struct AnyAttr {
public:
explicit AnyAttr(AttrType* attr) : attr_(attr) {}
AttrType& Get() { return *boost::any_cast<AttrType*>(attr_); }
private:
friend Graph;
AttrType* Release() {
released_ = true;
return boost::any_cast<AttrType*>(attr_);
}
void Delete() {
if (!released_) {
delete boost::any_cast<AttrType*>(attr_);
}
}
bool released_ = false;
boost::any attr_;
};
class Graph { class Graph {
public: public:
virtual ~Graph() { virtual ~Graph() {
for (auto& attr : attrs) { for (auto& attr : attrs_) {
attr_dels[attr.first](); attr_dels_[attr.first]();
} }
attrs.clear(); attrs_.clear();
attr_dels.clear(); attr_dels_.clear();
} }
template <typename AttrType> template <typename AttrType>
AttrType& Get(const std::string& attr_name) { AttrType& Get(const std::string& attr_name) const {
return boost::any_cast<AnyAttr<AttrType>>(attrs[attr_name]).Get(); return *boost::any_cast<AttrType*>(attrs_.at(attr_name));
} }
template <typename AttrType> template <typename AttrType>
void Set(const std::string& attr_name, AttrType* attr) { void Set(const std::string& attr_name, AttrType* attr) {
AnyAttr<AttrType> any_attr = AnyAttr<AttrType>(attr); attrs_[attr_name] = attr;
attrs[attr_name] = any_attr; attr_dels_[attr_name] = [attr, attr_name]() {
attr_dels[attr_name] = [&any_attr]() { any_attr.Delete(); }; VLOG(3) << "deleting " << attr_name;
delete attr;
};
} }
template <typename AttrType> template <typename AttrType>
AttrType* Erase(const std::string& attr_name) { AttrType* Erase(const std::string& attr_name) {
AnyAttr<AttrType> attr_type = AttrType* attr = boost::any_cast<AttrType*>(attrs_[attr_name]);
boost::any_cast<AnyAttr<AttrType>>(attrs[attr_name]); attrs_.erase(attr_name);
attrs.erase(attr_name); attr_dels_.erase(attr_name);
attr_dels.erase(attr_name); return attr;
return attr_type.Release();
} }
std::vector<Node*> inputs; std::vector<Node*> inputs;
std::vector<Node*> outputs; std::vector<Node*> outputs;
std::vector<std::unique_ptr<Node>> nodes; std::vector<std::unique_ptr<Node>> nodes;
std::map<std::string, boost::any> attrs;
std::map<std::string, std::function<void(void)>> attr_dels;
private: private:
std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_;
}; };
} // namespace framework } // namespace framework
......
...@@ -18,6 +18,8 @@ limitations under the License. */ ...@@ -18,6 +18,8 @@ limitations under the License. */
#include <tuple> #include <tuple>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/ssa_graph.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
#endif #endif
...@@ -131,9 +133,16 @@ ParallelExecutor::ParallelExecutor( ...@@ -131,9 +133,16 @@ ParallelExecutor::ParallelExecutor(
} }
builder_ = builder_factory.Create(); builder_ = builder_factory.Create();
std::unique_ptr<Graph> graph = builder_->Build(main_program);
std::unique_ptr<details::SSAGraph> ssa_graph(new details::SSAGraph);
ssa_graph->vars_ = std::move(graph->Get<details::GraphVars>("vars"));
ssa_graph->ops_ = std::move(graph->Get<details::GraphOps>("ops"));
ssa_graph->dep_vars_ =
std::move(graph->Get<details::GraphDepVars>("dep_vars"));
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, exec_strategy, member_->local_scopes_, places, std::move(ssa_graph)));
builder_->Build(main_program)));
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos), exec_strategy, member_->local_scopes_, std::move(var_infos),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册