# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from pathlib import Path

import yaml
from jinja2 import Environment, FileSystemLoader, StrictUndefined

from filters import (
    cartesian_prod_mapping,
    to_input_name,
    to_op_attr_type,
    to_opmaker_name,
    to_opmaker_name_cstr,
    to_pascal_case,
)
from parse_utils import to_named_dict
from tests import (
    is_base_op,
    is_initializer_list,
    is_scalar,
    is_vec,
    supports_inplace,
    supports_no_need_buffer,
)

file_loader = FileSystemLoader(Path(__file__).parent / "templates")
# Shared Jinja environment for every template in ./templates. StrictUndefined
# turns references to missing variables into render errors instead of
# silently rendering empty strings; the "do" extension enables {% do %}.
env = Environment(
    loader=file_loader,
    undefined=StrictUndefined,
    extensions=['jinja2.ext.do'],
    keep_trailing_newline=True,
    trim_blocks=True,
    lstrip_blocks=True,
)

# Filters exposed to the templates (usable as ``value | filter_name``).
env.filters.update(
    {
        "to_op_attr_type": to_op_attr_type,
        "to_opmaker_name": to_opmaker_name,
        "to_pascal_case": to_pascal_case,
        "to_input_name": to_input_name,
        "to_opmaker_name_cstr": to_opmaker_name_cstr,
        "cartesian_prod_mapping": cartesian_prod_mapping,
    }
)

# Tests exposed to the templates (usable as ``value is test_name``).
env.tests.update(
    {
        "base_op": is_base_op,
        "vec": is_vec,
        "scalar": is_scalar,
        "initializer_list": is_initializer_list,
        "supports_inplace": supports_inplace,
        "supports_no_need_buffer": supports_no_need_buffer,
    }
)


def restruct_io(op):
    """Attach name-indexed views of an op's argument lists.

    Adds ``input_dict``, ``attr_dict`` and ``output_dict``, built with
    ``to_named_dict`` from ``op["inputs"]``, ``op["attrs"]`` and
    ``op["outputs"]`` respectively. ``op`` is modified in place and also
    returned, so the function can be used inside a comprehension.
    """
    for list_key, dict_key in (
        ("inputs", "input_dict"),
        ("attrs", "attr_dict"),
        ("outputs", "output_dict"),
    ):
        op[dict_key] = to_named_dict(op[list_key])
    return op


# Suffix used to derive gradient argument names from forward argument names.
_GRAD_SUFFIX = '_grad'


def _split_compat_name(op_item):
    """Split an op_compat entry of the form ``"name (legacy_name)"``.

    Returns ``(name, legacy_name)`` with surrounding whitespace stripped.
    When the entry has no parenthesised part, both elements are the stripped
    entry itself.
    """
    names = op_item.split('(')
    if len(names) == 1:
        return names[0].strip(), names[0].strip()
    return names[0].strip(), names[1].split(')')[0].strip()


def _rename_items(items, name_map):
    """Rename in place every ``item['name']`` that is a key of ``name_map``."""
    for item in items:
        if item['name'] in name_map:
            item['name'] = name_map[item['name']]


def _rename_params(params, args_map):
    """Return a new list with each name in ``params`` mapped via ``args_map``.

    Names without an entry in ``args_map`` are kept unchanged.
    """
    return [args_map.get(param, param) for param in params]


def _rename_kernel(kernel, args_map):
    """Rename a kernel's ``param`` list and optional candidate lists in place.

    The ``data_type``, ``backend`` and ``layout`` entries are only rewritten
    when they are truthy.
    """
    kernel['param'] = _rename_params(kernel['param'], args_map)
    for field in ('data_type', 'backend', 'layout'):
        if kernel[field]:
            kernel[field]['candidates'] = _rename_params(
                kernel[field]['candidates'], args_map
            )


def _rename_inplace(inplace, args_map):
    """Return a copy of an inplace mapping with keys and values renamed."""
    return {
        args_map.get(key, key): args_map.get(val, val)
        for key, val in inplace.items()
    }


