未验证 提交 65a18a0f 编写于 作者: Y Yanzhan Yang 提交者: GitHub

enhance auto debug script (#1705)

上级 a087196f
...@@ -13,6 +13,7 @@ output_path = "outputs" ...@@ -13,6 +13,7 @@ output_path = "outputs"
diff_threshold = 0.01 diff_threshold = 0.01
is_lod = False is_lod = False
mobile_model_path = "" mobile_model_path = ""
fast_check = False
np.set_printoptions(linewidth=150) np.set_printoptions(linewidth=150)
...@@ -221,6 +222,9 @@ def save_all_op_output(feed_kv=None): ...@@ -221,6 +222,9 @@ def save_all_op_output(feed_kv=None):
if not os.path.exists(output_path): if not os.path.exists(output_path):
os.mkdir(output_path) os.mkdir(output_path)
ops = prog.current_block().ops ops = prog.current_block().ops
fetch_names = []
for fetch in fetches:
fetch_names.append(fetch.name)
for i in range(len(ops)): for i in range(len(ops)):
op = ops[i] op = ops[i]
var_name = None var_name = None
...@@ -230,6 +234,9 @@ def save_all_op_output(feed_kv=None): ...@@ -230,6 +234,9 @@ def save_all_op_output(feed_kv=None):
break break
if "sequence_pool" in var_name: if "sequence_pool" in var_name:
continue continue
if fast_check:
if var_name not in fetch_names:
continue
try: try:
data = get_var_data(var_name, feed_kv=feed_kv).flatten().tolist() data = get_var_data(var_name, feed_kv=feed_kv).flatten().tolist()
sample = tensor_sample(data) sample = tensor_sample(data)
...@@ -313,10 +320,11 @@ def check_mobile_results(args, fuse, mem_opt): ...@@ -313,10 +320,11 @@ def check_mobile_results(args, fuse, mem_opt):
error_values1 = values1 error_values1 = values1
error_values2 = values2 error_values2 = values2
break break
for name in fetch_names: if error_index == None:
if name not in checked_names: for name in fetch_names:
error_index = -1 if name not in checked_names:
break error_index = -1
break
if error_index == None: if error_index == None:
pp_green("outputs are all correct", 1) pp_green("outputs are all correct", 1)
elif error_index == -1: elif error_index == -1:
...@@ -379,8 +387,9 @@ def main(): ...@@ -379,8 +387,9 @@ def main():
args += " " + str(sample_step) args += " " + str(sample_step)
for var_name in output_var_cache.keys(): for var_name in output_var_cache.keys():
args += " " + var_name args += " " + var_name
check_mobile_results(args, False, False) if not fast_check:
check_mobile_results(args, False, True) check_mobile_results(args, False, False)
check_mobile_results(args, False, True)
check_mobile_results(args, True, False) check_mobile_results(args, True, False)
check_mobile_results(args, True, True) check_mobile_results(args, True, True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册