未验证 提交 a087196f 编写于 作者: Y Yanzhan Yang 提交者: GitHub

enhance auto debug script (#1704)

上级 51c8f923
...@@ -11,7 +11,8 @@ checked_model_path = "checked_model" ...@@ -11,7 +11,8 @@ checked_model_path = "checked_model"
feed_path = "feeds" feed_path = "feeds"
output_path = "outputs" output_path = "outputs"
diff_threshold = 0.01 diff_threshold = 0.01
is_lod = True is_lod = False
mobile_model_path = ""
np.set_printoptions(linewidth=150) np.set_printoptions(linewidth=150)
...@@ -61,6 +62,10 @@ prog, feeds, fetches = load_model(model_path) ...@@ -61,6 +62,10 @@ prog, feeds, fetches = load_model(model_path)
# 强制要求所有张量的形状,在model和params中一致,并重新保存模型 # 强制要求所有张量的形状,在model和params中一致,并重新保存模型
def resave_model(feed_kv): def resave_model(feed_kv):
if len(mobile_model_path) > 0:
pp_green("has set mobile_model_path, stop checking model & params", 1)
sh("cp {}/* {}".format(mobile_model_path, checked_model_path))
return
ops = prog.current_block().ops ops = prog.current_block().ops
vars = prog.current_block().vars vars = prog.current_block().vars
# 强制所有var为可持久化 # 强制所有var为可持久化
...@@ -157,7 +162,7 @@ def load_feed_kv(): ...@@ -157,7 +162,7 @@ def load_feed_kv():
cur_len = 0 cur_len = 0
lod = [cur_len] lod = [cur_len]
for l in seq_lens: for l in seq_lens:
cur_len += 1 cur_len += l
lod.append(cur_len) lod.append(cur_len)
data = data.reshape(feed_shape) data = data.reshape(feed_shape)
tensor.set(data, fluid.CPUPlace()) tensor.set(data, fluid.CPUPlace())
...@@ -223,7 +228,7 @@ def save_all_op_output(feed_kv=None): ...@@ -223,7 +228,7 @@ def save_all_op_output(feed_kv=None):
var_name = name var_name = name
if "tmp" in name: if "tmp" in name:
break break
if "sequence_pool" in name: if "sequence_pool" in var_name:
continue continue
try: try:
data = get_var_data(var_name, feed_kv=feed_kv).flatten().tolist() data = get_var_data(var_name, feed_kv=feed_kv).flatten().tolist()
...@@ -274,6 +279,10 @@ def check_mobile_results(args, fuse, mem_opt): ...@@ -274,6 +279,10 @@ def check_mobile_results(args, fuse, mem_opt):
error_index = None error_index = None
error_values1 = None error_values1 = None
error_values2 = None error_values2 = None
checked_names = []
fetch_names = []
for fetch in fetches:
fetch_names.append(fetch.name)
for index in op_cache: for index in op_cache:
op_output_var_name, op = op_cache[index] op_output_var_name, op = op_cache[index]
if mem_opt: if mem_opt:
...@@ -299,16 +308,24 @@ def check_mobile_results(args, fuse, mem_opt): ...@@ -299,16 +308,24 @@ def check_mobile_results(args, fuse, mem_opt):
if abs(v1 - v2) > diff_threshold: if abs(v1 - v2) > diff_threshold:
error_index = index error_index = index
break break
checked_names.append(op_output_var_name)
if error_index != None: if error_index != None:
error_values1 = values1 error_values1 = values1
error_values2 = values2 error_values2 = values2
break break
for name in fetch_names:
if name not in checked_names:
error_index = -1
break
if error_index == None: if error_index == None:
pp_green("outputs are all correct", 1) pp_green("outputs are all correct", 1)
elif error_index == -1:
pp_red("outputs are missing")
else: else:
error_values1 = np.array(error_values1) error_values1 = np.array(error_values1)
error_values2 = np.array(error_values2) error_values2 = np.array(error_values2)
pp_red("{} op's output is not correct, op's type is {}".format(error_index, op_cache[error_index][1].type), 1) # pp_red("mobile op is not correct, error occurs at {}th op, op's type is {}")
pp_red("corresponding fluid op is {}th op, op's type is {}".format(error_index, op_cache[error_index][1].type), 1)
pp_red("fluid results are : ", 1) pp_red("fluid results are : ", 1)
pp_red(str(error_values1).replace("\n", "\n" + "\t" * 1), 1) pp_red(str(error_values1).replace("\n", "\n" + "\t" * 1), 1)
pp_red("paddle mobile results are : ", 1) pp_red("paddle mobile results are : ", 1)
...@@ -325,7 +342,9 @@ def main(): ...@@ -325,7 +342,9 @@ def main():
feed_kv = load_feed_kv() feed_kv = load_feed_kv()
pp_yellow(dot + dot + " checking fetch info") pp_yellow(dot + dot + " checking fetch info")
for fetch in fetches: for fetch in fetches:
pp_tab("fetch var name : {}".format(fetch.name), 1) fetch_name = fetch.name
fetch_shape = get_var_shape(fetch_name)
pp_tab("fetch var name : {}; fetch var shape : {}".format(fetch_name, fetch_shape), 1)
# 预测 # 预测
pp_yellow(dot + dot + " checking inference") pp_yellow(dot + dot + " checking inference")
outputs = run_model(feed_kv=feed_kv) outputs = run_model(feed_kv=feed_kv)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册