Unverified · Commit 43ac17b2 · Author: whs · Committer: GitHub

Fix sensitive API. (#30)

Parent: d1f3229b
```diff
@@ -81,7 +81,8 @@ class Pruner():
                 pruned_idx = self._cal_pruned_idx(param_t, ratio, axis=0)
             param = graph.var(param)
             conv_op = param.outputs()[0]
-            walker = conv2d_walker(conv_op,pruned_params=pruned_params, visited=visited)
+            walker = conv2d_walker(
+                conv_op, pruned_params=pruned_params, visited=visited)
             walker.prune(param, pruned_axis=0, pruned_idx=pruned_idx)
         merge_pruned_params = {}
```
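The change in this hunk is purely a line-wrapping reformat of the `conv2d_walker` call. For context, `_cal_pruned_idx` in the lines above chooses which output channels of the conv weight to drop. A minimal sketch of how such an index selection could work, assuming an L1-norm criterion; the name `cal_pruned_idx_l1` and the NumPy implementation are illustrative, not the repository's code:

```python
import numpy as np

def cal_pruned_idx_l1(param, ratio, axis=0):
    """Return indices of the weakest slices along `axis`.

    A sketch assuming an L1-norm pruning criterion; not the
    repository's actual implementation.
    """
    # Sum absolute values over every axis except the pruned one.
    reduce_axes = tuple(i for i in range(param.ndim) if i != axis)
    scores = np.abs(param).sum(axis=reduce_axes)
    # Prune the `ratio` fraction of slices with the smallest L1 norm.
    num_pruned = int(round(param.shape[axis] * ratio))
    return np.argsort(scores)[:num_pruned]

# Example: a conv weight of shape [out_channels, in_channels, k, k];
# axis=0 matches the `pruned_axis=0` used in the diff above.
weight = np.random.rand(8, 4, 3, 3)
print(cal_pruned_idx_l1(weight, ratio=0.25))  # indices of 2 filters
```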
```diff
@@ -94,19 +95,24 @@ class Pruner():
         for param_name in merge_pruned_params:
             for pruned_axis in merge_pruned_params[param_name]:
-                pruned_idx = np.concatenate(merge_pruned_params[param_name][pruned_axis])
+                pruned_idx = np.concatenate(merge_pruned_params[param_name][
+                    pruned_axis])
                 param = graph.var(param_name)
-                _logger.debug("{}\t{}\t{}".format(param.name(), pruned_axis, len(pruned_idx)))
-                if param_shape_backup is not None:
-                    origin_shape = copy.deepcopy(param.shape())
-                    param_shape_backup[param.name()] = origin_shape
-                new_shape = list(param.shape())
-                new_shape[pruned_axis] -= len(pruned_idx)
-                param.set_shape(new_shape)
+                if not lazy:
+                    _logger.debug("{}\t{}\t{}".format(param.name(
+                    ), pruned_axis, len(pruned_idx)))
+                    if param_shape_backup is not None:
+                        origin_shape = copy.deepcopy(param.shape())
+                        param_shape_backup[param.name()] = origin_shape
+                    new_shape = list(param.shape())
+                    new_shape[pruned_axis] -= len(pruned_idx)
+                    param.set_shape(new_shape)
                 if not only_graph:
                     param_t = scope.find_var(param.name()).get_tensor()
-                    if param_backup is not None and (param.name() not in param_backup):
-                        param_backup[param.name()] = copy.deepcopy(np.array(param_t))
+                    if param_backup is not None and (
+                            param.name() not in param_backup):
+                        param_backup[param.name()] = copy.deepcopy(
+                            np.array(param_t))
                     try:
                         pruned_param = self._prune_tensor(
                             np.array(param_t),
```
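The substantive change is the new `if not lazy:` guard: the old code updated parameter shapes (and `param_shape_backup`) unconditionally, while in lazy mode the pruned entries should only be masked and all shapes left untouched. A sketch of the two behaviours the `lazy` flag presumably selects inside `_prune_tensor`; the body below is an assumption based on the flag's name, not the repository's implementation:

```python
import numpy as np

def prune_tensor(tensor, pruned_idx, pruned_axis, lazy=False):
    """Sketch of the assumed lazy vs. non-lazy semantics.

    lazy=True  -> zero out the pruned slices; shape is unchanged,
                  so the graph's shapes must not be touched.
    lazy=False -> delete the pruned slices; shape shrinks, matching
                  the `param.set_shape(new_shape)` bookkeeping above.
    """
    if lazy:
        out = tensor.copy()
        # Select the pruned slices along pruned_axis and zero them.
        index = [slice(None)] * tensor.ndim
        index[pruned_axis] = pruned_idx
        out[tuple(index)] = 0.0
        return out
    return np.delete(tensor, pruned_idx, axis=pruned_axis)

w = np.ones((4, 3))
assert prune_tensor(w, [1, 3], 0, lazy=True).shape == (4, 3)
assert prune_tensor(w, [1, 3], 0, lazy=False).shape == (2, 3)
```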
```diff
@@ -114,11 +120,11 @@ class Pruner():
                             pruned_axis=pruned_axis,
                             lazy=lazy)
                     except IndexError as e:
-                        _logger.error("Pruning {}, but get [{}]".format(param.name(
-                        ), e))
+                        _logger.error("Pruning {}, but get [{}]".format(
+                            param.name(), e))
                     param_t.set(pruned_param, place)
         graph.update_groups_of_conv()
         return graph.program, param_backup, param_shape_backup

     def _cal_pruned_idx(self, param, ratio, axis):
```
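Because `prune` returns `param_backup` and `param_shape_backup`, a caller can undo a pruning pass, for example while scanning per-parameter sensitivity. A sketch of such a restore step, reusing only the calls visible in the diff (`graph.var(...).set_shape(...)` and `scope.find_var(...).get_tensor().set(...)`); the `restore` helper itself is hypothetical:

```python
def restore(graph, scope, place, param_backup, param_shape_backup):
    """Hypothetical helper: undo a prune pass from the returned backups."""
    if param_shape_backup is not None:
        for name, shape in param_shape_backup.items():
            graph.var(name).set_shape(shape)  # restore the original shape
    if param_backup is not None:
        for name, value in param_backup.items():
            # Write the saved weights back into the scope's tensor.
            scope.find_var(name).get_tensor().set(value, place)
```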