transpose.py 9.9 KB
Newer Older
J
jiangjiajun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225
import copy
import sys

import numpy as np


class TransposeOpt:
    """Remove redundant transpose pairs from a converted fluid graph.

    A candidate region is entered through ``fluid.layers.transpose`` layers
    with ``perm=[0, 2, 3, 1]`` (NCHW -> NHWC, assuming the input is NCHW) and
    left through transposes with ``perm=[0, 3, 1, 2]``.  When every layer
    inside the region is an elementwise op, a concat, or a
    ``create_parameter`` that only feeds elementwise ops, all boundary
    transposes are deleted and the inner concat axes, broadcast axes and
    parameters are rewritten so the region computes the same values in the
    untransposed layout.
    """

    # Permutation of a transpose that enters a candidate region.
    _ENTRY_PERM = [0, 2, 3, 1]
    # Permutation of a transpose that leaves a candidate region.
    _EXIT_PERM = [0, 3, 1, 2]
    # Axis d inside the region (NHWC) is axis _NHWC_TO_NCHW_AXIS[d] once the
    # entry transposes are removed; numerically equal to _ENTRY_PERM.
    _NHWC_TO_NCHW_AXIS = [0, 2, 3, 1]

    def __init__(self):
        # Not referenced by any method of this class.
        self.image_layers = [
            'fluid.layers.conv2d', 'fluid.layers.batch_norm',
            'fluid.layers.conv2d_transpose', 'fluid.layers.resize_nearest',
            'fluid.layers.resize_bilinear', 'fluid.layers.pool2d',
            'fluid.layers.pad2d'
        ]
        # Single-input activations that are stripped (in a copy) so the
        # transposes around them can be matched as a pair.
        self.direct_layers = [
            'fluid.layers.relu', 'fluid.layers.relu6', 'fluid.layers.abs',
            'fluid.layers.sigmoid', 'fluid.layers.exp', 'fluid.layers.rsqrt',
            'fluid.layers.swish_f32', 'fluid.layers.tanh',
            'fluid.layers.softplus', 'fluid.layers.leaky_relu',
            'fluid.layers.floor', 'fluid.layers.erf'
        ]
        # Ops that may appear inside a region; parameters may only feed these.
        self.elementwise_layers = [
            'fluid.layers.elementwise_add', 'fluid.layers.elementwise_sub',
            'fluid.layers.elementwise_mul', 'fluid.layers.elementwise_div'
        ]

    def get_transpose_num(self, graph):
        """Return the number of ``fluid.layers.transpose`` layers in `graph`."""
        return sum(1 for _, layer in graph.layers.items()
                   if layer.kernel == "fluid.layers.transpose")

    def strip_direct_layers(self, graph):
        """Return a deep copy of `graph` with ``self.direct_layers`` removed.

        Only layers that have both producers and consumers are removed, so
        graph inputs/outputs are kept.  `graph` itself is not modified.

        Raises:
            AssertionError: if a removable direct layer has several inputs.
        """
        opt_graph = copy.deepcopy(graph)

        remove_layer_ids = set()
        for layer_id, layer in opt_graph.layers.items():
            if layer.kernel not in self.direct_layers:
                continue
            # .get: a layer may have no entry when it has no edges.
            layer_out = opt_graph.edges_out.get(layer_id, [])
            layer_in = opt_graph.edges_in.get(layer_id, [])
            if len(layer_out) == 0 or len(layer_in) == 0:
                continue

            assert len(
                layer_in
            ) == 1, "There should be only 1 input for direct layers."

            remove_layer_ids.add(layer_id)

        for layer_id in remove_layer_ids:
            opt_graph.del_layer(layer_id)
        return opt_graph

    def run(self, graph):
        """Remove redundant transpose pairs from `graph` in place.

        Regions are searched on a stripped copy of the graph (see
        ``strip_direct_layers``); the accepted transpose deletions, attribute
        rewrites and parameter rewrites are then replayed on `graph`.
        Progress is written to stderr and a summary line to stdout.
        """
        before_transpose_num = self.get_transpose_num(graph)

        opt_graph = self.strip_direct_layers(graph)
        total_layer_num = len(opt_graph.layers)
        scanned_layers = set()
        optimized_transpose_layers = list()
        modified_layer_attrs = dict()
        modified_parameters = dict()

        while True:
            result = self._strip_transpose(opt_graph, scanned_layers,
                                           total_layer_num)
            if result is None:
                break
            transpose_ids, new_attrs, new_parameters = result
            optimized_transpose_layers.extend(transpose_ids)
            # Later regions start from the committed attrs, so the last
            # update for a layer is the cumulative one.
            modified_layer_attrs.update(new_attrs)
            modified_parameters.update(new_parameters)

        for layer_id in optimized_transpose_layers:
            graph.del_layer(layer_id)

        for layer_id, attrs in modified_layer_attrs.items():
            graph.layers[layer_id].attrs = attrs

        for name, parameter in modified_parameters.items():
            graph.parameters[name] = parameter

        current_transpose_num = self.get_transpose_num(graph)
        print(
            "\nTranspose layers optimized, before: transpose_num={}, after: transpose_num={}".
            format(before_transpose_num, current_transpose_num))

    def _strip_transpose(self, _graph, scanned_layers, total_layer_num):
        """Find, commit and report one optimizable region of `_graph`.

        Scans layers not yet in `scanned_layers` (updated in place) for an
        entry transpose whose region can be optimized.  Returns
        ``(transpose_ids, new_attrs, new_parameters)`` for the first accepted
        region after applying it to `_graph`, or None when no layer is left.
        """
        # Snapshot: _graph is mutated right before returning.
        for layer_id, layer in list(_graph.layers.items()):
            if layer_id in scanned_layers:
                continue
            scanned_layers.add(layer_id)
            percent = round(len(scanned_layers) / total_layer_num * 100, 2)
            sys.stderr.write("\rOptimize Transpose Layers...{}%".format(
                percent))

            if layer.kernel != "fluid.layers.transpose":
                continue
            if layer.attrs["perm"] != self._ENTRY_PERM:
                continue

            region = self._collect_region(_graph, layer_id)
            if region is None:
                continue
            transpose_ids, concat_ids, parameter_ids = region

            updates = self._plan_layout_updates(_graph, concat_ids,
                                                parameter_ids)
            if updates is None:
                continue
            new_attrs, new_parameters = updates

            # Commit only after the whole region is accepted so rejected
            # candidates never leave a partially rewritten graph behind.
            for l, attrs in new_attrs.items():
                _graph.layers[l].attrs = attrs
            for name, value in new_parameters.items():
                _graph.parameters[name] = value
            for transpose_id in transpose_ids:
                _graph.del_layer(transpose_id)
            return transpose_ids, new_attrs, new_parameters
        return None

    def _collect_region(self, _graph, start_id):
        """Grow the region entered through transpose `start_id`.

        Returns ``(transpose_ids, concat_ids, parameter_ids)`` (boundary
        transposes incl. `start_id`, inner concats, inner parameters), or
        None when the region cannot be optimized because:

        * a region layer (including an entry transpose) has no consumers --
          it is a graph output whose layout would silently change;
        * a consumer is not an exit transpose, elementwise op or concat;
        * a producer of an inner layer is not an entry transpose,
          elementwise op, concat or ``create_parameter``;
        * an inner layer with producers is fed only by parameters.
        """
        transpose_kernel = "fluid.layers.transpose"
        concat_kernel = "fluid.layers.concat"
        parameter_kernel = "fluid.layers.create_parameter"

        transpose_ids = {start_id}
        concat_ids = set()
        parameter_ids = set()
        queued = {start_id}
        pending = [start_id]

        def enqueue(layer_id):
            if layer_id not in queued:
                queued.add(layer_id)
                pending.append(layer_id)

        while pending:
            current_id = pending.pop()

            consumers = _graph.edges_out.get(current_id, [])
            if not consumers:
                return None
            for out in consumers:
                out_layer = _graph.layers[out]
                if out_layer.kernel == transpose_kernel:
                    if out_layer.attrs["perm"] != self._EXIT_PERM:
                        return None
                    # Exit transpose: its consumers are outside the region.
                    transpose_ids.add(out)
                elif out_layer.kernel in self.elementwise_layers:
                    enqueue(out)
                elif out_layer.kernel == concat_kernel:
                    concat_ids.add(out)
                    enqueue(out)
                else:
                    return None

            # Entry transposes take their input from outside the region, but
            # their consumers (checked above) must all be inside it.
            if _graph.layers[current_id].kernel == transpose_kernel:
                continue

            producers = _graph.edges_in.get(current_id, [])
            # Layers without producers (create_parameter) are exempt.
            all_create_parameter = len(producers) > 0
            for ipt in producers:
                ipt_layer = _graph.layers[ipt]
                if ipt_layer.kernel == transpose_kernel:
                    if ipt_layer.attrs["perm"] != self._ENTRY_PERM:
                        return None
                    all_create_parameter = False
                    transpose_ids.add(ipt)
                    enqueue(ipt)
                elif ipt_layer.kernel in self.elementwise_layers:
                    all_create_parameter = False
                    enqueue(ipt)
                elif ipt_layer.kernel == concat_kernel:
                    all_create_parameter = False
                    concat_ids.add(ipt)
                    enqueue(ipt)
                elif ipt_layer.kernel == parameter_kernel:
                    parameter_ids.add(ipt)
                    enqueue(ipt)
                else:
                    return None
            if all_create_parameter:
                return None

        return list(transpose_ids), list(concat_ids), list(parameter_ids)

    def _plan_layout_updates(self, _graph, concat_ids, parameter_ids):
        """Compute the rewrites a region needs, without mutating `_graph`.

        Returns ``(new_attrs, new_parameters)`` mapping layer id -> copied
        and updated attrs, and parameter name -> transposed array; or None
        when a parameter feeds a non-elementwise layer or its broadcast
        cannot be expressed after the layout change.
        """
        new_attrs = dict()
        new_parameters = dict()

        for l in concat_ids:
            attrs = copy.copy(_graph.layers[l].attrs)
            axis = attrs.get('axis', 0)
            attrs['axis'] = self._NHWC_TO_NCHW_AXIS[axis]
            new_attrs[l] = attrs

        for l in parameter_ids:
            param_layer = _graph.layers[l]
            shape = param_layer.attrs['shape']
            rank = len(shape)
            for o in _graph.edges_out.get(l, []):
                consumer = _graph.layers[o]
                if consumer.kernel not in self.elementwise_layers:
                    return None
                axis = consumer.attrs.get('axis', -1)
                new_axis = self._remap_broadcast_axis(axis, rank)
                if new_axis is None:
                    return None
                if new_axis != axis:
                    attrs = copy.copy(consumer.attrs)
                    attrs['axis'] = new_axis
                    new_attrs[o] = attrs

            name = param_layer.outputs[0]
            if rank == 4:
                # N, H, W, C -> N, C, H, W
                attrs = copy.copy(param_layer.attrs)
                attrs['shape'] = [shape[0], shape[3], shape[1], shape[2]]
                new_attrs[l] = attrs
                new_parameters[name] = np.transpose(_graph.parameters[name],
                                                    (0, 3, 1, 2))
            elif rank == 3:
                # H, W, C -> C, H, W
                attrs = copy.copy(param_layer.attrs)
                attrs['shape'] = [shape[2], shape[0], shape[1]]
                new_attrs[l] = attrs
                new_parameters[name] = np.transpose(_graph.parameters[name],
                                                    (2, 0, 1))
        return new_attrs, new_parameters

    @staticmethod
    def _remap_broadcast_axis(axis, param_rank):
        """Map an elementwise broadcast `axis` from the NHWC region to NCHW.

        fluid elementwise ops broadcast Y onto X starting at dimension
        `axis`, where -1 means ``rank(X) - rank(Y)``.  Region tensors are
        rank 4 and the parameter (Y) has rank `param_rank`.  Returns the
        equivalent axis after the layout change, or None when no single axis
        can express the broadcast.
        """
        if param_rank == 0:
            return axis
        if not 0 < param_rank <= 4:
            return None
        if axis == -1:
            start = 4 - param_rank
        elif 0 <= axis and axis + param_rank <= 4:
            start = axis
        else:
            return None
        if param_rank >= 3:
            # Rank-3/4 parameters are transposed with the region, which is
            # only valid when they cover the trailing (H, W, C) dimensions;
            # trailing alignment then still holds, so the axis is unchanged.
            return axis if start + param_rank == 4 else None
        # Rank-1/2 parameters keep their data; the NHWC dimensions they cover
        # must still be consecutive (and in order) in NCHW.
        dims = [
            TransposeOpt._NHWC_TO_NCHW_AXIS[d]
            for d in range(start, start + param_rank)
        ]
        if dims != list(range(dims[0], dims[0] + param_rank)):
            return None
        return dims[0]