# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from pathlib import Path

import yaml
from jinja2 import Environment, FileSystemLoader, StrictUndefined

from filters import (
    cartesian_prod_mapping,
    to_input_name,
    to_op_attr_type,
    to_opmaker_name,
    to_opmaker_name_cstr,
    to_pascal_case,
)
from parse_utils import to_named_dict
from tests import (
    is_base_op,
    is_initializer_list,
    is_scalar,
    is_vec,
    supports_inplace,
    supports_no_need_buffer,
)

# Templates are loaded from the "templates" directory next to this script.
file_loader = FileSystemLoader(Path(__file__).parent / "templates")

# StrictUndefined makes a template fail on any undefined variable instead of
# silently rendering it as empty; 'jinja2.ext.do' enables the `{% do %}` tag.
env = Environment(
    loader=file_loader,
    keep_trailing_newline=True,
    trim_blocks=True,
    lstrip_blocks=True,
    undefined=StrictUndefined,
    extensions=['jinja2.ext.do'],
)

# Custom filters made available to the templates.
env.filters["to_op_attr_type"] = to_op_attr_type
env.filters["to_opmaker_name"] = to_opmaker_name
env.filters["to_pascal_case"] = to_pascal_case
env.filters["to_input_name"] = to_input_name
env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping

# Custom tests (used in templates as `x is vec`, `x is base_op`, ...).
env.tests["base_op"] = is_base_op
env.tests["vec"] = is_vec
env.tests["scalar"] = is_scalar
env.tests["initializer_list"] = is_initializer_list
env.tests["supports_inplace"] = supports_inplace
env.tests["supports_no_need_buffer"] = supports_no_need_buffer


def restruct_io(op):
    """Attach name-keyed views of an op's inputs, attrs and outputs.

    For each of the ``inputs``, ``attrs`` and ``outputs`` entries of ``op``
    the result of ``to_named_dict`` is stored under ``input_dict``,
    ``attr_dict`` and ``output_dict`` respectively.

    Args:
        op: op description dict containing ``inputs``, ``attrs`` and
            ``outputs``.

    Returns:
        The same ``op`` object, modified in place.
    """
    for list_key, dict_key in (
        ("inputs", "input_dict"),
        ("attrs", "attr_dict"),
        ("outputs", "output_dict"),
    ):
        op[dict_key] = to_named_dict(op[list_key])
    return op


def _split_compat_name(op_item):
    """Split a compat entry of the form ``"name (op_name)"``.

    Returns ``(name, op_name)``, both stripped. When the entry has no
    parenthesised part, ``name`` is returned for both elements.
    """
    parts = op_item.split('(')
    name = parts[0].strip()
    if len(parts) == 1:
        return name, name
    return name, parts[1].split(')')[0].strip()


def _rename_items(items, alias_map):
    """Rename in place every item whose ``name`` is a key of ``alias_map``."""
    for item in items:
        if item['name'] in alias_map:
            item['name'] = alias_map[item['name']]


def _remap_names(names, args_map):
    """Return ``names`` as a list with every mapped name replaced."""
    return [args_map.get(name, name) for name in names]


def _rename_grad_item(item, args_map):
    """Rename ``<arg>_grad`` to ``<mapped arg>_grad`` when ``<arg>`` is mapped.

    The derived mapping is also recorded in ``args_map`` so that later
    lookups (param lists, inplace maps, ...) see the renamed grad argument.
    """
    name = item['name']
    if name.endswith('_grad') and name[:-5] in args_map:
        args_map[name] = args_map[name[:-5]] + '_grad'
        item['name'] = args_map[name]


def _remap_infer_meta_and_kernel(op_item, args_map):
    """Apply ``args_map`` to the infer_meta/kernel param and candidate lists."""
    op_item['infer_meta']['param'] = _remap_names(
        op_item['infer_meta']['param'], args_map
    )
    kernel = op_item['kernel']
    kernel['param'] = _remap_names(kernel['param'], args_map)
    for selector in ('data_type', 'backend', 'layout'):
        if kernel[selector]:
            kernel[selector]['candidates'] = _remap_names(
                kernel[selector]['candidates'], args_map
            )


def _remap_inplace(op_item, args_map):
    """Apply ``args_map`` to both keys and values of a non-empty inplace map."""
    if op_item['inplace']:
        op_item['inplace'] = {
            args_map.get(key, key): args_map.get(val, val)
            for key, val in op_item['inplace'].items()
        }