def _rename_grad_item(args_item, args_map):
    """Rename ``<arg>_grad`` to ``<renamed arg>_grad`` when ``<arg>`` is mapped.

    The derived grad mapping is also recorded in ``args_map`` so that later
    param/candidate lists mentioning the grad name get renamed too. Items
    that do not match are left untouched.
    """
    name = args_item['name']
    base = name[: -len(_GRAD_SUFFIX)]
    if name.endswith(_GRAD_SUFFIX) and base in args_map:
        args_map[name] = args_map[base] + _GRAD_SUFFIX
        args_item['name'] = args_map[name]


def _apply_backward_compat_names(
    op_args, forward_op_item, backward_op_item, backward_op_dict
):
    """Apply the legacy backward op names listed in ``op_args['backward']``.

    The first comma-separated entry gives the first-order backward op's
    legacy name. The second and third entries (if present) name the double-
    and triple-grad ops: each is linked from the previous order via its
    ``backward`` field, gets its ``op_name`` set, and has its attrs renamed
    with ``op_args['attrs']``. Entries beyond the third are ignored.
    """
    backward_op_list = op_args['backward'].split(',')
    _, bw_op_name = _split_compat_name(backward_op_list[0])
    forward_op_item['backward'] = bw_op_name
    backward_op_item['op_name'] = bw_op_name

    prev_item = backward_op_item
    for entry in backward_op_list[1:3]:
        new_grad_op_name, grad_op_name = _split_compat_name(entry)
        grad_item = backward_op_dict[new_grad_op_name]
        prev_item['backward'] = grad_op_name
        grad_item['op_name'] = grad_op_name
        if 'attrs' in op_args:
            _rename_items(grad_item['attrs'], op_args['attrs'])
            _rename_items(grad_item['forward']['attrs'], op_args['attrs'])
        prev_item = grad_item


def _rename_forward_op(forward_op_item, args_map):
    """Rename params in a forward op's infer_meta, kernel and inplace fields."""
    forward_op_item['infer_meta']['param'] = _rename_params(
        forward_op_item['infer_meta']['param'], args_map
    )
    _rename_kernel(forward_op_item['kernel'], args_map)
    if forward_op_item['inplace']:
        forward_op_item['inplace'] = _rename_inplace(
            forward_op_item['inplace'], args_map
        )


def _rename_backward_op(backward_op_item, args_map):
    """Rename a backward op's arguments and param lists through ``args_map``.

    ``args_map`` may gain ``<arg>_grad`` entries along the way. For ops that
    ``invoke`` another op, only the invoke args are rewritten (turning the
    comma-separated string into a list); infer_meta, kernel, no_need_buffer
    and inplace are then left untouched.
    """
    for args_item in backward_op_item['inputs']:
        if args_item['name'] in args_map:
            args_item['name'] = args_map[args_item['name']]
        else:
            _rename_grad_item(args_item, args_map)
    _rename_items(backward_op_item['attrs'], args_map)
    for args_item in backward_op_item['outputs']:
        _rename_grad_item(args_item, args_map)

    if 'invoke' in backward_op_item:
        invoke = backward_op_item['invoke']
        invoke['args'] = _rename_params(
            [param.strip() for param in invoke['args'].split(',')], args_map
        )
        return

    backward_op_item['infer_meta']['param'] = _rename_params(
        backward_op_item['infer_meta']['param'], args_map
    )
    _rename_kernel(backward_op_item['kernel'], args_map)
    if backward_op_item['no_need_buffer']:
        backward_op_item['no_need_buffer'] = _rename_params(
            backward_op_item['no_need_buffer'], args_map
        )
    if backward_op_item['inplace']:
        backward_op_item['inplace'] = _rename_inplace(
            backward_op_item['inplace'], args_map
        )


