未验证 提交 a94a7355 编写于 作者: C chengduo 提交者: GitHub

Refine the GraphNum check (#14144)

* refine GraphCheck
test=develop

* fix ci fail
test=develop
上级 48be9dc3
......@@ -15,8 +15,15 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_helper.h"
#include <algorithm>
#include <deque>
#include <fstream>
#include <iosfwd>
#include <ostream>
#include <unordered_set>
// Defines the string command-line flag FLAGS_print_sub_graph_dir, whose
// default value is the empty string. GraphNum only dumps the sub-graph node
// listing when this flag is non-empty.
// NOTE(review): despite the "_dir" suffix, the value is handed directly to
// std::ofstream in GraphNum, so it looks like it must name an output *file*
// rather than a directory -- confirm and consider clarifying the help text.
DEFINE_string(print_sub_graph_dir, "",
"FLAGS_print_sub_graph_dir is used "
"to print the nodes of sub_graphs.");
namespace paddle {
namespace framework {
namespace ir {
......@@ -164,12 +171,15 @@ size_t GraphNum(const Graph &graph) {
graph_nodes.emplace_back(g_nodes);
}
if (VLOG_IS_ON(100)) {
VLOG(100) << "graph_num: " << graph_nodes.size();
for (auto &g_n : graph_nodes) {
VLOG(100) << "graph_nodes: " << g_n.size();
if (g_n.size() < 10) {
if (FLAGS_print_sub_graph_dir.size()) {
if (graph_nodes.size() > 1) {
std::stringstream out;
for (auto &g_n : graph_nodes) {
out << "graph_nodes: " << g_n.size() << "\n";
}
out << "\n\n";
for (auto &g_n : graph_nodes) {
out << "graph_nodes: " << g_n.size();
for (auto &node : g_n) {
out << "\nNode: " << node->Name() << " in [";
for (auto &n : node->inputs) {
......@@ -181,8 +191,12 @@ size_t GraphNum(const Graph &graph) {
}
out << "]";
}
VLOG(100) << out.str();
out << "\n\n\n";
}
std::unique_ptr<std::ostream> fout(
new std::ofstream(FLAGS_print_sub_graph_dir));
PADDLE_ENFORCE(fout->good());
*fout << out.str();
}
}
......
......@@ -171,8 +171,17 @@ ParallelExecutor::ParallelExecutor(
}
// If the loss_var_name is given, the number of graph should be only one.
if (loss_var_name.size()) {
PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1,
"The number of graph should be only one");
size_t graph_num = ir::GraphNum(*graph);
if (graph_num > 1) {
LOG(WARNING)
<< "The number of graph should be only one, "
"but the current graph has "
<< ir::GraphNum(*graph)
<< " sub_graphs. If you want to see the nodes of the "
"sub_graphs, you should use 'FLAGS_print_sub_graph_dir' "
"to specify the output dir. NOTES: if you not do training, "
"please don't pass loss_var_name.";
}
}
if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
......
......@@ -116,7 +116,8 @@ def __bootstrap__():
'use_mkldnn', 'use_ngraph', 'initial_cpu_memory_in_mb',
'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads',
"dist_threadpool_size", 'cpu_deterministic', 'eager_delete_tensor_gb',
'allocator_strategy', 'reader_queue_speed_test_mode'
'allocator_strategy', 'reader_queue_speed_test_mode',
'print_sub_graph_dir'
]
if os.name != 'nt':
read_env_flags.append('warpctc_dir')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册