# replace name of op and params for OpMaker
def replace_compat_name(op_op_map, forward_op_dict, backward_op_dict):
    """Rename ops and their arguments according to the compat entries.

    Every entry of ``op_op_map`` whose op is present in ``forward_op_dict``
    is applied in place to the forward op and, when the forward op has a
    backward op, to that backward op too (and to the double/triple grad ops
    listed in the entry's ``backward`` field).

    Args:
        op_op_map: iterable of compat entries. Each has an ``op`` field
            (``"name"`` or ``"name (op_name)"``) and optionally ``backward``
            (comma-separated entries of the same form) and ``inputs`` /
            ``attrs`` / ``outputs`` dicts mapping current argument names to
            replacement names.
        forward_op_dict: forward op descriptions keyed by name.
        backward_op_dict: backward op descriptions keyed by name.

    Returns:
        None. The op description dicts are modified in place.

    Raises:
        KeyError: if a referenced backward op is not in ``backward_op_dict``
            or an op description lacks a field that is accessed.
    """
    for op_args in op_op_map:
        new_op_name, op_name = _split_compat_name(op_args['op'])
        if new_op_name not in forward_op_dict:
            continue
        forward_op_item = forward_op_dict[new_op_name]
        has_backward = bool(forward_op_item['backward'])
        # Look the backward op up *before* its name is possibly replaced
        # on the forward op below.
        backward_op_item = (
            backward_op_dict[forward_op_item['backward']]
            if has_backward
            else None
        )
        if new_op_name != op_name:
            forward_op_item['op_name'] = op_name

        if has_backward and 'backward' in op_args:
            backward_op_list = op_args['backward'].split(',')
            _, bw_op_name = _split_compat_name(backward_op_list[0])
            forward_op_item['backward'] = bw_op_name
            backward_op_item['op_name'] = bw_op_name

            # Entries 1 and 2 describe the double and triple grad ops; each
            # one is linked as the `backward` of the previous level. Only
            # these two extra levels are processed.
            lower_item = backward_op_item
            for grad_entry in backward_op_list[1:3]:
                new_grad_name, grad_name = _split_compat_name(grad_entry)
                grad_item = backward_op_dict[new_grad_name]
                lower_item['backward'] = grad_name
                grad_item['op_name'] = grad_name
                if 'attrs' in op_args:
                    _rename_items(grad_item['attrs'], op_args['attrs'])
                    _rename_items(
                        grad_item['forward']['attrs'], op_args['attrs']
                    )
                lower_item = grad_item

        # Rename the declared arguments and collect one combined name map.
        args_map = {}
        for key in ('inputs', 'attrs', 'outputs'):
            if key not in op_args:
                continue
            args_map.update(op_args[key])
            _rename_items(forward_op_item[key], op_args[key])
            if has_backward:
                _rename_items(backward_op_item['forward'][key], op_args[key])

        _remap_infer_meta_and_kernel(forward_op_item, args_map)
        _remap_inplace(forward_op_item, args_map)

        if not has_backward:
            continue

        for args_item in backward_op_item['inputs']:
            if args_item['name'] in args_map:
                args_item['name'] = args_map[args_item['name']]
            else:
                _rename_grad_item(args_item, args_map)
        _rename_items(backward_op_item['attrs'], args_map)
        for args_item in backward_op_item['outputs']:
            _rename_grad_item(args_item, args_map)

        if 'invoke' in backward_op_item:
            # The invoke args arrive as one comma-separated string; they are
            # stored back as a list of (renamed) argument names. The
            # infer_meta/kernel fields below are not touched for such ops.
            backward_op_item['invoke']['args'] = _remap_names(
                (
                    param.strip()
                    for param in backward_op_item['invoke']['args'].split(',')
                ),
                args_map,
            )
            continue

        _remap_infer_meta_and_kernel(backward_op_item, args_map)
        if backward_op_item['no_need_buffer']:
            backward_op_item['no_need_buffer'] = _remap_names(
                backward_op_item['no_need_buffer'], args_map
            )
        _remap_inplace(backward_op_item, args_map)


def process_invoke_op(forward_op_dict, backward_op_dict):
    """Expand the ``invoke`` field of backward ops that reuse a forward op.

    For each backward op that has an ``invoke`` entry whose ``func`` names an
    op in ``forward_op_dict``, the positional ``invoke['args']`` are matched
    against the reused op's signature and stored as ``invoke['inputs']``,
    ``invoke['attrs']`` and ``invoke['outputs']`` (lists of
    ``{'name': ..., 'value': ...}`` dicts):

    * the leading args are bound to the reused op's inputs, in order;
    * the remaining args are bound to its attrs until the args run out; an
      arg that names one of the backward op's own attrs (a key of
      ``attr_dict``) becomes ``this->GetAttr("<arg>")``;
    * the reused op's outputs are bound positionally to the backward op's
      outputs.

    Args:
        forward_op_dict: forward op descriptions keyed by name.
        backward_op_dict: backward op descriptions keyed by name; modified
            in place.

    Raises:
        IndexError: if there are fewer args than inputs of the reused op, or
            fewer backward outputs than outputs of the reused op.
    """
    for bw_op in backward_op_dict.values():
        if 'invoke' not in bw_op:
            continue
        invoke = bw_op['invoke']
        invoke_op = invoke['func']
        args_list = invoke['args']
        if invoke_op not in forward_op_dict:
            continue

        # The args are only turned into a list for ops that went through
        # replace_compat_name; otherwise they are still the raw
        # comma-separated string, which must not be indexed per character.
        # The stored `invoke['args']` value itself is left untouched.
        if isinstance(args_list, str):
            args_list = [arg.strip() for arg in args_list.split(',')]

        reuse_op = forward_op_dict[invoke_op]
        invoke['inputs'] = []
        invoke['attrs'] = []
        invoke['outputs'] = []

        reuse_inputs = reuse_op['inputs']
        for idx, input_item in enumerate(reuse_inputs):
            invoke['inputs'].append(
                {'name': input_item['name'], 'value': args_list[idx]}
            )

        # zip stops at the shorter sequence: attrs without a matching
        # positional arg are simply not listed.
        attr_args = args_list[len(reuse_inputs):]
        for attr, arg in zip(reuse_op['attrs'], attr_args):
            attr_value = (
                f'this->GetAttr("{arg}")' if arg in bw_op['attr_dict'] else arg
            )
            invoke['attrs'].append({'name': attr['name'], 'value': attr_value})

        for idx, output_item in enumerate(reuse_op['outputs']):
            invoke['outputs'].append(
                {
                    'name': output_item['name'],
                    'value': bw_op['outputs'][idx]['name'],
                }
            )


