提交 6971cd19 编写于 作者: Y Yanzhan Yang 提交者: GitHub

enhance auto debug script (#1704)

上级 94e9170a
......@@ -11,7 +11,8 @@ checked_model_path = "checked_model"
feed_path = "feeds"
output_path = "outputs"
diff_threshold = 0.01
is_lod = True
is_lod = False
mobile_model_path = ""
np.set_printoptions(linewidth=150)
......@@ -61,6 +62,10 @@ prog, feeds, fetches = load_model(model_path)
# 强制要求所有张量的形状,在model和params中一致,并重新保存模型
def resave_model(feed_kv):
if len(mobile_model_path) > 0:
pp_green("has set mobile_model_path, stop checking model & params", 1)
sh("cp {}/* {}".format(mobile_model_path, checked_model_path))
return
ops = prog.current_block().ops
vars = prog.current_block().vars
# 强制所有var为可持久化
......@@ -157,7 +162,7 @@ def load_feed_kv():
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += 1
cur_len += l
lod.append(cur_len)
data = data.reshape(feed_shape)
tensor.set(data, fluid.CPUPlace())
......@@ -223,7 +228,7 @@ def save_all_op_output(feed_kv=None):
var_name = name
if "tmp" in name:
break
if "sequence_pool" in name:
if "sequence_pool" in var_name:
continue
try:
data = get_var_data(var_name, feed_kv=feed_kv).flatten().tolist()
......@@ -274,6 +279,10 @@ def check_mobile_results(args, fuse, mem_opt):
error_index = None
error_values1 = None
error_values2 = None
checked_names = []
fetch_names = []
for fetch in fetches:
fetch_names.append(fetch.name)
for index in op_cache:
op_output_var_name, op = op_cache[index]
if mem_opt:
......@@ -299,16 +308,24 @@ def check_mobile_results(args, fuse, mem_opt):
if abs(v1 - v2) > diff_threshold:
error_index = index
break
checked_names.append(op_output_var_name)
if error_index != None:
error_values1 = values1
error_values2 = values2
break
for name in fetch_names:
if name not in checked_names:
error_index = -1
break
if error_index == None:
pp_green("outputs are all correct", 1)
elif error_index == -1:
pp_red("outputs are missing")
else:
error_values1 = np.array(error_values1)
error_values2 = np.array(error_values2)
pp_red("{} op's output is not correct, op's type is {}".format(error_index, op_cache[error_index][1].type), 1)
# pp_red("mobile op is not correct, error occurs at {}th op, op's type is {}")
pp_red("corresponding fluid op is {}th op, op's type is {}".format(error_index, op_cache[error_index][1].type), 1)
pp_red("fluid results are : ", 1)
pp_red(str(error_values1).replace("\n", "\n" + "\t" * 1), 1)
pp_red("paddle mobile results are : ", 1)
......@@ -325,7 +342,9 @@ def main():
feed_kv = load_feed_kv()
pp_yellow(dot + dot + " checking fetch info")
for fetch in fetches:
pp_tab("fetch var name : {}".format(fetch.name), 1)
fetch_name = fetch.name
fetch_shape = get_var_shape(fetch_name)
pp_tab("fetch var name : {}; fetch var shape : {}".format(fetch_name, fetch_shape), 1)
# 预测
pp_yellow(dot + dot + " checking inference")
outputs = run_model(feed_kv=feed_kv)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册