未验证 提交 c0b1f2ef 编写于 作者: Y Yanzhan Yang 提交者: GitHub

support intermediate outputs on python 2.7 platform (#1764)

上级 db5f1aa5
......@@ -160,7 +160,10 @@ def load_feed_kv():
data = data.reshape(feed_shape).astype("float32")
if is_lod:
data = data.reshape((1, *feed_shape)).astype("float32")
data_shape = [1]
for dim in feed_shape:
data_shape.append(dim)
data = data.reshape(data_shape).astype("float32")
tensor = fluid.LoDTensor()
seq_lens = [len(seq) for seq in data]
cur_len = 0
......@@ -203,17 +206,32 @@ def get_feed_var_shape(var_name):
# return [1, 3, 224, 224]
return get_var_shape(var_name)
persistable_cache = []
# 所有var,全部变成持久化
def force_all_vars_to_persistable():
global persistable_cache
for var_name in vars.keys():
var_name = str(var_name)
v = fluid.framework._get_var(var_name, prog)
persistable = v.persistable
if not persistable:
persistable_cache.append(var_name)
v.persistable = True
# 恢复持久化属性
def restore_all_vars_persistable():
global persistable_cache
for var_name in vars.keys():
var_name = str(var_name)
v = fluid.framework._get_var(var_name, prog)
persistable = v.persistable
if var_name in persistable_cache:
v.persistable = False
persistable_cache = []
# 获取var的数据
def get_var_data(var_name, feed_kv=None):
# 强制var为可持久化
v = fluid.framework._get_var(var_name, prog)
persistable = v.persistable
if not persistable:
v.persistable = True
# outputs = run_model(feed_kv=feed_kv)
output = np.array(fluid.global_scope().find_var(var_name).get_tensor())
# 恢复var的可持久化属性
v.persistable = persistable
output = np.array(fluid.global_scope().var(var_name).get_tensor())
return output
output_var_cache = {}
......@@ -223,6 +241,7 @@ def tensor_sample(tensor):
else:
step = math.floor(len(tensor) / sample_num)
step = max(step, 1)
step = int(step)
sample = []
for i in range(0, len(tensor), step):
sample.append(tensor[i])
......@@ -231,6 +250,8 @@ def tensor_sample(tensor):
op_cache = {}
# 获取每层输出的数据
def save_all_op_output(feed_kv=None):
force_all_vars_to_persistable()
outputs = run_model(feed_kv=feed_kv)
if not os.path.exists(output_path):
os.mkdir(output_path)
ops = prog.current_block().ops
......@@ -281,6 +302,7 @@ def save_all_op_output(feed_kv=None):
except:
pass
pp_green("all the op outputs are saved into directory 【{}】".format(output_path), 1)
restore_all_vars_persistable()
ops = prog.current_block().ops
vars = prog.current_block().vars
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册