提交 0fc582e4 编写于 作者: W wanghaoshuang

Fix sensitive API.

上级 93f32775
...@@ -12,10 +12,14 @@ ...@@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
import os
import logging import logging
import pickle
import numpy as np import numpy as np
from ..core import GraphWrapper from ..core import GraphWrapper
from ..common import get_logger from ..common import get_logger
from ..prune import Pruner
_logger = get_logger(__name__, level=logging.INFO) _logger = get_logger(__name__, level=logging.INFO)
...@@ -24,14 +28,13 @@ __all__ = ["sensitivity"] ...@@ -24,14 +28,13 @@ __all__ = ["sensitivity"]
def sensitivity(program, def sensitivity(program,
scope, scope,
place,
param_names, param_names,
eval_func, eval_func,
sensitivities_file=None): sensitivities_file=None,
step_size=0.2):
graph = GraphWrapper(program) graph = GraphWrapper(program)
if sensitivities_file is not None:
assert os.path.exsits(sensitivities_file)
sensitivities = _load_sensitivities(sensitivities_file) sensitivities = _load_sensitivities(sensitivities_file)
for name in param_names: for name in param_names:
...@@ -47,12 +50,12 @@ def sensitivity(program, ...@@ -47,12 +50,12 @@ def sensitivity(program,
ratio = step_size ratio = step_size
while ratio < 1: while ratio < 1:
ratio = round(ratio, 2) ratio = round(ratio, 2)
if ratio in sensitivities[param]['pruned_percent']: if ratio in sensitivities[name]['pruned_percent']:
_logger.debug('{}, {} has computed.'.format(name, ratio)) _logger.debug('{}, {} has computed.'.format(name, ratio))
ratio += step_size ratio += step_size
continue continue
if baseline is None: if baseline is None:
baseline = _eval_func(grpah.program, scope) baseline = eval_func(graph.program, scope)
param_backup = {} param_backup = {}
pruner = Pruner() pruner = Pruner()
...@@ -65,7 +68,7 @@ def sensitivity(program, ...@@ -65,7 +68,7 @@ def sensitivity(program,
lazy=True, lazy=True,
only_graph=False, only_graph=False,
param_backup=param_backup) param_backup=param_backup)
pruned_metric = _eval_func(pruned_program) pruned_metric = eval_func(pruned_program, scope)
loss = (baseline - pruned_metric) / baseline loss = (baseline - pruned_metric) / baseline
_logger.info("pruned param: {}; {}; loss={}".format(name, ratio, _logger.info("pruned param: {}; {}; loss={}".format(name, ratio,
loss)) loss))
...@@ -97,5 +100,12 @@ def _load_sensitivities(sensitivities_file): ...@@ -97,5 +100,12 @@ def _load_sensitivities(sensitivities_file):
sensitivities[param]['pruned_percent'] = [ sensitivities[param]['pruned_percent'] = [
round(p, 2) for p in sensitivities[param]['pruned_percent'] round(p, 2) for p in sensitivities[param]['pruned_percent']
] ]
self._format_sensitivities(sensitivities)
return sensitivities return sensitivities
def _save_sensitivities(sensitivities, sensitivities_file):
"""
Save sensitivities into file.
"""
with open(sensitivities_file, 'wb') as f:
pickle.dump(sensitivities, f)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册