#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
from paddle.fluid import core
from .layers_base import BaseBlock

__all__ = ['check_search_space']

WEIGHT_OP = [
    'conv2d', 'linear', 'embedding', 'conv2d_transpose', 'depthwise_conv2d'
]
CONV_TYPES = [
    'conv2d', 'conv3d', 'conv1d', 'superconv2d', 'supergroupconv2d',
    'superdepthwiseconv2d'
]


def get_actual_shape(transform, channel):
    """Resolve the actual channel number: keep it if transform is None, scale it
       if transform is a float ratio, otherwise use transform as the new value."""
    if transform is None:
        channel = int(channel)
    elif isinstance(transform, float):
        channel = int(channel * transform)
    else:
        channel = int(transform)
    return channel
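
# A minimal usage sketch of get_actual_shape (the values below are illustrative,
# not taken from the source):
#   get_actual_shape(None, 64)   # -> 64: keep the original channel number
#   get_actual_shape(0.5, 64)    # -> 32: a float transform scales the channel
#   get_actual_shape(16, 64)     # -> 16: any other transform is used as-is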


def _is_depthwise(op):
    """Check if this op is depthwise conv. Only Cin == Cout == groups be consider as depthwise conv.
C
ceci3 已提交
44 45 46
       The shape of input and the shape of output in depthwise conv must be same in superlayer,
       so depthwise op cannot be consider as weight op
    """
    ### NOTE: checking op.type() == 'depthwise_conv2d' alone is not enough,
    ### because depthwise_conv2d in paddle only requires Cout % Cin == 0.
    if 'conv' in op.type():
        for inp in op.all_inputs():
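            # Conv weight vars are laid out as [Cout, Cin / groups, kh, kw]; the
            # check below therefore requires groups == Cout and Cin == Cout,
            # i.e. Cin == Cout == groups, which is exactly a depthwise conv.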
            if inp._var.persistable and (
                    op.attr('groups') == inp._var.shape[0] and
                    op.attr('groups') * inp._var.shape[1] == inp._var.shape[0]):
                return True
    return False


def _find_weight_ops(op, graph, weights):
    """ Find the vars come from operators with weight.
    """
    pre_ops = graph.pre_ops(op)
    for pre_op in pre_ops:
        ### if a depthwise conv is one of the elementwise op's inputs,
        ### add its weights into the same search space
        if _is_depthwise(pre_op):
            for inp in pre_op.all_inputs():
                if inp._var.persistable:
                    weights.append(inp._var.name)

        if pre_op.type() in WEIGHT_OP and not _is_depthwise(pre_op):
            for inp in pre_op.all_inputs():
                if inp._var.persistable:
                    weights.append(inp._var.name)
            return weights
        return _find_weight_ops(pre_op, graph, weights)
    return weights


def _find_pre_elementwise_add(op, graph):
    """ Find precedors of the elementwise_add operator in the model.
    """
    same_prune_before_elementwise_add = []
    pre_ops = graph.pre_ops(op)
    for pre_op in pre_ops:
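        # If the elementwise op is fed directly by a weight op, return None;
        # the caller then skips building a shared group for this elementwise op.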
        if pre_op.type() in WEIGHT_OP:
            return
        same_prune_before_elementwise_add = _find_weight_ops(
            pre_op, graph, same_prune_before_elementwise_add)
    return same_prune_before_elementwise_add


def check_search_space(graph):
    """ Find the shortcut in the model and set same config for this situation.
    """
    same_search_space = []
    depthwise_conv = []
    for op in graph.ops():
        if op.type() == 'elementwise_add' or op.type() == 'elementwise_mul':
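            # Only treat the op as a shortcut when both of its inputs are
            # activations (non-persistable vars), i.e. neither input is a parameter.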
            inp1, inp2 = op.all_inputs()[0], op.all_inputs()[1]
            if (not inp1._var.persistable) and (not inp2._var.persistable):
                pre_ele_op = _find_pre_elementwise_add(op, graph)
                if pre_ele_op is not None:
                    same_search_space.append(pre_ele_op)

        if _is_depthwise(op):
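            # Collect depthwise conv weights separately: their input and output
            # channels must stay equal, so they are returned for special handling.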
            for inp in op.all_inputs():
                if inp._var.persistable:
                    depthwise_conv.append(inp._var.name)

    if len(same_search_space) == 0:
        return None, []

    same_search_space = sorted([sorted(x) for x in same_search_space])
    final_search_space = []

    if len(same_search_space) >= 1:
        final_search_space = [same_search_space[0]]
        if len(same_search_space) > 1:
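            # Union-merge overlapping groups: any two groups that share a parameter
            # name are collapsed into one search space, e.g. (hypothetical names)
            # [['a', 'b'], ['b', 'c'], ['d']] -> [['a', 'b', 'c'], ['d']].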
            for l in same_search_space[1:]:
                listset = set(l)
                merged = False
                for idx in range(len(final_search_space)):
                    rset = set(final_search_space[idx])
                    if len(listset & rset) != 0:
                        final_search_space[idx] = list(listset | rset)
                        merged = True
                        break
                if not merged:
                    final_search_space.append(l)
    final_search_space = sorted([sorted(x) for x in final_search_space])
    depthwise_conv = sorted(depthwise_conv)

    return (final_search_space, depthwise_conv)


def broadcast_search_space(same_search_space, param2key, origin_config):
    """
    Inplace broadcast the origin_config according to the same search space. Such as: same_search_space = [['conv1_weight', 'conv3_weight']], param2key = {'conv1_weight': 'conv1.conv', 'conv3_weight': 'conv3.weight'}, origin_config= {'conv1.weight': {'channel': 10}, 'conv2.weight': {'channel': 20}}, the result after this function is origin_config={'conv1.weight': {'channel': 10}, 'conv2.weight': {'channel': 20}, 'conv3.weight': {'channel': 10}}

    Args:
        same_search_space(list<list>): broadcast according this list, each list in same_search_space means the channel must be consistent.
        param2key(dict): the name of layers corresponds to the name of parameter.
        origin_config(dict): the search space which can be searched.
    """
    for per_ss in same_search_space:
        for ss in per_ss[1:]:
            key = param2key[ss]
            pre_key = param2key[per_ss[0]]
            if key in origin_config:
                if 'expand_ratio' in origin_config[pre_key]:
                    origin_config[key].update({
                        'expand_ratio': origin_config[pre_key]['expand_ratio']
                    })
                elif 'channel' in origin_config[pre_key]:
                    origin_config[key].update({
                        'channel': origin_config[pre_key]['channel']
                    })
            else:
                if 'expand_ratio' in origin_config[pre_key]:
                    origin_config[key] = {
                        'expand_ratio': origin_config[pre_key]['expand_ratio']
                    }
                elif 'channel' in origin_config[pre_key]:
                    origin_config[key] = {
                        'channel': origin_config[pre_key]['channel']
                    }
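
# A minimal, hypothetical sketch of the in-place broadcast above; the parameter and
# layer names are illustrative, not taken from the source:
#
#   ss = [['conv1_weight', 'conv3_weight']]
#   p2k = {'conv1_weight': 'conv1.weight', 'conv3_weight': 'conv3.weight'}
#   cfg = {'conv1.weight': {'expand_ratio': 0.25}, 'conv3.weight': {'expand_ratio': 1.0}}
#   broadcast_search_space(ss, p2k, cfg)
#   # cfg['conv3.weight'] now reuses the leading parameter's setting:
#   # {'conv1.weight': {'expand_ratio': 0.25}, 'conv3.weight': {'expand_ratio': 0.25}}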