提交 0fc582e4 编写于 作者: W wanghaoshuang

Fix sensitive API.

上级 93f32775
......@@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import logging
import pickle
import numpy as np
from ..core import GraphWrapper
from ..common import get_logger
from ..prune import Pruner
_logger = get_logger(__name__, level=logging.INFO)
......@@ -24,14 +28,13 @@ __all__ = ["sensitivity"]
def sensitivity(program,
scope,
place,
param_names,
eval_func,
sensitivities_file=None):
sensitivities_file=None,
step_size=0.2):
graph = GraphWrapper(program)
if sensitivities_file is not None:
assert os.path.exsits(sensitivities_file)
sensitivities = _load_sensitivities(sensitivities_file)
for name in param_names:
......@@ -47,12 +50,12 @@ def sensitivity(program,
ratio = step_size
while ratio < 1:
ratio = round(ratio, 2)
if ratio in sensitivities[param]['pruned_percent']:
if ratio in sensitivities[name]['pruned_percent']:
_logger.debug('{}, {} has computed.'.format(name, ratio))
ratio += step_size
continue
if baseline is None:
baseline = _eval_func(grpah.program, scope)
baseline = eval_func(graph.program, scope)
param_backup = {}
pruner = Pruner()
......@@ -65,7 +68,7 @@ def sensitivity(program,
lazy=True,
only_graph=False,
param_backup=param_backup)
pruned_metric = _eval_func(pruned_program)
pruned_metric = eval_func(pruned_program, scope)
loss = (baseline - pruned_metric) / baseline
_logger.info("pruned param: {}; {}; loss={}".format(name, ratio,
loss))
......@@ -97,5 +100,12 @@ def _load_sensitivities(sensitivities_file):
sensitivities[param]['pruned_percent'] = [
round(p, 2) for p in sensitivities[param]['pruned_percent']
]
self._format_sensitivities(sensitivities)
return sensitivities
def _save_sensitivities(sensitivities, sensitivities_file):
"""
Save sensitivities into file.
"""
with open(sensitivities_file, 'wb') as f:
pickle.dump(sensitivities, f)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册