diff --git a/tools/python/fluidtools/run.py b/tools/python/fluidtools/run.py index 03efa209e2482bbc11cbbfdbe91b5fd1a9f4b159..3503f633e0a4c6ed1a78d8afb887c80886ab3bd7 100644 --- a/tools/python/fluidtools/run.py +++ b/tools/python/fluidtools/run.py @@ -160,7 +160,10 @@ def load_feed_kv(): data = data.reshape(feed_shape).astype("float32") if is_lod: - data = data.reshape((1, *feed_shape)).astype("float32") + data_shape = [1] + for dim in feed_shape: + data_shape.append(dim) + data = data.reshape(data_shape).astype("float32") tensor = fluid.LoDTensor() seq_lens = [len(seq) for seq in data] cur_len = 0 @@ -203,17 +206,32 @@ def get_feed_var_shape(var_name): # return [1, 3, 224, 224] return get_var_shape(var_name) +persistable_cache = [] +# 所有var,全部变成持久化 +def force_all_vars_to_persistable(): + global persistable_cache + for var_name in vars.keys(): + var_name = str(var_name) + v = fluid.framework._get_var(var_name, prog) + persistable = v.persistable + if not persistable: + persistable_cache.append(var_name) + v.persistable = True + +# 恢复持久化属性 +def restore_all_vars_persistable(): + global persistable_cache + for var_name in vars.keys(): + var_name = str(var_name) + v = fluid.framework._get_var(var_name, prog) + persistable = v.persistable + if var_name in persistable_cache: + v.persistable = False + persistable_cache = [] + # 获取var的数据 def get_var_data(var_name, feed_kv=None): - # 强制var为可持久化 - v = fluid.framework._get_var(var_name, prog) - persistable = v.persistable - if not persistable: - v.persistable = True - # outputs = run_model(feed_kv=feed_kv) - output = np.array(fluid.global_scope().find_var(var_name).get_tensor()) - # 恢复var的可持久化属性 - v.persistable = persistable + output = np.array(fluid.global_scope().var(var_name).get_tensor()) return output output_var_cache = {} @@ -223,6 +241,7 @@ def tensor_sample(tensor): else: step = math.floor(len(tensor) / sample_num) step = max(step, 1) + step = int(step) sample = [] for i in range(0, len(tensor), step): sample.append(tensor[i]) @@ -231,6 +250,8 @@ def tensor_sample(tensor): op_cache = {} # 获取每层输出的数据 def save_all_op_output(feed_kv=None): + force_all_vars_to_persistable() + outputs = run_model(feed_kv=feed_kv) if not os.path.exists(output_path): os.mkdir(output_path) ops = prog.current_block().ops @@ -281,6 +302,7 @@ def save_all_op_output(feed_kv=None): except: pass pp_green("all the op outputs are saved into directory 【{}】".format(output_path), 1) + restore_all_vars_persistable() ops = prog.current_block().ops vars = prog.current_block().vars