提交 5c7a5233 编写于 作者: Y Yu Yang

Add Graphviz output

上级 edfd741e
...@@ -35,6 +35,8 @@ void ComputationOpHandle::RunImpl() { ...@@ -35,6 +35,8 @@ void ComputationOpHandle::RunImpl() {
op_->Run(*scope_, place_); op_->Run(*scope_, place_);
} }
std::string ComputationOpHandle::Name() const { return op_->Type(); }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -31,6 +31,8 @@ struct ComputationOpHandle : public OpHandleBase { ...@@ -31,6 +31,8 @@ struct ComputationOpHandle : public OpHandleBase {
ComputationOpHandle(const OpDesc &op_desc, Scope *scope, ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
platform::Place place); platform::Place place);
std::string Name() const override;
protected: protected:
void RunImpl() override; void RunImpl() override;
}; };
......
...@@ -72,6 +72,8 @@ void FetchOpHandle::RunImpl() { ...@@ -72,6 +72,8 @@ void FetchOpHandle::RunImpl() {
} }
} }
std::string FetchOpHandle::Name() const { return "Fetch"; }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -38,6 +38,8 @@ struct FetchOpHandle : public OpHandleBase { ...@@ -38,6 +38,8 @@ struct FetchOpHandle : public OpHandleBase {
void WaitAndMergeCPUTensors() const; void WaitAndMergeCPUTensors() const;
std::string Name() const override;
protected: protected:
void RunImpl() override; void RunImpl() override;
}; };
......
...@@ -136,6 +136,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -136,6 +136,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
*/ */
PolishGraphToSupportDataHazards(&result); PolishGraphToSupportDataHazards(&result);
if (VLOG_IS_ON(10)) {
std::ostringstream sout;
PrintGraphviz(*graph, sout);
VLOG(10) << sout.str();
}
return std::unique_ptr<SSAGraph>(graph); return std::unique_ptr<SSAGraph>(graph);
} }
} // namespace details } // namespace details
......
...@@ -69,6 +69,8 @@ void NCCLAllReduceOpHandle::RunImpl() { ...@@ -69,6 +69,8 @@ void NCCLAllReduceOpHandle::RunImpl() {
} }
} }
} }
std::string NCCLAllReduceOpHandle::Name() const { return "NCCL AllReduce"; }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -32,6 +32,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase { ...@@ -32,6 +32,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::NCCLContextMap &ctxs); const platform::NCCLContextMap &ctxs);
std::string Name() const override;
protected: protected:
void RunImpl() override; void RunImpl() override;
}; };
......
...@@ -33,6 +33,8 @@ struct OpHandleBase { ...@@ -33,6 +33,8 @@ struct OpHandleBase {
std::string DebugString() const; std::string DebugString() const;
virtual std::string Name() const = 0;
virtual ~OpHandleBase(); virtual ~OpHandleBase();
void Run(bool use_event); void Run(bool use_event);
......
...@@ -45,6 +45,8 @@ void ScaleLossGradOpHandle::RunImpl() { ...@@ -45,6 +45,8 @@ void ScaleLossGradOpHandle::RunImpl() {
#endif #endif
} }
} }
std::string ScaleLossGradOpHandle::Name() const { return "Scale LossGrad"; }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -32,6 +32,8 @@ struct ScaleLossGradOpHandle : public OpHandleBase { ...@@ -32,6 +32,8 @@ struct ScaleLossGradOpHandle : public OpHandleBase {
~ScaleLossGradOpHandle() final; ~ScaleLossGradOpHandle() final;
std::string Name() const override;
protected: protected:
void RunImpl() override; void RunImpl() override;
}; };
......
...@@ -83,6 +83,64 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, ...@@ -83,6 +83,64 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
var.place_ = place; var.place_ = place;
op_handle->AddOutput(&var); op_handle->AddOutput(&var);
} }
template <typename Callback>
void IterAllVar(const SSAGraph &graph, Callback callback) {
for (auto &each : graph.vars_) {
for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) {
callback(pair2.second);
}
}
}
for (auto &var : graph.dep_vars_) {
callback(*var);
}
}
void SSAGraphBuilder::PrintGraphviz(const SSAGraph &graph, std::ostream &sout) {
size_t var_id = 0;
std::unordered_map<const VarHandleBase *, size_t> vars;
sout << "digraph G {\n";
IterAllVar(graph, [&](const VarHandleBase &var) {
auto *var_ptr = &var;
auto *var_handle_ptr = dynamic_cast<const VarHandle *>(var_ptr);
auto *dummy_ptr = dynamic_cast<const DummyVarHandle *>(var_ptr);
size_t cur_var_id = var_id++;
vars[var_ptr] = cur_var_id;
if (var_handle_ptr) {
sout << "var_" << cur_var_id << " [label=\"" << var_handle_ptr->name_
<< "\\n"
<< var_handle_ptr->place_ << "\\n"
<< var_handle_ptr->version_ << "\"]" << std::endl;
} else if (dummy_ptr) {
sout << "var_" << cur_var_id << " [label=\"dummy\"]" << std::endl;
}
});
size_t op_id = 0;
for (auto &op : graph.ops_) {
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl;
for (auto in : op->inputs_) {
std::string var_name = "var_" + std::to_string(vars[in]);
sout << var_name << " -> " << op_name << std::endl;
}
for (auto out : op->outputs_) {
std::string var_name = "var_" + std::to_string(vars[out]);
sout << op_name << " -> " << var_name << std::endl;
}
}
sout << "}\n";
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -51,6 +51,8 @@ class SSAGraphBuilder { ...@@ -51,6 +51,8 @@ class SSAGraphBuilder {
static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
const std::string &each_var_name, const std::string &each_var_name,
const platform::Place &place, size_t place_offset); const platform::Place &place, size_t place_offset);
static void PrintGraphviz(const SSAGraph &graph, std::ostream &sout);
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -133,6 +133,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -133,6 +133,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
if (exception_) { if (exception_) {
throw * exception_; throw * exception_;
} }
VLOG(10) << "=============================";
for (auto &op : pending_ops) {
VLOG(10) << op.first->DebugString();
}
// keep waiting the ready variables // keep waiting the ready variables
continue; continue;
} }
......
...@@ -48,7 +48,7 @@ def fc_with_batchnorm(): ...@@ -48,7 +48,7 @@ def fc_with_batchnorm():
dtypes=['float32', 'int64']) dtypes=['float32', 'int64'])
img, label = fluid.layers.read_file(reader) img, label = fluid.layers.read_file(reader)
hidden = img hidden = img
for _ in xrange(4): for _ in xrange(1):
hidden = fluid.layers.fc( hidden = fluid.layers.fc(
hidden, hidden,
size=200, size=200,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册