From 6971cd19997ac43d9a88a37494fe462b5310705b Mon Sep 17 00:00:00 2001 From: Yanzhan Yang Date: Wed, 26 Jun 2019 10:38:40 +0800 Subject: [PATCH] enhance auto debug script (#1704) --- tools/python/fluidtools/run.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/tools/python/fluidtools/run.py b/tools/python/fluidtools/run.py index e7ac827b99..fef7094e09 100644 --- a/tools/python/fluidtools/run.py +++ b/tools/python/fluidtools/run.py @@ -11,7 +11,8 @@ checked_model_path = "checked_model" feed_path = "feeds" output_path = "outputs" diff_threshold = 0.01 -is_lod = True +is_lod = False +mobile_model_path = "" np.set_printoptions(linewidth=150) @@ -61,6 +62,10 @@ prog, feeds, fetches = load_model(model_path) # 强制要求所有张量的形状,在model和params中一致,并重新保存模型 def resave_model(feed_kv): + if len(mobile_model_path) > 0: + pp_green("has set mobile_model_path, stop checking model & params", 1) + sh("cp {}/* {}".format(mobile_model_path, checked_model_path)) + return ops = prog.current_block().ops vars = prog.current_block().vars # 强制所有var为可持久化 @@ -157,7 +162,7 @@ def load_feed_kv(): cur_len = 0 lod = [cur_len] for l in seq_lens: - cur_len += 1 + cur_len += l lod.append(cur_len) data = data.reshape(feed_shape) tensor.set(data, fluid.CPUPlace()) @@ -223,7 +228,7 @@ def save_all_op_output(feed_kv=None): var_name = name if "tmp" in name: break - if "sequence_pool" in name: + if "sequence_pool" in var_name: continue try: data = get_var_data(var_name, feed_kv=feed_kv).flatten().tolist() @@ -274,6 +279,10 @@ def check_mobile_results(args, fuse, mem_opt): error_index = None error_values1 = None error_values2 = None + checked_names = [] + fetch_names = [] + for fetch in fetches: + fetch_names.append(fetch.name) for index in op_cache: op_output_var_name, op = op_cache[index] if mem_opt: @@ -299,16 +308,24 @@ def check_mobile_results(args, fuse, mem_opt): if abs(v1 - v2) > diff_threshold: error_index = index break + checked_names.append(op_output_var_name) if error_index != None: error_values1 = values1 error_values2 = values2 break + for name in fetch_names: + if name not in checked_names: + error_index = -1 + break if error_index == None: pp_green("outputs are all correct", 1) + elif error_index == -1: + pp_red("outputs are missing") else: error_values1 = np.array(error_values1) error_values2 = np.array(error_values2) - pp_red("{} op's output is not correct, op's type is {}".format(error_index, op_cache[error_index][1].type), 1) + # pp_red("mobile op is not correct, error occurs at {}th op, op's type is {}") + pp_red("corresponding fluid op is {}th op, op's type is {}".format(error_index, op_cache[error_index][1].type), 1) pp_red("fluid results are : ", 1) pp_red(str(error_values1).replace("\n", "\n" + "\t" * 1), 1) pp_red("paddle mobile results are : ", 1) @@ -325,7 +342,9 @@ def main(): feed_kv = load_feed_kv() pp_yellow(dot + dot + " checking fetch info") for fetch in fetches: - pp_tab("fetch var name : {}".format(fetch.name), 1) + fetch_name = fetch.name + fetch_shape = get_var_shape(fetch_name) + pp_tab("fetch var name : {}; fetch var shape : {}".format(fetch_name, fetch_shape), 1) # 预测 pp_yellow(dot + dot + " checking inference") outputs = run_model(feed_kv=feed_kv) -- GitLab