未验证 提交 c0b1f2ef 编写于 作者: Y Yanzhan Yang 提交者: GitHub

support intermediate outputs on python 2.7 platform (#1764)

上级 db5f1aa5
...@@ -160,7 +160,10 @@ def load_feed_kv(): ...@@ -160,7 +160,10 @@ def load_feed_kv():
data = data.reshape(feed_shape).astype("float32") data = data.reshape(feed_shape).astype("float32")
if is_lod: if is_lod:
data = data.reshape((1, *feed_shape)).astype("float32") data_shape = [1]
for dim in feed_shape:
data_shape.append(dim)
data = data.reshape(data_shape).astype("float32")
tensor = fluid.LoDTensor() tensor = fluid.LoDTensor()
seq_lens = [len(seq) for seq in data] seq_lens = [len(seq) for seq in data]
cur_len = 0 cur_len = 0
...@@ -203,17 +206,32 @@ def get_feed_var_shape(var_name): ...@@ -203,17 +206,32 @@ def get_feed_var_shape(var_name):
# return [1, 3, 224, 224] # return [1, 3, 224, 224]
return get_var_shape(var_name) return get_var_shape(var_name)
persistable_cache = []
# 所有var,全部变成持久化
def force_all_vars_to_persistable():
global persistable_cache
for var_name in vars.keys():
var_name = str(var_name)
v = fluid.framework._get_var(var_name, prog)
persistable = v.persistable
if not persistable:
persistable_cache.append(var_name)
v.persistable = True
# 恢复持久化属性
def restore_all_vars_persistable():
global persistable_cache
for var_name in vars.keys():
var_name = str(var_name)
v = fluid.framework._get_var(var_name, prog)
persistable = v.persistable
if var_name in persistable_cache:
v.persistable = False
persistable_cache = []
# 获取var的数据 # 获取var的数据
def get_var_data(var_name, feed_kv=None): def get_var_data(var_name, feed_kv=None):
# 强制var为可持久化 output = np.array(fluid.global_scope().var(var_name).get_tensor())
v = fluid.framework._get_var(var_name, prog)
persistable = v.persistable
if not persistable:
v.persistable = True
# outputs = run_model(feed_kv=feed_kv)
output = np.array(fluid.global_scope().find_var(var_name).get_tensor())
# 恢复var的可持久化属性
v.persistable = persistable
return output return output
output_var_cache = {} output_var_cache = {}
...@@ -223,6 +241,7 @@ def tensor_sample(tensor): ...@@ -223,6 +241,7 @@ def tensor_sample(tensor):
else: else:
step = math.floor(len(tensor) / sample_num) step = math.floor(len(tensor) / sample_num)
step = max(step, 1) step = max(step, 1)
step = int(step)
sample = [] sample = []
for i in range(0, len(tensor), step): for i in range(0, len(tensor), step):
sample.append(tensor[i]) sample.append(tensor[i])
...@@ -231,6 +250,8 @@ def tensor_sample(tensor): ...@@ -231,6 +250,8 @@ def tensor_sample(tensor):
op_cache = {} op_cache = {}
# 获取每层输出的数据 # 获取每层输出的数据
def save_all_op_output(feed_kv=None): def save_all_op_output(feed_kv=None):
force_all_vars_to_persistable()
outputs = run_model(feed_kv=feed_kv)
if not os.path.exists(output_path): if not os.path.exists(output_path):
os.mkdir(output_path) os.mkdir(output_path)
ops = prog.current_block().ops ops = prog.current_block().ops
...@@ -281,6 +302,7 @@ def save_all_op_output(feed_kv=None): ...@@ -281,6 +302,7 @@ def save_all_op_output(feed_kv=None):
except: except:
pass pass
pp_green("all the op outputs are saved into directory 【{}】".format(output_path), 1) pp_green("all the op outputs are saved into directory 【{}】".format(output_path), 1)
restore_all_vars_persistable()
ops = prog.current_block().ops ops = prog.current_block().ops
vars = prog.current_block().vars vars = prog.current_block().vars
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册