# -*- coding:UTF-8 -*-
#   Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import os
import os.path as osp
from queue import Queue

from treelib import Tree

from x2paddle.core.program import PaddleLayer
from x2paddle.optimizer.pytorch_code_optimizer.layer_code_generator import gen_layer_code, rename_layers, NN_KERNEL_WITH_PARAMS, NN_KERNEL_NAME
from x2paddle.optimizer.pytorch_code_optimizer.parameter_tree import PamareterNode, PamareterTree
from x2paddle.optimizer.pytorch_code_optimizer.subgraphs_union import distinguish_sequential, get_inputs_outputs

SEPARATOR_IN_SCOPE = "/"


class HierarchicalTree(Tree):
    """ 定义层次树。
    """
    def __init__(self, pd_graph):
        """ Create an empty hierarchical tree whose root is the graph itself.

        Args:
            pd_graph: Graph to organize. Its `name` becomes the identifier of
                the root node and its `script` attribute is kept for the
                module-name lookup done in `convert_subgraph_to_layer`.
        """
        super(HierarchicalTree, self).__init__()
        self.pd_graph = pd_graph
        self.script = pd_graph.script
        self.create_node("Module", self.pd_graph.name)  # create root
        # Level -> node names; filled by update_hierarchical_order().
        self._hierarchical_order = dict()
        # Generated class source strings, appended by merge_node().
        self.codes = list()
        # Per-identifier counter used by insert() to number sibling nodes
        # that share the same scope identifier.
        self.identifier_idx = dict()
        self.param_tree = PamareterTree()
        # Module name -> index of its latest instance; see merge_node().
        self.module_name2count = dict()
        # Every scope name passed through insert(), in insertion order.
        self.scope_name_list = list()

    def insert(self, layer):
        """ Insert a layer into the hierarchical tree.

        A layer with an empty scope name is attached directly under the root
        with the identifier ``no_scope_<layer id>``, except for `prim.tuple`
        and `prim.tuple_unpack` layers, which first try to borrow a scope
        name from their input layers. For every other layer the scope name
        is split on SEPARATOR_IN_SCOPE and one tree node is created per path
        segment; the layer becomes the data of the node created for the
        last segment.

        Args:
            layer (PaddleLayer): The layer to insert. Its `scope_name` may be
                rewritten by this method.
        """
        scope_name = layer.scope_name
        self.scope_name_list.append(scope_name)
        if scope_name == "":
            if layer.kernel in ("prim.tuple", "prim.tuple_unpack"):
                # Unique ids of the input layers, largest id first.
                layer_id_list = sorted(
                    {
                        int(input_layer_id)
                        for input_layer_id in self.pd_graph.edges_in[layer.id]
                    },
                    reverse=True)

                if layer.kernel == "prim.tuple":
                    # Choose a scope among the inputs: a scope that has not
                    # been inserted yet wins immediately, otherwise the one
                    # with fewer path segments wins.
                    for i, input_layer_id in enumerate(layer_id_list):
                        scope_name = self.pd_graph.layers[str(
                            input_layer_id)].scope_name
                        if i == 0:
                            min_scope_name = scope_name
                            continue
                        len1 = len(min_scope_name.split(SEPARATOR_IN_SCOPE))
                        len2 = len(scope_name.split(SEPARATOR_IN_SCOPE))
                        if scope_name not in self.scope_name_list:
                            min_scope_name = scope_name
                            continue
                        if len1 > len2:
                            min_scope_name = scope_name
                    if min_scope_name == "":
                        self.create_node(
                            tag=layer.id,
                            identifier="no_scope_" + layer.id,
                            parent=self.pd_graph.name,
                            data=layer)
                        return
                    layer.scope_name = min_scope_name
                    scope_name = min_scope_name
                else:
                    # prim.tuple_unpack: take the first non-empty input scope.
                    for input_layer_id in layer_id_list:
                        input_scope_name = self.pd_graph.layers[str(
                            input_layer_id)].scope_name
                        if input_scope_name != "":
                            scope_name = input_scope_name
                            break
                    layer.scope_name = scope_name
            else:
                self.create_node(
                    tag=layer.id,
                    identifier="no_scope_" + layer.id,
                    parent=self.pd_graph.name,
                    data=layer)
                return

        scopes = scope_name.split(SEPARATOR_IN_SCOPE)
        for idx, scope in enumerate(scopes):
            is_last = idx == len(scopes) - 1
            parent = SEPARATOR_IN_SCOPE.join(scopes[:idx])
            identifier = SEPARATOR_IN_SCOPE.join(scopes[:idx + 1])
            if self.contains(identifier):
                if not is_last:
                    # Intermediate scope already present: re-attach it to its
                    # current parent and go one level deeper.
                    parent_node = self.parent(identifier)
                    self.move_node(identifier, parent_node.identifier)
                    continue
                if self.get_node(identifier).data is None:
                    # The node is already a container: add the layer as its
                    # next numbered child and record the new scope name.
                    if identifier not in self.identifier_idx:
                        self.identifier_idx[identifier] = 0
                    else:
                        self.identifier_idx[identifier] += 1
                    identifier_name = identifier + SEPARATOR_IN_SCOPE + str(
                        self.identifier_idx[identifier])
                    self.create_node(
                        tag=scope,
                        identifier=identifier_name,
                        parent=identifier,
                        data=layer)
                    layer.scope_name = identifier_name
                    continue
                # The node currently holds a layer: turn it into a container
                # and hang both the old layer and the new one below it as
                # consecutively numbered children.
                previous_layer = self[identifier].data
                self[identifier].data = None
                parent_node = self.parent(identifier)
                self.move_node(identifier, parent_node.identifier)
                if identifier not in self.identifier_idx:
                    self.identifier_idx[identifier] = 0
                else:
                    self.identifier_idx[identifier] += 1
                self.create_node(
                    tag=scope,
                    identifier=identifier + SEPARATOR_IN_SCOPE +
                    str(self.identifier_idx[identifier]),
                    parent=identifier,
                    data=previous_layer)
                self.identifier_idx[identifier] += 1
                self.create_node(
                    tag=scope,
                    identifier=identifier + SEPARATOR_IN_SCOPE +
                    str(self.identifier_idx[identifier]),
                    parent=identifier,
                    data=layer)
                continue

            # From here on `identifier` is known not to be in the tree.
            if idx == 0:
                self.create_node(
                    tag=scope,
                    identifier=identifier,
                    parent=self.pd_graph.name,
                    data=layer if is_last else None)
            else:
                if is_last:
                    if parent == "":
                        childs = self.children(self.pd_graph.name)
                        parent = self.pd_graph.name
                    else:
                        childs = self.children(parent)
                    prefix = identifier
                    # Siblings whose identifier extends `prefix`; the new
                    # node gets the next free "_<n>" suffix.
                    # NOTE(review): assumes every such sibling ends in
                    # "_<int>"; any other sibling sharing the prefix would
                    # make int() raise - confirm against callers.
                    identifiers = [
                        child.identifier for child in childs
                        if child.identifier.startswith(prefix) and
                        child.identifier != prefix
                    ]
                    if len(identifiers) == 0:
                        identifier = prefix + "_0"
                    else:
                        identifier_ids = [
                            int(id_obj.split("_")[-1])
                            for id_obj in identifiers
                        ]
                        identifier = prefix + "_{}".format(
                            max(identifier_ids) + 1)
                self.create_node(
                    tag=scope,
                    identifier=identifier,
                    parent=parent,
                    data=layer if is_last else None)

    def update_hierarchical_order(self):
        """ Recompute the per-level ordering of the tree nodes.

        Walks the tree breadth-first from the root and stores the result in
        `self._hierarchical_order`: a dict keyed by level (the root is level
        0) whose values are the node names found on that level, in visiting
        order.
        """
        levels = dict()
        pending = Queue()
        pending.put((self.pd_graph.name, 0), block=False)
        while not pending.empty():
            name, depth = pending.get(block=False)
            node = self[name]
            levels.setdefault(depth, []).append(name)
            for child_name in node.successors(self.identifier):
                pending.put((child_name, depth + 1), block=False)
        self._hierarchical_order = levels

    def analyze_attrs_table(self, attrs_table):
        """ Find the attribute columns whose values are not all identical.

        A column whose first value is a string without a single quote is
        never reported, whatever the remaining values are.

        Args:
            attrs_table: Table exposing `columns` and `get(column)`
                (presumably a pandas DataFrame - verify against the caller).

        Returns:
            list: Names of the differing columns, in table order.
        """
        differing_columns = []
        for column_name in attrs_table.columns:
            values = list(attrs_table.get(column_name))
            first, others = values[0], values[1:]
            if isinstance(first, str) and "'" not in first:
                continue
            if any(value != first for value in others):
                differing_columns.append(column_name)
        return differing_columns
    def merge_node(self, sub_layers_list, attrs_table, node_name2sub_layers,
                   module_name):
        """ Merge the nodes of one scope into a Module (class).

        The class source is generated once, from the first sub-graph, and
        appended to `self.codes`. Every sub-graph in `sub_layers_list` is
        then replaced in `self.pd_graph` by a single layer of kernel
        "module", which also becomes the data of the matching tree node.

        Args:
            sub_layers_list (list): Sub-graphs (dicts of layer id -> layer)
                that share the generated module.
            attrs_table: Attribute table passed to `analyze_attrs_table` and
                read with `.get(column).loc[node_name]`.
            node_name2sub_layers (dict): Tree node name -> sub-graph, used to
                map a sub-graph back to its node name.
            module_name (str | None): Class name; when None it is derived
                from the node name of the first sub-graph.

        Raises:
            ValueError: If a sub-graph is not a value of
                `node_name2sub_layers`.
        """

        def get_node_name(sub_layers):
            # Reverse lookup of the node name registered for a sub-graph.
            for k, v in node_name2sub_layers.items():
                if v == sub_layers:
                    return k
            raise ValueError(
                "sub-graph is not registered in node_name2sub_layers")

        first_sub_layers = sub_layers_list[0]
        node_name = get_node_name(first_sub_layers)

        renamed_sub_layers, _, _ = rename_layers(first_sub_layers)
        diff_attrs_column = self.analyze_attrs_table(attrs_table)
        if module_name is None:
            # No explicit name: build one from the scope path of the node.
            module_name = node_name.replace(SEPARATOR_IN_SCOPE, "_")
            module_name = module_name[0].upper() + module_name[1:]
        if module_name in self.module_name2count:
            module_name = module_name + "_0"
        code_str = gen_layer_code(
            self.pd_graph,
            renamed_sub_layers,
            module_name,
            different_attrs=diff_attrs_column)
        self.codes.append(code_str)

        for sub_layers in sub_layers_list:
            inputs, outputs = get_inputs_outputs(self.pd_graph, sub_layers)
            inputs_dict = {
                "input_{}".format(i): input_name
                for i, input_name in enumerate(inputs)
            }
            if module_name in self.module_name2count:
                self.module_name2count[module_name] += 1
            else:
                self.module_name2count[module_name] = 0
            # Append "__" when the lower-cased name is one of the
            # NN_KERNEL_NAME values, so the two names stay distinct.
            if module_name.lower() in NN_KERNEL_NAME.values():
                mn = module_name.lower() + "__"
            else:
                mn = module_name.lower()
            # The first output names the module instance itself.
            outputs = [
                "{}/{}".format(mn, self.module_name2count[module_name])
            ] + outputs
            node_name = get_node_name(sub_layers)
            diff_attrs = {
                column: attrs_table.get(column).loc[node_name]
                for column in diff_attrs_column
            }

            node_name_seg = node_name.split(SEPARATOR_IN_SCOPE)
            node_name_seg[-1] = module_name.lower()
            new_node_name = SEPARATOR_IN_SCOPE.join(node_name_seg)
            last_layer_id = list(sub_layers.keys())[-1]
            new_layer = PaddleLayer(
                id=last_layer_id,
                kernel="module",
                inputs=inputs_dict,
                outputs=outputs,
                scope_name=new_node_name,
                module=module_name,
                **diff_attrs)

            _, nn_param_nodes, _ = rename_layers(sub_layers, self.param_tree)
            param_node = PamareterNode(old_name=outputs[0])
            for node in nn_param_nodes:
                param_node.add_child(node)
            self.param_tree.add_node(param_node)

            # The merged layer takes the slot of the last original layer;
            # all the other layers of the sub-graph leave the graph.
            for layer_id in sub_layers:
                if layer_id == last_layer_id:
                    self.pd_graph.layers[layer_id] = new_layer
                else:
                    self.pd_graph.layers.pop(layer_id)

            self.pd_graph.build()
            self[node_name].data = new_layer

    def find_subgraph_diff(self, module_name2sub_layers,
                           module_name2sub_identifiers, node_name2sub_layers,
                           name):
        """ Find the differences between the sub-graphs of one module,
            mainly differences in their input parameters.

        The entry `name` is removed from both `module_name2*` dicts and
        replaced, in place, by the groups returned by
        `distinguish_sequential`.

        Args:
            module_name2sub_layers (dict): Module name -> list of sub-graphs.
            module_name2sub_identifiers (dict): Module name -> list of layer
                naming dicts, parallel to `module_name2sub_layers`.
            node_name2sub_layers (dict): Tree node name -> sub-graph.
            name (str): Module name to analyze.

        Returns:
            The third value returned by `distinguish_sequential`, indexed by
            the new group names.
        """
        sub_layers = module_name2sub_layers[name]
        sub_identifiers = module_name2sub_identifiers[name]
        (new_sub_layers, new_sub_sequentials,
         sequentials2attrs_table) = distinguish_sequential(
             self.pd_graph, name, sub_layers, sub_identifiers,
             node_name2sub_layers)
        # Drop the old entry only after the split succeeded.
        module_name2sub_layers.pop(name)
        module_name2sub_identifiers.pop(name)
        for group_name, group_layers in new_sub_layers.items():
            module_name2sub_layers[group_name] = group_layers
            module_name2sub_identifiers[group_name] = new_sub_sequentials[
                group_name]
        return sequentials2attrs_table

    def convert_subgraph_to_layer(self):
        """ Merge sub-graphs into single layers, deepest level first.

        1. Following `self._hierarchical_order`, start from the deepest
           levels and merge each sub-graph into one layer (a merged node).
        2. The old/new parameter name relations needed later are recorded
           by `merge_node` in `self.param_tree`.
        """

        def last_output_count(layers):
            # Number of outputs of the last layer of a sub-graph.
            return len(layers[list(layers.keys())[-1]].outputs)

        depths = sorted(list(self._hierarchical_order.keys()), reverse=True)
        current_module_name_list = list()
        # depths[1:] skips the deepest level; presumably it only holds leaf
        # nodes with nothing to merge - TODO confirm.
        for depth in depths[1:]:
            # Module name -> list of sub-graphs.
            module_name2sub_layers = dict()
            # Module name -> list of layer naming dicts of those sub-graphs.
            module_name2sub_identifiers = dict()
            # Name of each tree node that has a subtree -> its sub-graph.
            node_name2sub_layers = dict()
            for node_name in self._hierarchical_order[depth]:
                node_inst = self[node_name]
                if node_inst.data is not None:
                    continue
                sub_layers = dict()
                sub_identifiers = dict()
                for successor_name in node_inst.successors(self.identifier):
                    successor_layer = self[successor_name].data
                    sub_layers[successor_layer.id] = successor_layer
                    sub_identifiers[
                        successor_layer.id] = successor_layer.scope_name.split(
                            SEPARATOR_IN_SCOPE)[-1]
                node_name2sub_layers[node_name] = sub_layers
                node_name_segs = node_name.split(SEPARATOR_IN_SCOPE)

                # Get the module name by walking the scope path through the
                # attributes of self.script.
                module = self.script
                # Whether the path could not be resolved on the script; the
                # original comment calls this "the outermost Module".
                is_largest_module = False
                for name_id, name in enumerate(node_name_segs):
                    name = name.split("__")[0]
                    if not hasattr(module, name):
                        is_largest_module = True
                        break
                    module = getattr(module, name)
                if is_largest_module:
                    if name_id == 0:
                        module_name = name
                    else:
                        module_name = "_".join(node_name_segs)
                else:
                    module_name = module._get_name()

                # Sub-graphs whose last layers have different numbers of
                # outputs cannot share a module: keep appending "__tmp" until
                # the name is free or names a group with a matching count.
                if module_name in module_name2sub_layers and \
                        last_output_count(sub_layers) != last_output_count(
                            module_name2sub_layers[module_name][0]):
                    while module_name in module_name2sub_layers:
                        module_name = module_name + "__tmp"
                        if module_name in module_name2sub_layers and \
                                last_output_count(sub_layers) == \
                                last_output_count(
                                    module_name2sub_layers[module_name][0]):
                            break
                module_name2sub_layers.setdefault(module_name,
                                                  []).append(sub_layers)
                module_name2sub_identifiers.setdefault(
                    module_name, []).append(sub_identifiers)

            # Snapshot the keys: find_subgraph_diff edits the dict in place.
            module_names = list(module_name2sub_layers.keys())
            for module_name in module_names:
                sequentials2attrs_table = self.find_subgraph_diff(
                    module_name2sub_layers, module_name2sub_identifiers,
                    node_name2sub_layers, module_name)
                for name in sequentials2attrs_table.keys():
                    if name.startswith("Sequential"):
                        # A module called Sequential is named after its
                        # scope name instead; merge_node does that when it
                        # receives None.
                        merged_module_name = None
                    else:
                        merged_module_name = name
                        while merged_module_name in current_module_name_list:
                            merged_module_name += "__0"
                    current_module_name_list.append(merged_module_name)
                    self.merge_node(module_name2sub_layers[name],
                                    sequentials2attrs_table[name],
                                    node_name2sub_layers, merged_module_name)

    def update_parameters(self):
        """ Rename the graph parameters according to the parameter tree.

        After `self.param_tree.traverse()`, every key of
        `self.pd_graph.parameters` that equals an old name, or starts with
        "<old name>.", is re-registered under the matching new name. Only
        the leading prefix is rewritten; later occurrences of the old name
        inside the key are left untouched.
        """
        self.param_tree.traverse()
        parameters = self.pd_graph.parameters
        # Snapshot of the keys: the dict is mutated while renaming.
        full_old_name_list = list(parameters.keys())
        for old_name, new_name in self.param_tree.old2new.items():
            old_prefix = "{}.".format(old_name)
            new_prefix = "{}.".format(new_name)
            for full_old_name in full_old_name_list:
                if full_old_name.startswith(old_prefix):
                    full_new_name = new_prefix + full_old_name[len(
                        old_prefix):]
                    parameters[full_new_name] = parameters.pop(full_old_name)
                if full_old_name == old_name:
                    parameters[new_name] = parameters.pop(full_old_name)

    def save_source_files(self, save_dir):
        """ Generate `x2paddle_code.py` inside `save_dir`.

        Runs the merge pipeline (`update_hierarchical_order`,
        `convert_subgraph_to_layer`, `update_parameters`) and then writes the
        import header, every generated class in `self.codes` and a `main`
        function that loads `model.pdparams` from `save_dir`.

        Args:
            save_dir (str): Output directory; created when missing.
        """

        def gen_main_code():
            # Source of the generated `main` entry point.
            input_data_name = ', '.join(self.pd_graph.inputs)
            run_func_list = [
                "def main({}):".format(input_data_name),
                "    # There are {} inputs.".format(
                    len(self.pd_graph.inputs_info)),
            ]
            for k, v in self.pd_graph.inputs_info.items():
                run_func_list.append("    # {}: shape-{}, type-{}.".format(
                    k, v[0], v[1]))
            run_func_list.extend([
                "    paddle.disable_static()",
                # {!r} emits a valid string literal even when the path holds
                # backslashes or quotes; plain paths render exactly as before.
                "    params = paddle.load({!r})".format(
                    osp.join(osp.abspath(save_dir), "model.pdparams")),
                "    model = {}()".format(self.pd_graph.name),
                "    model.set_dict(params)",
                "    model.eval()",
                "    out = model({})".format(input_data_name),
                "    return out",
            ])
            return "\n".join(run_func_list)

        self.update_hierarchical_order()
        self.convert_subgraph_to_layer()
        self.update_parameters()
        import_list = [
            "import paddle",
            "import math",
            "from x2paddle.op_mapper.pytorch2paddle "
            "import pytorch_custom_layer as x2paddle_nn",
            "",
        ]
        import_str = "\n".join(import_list) + "\n"
        os.makedirs(save_dir, exist_ok=True)
        # The context manager closes the file even if a write fails.
        with open(osp.join(save_dir, 'x2paddle_code.py'), 'w') as f:
            f.write(import_str)
            for code in self.codes:
                f.write(code)
                f.write("\n")
            f.write(gen_main_code())