提交 68aa5004 编写于 作者: X Xin Pan

polish attrs

上级 9b960330
......@@ -167,7 +167,7 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
return dev_id;
}
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const {
std::unique_ptr<Graph> graph(new Graph);
for (auto *var : program.Block(0).AllVars()) {
......@@ -301,12 +301,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
* Only variables should be the leaves of graph.
*/
AddOutputToLeafOps(&result);
std::unique_ptr<SSAGraph> ssa_graph(new SSAGraph);
ssa_graph->vars_ = std::move(*graph->Erase<GraphVars>("vars"));
ssa_graph->ops_ = std::move(*graph->Erase<GraphOps>("ops"));
ssa_graph->dep_vars_ = std::move(*graph->Erase<GraphDepVars>("dep_vars"));
return std::move(ssa_graph);
return std::move(graph);
}
bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
......
......@@ -47,7 +47,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const BuildStrategy &strategy);
#endif
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
std::unique_ptr<Graph> Build(const ProgramDesc &program) const override;
int GetVarDeviceID(const std::string &varname) const override;
private:
......
......@@ -38,7 +38,7 @@ class SSAGraphBuilder {
public:
SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {}
virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
virtual std::unique_ptr<Graph> Build(const ProgramDesc &program) const = 0;
virtual int GetVarDeviceID(const std::string &var_name) const = 0;
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
......
......@@ -20,7 +20,7 @@ namespace paddle {
namespace framework {
namespace details {
bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const {
bool SSAGraghBuilderWithChecker::IsValidGraph(const Graph *graph) const {
std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars;
std::unordered_set<VarHandleBase *> ready_vars;
......@@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const {
}
};
for (auto &var_map : graph->vars_) {
for (auto &var_map : graph->Get<GraphVars>("vars")) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair.get());
......@@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const {
}
}
for (auto &var : graph->dep_vars_) {
for (auto &var : graph->Get<GraphDepVars>("dep_vars")) {
insert_pending_var(var.get());
}
for (auto &op : graph->ops_) {
for (auto &op : graph->Get<GraphOps>("ops")) {
if (op->Inputs().empty()) {
ready_ops.insert(op.get());
} else {
......
......@@ -29,7 +29,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
std::unique_ptr<SSAGraphBuilder>&& builder)
: builder_(std::move(builder)) {}
std::unique_ptr<SSAGraph> Build(const ProgramDesc& program) const override {
std::unique_ptr<Graph> Build(const ProgramDesc& program) const override {
auto graph = builder_->Build(program);
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph;
......@@ -39,7 +39,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
return builder_->GetVarDeviceID(var_name);
}
bool IsValidGraph(const SSAGraph* graph) const;
bool IsValidGraph(const Graph* graph) const;
private:
std::unique_ptr<SSAGraphBuilder> builder_;
......
......@@ -21,8 +21,8 @@ namespace framework {
namespace details {
template <typename Callback>
static inline void IterAllVar(const SSAGraph &graph, Callback callback) {
for (auto &each : graph.vars_) {
static inline void IterAllVar(const Graph &graph, Callback callback) {
for (auto &each : graph.Get<GraphVars>("vars")) {
for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) {
callback(*pair2);
......@@ -30,12 +30,12 @@ static inline void IterAllVar(const SSAGraph &graph, Callback callback) {
}
}
for (auto &var : graph.dep_vars_) {
for (auto &var : graph.Get<GraphDepVars>("dep_vars")) {
callback(*var);
}
}
void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph,
void GraphvizSSAGraphPrinter::Print(const Graph &graph,
std::ostream &sout) const {
size_t var_id = 0;
std::unordered_map<const VarHandleBase *, size_t> vars;
......@@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph,
});
size_t op_id = 0;
for (auto &op : graph.ops_) {
for (auto &op : graph.Get<GraphOps>("ops")) {
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl;
......
......@@ -25,12 +25,12 @@ struct SSAGraph;
class SSAGraphPrinter {
public:
virtual ~SSAGraphPrinter() {}
virtual void Print(const SSAGraph& graph, std::ostream& sout) const = 0;
virtual void Print(const Graph& graph, std::ostream& sout) const = 0;
};
class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
public:
void Print(const SSAGraph& graph, std::ostream& sout) const override;
void Print(const Graph& graph, std::ostream& sout) const override;
};
class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
......@@ -50,7 +50,7 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
stream_ptr_(std::move(sout)),
stream_ref_(*stream_ptr_) {}
std::unique_ptr<SSAGraph> Build(const ProgramDesc& program) const override {
std::unique_ptr<Graph> Build(const ProgramDesc& program) const override {
auto graph = builder_->Build(program);
printer_->Print(*graph, stream_ref_);
return graph;
......
......@@ -26,71 +26,45 @@ limitations under the License. */
namespace paddle {
namespace framework {
class Graph;
template <typename AttrType>
struct AnyAttr {
public:
explicit AnyAttr(AttrType* attr) : attr_(attr) {}
AttrType& Get() { return *boost::any_cast<AttrType*>(attr_); }
private:
friend Graph;
AttrType* Release() {
released_ = true;
return boost::any_cast<AttrType*>(attr_);
}
void Delete() {
if (!released_) {
delete boost::any_cast<AttrType*>(attr_);
}
}
bool released_ = false;
boost::any attr_;
};
class Graph {
public:
virtual ~Graph() {
for (auto& attr : attrs) {
attr_dels[attr.first]();
for (auto& attr : attrs_) {
attr_dels_[attr.first]();
}
attrs.clear();
attr_dels.clear();
attrs_.clear();
attr_dels_.clear();
}
template <typename AttrType>
AttrType& Get(const std::string& attr_name) {
return boost::any_cast<AnyAttr<AttrType>>(attrs[attr_name]).Get();
AttrType& Get(const std::string& attr_name) const {
return *boost::any_cast<AttrType*>(attrs_.at(attr_name));
}
template <typename AttrType>
void Set(const std::string& attr_name, AttrType* attr) {
AnyAttr<AttrType> any_attr = AnyAttr<AttrType>(attr);
attrs[attr_name] = any_attr;
attr_dels[attr_name] = [&any_attr]() { any_attr.Delete(); };
attrs_[attr_name] = attr;
attr_dels_[attr_name] = [attr, attr_name]() {
VLOG(3) << "deleting " << attr_name;
delete attr;
};
}
template <typename AttrType>
AttrType* Erase(const std::string& attr_name) {
AnyAttr<AttrType> attr_type =
boost::any_cast<AnyAttr<AttrType>>(attrs[attr_name]);
attrs.erase(attr_name);
attr_dels.erase(attr_name);
return attr_type.Release();
AttrType* attr = boost::any_cast<AttrType*>(attrs_[attr_name]);
attrs_.erase(attr_name);
attr_dels_.erase(attr_name);
return attr;
}
std::vector<Node*> inputs;
std::vector<Node*> outputs;
std::vector<std::unique_ptr<Node>> nodes;
std::map<std::string, boost::any> attrs;
std::map<std::string, std::function<void(void)>> attr_dels;
private:
std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_;
};
} // namespace framework
......
......@@ -18,6 +18,8 @@ limitations under the License. */
#include <tuple>
#include <vector>
#include "paddle/fluid/framework/details/ssa_graph.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
......@@ -131,9 +133,16 @@ ParallelExecutor::ParallelExecutor(
}
builder_ = builder_factory.Create();
std::unique_ptr<Graph> graph = builder_->Build(main_program);
std::unique_ptr<details::SSAGraph> ssa_graph(new details::SSAGraph);
ssa_graph->vars_ = std::move(graph->Get<details::GraphVars>("vars"));
ssa_graph->ops_ = std::move(graph->Get<details::GraphOps>("ops"));
ssa_graph->dep_vars_ =
std::move(graph->Get<details::GraphDepVars>("dep_vars"));
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places,
builder_->Build(main_program)));
exec_strategy, member_->local_scopes_, places, std::move(ssa_graph)));
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册