def replace_compat_name(op_op_map, forward_op_dict, backward_op_dict):
    """Replace op and param names with the compat names used for OpMaker.

    Args:
        op_op_map: entries loaded from the op_compat yaml. Each has an ``op``
            key (``"name"`` or ``"name (compat_name)"``) and optional keys
            ``backward`` (comma-separated entries in the same format) and
            ``inputs`` / ``attrs`` / ``outputs`` (``{name: compat_name}``
            mappings). ``None`` (an empty yaml file) is treated as no entries.
        forward_op_dict: forward ops keyed by name; items are mutated in place.
        backward_op_dict: backward ops keyed by name; items are mutated in
            place.

    Entries whose op name is not in ``forward_op_dict`` are skipped. Only the
    op items are modified; the keys of both dicts keep their original names.
    """
    for op_args in op_op_map or []:
        new_op_name, op_name = _split_compat_name(op_args['op'])
        if new_op_name not in forward_op_dict:
            continue

        forward_op_item = forward_op_dict[new_op_name]
        has_backward = bool(forward_op_item['backward'])
        # Fetch before any renaming: the lookup needs the original name.
        backward_op_item = (
            backward_op_dict[forward_op_item['backward']]
            if has_backward
            else None
        )

        if new_op_name != op_name:
            forward_op_item['op_name'] = op_name
        if 'backward' in op_args and has_backward:
            _apply_backward_compat_names(
                op_args, forward_op_item, backward_op_item, backward_op_dict
            )

        # Rename the declared arguments and accumulate one mapping covering
        # inputs, attrs and outputs for the param lists renamed below.
        args_map = {}
        for key in ('inputs', 'attrs', 'outputs'):
            if key not in op_args:
                continue
            args_map.update(op_args[key])
            _rename_items(forward_op_item[key], op_args[key])
            if has_backward:
                _rename_items(backward_op_item['forward'][key], op_args[key])

        _rename_forward_op(forward_op_item, args_map)
        if has_backward:
            _rename_backward_op(backward_op_item, args_map)

def process_invoke_op(forward_op_dict, backward_op_dict):
    """Expand backward ops that ``invoke`` a forward op into explicit args.

    For every backward op whose ``invoke['func']`` names an op in
    ``forward_op_dict``, fills ``invoke['inputs']``, ``invoke['attrs']`` and
    ``invoke['outputs']`` with ``{'name': <reused op param>, 'value': ...}``
    entries:

    * inputs consume the invoke args positionally, one per reused-op input;
    * attrs consume the remaining args positionally until they run out. An
      arg that names one of the backward op's own attrs becomes
      ``this->GetAttr("<arg>")``; any other arg is emitted verbatim;
    * outputs pair the reused op's outputs with the backward op's outputs
      by index.

    ``invoke['args']`` may be either a list of names or the raw
    comma-separated string; it is not modified here. Backward ops that
    invoke an unknown function are left untouched.
    """
    for bw_op in backward_op_dict.values():
        if 'invoke' not in bw_op:
            continue
        invoke = bw_op['invoke']
        if invoke['func'] not in forward_op_dict:
            continue
        reuse_op = forward_op_dict[invoke['func']]

        args_list = invoke['args']
        if isinstance(args_list, str):
            # Args remain the raw "a, b, c" string unless replace_compat_name
            # already split them; indexing the string would yield characters.
            args_list = [
                arg.strip() for arg in args_list.split(',') if arg.strip()
            ]

        inputs = [
            {'name': input_item['name'], 'value': args_list[idx]}
            for idx, input_item in enumerate(reuse_op['inputs'])
        ]

        attrs = []
        # zip() stops when either the attrs or the remaining args run out.
        for attr, arg in zip(reuse_op['attrs'], args_list[len(inputs) :]):
            attr_value = (
                f"this->GetAttr(\"{arg}\")"
                if arg in bw_op['attr_dict']
                else arg
            )
            attrs.append({'name': attr['name'], 'value': attr_value})

        outputs = [
            {
                'name': output_item['name'],
                'value': bw_op['outputs'][idx]['name'],
            }
            for idx, output_item in enumerate(reuse_op['outputs'])
        ]

        invoke['inputs'] = inputs
        invoke['attrs'] = attrs
        invoke['outputs'] = outputs


