#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
from paddle.fluid import core
from .layers_base import BaseBlock

__all__ = ['check_search_space']

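# Op types whose weights take part in the dynamic (searchable) sub-model;
# this is the default list checked by _is_dynamic_weight_op below.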
DYNAMIC_WEIGHT_OP = [
    'conv2d', 'mul', 'matmul', 'embedding', 'conv2d_transpose',
    'depthwise_conv2d', 'matmul_v2'
]

CONV_TYPES = [
    'conv2d', 'conv3d', 'conv1d', 'superconv2d', 'supergroupconv2d',
    'superdepthwiseconv2d', 'matmul_v2'
]

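# All parameterized op types, including the normalization layers; this list is
# selected by _is_dynamic_weight_op when it is called with all_weight_op=True.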
ALL_WEIGHT_OP = [
    'conv2d', 'mul', 'matmul', 'elementwise_add', 'embedding',
    'conv2d_transpose', 'depthwise_conv2d', 'batch_norm', 'layer_norm',
    'instance_norm', 'sync_batch_norm', 'matmul_v2'
]


def _is_dynamic_weight_op(op, all_weight_op=False):
    if all_weight_op:
        weight_ops = ALL_WEIGHT_OP
    else:
        weight_ops = DYNAMIC_WEIGHT_OP
    if op.type() in weight_ops:
        if op.type() in ['mul', 'matmul']:
            # mul/matmul only counts as a weight op when one of its inputs
            # is a persistable var (a parameter) rather than an activation.
            for inp in sorted(op.all_inputs()):
                if inp._var.persistable:
                    return True
            return False
        return True
    return False
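
# Behaviour sketch of _is_dynamic_weight_op (derived from the logic above):
#   conv2d / conv2d_transpose / embedding ops       -> True
#   mul / matmul ops without a persistable input    -> False (pure activation matmul)
#   batch_norm / layer_norm ops                     -> False, unless all_weight_op=True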


def get_actual_shape(transform, channel):
    if transform is None:
        channel = int(channel)
    else:
        if isinstance(transform, float):
            channel = int(channel * transform)
        else:
            channel = int(transform)
    return channel
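
# Examples of get_actual_shape (derived from the logic above):
#   get_actual_shape(None, 64) -> 64   # no transform: keep the original channel
#   get_actual_shape(0.5, 64)  -> 32   # float transform: scale the channel number
#   get_actual_shape(48, 64)   -> 48   # non-float transform: use it as the channel directly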


def _is_depthwise(op):
    """Check if this op is depthwise conv. Only Cin == Cout == groups be consider as depthwise conv.
       The input shape and output shape of a depthwise conv must stay the same in a super layer,
       so a depthwise op cannot be treated as a weight op.
    """
    # if op.type() == 'depthwise_conv2d':  ### depthwise_conv2d in paddle only requires Cout % Cin == 0
    #     return True
    if 'conv' in op.type():
        for inp in op.all_inputs():
            if inp._var.persistable and (
                    op.attr('groups') == inp._var.shape[0] and
                    op.attr('groups') * inp._var.shape[1] == inp._var.shape[0]):
                return True
    return False
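
# Example of the depthwise check (conv weight shape is [Cout, Cin/groups, kh, kw]):
#   groups=32, weight shape [32, 1, 3, 3]  -> depthwise (Cin == Cout == groups == 32)
#   groups=1,  weight shape [32, 16, 3, 3] -> not depthwise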


def _find_weight_ops(op, graph, weights):
    """ Find the vars come from operators with weight.
    """
    pre_ops = sorted(graph.pre_ops(op))
    for pre_op in pre_ops:
        ### if a depthwise conv is one of the elementwise op's inputs,
        ### add it into the same search space
        if _is_depthwise(pre_op):
            for inp in pre_op.all_inputs():
                if inp._var.persistable:
                    weights.append(inp._var.name)

        if _is_dynamic_weight_op(pre_op) and not _is_depthwise(pre_op):
            for inp in pre_op.all_inputs():
                if inp._var.persistable:
                    weights.append(inp._var.name)
            return weights
        return _find_weight_ops(pre_op, graph, weights)
    return weights


def _find_pre_elementwise_op(op, graph):
    """ Find precedors of the elementwise_add operator in the model.
    """
    same_prune_before_elementwise_add = []
    pre_ops = sorted(graph.pre_ops(op))
    for pre_op in pre_ops:
        if _is_dynamic_weight_op(pre_op):
            return
        same_prune_before_elementwise_add = _find_weight_ops(
            pre_op, graph, same_prune_before_elementwise_add)
    return same_prune_before_elementwise_add


def _is_output_weight_ops(op, graph):
    """ Check whether no weight op appears after this op in the graph. """
    next_ops = sorted(graph.next_ops(op))
    for next_op in next_ops:
        if op == next_op:
            continue
        if _is_dynamic_weight_op(next_op):
            return False
        return _is_output_weight_ops(next_op, graph)
    return True


def check_search_space(graph):
    """ Find the shortcut in the model and set same config for this situation.
    """
    output_conv = []
    same_search_space = []
    depthwise_conv = []
    fixed_by_input = []
    for op in graph.ops():
        # if there are no weight ops after this op,
        # it can be treated as an output
        if _is_output_weight_ops(op, graph) and _is_dynamic_weight_op(op):
            for inp in op.all_inputs():
                if inp._var.persistable:
                    output_conv.append(inp._var.name)

        if op.type() == 'elementwise_add' or op.type() == 'elementwise_mul':
            inp1, inp2 = op.all_inputs()[0], op.all_inputs()[1]
            if (not inp1._var.persistable) and (not inp2._var.persistable):
                # if one of the two vars comes directly from the model input,
                # then both vars of this elementwise op must be fixed
                if inp1.inputs() and inp2.inputs():
                    pre_fixed_op_1, pre_fixed_op_2 = [], []
                    pre_fixed_op_1 = _find_weight_ops(inp1.inputs()[0], graph,
                                                      pre_fixed_op_1)
                    pre_fixed_op_2 = _find_weight_ops(inp2.inputs()[0], graph,
                                                      pre_fixed_op_2)
                    if not pre_fixed_op_1:
                        fixed_by_input += pre_fixed_op_2
                    if not pre_fixed_op_2:
                        fixed_by_input += pre_fixed_op_1
                elif (not inp1.inputs() and inp2.inputs()) or (
                        inp1.inputs() and not inp2.inputs()):
                    pre_fixed_op = []
                    inputs = inp1.inputs() if not inp2.inputs() else inp2.inputs()
                    pre_fixed_op = _find_weight_ops(inputs[0], graph,
                                                    pre_fixed_op)
                    fixed_by_input += pre_fixed_op

                pre_ele_op = _find_pre_elementwise_op(op, graph)
                if pre_ele_op is not None:
                    same_search_space.append(pre_ele_op)

        if _is_depthwise(op):
            for inp in op.all_inputs():
                if inp._var.persistable:
                    depthwise_conv.append(inp._var.name)

    if len(same_search_space) == 0:
        return None, [], [], output_conv

    same_search_space = sorted([sorted(x) for x in same_search_space])
    final_search_space = []

    if len(same_search_space) >= 1:
        final_search_space = [same_search_space[0]]
        if len(same_search_space) > 1:
            for l in same_search_space[1:]:
                listset = set(l)
                merged = False
                for idx in range(len(final_search_space)):
                    rset = set(final_search_space[idx])
                    if len(listset & rset) != 0:
                        final_search_space[idx] = list(listset | rset)
                        merged = True
                        break
                if not merged:
                    final_search_space.append(l)
    final_search_space = sorted([sorted(x) for x in final_search_space])
    depthwise_conv = sorted(depthwise_conv)
    fixed_by_input = sorted(fixed_by_input)

    return (final_search_space, depthwise_conv, fixed_by_input, output_conv)
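
# A minimal usage sketch (hedged: `GraphWrapper` and the surrounding setup are
# assumptions about how a Program is usually wrapped, not taken from this file):
#
#   from paddleslim.core import GraphWrapper
#   graph = GraphWrapper(main_program)
#   same_ss, depthwise_conv, fixed_by_input, output_conv = check_search_space(graph)
#   # `same_ss` groups parameter names whose channel numbers must stay consistent,
#   # e.g. the branches feeding an elementwise_add in a shortcut connection.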


def broadcast_search_space(same_search_space, param2key, origin_config):
    """
    Inplace broadcast the origin_config according to the same search space.
    For example, given
        same_search_space = [['conv1_weight', 'conv3_weight']],
        param2key = {'conv1_weight': 'conv1.weight', 'conv3_weight': 'conv3.weight'},
        origin_config = {'conv1.weight': {'channel': 10}, 'conv2.weight': {'channel': 20}},
    the result after this function is
        origin_config = {'conv1.weight': {'channel': 10}, 'conv2.weight': {'channel': 20}, 'conv3.weight': {'channel': 10}}.

    Args:
        same_search_space(list<list>): broadcast according to this list; the parameters in each inner list must keep a consistent channel number.
        param2key(dict): mapping from parameter names to the corresponding layer keys.
        origin_config(dict): the config of the search space to be broadcast.
    """
    for per_ss in same_search_space:
        for ss in per_ss[1:]:
            key = param2key[ss]
            pre_key = param2key[per_ss[0]]
            if key in origin_config:
                if 'expand_ratio' in origin_config[pre_key]:
                    origin_config[key].update({
                        'expand_ratio': origin_config[pre_key]['expand_ratio']
                    })
                elif 'channel' in origin_config[pre_key]:
                    origin_config[key].update({
                        'channel': origin_config[pre_key]['channel']
                    })
            else:
                # if pre_key has been removed from the config for some reason,
                # e.g. it was fixed by hand or by an elementwise op
                if pre_key in origin_config:
                    if 'expand_ratio' in origin_config[pre_key]:
                        origin_config[key] = {
                            'expand_ratio':
                            origin_config[pre_key]['expand_ratio']
                        }
                    elif 'channel' in origin_config[pre_key]:
                        origin_config[key] = {
                            'channel': origin_config[pre_key]['channel']
                        }
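
# A minimal sketch of chaining the two entry points above (hedged: `param2key`
# construction and the config layout are illustrative assumptions):
#
#   same_ss, _, _, _ = check_search_space(graph)
#   if same_ss is not None:
#       broadcast_search_space(same_ss, param2key, origin_config)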