diff --git a/tools/python/fluidtools/run.py b/tools/python/fluidtools/run.py index fef7094e09f97e45f07a762eb6954842fc40c419..5ded7d2d25bf24b1456ae53eacec09916ee2021d 100644 --- a/tools/python/fluidtools/run.py +++ b/tools/python/fluidtools/run.py @@ -13,6 +13,7 @@ output_path = "outputs" diff_threshold = 0.01 is_lod = False mobile_model_path = "" +fast_check = False np.set_printoptions(linewidth=150) @@ -221,6 +222,9 @@ def save_all_op_output(feed_kv=None): if not os.path.exists(output_path): os.mkdir(output_path) ops = prog.current_block().ops + fetch_names = [] + for fetch in fetches: + fetch_names.append(fetch.name) for i in range(len(ops)): op = ops[i] var_name = None @@ -230,6 +234,9 @@ def save_all_op_output(feed_kv=None): break if "sequence_pool" in var_name: continue + if fast_check: + if var_name not in fetch_names: + continue try: data = get_var_data(var_name, feed_kv=feed_kv).flatten().tolist() sample = tensor_sample(data) @@ -313,10 +320,11 @@ def check_mobile_results(args, fuse, mem_opt): error_values1 = values1 error_values2 = values2 break - for name in fetch_names: - if name not in checked_names: - error_index = -1 - break + if error_index == None: + for name in fetch_names: + if name not in checked_names: + error_index = -1 + break if error_index == None: pp_green("outputs are all correct", 1) elif error_index == -1: @@ -379,8 +387,9 @@ def main(): args += " " + str(sample_step) for var_name in output_var_cache.keys(): args += " " + var_name - check_mobile_results(args, False, False) - check_mobile_results(args, False, True) + if not fast_check: + check_mobile_results(args, False, False) + check_mobile_results(args, False, True) check_mobile_results(args, True, False) check_mobile_results(args, True, True)