def _load_yaml_list(path):
    """Return the top-level list stored in the YAML file at ``path``.

    ``yaml.safe_load`` returns ``None`` for an empty file; that is normalised
    to ``[]`` so callers can always iterate the result.
    """
    with open(path, "rt", encoding="utf-8") as f:
        return yaml.safe_load(f) or []


def main(
    ops_yaml_path,
    backward_yaml_path,
    op_compat_yaml_path,
    op_version_yaml_path,
    output_op_path,
    output_arg_map_path,
):
    """Render the operator and argument-mapping sources from parsed op yaml.

    Args:
        ops_yaml_path: parsed forward ops yaml.
        backward_yaml_path: parsed backward ops yaml.
        op_compat_yaml_path: op_compat yaml with compat names for ops/params.
        op_version_yaml_path: op version yaml.
        output_op_path: where the ``op.c.j2`` rendering is written.
        output_arg_map_path: where the ``ks.c.j2`` rendering is written.

    When both the forward and backward op files contain no ops, any existing
    files at the two output paths are removed and nothing is rendered.
    """
    ops = [restruct_io(op) for op in _load_yaml_list(ops_yaml_path)]
    forward_op_dict = to_named_dict(ops)

    backward_ops = [
        restruct_io(op) for op in _load_yaml_list(backward_yaml_path)
    ]
    backward_op_dict = to_named_dict(backward_ops)

    # Add op version info into op. Versions of ops not defined in
    # ops_yaml_path are skipped; a KeyError here would otherwise make the
    # empty-ops cleanup path below unreachable.
    for op_version in _load_yaml_list(op_version_yaml_path):
        if op_version['op'] in forward_op_dict:
            forward_op_dict[op_version['op']]['version'] = op_version[
                'version'
            ]

    op_op_map = _load_yaml_list(op_compat_yaml_path)

    # op_name defaults to the yaml name; replace_compat_name overrides it
    # where op_compat declares a different name.
    for op in ops:
        op['op_name'] = op['name']
    for bw_op in backward_ops:
        bw_op['op_name'] = bw_op['name']

    replace_compat_name(op_op_map, forward_op_dict, backward_op_dict)

    # prepare for invoke case
    process_invoke_op(forward_op_dict, backward_op_dict)

    # fill backward field for an op if another op claims it as forward
    for name, backward_op in backward_op_dict.items():
        forward_name = backward_op["forward"]["name"]
        if forward_name in backward_op_dict:
            forward_op = backward_op_dict[forward_name]
            if forward_op["backward"] is None:
                forward_op["backward"] = name

    # Backward entries take precedence on name clashes.
    op_dict = {**forward_op_dict, **backward_op_dict}

    if not ops and not backward_ops:
        # Nothing to generate: remove any existing output files.
        for path in (output_op_path, output_arg_map_path):
            if os.path.isfile(path):
                os.remove(path)
        return

    op_template = env.get_template('op.c.j2')
    with open(output_op_path, "wt", encoding="utf-8") as f:
        f.write(
            op_template.render(
                ops=ops, backward_ops=backward_ops, op_dict=op_dict
            )
        )

    ks_template = env.get_template('ks.c.j2')
    with open(output_arg_map_path, "wt", encoding="utf-8") as f:
        f.write(ks_template.render(ops=ops, backward_ops=backward_ops))


if __name__ == "__main__":
    # (flag, help text) for every command-line option. Each value is a str
    # and is forwarded to main() under the parameter of the same name.
    option_specs = [
        ('--ops_yaml_path', "parsed ops yaml file."),
        ('--backward_yaml_path', "parsed backward ops yaml file."),
        ('--op_compat_yaml_path', "ops args compat yaml file."),
        ('--op_version_yaml_path', "ops version yaml file."),
        ("--output_op_path", "path to save generated operators."),
        (
            "--output_arg_map_path",
            "path to save generated argument mapping functions.",
        ),
    ]

    parser = argparse.ArgumentParser(
        description="Generate operator file from op yaml."
    )
    for flag, help_text in option_specs:
        parser.add_argument(flag, type=str, help=help_text)

    main(**vars(parser.parse_args()))