提交 0d04d6bd 编写于 作者: W wanghaoshuang

Simplify arguments of sensitive API.

上级 5ed67bd9
......@@ -33,12 +33,14 @@ def sensitivity(program,
param_names,
eval_func,
sensitivities_file=None,
step_size=0.2,
max_pruned_times=None):
pruned_ratios=None):
scope = fluid.global_scope()
graph = GraphWrapper(program)
sensitivities = _load_sensitivities(sensitivities_file)
if pruned_ratios is None:
pruned_ratios = np.arange(0.1, 1, step=0.1)
for name in param_names:
if name not in sensitivities:
size = graph.var(name).shape()[0]
......@@ -49,16 +51,9 @@ def sensitivity(program,
}
baseline = None
for name in sensitivities:
ratio = step_size
pruned_times = 0
while ratio < 1:
if max_pruned_times is not None and pruned_times >= max_pruned_times:
break
ratio = round(ratio, 2)
for ratio in pruned_ratios:
if ratio in sensitivities[name]['pruned_percent']:
_logger.debug('{}, {} has computed.'.format(name, ratio))
ratio += step_size
pruned_times += 1
continue
if baseline is None:
baseline = eval_func(graph.program)
......@@ -92,8 +87,6 @@ def sensitivity(program,
for param_name in param_backup.keys():
param_t = scope.find_var(param_name).get_tensor()
param_t.set(param_backup[param_name], place)
ratio += step_size
pruned_times += 1
return sensitivities
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册