提交 afdbc1ff 编写于 作者: W wanghaoshuang

Update groups after pruning parameters of conv layer.

上级 72c800e9
...@@ -12,13 +12,17 @@ ...@@ -12,13 +12,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import copy import copy
from ..core import VarWrapper, OpWrapper, GraphWrapper from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import get_logger
__all__ = ["Pruner"] __all__ = ["Pruner"]
_logger = get_logger(__name__, level=logging.INFO)
class Pruner(): class Pruner():
def __init__(self, criterion="l1_norm"): def __init__(self, criterion="l1_norm"):
...@@ -69,6 +73,10 @@ class Pruner(): ...@@ -69,6 +73,10 @@ class Pruner():
only_graph=only_graph, only_graph=only_graph,
param_backup=param_backup, param_backup=param_backup,
param_shape_backup=param_shape_backup) param_shape_backup=param_shape_backup)
for op in graph.ops():
if op.type() == 'depthwise_conv2d' or op.type(
) == 'depthwise_conv2d_grad':
op.set_attr('groups', op.inputs('Filter')[0].shape()[0])
return graph.program return graph.program
def _prune_filters_by_ratio(self, def _prune_filters_by_ratio(self,
...@@ -113,6 +121,8 @@ class Pruner(): ...@@ -113,6 +121,8 @@ class Pruner():
new_shape = list(param.shape()) new_shape = list(param.shape())
new_shape[0] = pruned_param.shape[0] new_shape[0] = pruned_param.shape[0]
param.set_shape(new_shape) param.set_shape(new_shape)
_logger.info("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[0].append(param.name()) self.pruned_list[0].append(param.name())
return pruned_idx return pruned_idx
...@@ -158,6 +168,8 @@ class Pruner(): ...@@ -158,6 +168,8 @@ class Pruner():
new_shape = list(param.shape()) new_shape = list(param.shape())
new_shape[pruned_axis] = pruned_param.shape[pruned_axis] new_shape[pruned_axis] = pruned_param.shape[pruned_axis]
param.set_shape(new_shape) param.set_shape(new_shape)
_logger.info("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[pruned_axis].append(param.name()) self.pruned_list[pruned_axis].append(param.name())
def _forward_search_related_op(self, graph, param): def _forward_search_related_op(self, graph, param):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册