# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import math
import os
from pathlib import Path

import yaml
from filters import (
    assert_dense_or_sr,
    cartesian_prod_mapping,
    delete_last_underline,
    find_optinal_inputs_name,
    get_infer_var_type_func,
    to_composite_grad_opmaker_name,
    to_input_name,
    to_int_array_tensor_name,
    to_int_array_tensors_name,
    to_op_attr_type,
    to_opmaker_name,
    to_opmaker_name_cstr,
    to_pascal_case,
    to_scalar_tensor_name,
    to_variable_names,
)
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from parse_utils import to_named_dict
from tests_utils import (
    is_base_op,
    is_composite_op,
    is_initializer_list,
    is_only_composite_op,
    is_scalar,
    is_vec,
    supports_inplace,
    supports_no_need_buffer,
)

file_loader = FileSystemLoader(Path(__file__).parent / "templates")
env = Environment(
    loader=file_loader,
    keep_trailing_newline=True,
    trim_blocks=True,
    lstrip_blocks=True,
    undefined=StrictUndefined,
    extensions=['jinja2.ext.do'],
)
env.filters["to_op_attr_type"] = to_op_attr_type
env.filters["to_opmaker_name"] = to_opmaker_name
env.filters["to_pascal_case"] = to_pascal_case
63 64 65
env.filters["to_scalar_tensor_name"] = to_scalar_tensor_name
env.filters["to_int_array_tensor_name"] = to_int_array_tensor_name
env.filters["to_int_array_tensors_name"] = to_int_array_tensors_name
66 67
env.filters["to_input_name"] = to_input_name
env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
68
env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
J
Jiabin Yang 已提交
69
env.filters["to_composite_grad_opmaker_name"] = to_composite_grad_opmaker_name
70
env.filters["to_variable_names"] = to_variable_names
71
env.filters["get_infer_var_type_func"] = get_infer_var_type_func
72 73
env.filters["assert_dense_or_sr"] = assert_dense_or_sr
env.filters["find_optinal_inputs_name"] = find_optinal_inputs_name
74
env.tests["base_op"] = is_base_op
J
Jiabin Yang 已提交
75
env.tests["composite_op"] = is_composite_op
X
xiaoguoguo626807 已提交
76
env.tests["only_composite_op"] = is_only_composite_op
77 78 79 80 81 82 83
env.tests["vec"] = is_vec
env.tests["scalar"] = is_scalar
env.tests["initializer_list"] = is_initializer_list
env.tests["supports_inplace"] = supports_inplace
env.tests["supports_no_need_buffer"] = supports_no_need_buffer


def restruct_io(op):
    op["input_dict"] = to_named_dict(op["inputs"])
    op["attr_dict"] = to_named_dict(op["attrs"])
    op["output_dict"] = to_named_dict(op["outputs"])
    return op
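
# Illustrative sketch (op name hypothetical): given a parsed entry such as
#     {"name": "scale", "inputs": [{"name": "x", ...}], ...}
# restruct_io adds name-keyed views, so templates can look up
#     op["input_dict"]["x"]
# instead of scanning the "inputs" list.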


def process_scalar(op_item, scalar_configs):
    scalar_map = {
        'Scalar': 'float',
        'Scalar(float)': 'float',
        'Scalar(int)': 'int',
        'Scalar(int64_t)': 'int64_t',
    }
    if scalar_configs is not None:
        for attr_item in op_item['attrs']:
            if attr_item['name'] in scalar_configs:
                attr_type = attr_item['typename']
                assert (
                    attr_type in scalar_map
                ), f"{op_item['name']}'s scalar config in op_compat.yaml is invalid: the data_type of {attr_item['name']} is expected to be one of Scalar, Scalar(float), Scalar(int) or Scalar(int64_t), but is {attr_type}."

                scalar_config = scalar_configs[attr_item['name']]
                attr_item['is_support_tensor'] = (
                    True
                    if 'support_tensor' in scalar_config
                    and scalar_config['support_tensor']
                    else False
                )
                attr_item['data_type'] = (
                    scalar_config['data_type']
                    if 'data_type' in scalar_config
                    else scalar_map[attr_type]
                )
                if attr_item['is_support_tensor'] is False:
                    attr_item['tensor_name'] = scalar_config['tensor_name']
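
# Sketch of the op_compat.yaml scalar entry this consumes (values
# illustrative):
#
#   - op : scale
#     scalar :
#       scale :
#         data_type : float
#         support_tensor : true
#
# When support_tensor is false or absent, a tensor_name entry must be present
# so the attribute can also be fed through a named input tensor.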


def process_int_array(op_item, int_array_configs):
    data_type_map = {
        'int': 'std::vector<int>',
        'int64_t': 'std::vector<int64_t>',
    }
    if int_array_configs is not None:
        for attr_item in op_item['attrs']:
            if attr_item['name'] in int_array_configs:
                attr_type = attr_item['typename']
                assert (
                    attr_item['typename'] == "IntArray"
                ), f"{op_item['name']}'s int_array config in op_compat.yaml is invalid: the data_type of {attr_item['name']} is expected to be IntArray, but is {attr_type}."

                int_array_config = int_array_configs[attr_item['name']]
                attr_item['is_support_tensor'] = (
                    True
                    if 'support_tensor' in int_array_config
                    and int_array_config['support_tensor']
                    else False
                )
                attr_item['data_type'] = (
                    data_type_map[int_array_config['data_type']]
                    if 'data_type' in int_array_config
                    else 'std::vector<int64_t>'
                )
                if attr_item['is_support_tensor'] is False:
                    attr_item['manual_flag'] = True
                    if 'tensor_name' in int_array_config:
                        attr_item['tensor_name'] = int_array_config[
                            'tensor_name'
                        ]
                    if 'tensors_name' in int_array_config:
                        attr_item['tensors_name'] = int_array_config[
                            'tensors_name'
                        ]
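
# Companion sketch for IntArray attributes (values illustrative):
#
#   - op : reshape
#     int_array :
#       shape :
#         data_type : int
#         tensor_name : Shape
#         tensors_name : ShapeTensor
#
# Without support_tensor, the attribute may be bound to a single Tensor
# (tensor_name) and/or a list of Tensors (tensors_name).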


def add_composite_info(ops, backward_ops, backward_op_dict):
    # add backward composite name in forward
    for op in ops + backward_ops:
        if (
            op["backward"] in backward_op_dict
            and "composite" in backward_op_dict[op["backward"]]
        ):
            op["backward_composite"] = op["backward"]
        else:
            op["backward_composite"] = None

        # mark whether the backward op is composite-only
        if (
            op["backward_composite"] is not None
            and "invoke" not in backward_op_dict[op["backward"]]
            and "kernel" not in backward_op_dict[op["backward"]]
        ):
            op["only_backward_composite"] = True
        else:
            op["only_backward_composite"] = False


# add fluid names to forward and backward op info
def add_fluid_name(dict_list):
    for item in dict_list:
        item["fluid_name"] = item["name"]


# add fluid name of op and params for OpMaker
def add_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
    def get_phi_and_fluid_op_name(op_item):
        names = op_item.split('(')
        if len(names) == 1:
            return names[0].strip(), names[0].strip()
        else:
            return names[0].strip(), names[1].split(')')[0].strip()
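
    # e.g. get_phi_and_fluid_op_name('multiply (elementwise_mul)') returns
    # ('multiply', 'elementwise_mul') (illustrative); a bare 'scale' maps to
    # itself.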

    def add_op_param_name(op_args, args_alias_map):
        for item in op_args:
            if item['name'] in args_alias_map:
                item['fluid_name'] = args_alias_map[item['name']]
            else:
                item['fluid_name'] = item['name']

    def add_grad_args_name(op_args, args_alias_map):
        for item in op_args:
            if (
                item['name'].endswith('_grad')
                and item['name'][:-5] in args_alias_map
            ):
                args_alias_map[item['name']] = (
                    args_alias_map[item['name'][:-5]] + '_grad'
                )
                item['fluid_name'] = args_alias_map[item['name'][:-5]] + '_grad'
            elif (
                item['name'].endswith('_grad')
                and item['name'][:-5] not in args_alias_map
            ):
                item['fluid_name'] = item['name']
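
    # Example (illustrative): if the compat map renames x -> X, then x_grad
    # is given the fluid name 'X_grad'; grad args whose base name has no
    # alias keep their own name.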

    def get_param_list_alias(param_list, args_map):
        return [
            args_map[param] if param in args_map else param
            for param in param_list
        ]

    def update_common_params_name(
        op_item, args_name_map, scalar_configs, int_array_configs
    ):
        if 'inplace' in op_item and op_item['inplace']:
            inplace_map = {}
            for key, val in op_item['inplace'].items():
                if key in args_map:
                    key = args_map[key]
                if val in args_map:
                    val = args_map[val]
                inplace_map[key] = val
            op_item['inplace'] = inplace_map
        if 'no_need_buffer' in op_item and op_item['no_need_buffer']:
            op_item['no_need_buffer'] = get_param_list_alias(
                op_item['no_need_buffer'], args_map
            )
        if 'data_transform' in op_item and op_item['data_transform']:
            data_trans_item = op_item['data_transform']
            if 'skip_transform' in data_trans_item:
                data_trans_item['skip_transform'] = get_param_list_alias(
                    data_trans_item['skip_transform'], args_map
                )
            if 'support_trans_dtype' in data_trans_item:
                data_trans_item['support_trans_dtype'] = get_param_list_alias(
                    data_trans_item['support_trans_dtype'], args_map
                )

        process_scalar(op_item, scalar_configs)
        process_int_array(op_item, int_array_configs)

        if 'invoke' in op_item:
            op_item['invoke']['args'] = [
                args_map[param.strip()]
                if param.strip() in args_map
                else param.strip()
                for param in op_item['invoke']['args'].split(',')
            ]
            return
        elif 'composite' in op_item and 'kernel' not in op_item:
            return

        op_item['infer_meta']['param'] = get_param_list_alias(
            op_item['infer_meta']['param'], args_name_map
        )
        op_item['kernel']['param'] = get_param_list_alias(
            op_item['kernel']['param'], args_name_map
        )
        if op_item['kernel']['data_type']:
            op_item['kernel']['data_type']['candidates'] = get_param_list_alias(
                op_item['kernel']['data_type']['candidates'], args_name_map
            )
        if op_item['kernel']['backend']:
            op_item['kernel']['backend']['candidates'] = get_param_list_alias(
                op_item['kernel']['backend']['candidates'], args_name_map
            )
        if op_item['kernel']['layout']:
            op_item['kernel']['layout']['candidates'] = get_param_list_alias(
                op_item['kernel']['layout']['candidates'], args_name_map
            )

    def add_grad_op_compat_name(grad_op_item, args_name_map):
        add_op_param_name(grad_op_item['inputs'], args_name_map)
        add_op_param_name(grad_op_item['outputs'], args_name_map)
        add_op_param_name(grad_op_item['attrs'], args_name_map)
        add_op_param_name(grad_op_item['forward']['inputs'], args_name_map)
        add_op_param_name(grad_op_item['forward']['outputs'], args_name_map)
        add_op_param_name(grad_op_item['forward']['attrs'], args_name_map)
        add_grad_args_name(grad_op_item['inputs'], args_map)
        add_grad_args_name(grad_op_item['outputs'], args_map)

    for op_args in op_fluid_map_list:
        new_op_name, op_name = get_phi_and_fluid_op_name(op_args['op'])
        if new_op_name not in forward_op_dict:
            continue
        forward_op_item = forward_op_dict[new_op_name]
        has_backward = True if forward_op_item['backward'] else False
        if has_backward:
            backward_op_item = backward_op_dict[forward_op_item['backward']]
        if new_op_name != op_name:
            forward_op_item['op_name'] = op_name

        # add complex promote information
        if "complex_promote" in op_args:
            forward_op_item["complex_promote"] = op_args["complex_promote"]
            if has_backward:
                backward_op_item["complex_promote"] = op_args["complex_promote"]
        scalar_configs = None
        int_array_configs = None
        if 'scalar' in op_args:
            scalar_configs = op_args['scalar']
        if 'int_array' in op_args:
            int_array_configs = op_args['int_array']
        if 'extra' in op_args and 'outputs' in op_args['extra']:
            for out_item in forward_op_item['outputs']:
                if out_item['name'] in op_args['extra']['outputs']:
                    out_item['is_extra'] = True

        key_set = ['inputs', 'attrs', 'outputs']
        args_map = {}
        for key in key_set:
            if key in op_args:
                args_map.update(op_args[key])
                for args_item in forward_op_item[key]:
                    if args_item['name'] in op_args[key]:
                        if (
                            scalar_configs
                            and args_item['name'] in scalar_configs
                        ):
                            scalar_configs[
                                op_args[key][args_item['name']]
                            ] = scalar_configs[args_item['name']]
                        if (
                            int_array_configs
                            and args_item['name'] in int_array_configs
                        ):
                            int_array_configs[
                                op_args[key][args_item['name']]
                            ] = int_array_configs[args_item['name']]
                        args_item['fluid_name'] = op_args[key][
                            args_item['name']
                        ]
        update_common_params_name(
            forward_op_item, args_map, scalar_configs, int_array_configs
        )

        if has_backward:
            # update fluid info in backward
            add_grad_op_compat_name(backward_op_item, args_map)
            update_common_params_name(
                backward_op_item, args_map, scalar_configs, int_array_configs
            )

            if 'backward' not in op_args:
                continue

            backward_op_list = op_args['backward'].split(',')
            phi_bw_op_name, bw_op_name = get_phi_and_fluid_op_name(
                backward_op_list[0]
            )
            if (
                forward_op_item["backward_composite"] is not None
                and phi_bw_op_name != bw_op_name
            ):
                forward_op_item["backward_composite"] = bw_op_name
            forward_op_item['backward'] = bw_op_name
            backward_op_item['op_name'] = bw_op_name

            # for double grad
            if len(backward_op_list) > 1:
                (
                    phi_double_grad_op_name,
                    double_grad_op_name,
                ) = get_phi_and_fluid_op_name(backward_op_list[1])
                double_grad_item = backward_op_dict[phi_double_grad_op_name]
                if (
                    backward_op_item["backward_composite"] is not None
                    and phi_double_grad_op_name != double_grad_op_name
                ):
                    backward_op_item["backward_composite"] = double_grad_op_name
                backward_op_item['backward'] = double_grad_op_name
                double_grad_item['op_name'] = double_grad_op_name
                add_grad_op_compat_name(double_grad_item, args_map)
                update_common_params_name(
                    double_grad_item,
                    args_map,
                    scalar_configs,
                    int_array_configs,
                )

                # for triple grad
                if len(backward_op_list) > 2:
                    (
                        phi_triple_grad_op_name,
                        triple_grad_op_name,
                    ) = get_phi_and_fluid_op_name(backward_op_list[2])
                    triple_grad_item = backward_op_dict[phi_triple_grad_op_name]
                    if (
                        double_grad_item["backward_composite"] is not None
                        and phi_triple_grad_op_name != triple_grad_op_name
                    ):
                        double_grad_item[
                            "backward_composite"
                        ] = triple_grad_op_name
                    double_grad_item['backward'] = triple_grad_op_name
                    triple_grad_item['op_name'] = triple_grad_op_name
                    add_grad_op_compat_name(triple_grad_item, args_map)
                    update_common_params_name(
                        triple_grad_item,
                        args_map,
                        scalar_configs,
                        int_array_configs,
                    )


def process_invoke_op(forward_op_dict, backward_op_dict):
    for bw_op in backward_op_dict.values():
        if 'invoke' in bw_op:
            invoke_op = bw_op['invoke']['func']
            args_list = bw_op['invoke']['args']
            args_index = 0
            # the backward op invokes a forward op; reuse its signature
            if invoke_op in forward_op_dict:
                reuse_op = forward_op_dict[invoke_op]
                bw_op['invoke']['func'] = reuse_op['op_name']
                bw_op['invoke']['inputs'] = []
                bw_op['invoke']['attrs'] = []
                bw_op['invoke']['outputs'] = []
                for input_item in reuse_op['inputs']:
                    bw_op['invoke']['inputs'].append(
                        {
                            'fluid_name': input_item['fluid_name'],
                            'name': input_item['name'],
                            'value': args_list[args_index],
                        }
                    )
                    args_index = args_index + 1
                bw_fluid_attrs_set = [
                    item['fluid_name'] for item in bw_op['attrs']
                ]
                for attr in reuse_op['attrs']:
                    if args_index < len(args_list):
                        attr_value = (
                            f"this->GetAttr(\"{args_list[args_index]}\")"
                            if args_list[args_index] in bw_fluid_attrs_set
                            else args_list[args_index]
                        )
                        bw_op['invoke']['attrs'].append(
                            {
                                'name': attr['name'],
                                'fluid_name': attr['fluid_name'],
                                'value': attr_value,
                            }
                        )
                        args_index = args_index + 1
                    else:
                        break
                for idx, output_item in enumerate(reuse_op['outputs']):
                    bw_op['invoke']['outputs'].append(
                        {
                            'name': output_item['name'],
                            'fluid_name': output_item['fluid_name'],
                            'value': bw_op['outputs'][idx]['fluid_name'],
                        }
                    )
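
# Sketch of the invoke case handled above (op names illustrative): a
# backward.yaml entry such as
#
#   - backward_op : scale_grad
#     invoke : scale(out_grad, scale, 0.0, true)
#
# has its positional args matched, in order, against the reused forward op's
# inputs and then attrs, yielding explicit inputs/attrs/outputs for codegen.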


def parse_drop_empty_grad(op_fluid_list: list, bw_op_dict: dict):
    for op_comp_map in op_fluid_list:
        if 'drop_empty_grad' in op_comp_map:
            bw_names = [
                bw_name.split('(')[0].strip()
                for bw_name in op_comp_map['backward'].split(',')
            ]
            for bw_name in bw_names:
                # static_ops.yaml and ops.yaml use the common op_compat.yaml
                if bw_name in bw_op_dict:
                    for out_grad in op_comp_map['drop_empty_grad']:
                        assert (
                            out_grad in bw_op_dict[bw_name]['output_dict']
                        ), f'''
                            {bw_name} with {out_grad} does not exist in output_dict '''
                        bw_op_dict[bw_name]['output_dict'][out_grad][
                            'drop_empty_grad'
                        ] = False
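
# Illustrative op_compat.yaml usage:
#
#   - op : concat
#     backward : concat_grad
#     drop_empty_grad : [x_grad]
#
# flips the default drop_empty_grad=True to False for concat_grad's x_grad
# output, so the generated grad op keeps empty gradient slots instead of
# dropping them.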


def parse_get_expected_kerneltype(
    op_fluid_list: list, fw_op_dict: dict, bw_op_dict: dict
):
    for op_comp_map in op_fluid_list:
        if 'get_expected_kernel_type' in op_comp_map:
            fw_name = op_comp_map['op'].split('(')[0].strip()
            # strip the trailing underscore from the function names in
            # op_comp_map['get_expected_kernel_type']
            new_get_expected_kernel_type_func_map = {}
            for (key, value) in op_comp_map['get_expected_kernel_type'].items():
                new_get_expected_kernel_type_func_map[
                    delete_last_underline(key)
                ] = value
            op_comp_map[
                'get_expected_kernel_type'
            ] = new_get_expected_kernel_type_func_map
            if fw_name in op_comp_map['get_expected_kernel_type']:
                # static_ops.yaml and ops.yaml use the common op_compat.yaml
                if fw_name in fw_op_dict:
                    fw_op_dict[fw_name][
                        "get_expected_kernel_type"
                    ] = op_comp_map['get_expected_kernel_type'][fw_name]
            if "backward" in op_comp_map:
                bw_names = [
                    bw_name.split('(')[0].strip()
                    for bw_name in op_comp_map['backward'].split(',')
                ]
                for bw_name in bw_names:
                    # static_ops.yaml and ops.yaml use the common op_compat.yaml
                    if (
                        bw_name in bw_op_dict
                        and bw_name in op_comp_map['get_expected_kernel_type']
                    ):
                        bw_op_dict[bw_name][
                            "get_expected_kernel_type"
                        ] = op_comp_map['get_expected_kernel_type'][bw_name]
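
# Illustrative op_compat.yaml entry (function names illustrative):
#
#   - op : frobenius_norm
#     backward : frobenius_norm_grad
#     get_expected_kernel_type :
#       frobenius_norm : GetReduceExpectedKernelType
#       frobenius_norm_grad : GetReduceGradExpectedKernelType
#
# attaches a hand-written GetExpectedKernelType function to the matching
# forward/backward op entries.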


def parse_keep_signature(
    op_fluid_list: list, fw_op_dict: dict, bw_op_dict: dict
):
    for op_comp_map in op_fluid_list:
        if 'manual_signature' in op_comp_map:
            for op_name in op_comp_map['manual_signature']:
                op_name_without_last_underline = delete_last_underline(op_name)
                if op_name_without_last_underline in fw_op_dict:
                    fw_op_dict[op_name_without_last_underline][
                        "manual_signature"
                    ] = True
                elif op_name_without_last_underline in bw_op_dict:
                    bw_op_dict[op_name_without_last_underline][
                        "manual_signature"
                    ] = True
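
# Illustrative entry: `manual_signature : [norm]` in op_compat.yaml flags the
# op (trailing-underscore names are normalized first, e.g. 'sgd_' -> 'sgd')
# so its argument-mapping signature is written by hand rather than generated.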


def split_ops_list(ops, backward_op_dict, split_num):
    new_ops_list = []
    new_bw_ops_list = []
    list_size = math.ceil(len(ops) / split_num)
    tmp_ops_list = []
    tmp_bw_ops_list = []
    for idx, op in enumerate(ops):
        tmp_ops_list.append(op)
        current_op = op
        while (
            'backward' in current_op
            and current_op['backward'] in backward_op_dict
        ):
            tmp_bw_ops_list.append(backward_op_dict[current_op['backward']])
            current_op = backward_op_dict[current_op['backward']]
        if (idx + 1) % list_size == 0 or idx == len(ops) - 1:
            new_ops_list.append(tmp_ops_list)
            new_bw_ops_list.append(tmp_bw_ops_list)
            tmp_ops_list = []
            tmp_bw_ops_list = []
    return new_ops_list, new_bw_ops_list
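
# Worked example: 100 ops with split_num = 8 gives list_size = ceil(100 / 8)
# = 13, so the ops are emitted as seven chunks of 13 plus a final chunk of 9,
# each chunk paired with the backward-op chains of its forward ops.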


def to_phi_and_fluid_op_name_without_underline(op_item):
    '''
    If the op name ends with '_', delete the trailing '_'. For example, 'sgd_' becomes 'sgd'.
    '''
    names = op_item.split('(')
    if len(names) == 1:
        op_kernel_name = delete_last_underline(names[0].strip())
        return op_kernel_name
    else:
        op_name = delete_last_underline(names[0].strip())
        kernel_name = delete_last_underline(names[1].split(')')[0].strip())
        return op_name + '(' + kernel_name + ')'


def main(
    ops_yaml_path,
    backward_yaml_path,
    op_compat_yaml_path,
    op_version_yaml_path,
    output_op_path,
    output_arg_map_path,
):
    with open(ops_yaml_path, "rt") as f:
        ops = yaml.safe_load(f)
        ops = [restruct_io(op) for op in ops]
    forward_op_dict = to_named_dict(ops, True)
    with open(backward_yaml_path, "rt") as f:
        backward_ops = yaml.safe_load(f)
        backward_ops = [restruct_io(op) for op in backward_ops]
    backward_op_dict = to_named_dict(backward_ops, True)
    with open(op_version_yaml_path, "rt") as f:
        op_versions = yaml.safe_load(f)
    # add op version info into op
    for op_version in op_versions:
        if op_version['op'] in forward_op_dict:
            forward_op_dict[op_version['op']]['version'] = op_version['version']

    with open(op_compat_yaml_path, "rt") as f:
        op_fluid_map_list = yaml.safe_load(f)
        for op_args in op_fluid_map_list:
            op_args["op"] = to_phi_and_fluid_op_name_without_underline(
                op_args["op"]
            )

    for op in ops:
        op['op_name'] = op['name']
        add_fluid_name(op['inputs'])
        add_fluid_name(op['attrs'])
        add_fluid_name(op['outputs'])
    for bw_op in backward_ops:
        bw_op['op_name'] = bw_op['name']
        add_fluid_name(bw_op['inputs'])
        add_fluid_name(bw_op['attrs'])
        add_fluid_name(bw_op['outputs'])
        add_fluid_name(bw_op['forward']['inputs'])
        add_fluid_name(bw_op['forward']['attrs'])
        add_fluid_name(bw_op['forward']['outputs'])
        for bw_output in bw_op['outputs']:
            bw_output['drop_empty_grad'] = True

    # resolve drop_empty_grad for backward ops according to op_compat.yaml
    parse_drop_empty_grad(op_fluid_map_list, backward_op_dict)

    parse_get_expected_kerneltype(
        op_fluid_map_list, forward_op_dict, backward_op_dict
    )

    parse_keep_signature(op_fluid_map_list, forward_op_dict, backward_op_dict)

    add_composite_info(ops, backward_ops, backward_op_dict)

    add_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict)

    # prepare for invoke case
    process_invoke_op(forward_op_dict, backward_op_dict)

    # fill backward field for an op if another op claims it as forward
    for name, backward_op in backward_op_dict.items():
        forward_name = backward_op["forward"]["name"]
        if forward_name in backward_op_dict:
            forward_op = backward_op_dict[forward_name]
            if forward_op["backward"] is None:
                forward_op["backward"] = name

    op_dict = {}
    op_dict.update(forward_op_dict)
    op_dict.update(backward_op_dict)
    if len(ops) == 0 and len(backward_ops) == 0:
        if os.path.isfile(output_op_path):
            os.remove(output_op_path)
        if os.path.isfile(output_arg_map_path):
            os.remove(output_arg_map_path)
        return
    op_template = env.get_template('op.c.j2')

    backward_fluid_op_dict = {}
    for bw_op in backward_ops:
        backward_fluid_op_dict[bw_op['op_name']] = bw_op
    output_op_files_num = len(output_op_path)
    new_ops_list, new_bw_ops_list = split_ops_list(
        ops, backward_fluid_op_dict, output_op_files_num
    )
    for idx, output_op_file in enumerate(output_op_path):
        with open(output_op_file, "wt") as f:
            msg = op_template.render(
                ops=new_ops_list[idx],
                backward_ops=new_bw_ops_list[idx],
                op_dict=op_dict,
            )
            f.write(msg)

    ks_template = env.get_template('ks.c.j2')
    with open(output_arg_map_path, 'wt') as f:
        msg = ks_template.render(ops=ops, backward_ops=backward_ops)
        f.write(msg)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate operator file from op yaml."
    )
    parser.add_argument(
        '--ops_yaml_path', type=str, help="parsed ops yaml file."
    )
    parser.add_argument(
        '--backward_yaml_path', type=str, help="parsed backward ops yaml file."
    )
    parser.add_argument(
        '--op_compat_yaml_path', type=str, help="ops args compat yaml file."
    )
    parser.add_argument(
        '--op_version_yaml_path', type=str, help="ops version yaml file."
    )
    parser.add_argument(
        "--output_op_path",
        type=str,
        nargs='+',
        help="path to save generated operators.",
    )
    parser.add_argument(
        "--output_arg_map_path",
        type=str,
        help="path to save generated argument mapping functions.",
    )

    args = parser.parse_args()
    main(
        args.ops_yaml_path,
        args.backward_yaml_path,
        args.op_compat_yaml_path,
        args.op_version_yaml_path,
        args.output_op_path,
        args.output_arg_map_path,
    )
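

# Example invocation (paths illustrative; --output_op_path takes one file per
# generated chunk):
#
#   python generate_op.py \
#       --ops_yaml_path ops.parsed.yaml \
#       --backward_yaml_path backward.parsed.yaml \
#       --op_compat_yaml_path op_compat.yaml \
#       --op_version_yaml_path op_version.yaml \
#       --output_op_path generated_op1.cc generated_op2.cc \
#       --output_arg_map_path generated_sig.cc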