def _load_yaml_list(path):
    """Load a YAML file, treating an empty document as an empty list."""
    with open(path, "rt") as f:
        # safe_load returns None for an empty document.
        return yaml.safe_load(f) or []


def main(
    ops_yaml_path,
    backward_yaml_path,
    op_compat_yaml_path,
    op_version_yaml_path,
    output_op_path,
    output_arg_map_path,
):
    """Generate the operator and argument-mapping sources from op YAML files.

    Args:
        ops_yaml_path: YAML file with the parsed forward ops.
        backward_yaml_path: YAML file with the parsed backward ops.
        op_compat_yaml_path: YAML file with the op/argument compat entries.
        op_version_yaml_path: YAML file with op version entries (each with
            ``op`` and ``version`` fields).
        output_op_path: file the rendered ``op.c.j2`` template is written to.
        output_arg_map_path: file the rendered ``ks.c.j2`` template is
            written to.

    When both op lists are empty, existing output files are removed and
    nothing is generated.

    Raises:
        KeyError: if a version entry names an op that is not a forward op.
    """
    ops = [restruct_io(op) for op in _load_yaml_list(ops_yaml_path)]
    forward_op_dict = to_named_dict(ops)

    backward_ops = [
        restruct_io(op) for op in _load_yaml_list(backward_yaml_path)
    ]
    backward_op_dict = to_named_dict(backward_ops)

    # add op version info into op
    for op_version in _load_yaml_list(op_version_yaml_path):
        forward_op_dict[op_version['op']]['version'] = op_version['version']

    op_op_map = _load_yaml_list(op_compat_yaml_path)

    # By default an op keeps its own name; compat entries may override it.
    for op in ops:
        op['op_name'] = op['name']
    for bw_op in backward_ops:
        bw_op['op_name'] = bw_op['name']

    replace_compat_name(op_op_map, forward_op_dict, backward_op_dict)

    # prepare for invoke case
    process_invoke_op(forward_op_dict, backward_op_dict)

    # fill backward field for an op if another op claims it as forward
    for name, backward_op in backward_op_dict.items():
        forward_name = backward_op["forward"]["name"]
        if forward_name in backward_op_dict:
            forward_op = backward_op_dict[forward_name]
            if forward_op["backward"] is None:
                forward_op["backward"] = name

    # Backward ops take precedence on a name clash, as they are merged last.
    op_dict = {**forward_op_dict, **backward_op_dict}

    if not ops and not backward_ops:
        # Nothing to generate: drop stale outputs of a previous run.
        for stale_path in (output_op_path, output_arg_map_path):
            if os.path.isfile(stale_path):
                os.remove(stale_path)
        return

    # Render before opening the outputs so that a template error does not
    # leave a truncated file behind.
    op_source = env.get_template('op.c.j2').render(
        ops=ops, backward_ops=backward_ops, op_dict=op_dict
    )
    ks_source = env.get_template('ks.c.j2').render(
        ops=ops, backward_ops=backward_ops
    )

    with open(output_op_path, "wt") as f:
        f.write(op_source)
    with open(output_arg_map_path, 'wt') as f:
        f.write(ks_source)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate operator file from op yaml."
    )
    # main() uses every one of these paths, so a missing option is reported
    # by argparse up front instead of failing later on a None path.
    cli_options = (
        ('--ops_yaml_path', "parsed ops yaml file."),
        ('--backward_yaml_path', "parsed backward ops yaml file."),
        ('--op_compat_yaml_path', "ops args compat yaml file."),
        ('--op_version_yaml_path', "ops version yaml file."),
        ("--output_op_path", "path to save generated operators."),
        (
            "--output_arg_map_path",
            "path to save generated argument mapping functions.",
        ),
    )
    for flag, help_text in cli_options:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    args = parser.parse_args()
    main(
        args.ops_yaml_path,
        args.backward_yaml_path,
        args.op_compat_yaml_path,
        args.op_version_yaml_path,
        args.output_op_path,
        args.output_arg_map_path,
    )