提交 afdbc1ff 编写于 作者: W wanghaoshuang

Update groups after pruning parameters of conv layer.

上级 72c800e9
......@@ -12,13 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import paddle.fluid as fluid
import copy
from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import get_logger
__all__ = ["Pruner"]
_logger = get_logger(__name__, level=logging.INFO)
class Pruner():
def __init__(self, criterion="l1_norm"):
......@@ -69,6 +73,10 @@ class Pruner():
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
for op in graph.ops():
if op.type() == 'depthwise_conv2d' or op.type(
) == 'depthwise_conv2d_grad':
op.set_attr('groups', op.inputs('Filter')[0].shape()[0])
return graph.program
def _prune_filters_by_ratio(self,
......@@ -113,6 +121,8 @@ class Pruner():
new_shape = list(param.shape())
new_shape[0] = pruned_param.shape[0]
param.set_shape(new_shape)
_logger.info("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[0].append(param.name())
return pruned_idx
......@@ -158,6 +168,8 @@ class Pruner():
new_shape = list(param.shape())
new_shape[pruned_axis] = pruned_param.shape[pruned_axis]
param.set_shape(new_shape)
_logger.info("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[pruned_axis].append(param.name())
def _forward_search_related_op(self, graph, param):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册