未验证 提交 65a18a0f 编写于 作者: Y Yanzhan Yang 提交者: GitHub

enhance auto debug script (#1705)

上级 a087196f
......@@ -13,6 +13,7 @@ output_path = "outputs"
diff_threshold = 0.01
is_lod = False
mobile_model_path = ""
fast_check = False
np.set_printoptions(linewidth=150)
......@@ -221,6 +222,9 @@ def save_all_op_output(feed_kv=None):
if not os.path.exists(output_path):
os.mkdir(output_path)
ops = prog.current_block().ops
fetch_names = []
for fetch in fetches:
fetch_names.append(fetch.name)
for i in range(len(ops)):
op = ops[i]
var_name = None
......@@ -230,6 +234,9 @@ def save_all_op_output(feed_kv=None):
break
if "sequence_pool" in var_name:
continue
if fast_check:
if var_name not in fetch_names:
continue
try:
data = get_var_data(var_name, feed_kv=feed_kv).flatten().tolist()
sample = tensor_sample(data)
......@@ -313,10 +320,11 @@ def check_mobile_results(args, fuse, mem_opt):
error_values1 = values1
error_values2 = values2
break
for name in fetch_names:
if name not in checked_names:
error_index = -1
break
if error_index == None:
for name in fetch_names:
if name not in checked_names:
error_index = -1
break
if error_index == None:
pp_green("outputs are all correct", 1)
elif error_index == -1:
......@@ -379,8 +387,9 @@ def main():
args += " " + str(sample_step)
for var_name in output_var_cache.keys():
args += " " + var_name
check_mobile_results(args, False, False)
check_mobile_results(args, False, True)
if not fast_check:
check_mobile_results(args, False, False)
check_mobile_results(args, False, True)
check_mobile_results(args, True, False)
check_mobile_results(args, True, True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册