提交 9ac785be 编写于 作者: C chengduoZH

check graph's validation

上级 50104f18
...@@ -272,7 +272,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -272,7 +272,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
* Only variables should be the leaves of graph. * Only variables should be the leaves of graph.
*/ */
AddOutputToLeafOps(&result); AddOutputToLeafOps(&result);
return std::unique_ptr<SSAGraph>(graph); return std::unique_ptr<SSAGraph>(graph);
} }
......
...@@ -11,8 +11,8 @@ ...@@ -11,8 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_builder.h" #include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include <utility>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -83,6 +83,74 @@ void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) { ...@@ -83,6 +83,74 @@ void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) {
op->AddOutput(dummy_leaf); op->AddOutput(dummy_leaf);
} }
} }
std::unique_ptr<SSAGraph> SSAGraphBuilder::BuildAndCheck(
const ProgramDesc &program) final {
std::unique_ptr<SSAGraph> graph = Build(program);
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return std::move(graph);
}
bool SSAGraphBuilder::IsValidGraph(const SSAGraph *graph) const {
std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars;
std::unordered_set<VarHandleBase *> ready_vars;
std::unordered_set<OpHandleBase *> ready_ops;
auto insert_pending_var = [&](VarHandleBase *var) {
pending_vars.insert(var);
if (var->generated_op_ == nullptr) {
ready_vars.emplace(var);
}
};
for (auto &var_map : graph->vars_) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair.get());
}
}
}
for (auto &var : graph->dep_vars_) {
insert_pending_var(var.get());
}
for (auto &op : graph->ops_) {
if (op->Inputs().empty()) {
ready_ops.insert(op.get());
} else {
pending_ops.insert({op.get(), op.get()->NoDupInputSize()});
}
}
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : set) {
for (auto out : op->Outputs()) {
ready_vars.emplace(out);
}
}
set.clear();
};
while (!pending_vars.empty()) {
run_all_ops(ready_ops);
if (ready_vars.empty()) {
return false;
}
for (auto ready_var : ready_vars.) {
pending_vars.erase(ready_var);
for (auto *op : ready_var->pending_ops_) {
auto &deps = --pending_ops[op];
if (deps == 0) {
ready_ops.insert(op);
}
}
}
ready_vars.clear();
}
return true;
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -31,6 +31,8 @@ class SSAGraphBuilder { ...@@ -31,6 +31,8 @@ class SSAGraphBuilder {
virtual ~SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {}
virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0; virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
std::unique_ptr<SSAGraph> BuildAndCheck(const ProgramDesc &program) final;
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
protected: protected:
...@@ -48,6 +50,7 @@ class SSAGraphBuilder { ...@@ -48,6 +50,7 @@ class SSAGraphBuilder {
const platform::Place &place, const platform::Place &place,
size_t place_offset); size_t place_offset);
bool IsValidGraph(const SSAGraph *graph) const;
// Add an output variable (each_var_name, place, place_offset) to op_handle, // Add an output variable (each_var_name, place, place_offset) to op_handle,
// which belongs to graph // which belongs to graph
static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
......
...@@ -185,6 +185,7 @@ void ThreadedSSAGraphExecutor::InsertPendingVar( ...@@ -185,6 +185,7 @@ void ThreadedSSAGraphExecutor::InsertPendingVar(
ready_vars->Push(var); ready_vars->Push(var);
} }
} }
void ThreadedSSAGraphExecutor::RunOp( void ThreadedSSAGraphExecutor::RunOp(
BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) { BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
auto op_run = [ready_var_q, op, this] { auto op_run = [ready_var_q, op, this] {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册