diff --git a/test/net/test_net.cpp b/test/net/test_net.cpp
index dc0fc887d4d658320cf9c8ff492c4bd3f88b563c..fa2584565b3116c5ee612efe261c9dd340bf4b6d 100644
--- a/test/net/test_net.cpp
+++ b/test/net/test_net.cpp
@@ -176,19 +176,35 @@ void test(int argc, char *argv[]) {
       if (out->memory_size() == 0) {
         continue;
       }
-      auto data = out->data<float>();
-      std::string sample = "";
-      if (!is_sample_step) {
-        sample_step = len / sample_num;
+      if (out->type() == type_id<float>()) {
+        auto data = out->data<float>();
+        std::string sample = "";
+        if (!is_sample_step) {
+          sample_step = len / sample_num;
+        }
+        if (sample_step <= 0) {
+          sample_step = 1;
+        }
+        for (int i = 0; i < len; i += sample_step) {
+          sample += " " + std::to_string(data[i]);
+        }
+        std::cout << "auto-test"
+                  << " var " << var_name << sample << std::endl;
+      } else if (out->type() == type_id<int8_t>()) {
+        auto data = out->data<int8_t>();
+        std::string sample = "";
+        if (!is_sample_step) {
+          sample_step = len / sample_num;
+        }
+        if (sample_step <= 0) {
+          sample_step = 1;
+        }
+        for (int i = 0; i < len; i += sample_step) {
+          sample += " " + std::to_string(data[i]);
+        }
+        std::cout << "auto-test"
+                  << " var " << var_name << sample << std::endl;
       }
-      if (sample_step <= 0) {
-        sample_step = 1;
-      }
-      for (int i = 0; i < len; i += sample_step) {
-        sample += " " + std::to_string(data[i]);
-      }
-      std::cout << "auto-test"
-                << " var " << var_name << sample << std::endl;
     }
     std::cout << std::endl;
   }
diff --git a/tools/python/fluidtools/run.py b/tools/python/fluidtools/run.py
index 26526ceef9ff459f80690e58661083caeaf64ded..f68f1576d4c318fec8fdf7c58df17246f1911280 100644
--- a/tools/python/fluidtools/run.py
+++ b/tools/python/fluidtools/run.py
@@ -382,11 +382,6 @@ def main():
     feed_kv = gen_feed_kv()
     save_feed_kv(feed_kv)
     feed_kv = load_feed_kv()
-    pp_yellow(dot + dot + " checking fetch info")
-    for fetch in fetches:
-        fetch_name = fetch.name
-        fetch_shape = get_var_shape(fetch_name)
-        pp_tab("fetch var name : {}; fetch var shape : {}".format(fetch_name, fetch_shape), 1)
     # 预测
     pp_yellow(dot + dot + " checking inference")
     outputs = run_model(feed_kv=feed_kv)
@@ -397,6 +392,11 @@ def main():
     # 输出所有中间结果
     pp_yellow(dot + dot + " checking output result of every op")
     save_all_op_output(feed_kv=feed_kv)
+    pp_yellow(dot + dot + " checking fetch info")
+    for fetch in fetches:
+        fetch_name = fetch.name
+        fetch_shape = get_var_shape(fetch_name)
+        pp_tab("fetch var name : {}; fetch var shape : {}".format(fetch_name, fetch_shape), 1)
     # 开始检查mobile的正确性
     print("")
    print("==================================================")