Unverified · Commit 43ac17b2 authored by whs, committed by GitHub

Fix sensitive API. (#30)

Parent: d1f3229b
@@ -81,7 +81,8 @@ class Pruner():
             pruned_idx = self._cal_pruned_idx(param_t, ratio, axis=0)
             param = graph.var(param)
             conv_op = param.outputs()[0]
-            walker = conv2d_walker(conv_op, pruned_params=pruned_params, visited=visited)
+            walker = conv2d_walker(
+                conv_op, pruned_params=pruned_params, visited=visited)
             walker.prune(param, pruned_axis=0, pruned_idx=pruned_idx)
         merge_pruned_params = {}
@@ -94,9 +95,12 @@ class Pruner():
         for param_name in merge_pruned_params:
             for pruned_axis in merge_pruned_params[param_name]:
-                pruned_idx = np.concatenate(merge_pruned_params[param_name][pruned_axis])
+                pruned_idx = np.concatenate(merge_pruned_params[param_name][
+                    pruned_axis])
                 param = graph.var(param_name)
-                _logger.debug("{}\t{}\t{}".format(param.name(), pruned_axis, len(pruned_idx)))
+                if not lazy:
+                    _logger.debug("{}\t{}\t{}".format(param.name(
+                    ), pruned_axis, len(pruned_idx)))
                 if param_shape_backup is not None:
                     origin_shape = copy.deepcopy(param.shape())
                     param_shape_backup[param.name()] = origin_shape
@@ -105,8 +109,10 @@ class Pruner():
                     param.set_shape(new_shape)
                 if not only_graph:
                     param_t = scope.find_var(param.name()).get_tensor()
-                    if param_backup is not None and (param.name() not in param_backup):
-                        param_backup[param.name()] = copy.deepcopy(np.array(param_t))
+                    if param_backup is not None and (
+                            param.name() not in param_backup):
+                        param_backup[param.name()] = copy.deepcopy(
+                            np.array(param_t))
                     try:
                         pruned_param = self._prune_tensor(
                             np.array(param_t),
@@ -114,11 +120,11 @@ class Pruner():
                             pruned_axis=pruned_axis,
                             lazy=lazy)
                     except IndexError as e:
-                        _logger.error("Pruning {}, but get [{}]".format(param.name(
-                        ), e))
+                        _logger.error("Pruning {}, but get [{}]".format(
+                            param.name(), e))
                     param_t.set(pruned_param, place)
         graph.update_groups_of_conv()
         return graph.program, param_backup, param_shape_backup

     def _cal_pruned_idx(self, param, ratio, axis):
...
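Note on the `lazy` flag exercised in the hunks above: the diff itself only shows that `_prune_tensor(..., lazy=lazy)` is called and that the debug log is now skipped when `lazy` is set. As a rough, self-contained NumPy sketch of what such a flag commonly means in channel pruning (a hypothetical `prune_tensor_sketch` helper, not PaddleSlim's actual `_prune_tensor`), non-lazy pruning removes the selected indices along an axis, while lazy pruning keeps the original shape and zeroes those slices instead:

import numpy as np

def prune_tensor_sketch(tensor, pruned_idx, pruned_axis, lazy=False):
    # Hypothetical stand-in for Pruner._prune_tensor, used only for illustration.
    # lazy=False: delete the pruned indices along pruned_axis (the shape shrinks).
    # lazy=True:  keep the original shape and zero out the pruned slices.
    if lazy:
        pruned = tensor.copy()
        slicer = [slice(None)] * tensor.ndim
        slicer[pruned_axis] = pruned_idx
        pruned[tuple(slicer)] = 0
        return pruned
    return np.delete(tensor, pruned_idx, axis=pruned_axis)

# Example: prune output channels 1 and 3 of a 4x2x3x3 conv weight.
weight = np.random.rand(4, 2, 3, 3)
print(prune_tensor_sketch(weight, [1, 3], pruned_axis=0).shape)             # (2, 2, 3, 3)
print(prune_tensor_sketch(weight, [1, 3], pruned_axis=0, lazy=True).shape)  # (4, 2, 3, 3)

Under that reading, lazy pruning leaves tensor shapes intact, which is consistent with the change above logging the pruned-index count only when `lazy` is not set.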