# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from copy import copy
17 18
from typing import Any, Dict, List, Tuple

19
from tests import is_attr, is_input, is_output, is_vec
20
from type_mapping import opmaker_attr_types_map
21 22 23 24 25 26 27 28 29 30 31 32


def to_named_dict(items: List[Dict]) -> Dict[str, Dict]:
    """Index a list of dicts by the value stored under their "name" key.

    Raises:
        KeyError: If an item has no "name" key.

    When several items share a name, the last one wins.
    """
    by_name: Dict[str, Dict] = {}
    for entry in items:
        if "name" in entry:
            by_name[entry["name"]] = entry
        else:
            raise KeyError(f"name not in {entry}")
    return by_name


33
def parse_arg(op_name: str, s: str) -> Dict[str, str]:
    """Parse one argument declaration.

    Accepted formats:
    1. typename name
    2. typename name = default_value

    Args:
        op_name: Name of the op being parsed; only used in error messages.
        s: The argument declaration.

    Returns:
        A dict with "typename" and "name", plus "default_value" when the
        declaration provides one.

    Raises:
        AssertionError: If the typename, the name or the default value is
            empty, or if the declaration contains more than one "=".
    """
    # partition (unlike unpacking the result of split) tolerates a
    # declaration without any space, so a missing name is reported by the
    # assertions below instead of an opaque unpacking ValueError.
    typename, _, rest = s.partition(" ")
    typename = typename.strip()
    rest = rest.strip()
    assert (
        len(typename) > 0
    ), f"The arg typename should not be empty. Please check the args of {op_name} in yaml."

    assert (
        rest.count("=") <= 1
    ), f"There is more than 1 = in an arg in {op_name}"

    if "=" not in rest:
        name = rest
        assert (
            len(name) > 0
        ), f"The arg name should not be empty. Please check the args of {op_name} in yaml."
        return {"typename": typename, "name": name}

    name, default_value = [item.strip() for item in rest.split("=", 1)]
    assert (
        len(name) > 0
    ), f"The arg name should not be empty. Please check the args of {op_name} in yaml."
    assert (
        len(default_value) > 0
    ), f"The default value should not be empty. Please check the args of {op_name} in yaml."
    return {
        "typename": typename,
        "name": name,
        "default_value": default_value,
    }


67
def parse_input_and_attr(
    op_name: str, arguments: str
) -> Tuple[List[Dict], List[Dict]]:
    """Split a parenthesized argument declaration into inputs and attributes.

    Args:
        op_name: Name of the op being parsed; only used in error messages.
        arguments: The declaration, e.g. "(Tensor x, float scale = 1.0)".

    Returns:
        (inputs, attrs): two lists of the dicts produced by parse_arg.
        Scalar* and IntArray attributes additionally get a "data_type" entry
        looked up in opmaker_attr_types_map. An empty declaration "()" yields
        two empty lists.

    Raises:
        AssertionError: If the declaration is not wrapped in parentheses, an
            input follows an attribute, or an attribute without default value
            follows one with a default value.
        KeyError: If an argument type is neither an input nor an attribute.
    """
    args_str = arguments.strip()
    assert args_str.startswith('(') and args_str.endswith(')'), (
        f"Args declaration should start with '(' and end with ')', "
        f"please check the args of {op_name} in yaml."
    )
    args_str = args_str[1:-1]

    inputs: List[Dict] = []
    attrs: List[Dict] = []

    # "()" would otherwise be split into a single empty argument, which
    # parse_arg cannot handle.
    if not args_str.strip():
        return inputs, attrs

    met_attr_with_default_value = False

    for arg in parse_plain_list(args_str):
        item = parse_arg(op_name, arg)
        typename = item["typename"]
        name = item["name"]
        if is_input(typename):
            assert len(attrs) == 0, (
                f"The input Tensor should appear before attributes. "
                f"please check the position of {op_name}:input({name}) "
                f"in yaml."
            )
            inputs.append(item)
        elif is_attr(typename):
            # Once one attribute has a default value, all later ones must too.
            if met_attr_with_default_value:
                assert (
                    "default_value" in item
                ), f"{op_name}: Arguments with default value should not precede those without default value"
            elif "default_value" in item:
                met_attr_with_default_value = True
            if typename.startswith('Scalar') or typename == 'IntArray':
                item['data_type'] = opmaker_attr_types_map[typename]
            attrs.append(item)
        else:
            raise KeyError(f"{op_name}: Invalid argument type {typename}.")
    return inputs, attrs


109
def parse_output(op_name: str, s: str) -> Dict[str, str]:
    """Parse one output declaration.

    The declaration is a typename optionally followed by "(name)" and then
    optionally by a "{size_expr}", e.g. "Tensor", "Tensor(out)" or
    "Tensor[](out){n}".

    Args:
        op_name: Name of the op being parsed; only used in error messages.
        s: The output declaration.

    Returns:
        A dict with "typename" and "name" (defaulting to 'out'), plus "size"
        when a size expression is given.

    Raises:
        AssertionError: If the declaration cannot be parsed, the type is not
            an output type, or a size expression is attached to a non-vector
            type.
    """
    match = re.search(
        r"(?P<out_type>[a-zA-Z0-9_[\]]+)\s*(?P<name>\([a-zA-Z0-9_@]+\))?\s*(?P<expr>\{[^\}]+\})?",
        s,
    )
    # Without this check a malformed declaration surfaces as an
    # AttributeError on None, with no hint about the offending op.
    assert (
        match is not None
    ), f"Invalid output declaration: '{s}' in op : {op_name}."
    typename = match.group("out_type")
    name = match.group("name")
    size_expr = match.group("expr")

    # Drop the surrounding "()" / "{}" captured by the regex.
    name = 'out' if name is None else name[1:-1]
    if size_expr is not None:
        size_expr = size_expr[1:-1]

    assert is_output(typename), (
        f"Invalid output type: {typename} in op : {op_name}. "
        f"Supported types are Tensor and Tensor[]"
    )
    if size_expr is None:
        return {"typename": typename, "name": name}

    assert is_vec(typename), (
        f"Invalid output size: output {name} in op : {op_name} is "
        f"not a vector but has size expr"
    )
    return {"typename": typename, "name": name, "size": size_expr}


136
def parse_outputs(op_name: str, outputs: str) -> List[Dict]:
    """Parse a comma-separated list of output declarations.

    Each piece is handed to parse_output; the results are returned in
    declaration order.
    """
    return [
        parse_output(op_name, declaration)
        for declaration in parse_plain_list(outputs, sep=",")
    ]


def parse_infer_meta(infer_meta: Dict[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of ``infer_meta`` that always has a "param" key.

    A missing "param" is filled in with None. The input itself is left
    untouched.
    """
    result = copy(infer_meta)
    result.setdefault("param", None)
    return result


def parse_candidates(s: str) -> Dict[str, Any]:
    """Parse candidates joined by either '>' (ordered) or ',' (unordered).

    Returns a dict with "ordered" (True when '>' occurs in ``s``) and
    "candidates" (the stripped pieces).
    """
    is_ordered = ">" in s
    if is_ordered:
        names = parse_plain_list(s, ">")
    else:
        names = parse_plain_list(s, ",")
    return {"ordered": is_ordered, "candidates": names}


def parse_plain_list(s: str, sep=",") -> List[str]:
    """Split ``s`` on ``sep`` and strip the whitespace around every piece."""
    pieces = s.strip().split(sep)
    return list(map(str.strip, pieces))


164
def parse_kernel(op_name: str, kernel_config: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize the ``kernel`` section of an op entry.

    Keys read from ``kernel_config``:
        func : str, kernel functions (example: scale, scale_sr), each
            optionally followed by a "{inputs -> outputs}" tensor-type spec
        param : input params of kernel
        backend : str, the names of param to choose the kernel backend
        layout : str, the names of param to choose the kernel layout
        data_type : str, the names of param to choose the kernel data_type;
            a candidate written as complex(name) is recorded as ``name`` with
            its "to_complex_flag" set

    Returns:
        A dict with keys 'func' (list of function names), 'param', 'backend',
        'layout', 'data_type' (each None when not configured) and 'dispatch',
        which maps every kernel function to an (inputs, outputs) pair of
        tensor-type lists, or to None when the function has no type spec.

    Raises:
        KeyError: If ``kernel_config`` has no 'func'.
        AssertionError: If a type spec has no '->' or names an unsupported
            tensor type.
    """
    valid_tensor_types = ('dense', 'selected_rows', 'sparse_coo', 'sparse_csr')

    def parse_kernel_in_out_type(in_out_str):
        # An empty capture means the function carried no "{...}" spec.
        if len(in_out_str) == 0:
            return None
        sides = in_out_str[1:-1].split('->')
        assert len(sides) >= 2, (
            f"{op_name} : Invalid kernel dispatch '{in_out_str}', "
            f"expected the form '{{inputs -> outputs}}'."
        )
        inputs = [item.strip() for item in sides[0].split(',')]
        outputs = [item.strip() for item in sides[1].split(',')]

        # check the tensor type
        for kind, items in (('input', inputs), ('output', outputs)):
            for item in items:
                assert (
                    item in valid_tensor_types
                ), f"{op_name} : Invalid {kind} tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."

        return (inputs, outputs)

    kernel = {
        'func': [],  # up to 2 function names
        'param': None,
        'backend': None,
        'layout': None,
        'data_type': None,
        'dispatch': {},
    }
    if 'param' in kernel_config:
        kernel['param'] = kernel_config['param']

    if 'backend' in kernel_config:
        kernel['backend'] = parse_candidates(kernel_config["backend"])

    if 'layout' in kernel_config:
        kernel['layout'] = parse_candidates(kernel_config["layout"])

    if 'data_type' in kernel_config:
        data_type_item = parse_candidates(kernel_config["data_type"])
        candidates = data_type_item['candidates']
        to_complex_flag = [False] * len(candidates)
        for i, candidate in enumerate(candidates):
            complex_match_result = re.match(
                r"complex\((?P<param_name>\w+)\)",
                candidate,
            )
            if complex_match_result:
                # Keep only the wrapped param name and remember the wrapper.
                candidates[i] = complex_match_result.group('param_name')
                to_complex_flag[i] = True
        data_type_item['to_complex_flag'] = to_complex_flag
        kernel['data_type'] = data_type_item

    # Each match is (function name, "{...}" spec or "" when absent).
    kernel_funcs = re.compile(r'([a-zA-Z0-9_]+)\s*({[^}]+})?').findall(
        kernel_config['func']
    )
    for func_name, in_out_str in kernel_funcs:
        kernel['func'].append(func_name)
        kernel['dispatch'][func_name] = parse_kernel_in_out_type(in_out_str)

    return kernel


243
def parse_inplace(op_name: str, inplace_cfg: str) -> Dict[str, str]:
    """Parse an inplace declaration such as "(x -> out, y -> out2)".

    Args:
        op_name: Name of the op being parsed; only used in error messages.
        inplace_cfg: Comma-separated "input -> output" pairs, optionally
            wrapped in parentheses.

    Returns:
        A dict mapping every output name to the input name it is paired with.

    Raises:
        AssertionError: If a pair is not of the form "input -> output".
    """
    inplace_map = {}
    # Strip whitespace first, otherwise a leading or trailing blank keeps the
    # parentheses from being removed.
    inplace_cfg = inplace_cfg.strip().lstrip("(").rstrip(")")
    for pair in parse_plain_list(inplace_cfg):
        names = parse_plain_list(pair, sep="->")
        assert len(names) == 2, (
            f"{op_name} : Invalid inplace mapping '{pair}', "
            f"expected the form 'input -> output'."
        )
        in_name, out_name = names
        inplace_map[out_name] = in_name
    return inplace_map


253
def parse_invoke(op_name: str, invoke_config: str) -> Dict[str, Any]:
    """Split an invoke declaration "func(args)" into its name and arguments.

    Args:
        op_name: Name of the op being parsed (not used by the parsing itself).
        invoke_config: The declaration, e.g. "scale(x, 1.0f, 0.0f, true)".

    Returns:
        A dict with "func" (the function name) and "args" (the raw text
        between the outer parentheses, stripped).

    Raises:
        ValueError: If the declaration contains no "(".
    """
    invoke_config = invoke_config.strip()
    func, rest = invoke_config.split("(", 1)
    # Remove only the single parenthesis closing the call: rstrip(")") would
    # also eat the one closing a trailing nested call such as "f(x, g())".
    if rest.endswith(")"):
        rest = rest[:-1]
    return {"func": func.strip(), "args": rest.strip()}


def extract_type_and_name(records: List[Dict]) -> List[Dict]:
    """Extract type and name from a forward call; it is simpler than a forward op.

    Returns a new list of new dicts holding only the "name" and "typename"
    of every record. The records themselves are not modified.
    """
    return [
        {"name": record["name"], "typename": record["typename"]}
        for record in records
    ]


270 271
def parse_forward(op_name: str, forward_config: str) -> Dict[str, Any]:
    """Parse the forward call declared by a backward op.

    The declaration has the form
        op_name (const Tensor& input, ... , int attr, ...) -> Tensor(out)

    Args:
        op_name: Name of the op being parsed; used in error messages and
            passed on to the argument and output parsers.
        forward_config: The forward call declaration.

    Returns:
        A dict with "name" (the forward op name) and "inputs", "attrs" and
        "outputs", each a list of {"name", "typename"} dicts.

    Raises:
        AssertionError: If the declaration does not have the form above.
    """
    result = re.search(
        r"(?P<op>[a-z][a-z0-9_]+)\s*(?P<args>\([^\)]+\))\s*->\s*(?P<outputs>.+)",
        forward_config,
    )
    # Without this check a malformed declaration surfaces as an
    # AttributeError on None, with no hint about the offending op.
    assert result is not None, (
        f"Invalid forward declaration: '{forward_config}' in op : {op_name}, "
        f"expected the form 'op_name (args) -> outputs'."
    )
    op = result.group("op")
    outputs = extract_type_and_name(
        parse_outputs(op_name, result.group("outputs"))
    )

    inputs, attrs = parse_input_and_attr(op_name, result.group("args"))
    return {
        "name": op,
        "inputs": extract_type_and_name(inputs),
        "attrs": extract_type_and_name(attrs),
        "outputs": outputs,
    }


J
Jiabin Yang 已提交
292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311
def parse_composite(
    op_name: str,
    composite_config: str,
) -> Dict[str, Any]:
    """Split a composite declaration "func(args1, args2, ...)".

    Args:
        op_name: Name of the op being parsed; only used in error messages.
        composite_config: The declaration.

    Returns:
        A dict with "func_name" (the text before the first "(" without
        trailing whitespace) and "func_args" (the text up to the first ")",
        without surrounding whitespace).

    Raises:
        AssertionError: If the declaration contains no "(...)" part.
    """
    # composite_config: func(args1, args2,.....)
    fname = r'(.*?)'
    wspace = r'\s*'
    fargs = r'(.*?)'
    pattern = fr'{fname}{wspace}\({wspace}{fargs}{wspace}\)'

    m = re.search(pattern, composite_config)
    # Without this check a malformed declaration surfaces as an
    # AttributeError on None, with no hint about the offending op.
    assert m is not None, (
        f"Invalid composite declaration: '{composite_config}' in op : "
        f"{op_name}, expected the form 'func(args)'."
    )
    return {"func_name": m.group(1), "func_args": m.group(2)}


312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328
def check_op_config(op_entry, op_name):
    """Assert that an op entry only uses known keys.

    Checks the top-level keys of ``op_entry`` and, when present, the keys of
    its 'infer_meta' and 'kernel' sub-sections.

    Args:
        op_entry: The op entry (a dict).
        op_name: Name of the op; only used in error messages.

    Raises:
        AssertionError: On the first unknown key.
    """
    base_key_set = (
        'op',
        'backward_op',
        'forward',
        'args',
        'output',
        'infer_meta',
        'kernel',
        'backward',
        'invoke',
        'inplace',
        'view',
        'optional',
        'intermediate',
        'no_need_buffer',
        'data_transform',
        'composite',
    )
    # Allowed keys of the nested sections, checked in this order.
    section_key_sets = (
        ('infer_meta', ('func', 'param')),
        ('kernel', ('func', 'param', 'data_type', 'layout', 'backend')),
    )
    for key in op_entry:
        assert (
            key in base_key_set
        ), f"Op ({op_name}) : invalid key ({key}) in Yaml."

    for section, allowed_keys in section_key_sets:
        if section not in op_entry:
            continue
        for sub_key in op_entry[section].keys():
            assert (
                sub_key in allowed_keys
            ), f"Op ({op_name}) : invalid key ({section}.{sub_key}) in Yaml."


351 352 353 354
def parse_op_entry(op_entry: Dict[str, Any], name_field="op"):
    """Parse one op entry into a normalized dict.

    Args:
        op_entry: The raw op entry; must contain ``name_field``, "args" and
            "output", and either "invoke" or both "kernel" and "infer_meta".
        name_field: Key holding the op name; "backward_op" additionally
            triggers parsing and validation of the "forward" declaration.

    Returns:
        A dict with "name", "inputs", "attrs", "outputs", "no_need_buffer"
        and "backward"; plus "infer_meta", "kernel" and "inplace" for ops
        that do not invoke another op, or "invoke" for those that do; plus
        "composite" when declared and "forward" for backward ops.

    Raises:
        AssertionError: If the entry uses unknown keys, names a
            non-existent optional / intermediate / no_need_buffer argument,
            has an invalid DataType / DataLayout default value, or fails
            the forward/backward validation.
        KeyError: If a required key is missing.
    """
    op_name = op_entry[name_field]
    inputs, attrs = parse_input_and_attr(op_name, op_entry["args"])
    outputs = parse_outputs(op_name, op_entry["output"])
    composite_dict = (
        parse_composite(op_name, op_entry["composite"])
        if "composite" in op_entry
        else None
    )
    check_op_config(op_entry, op_name)

    # validate default value of DataType and DataLayout
    for attr in attrs:
        if "default_value" not in attr:
            continue
        typename = attr["typename"]
        if typename in ("DataType", "DataLayout"):
            default_value = attr["default_value"]
            assert (
                typename in default_value
            ), f"invalid {typename} default value in {op_name}"
            # remove namespace
            attr["default_value"] = default_value[
                default_value.find(typename) :
            ]

    input_names = [item["name"] for item in inputs]
    attr_names = [item["name"] for item in attrs]

    # add optional tag for every input
    optional_args = (
        parse_plain_list(op_entry["optional"])
        if "optional" in op_entry
        else []
    )
    _flag_members(
        op_name, inputs, "optional", optional_args, "optional input", "input"
    )

    # add intermediate tag for every output
    intermediate_outs = (
        parse_plain_list(op_entry["intermediate"])
        if "intermediate" in op_entry
        else []
    )
    _flag_members(
        op_name,
        outputs,
        "intermediate",
        intermediate_outs,
        "intermediate output",
        "output",
    )

    # add no_need_buffer for every input; the parsed list itself (or None
    # when not declared) is also stored on the op.
    no_buffer_args = (
        parse_plain_list(op_entry["no_need_buffer"])
        if "no_need_buffer" in op_entry
        else None
    )
    _flag_members(
        op_name,
        inputs,
        "no_need_buffer",
        no_buffer_args or [],
        "no buffer input",
        "input",
    )

    # TODO(chenfeiyu): data_transform

    op = {
        "name": op_name,
        "inputs": inputs,
        "attrs": attrs,
        "outputs": outputs,
        "no_need_buffer": no_buffer_args,
    }

    # invokes another op ?
    is_base_op = "invoke" not in op_entry

    if is_base_op:
        # kernel: params default to all inputs followed by all attrs
        kernel = parse_kernel(op_name, op_entry["kernel"])
        if kernel["param"] is None:
            kernel["param"] = input_names + attr_names

        # infer meta: params default to a copy of the kernel params
        infer_meta = parse_infer_meta(op_entry["infer_meta"])
        if infer_meta["param"] is None:
            infer_meta["param"] = copy(kernel["param"])

        # inplace
        inplace_pairs = (
            parse_inplace(op_name, op_entry["inplace"])
            if "inplace" in op_entry
            else None
        )
        op.update(
            {
                "infer_meta": infer_meta,
                "kernel": kernel,
                "inplace": inplace_pairs,
            }
        )
    else:
        # invoke
        op["invoke"] = parse_invoke(op_name, op_entry["invoke"])

    # has composite ?
    if composite_dict is not None:
        op["composite"] = composite_dict

    # backward
    op["backward"] = op_entry.get("backward")

    # forward for backward_ops
    if name_field == "backward_op":
        forward = None
        if "forward" in op_entry:
            forward = parse_forward(op_name, op_entry["forward"])
            # validate_fb
            validate_backward_inputs(
                op_name, forward["inputs"], forward["outputs"], inputs
            )
            validate_backward_attrs(op_name, forward["attrs"], attrs)
            validate_backward_outputs(op_name, forward["inputs"], outputs)
        op["forward"] = forward
    return op


def _flag_members(op_name, items, flag, selected, label, kind):
    """Set ``item[flag]`` to whether the item's name is listed in ``selected``.

    Raises AssertionError when ``selected`` names something that is not the
    name of any item; ``label`` and ``kind`` only shape that message.
    """
    known_names = [item["name"] for item in items]
    for name in selected:
        assert (
            name in known_names
        ), f"{op_name} has an {label}: '{name}' which is not an {kind}."
    for item in items:
        item[flag] = item["name"] in selected
491 492


493
def validate_backward_attrs(op, forward_attrs, backward_attrs):
    """Check the attributes a backward op has beyond its forward op.

    This is a not-that-clean trick to allow a backward op to have more attrs
    than the forward op, as long as the extra (trailing) ones all have a
    default value.

    Args:
        op: Name of the backward op; only used in the error message.
        forward_attrs: Attribute dicts of the forward call.
        backward_attrs: Attribute dicts of the backward op.

    Raises:
        AssertionError: If one of the extra attributes has no
            "default_value".
    """
    num_exceptional_attrs = len(backward_attrs) - len(forward_attrs)
    if num_exceptional_attrs <= 0:
        return
    for attr in backward_attrs[-num_exceptional_attrs:]:
        assert (
            "default_value" in attr
        ), f"{op} has exceptional attr without default value"
503 504


505
def validate_backward_inputs(
    op, forward_inputs, forward_outputs, backward_inputs
):
    """Check that a backward op does not take too many inputs.

    The limit is the number of forward inputs plus twice the number of
    forward outputs.
    NOTE(review): presumably every forward output may be passed both as
    itself and as its gradient; confirm.

    Args:
        op: Name of the backward op; only used in the error message.
        forward_inputs: Inputs of the forward call.
        forward_outputs: Outputs of the forward call.
        backward_inputs: Inputs of the backward op.

    Raises:
        AssertionError: If the backward op has more inputs than the limit.
    """
    max_inputs = len(forward_inputs) + 2 * len(forward_outputs)
    assert len(backward_inputs) <= max_inputs, f"{op} has too many inputs."
515 516


517
def validate_backward_outputs(op, forward_inputs, backward_outputs):
    """Check that a backward op has at most one output per forward input.

    Args:
        op: Name of the backward op; only used in the error message.
        forward_inputs: Inputs of the forward call.
        backward_outputs: Outputs of the backward op.

    Raises:
        AssertionError: If there are more backward outputs than forward
            inputs.
    """
    assert len(backward_outputs) <= len(
        forward_inputs
    ), f"{op} has too many outputs"
521 522


523 524 525 526
def cross_validate(ops):
    """Check every backward op against the forward op it refers to.

    For each op with a forward call, the referenced forward op should exist
    and name this op as its backward (a missing op or missing claim is only
    printed), and the types of the forward call's inputs, attrs and outputs
    must agree with the forward op's declaration.

    Args:
        ops: Dict mapping op names to parsed op dicts.

    Raises:
        AssertionError: If the forward op names a different backward op, the
            forward call has more inputs or attrs than the forward op, the
            numbers of outputs differ, or any compared types differ.
    """
    for name, op in ops.items():
        # A backward op without a forward declaration stores "forward": None;
        # there is nothing to compare in that case.
        fw_call = op.get("forward")
        if fw_call is None:
            continue

        fw_name = fw_call["name"]
        if fw_name not in ops:
            print(
                f"Something Wrong here, this backward op ({name})'s forward op ({fw_name}) does not exist."
            )
            continue

        fw_op = ops[fw_name]
        if fw_op.get("backward") is None:
            print(
                f"Something Wrong here, {name}'s forward op ({fw_name}) does not claim {name} as its backward."
            )
        else:
            assert (
                fw_op["backward"] == name
            ), f"{name}: backward and forward name mismatch"

        assert len(fw_call["inputs"]) <= len(
            fw_op["inputs"]
        ), f"{name}: forward call has more inputs than the op "
        for call_input, op_input in zip(fw_call["inputs"], fw_op["inputs"]):
            assert (
                call_input["typename"] == op_input["typename"]
            ), f"type mismatch in {name} and {fw_name}"

        assert len(fw_call["attrs"]) <= len(
            fw_op["attrs"]
        ), f"{name}: forward call has more attrs than the op "
        for call_attr, op_attr in zip(fw_call["attrs"], fw_op["attrs"]):
            if call_attr["typename"] == "Scalar":
                # special case for Scalar, fw_call can omit the type
                assert re.match(
                    r"Scalar(\(\w+\))*", op_attr["typename"]
                ), f"type mismatch in {name} and {fw_name}"
            else:
                assert (
                    call_attr["typename"] == op_attr["typename"]
                ), f"type mismatch in {name} and {fw_name}"

        # The check is an equality, so the message must cover both too many
        # and too few outputs.
        assert len(fw_call["outputs"]) == len(
            fw_op["outputs"]
        ), f"{name}: forward call and the op have different numbers of outputs"
        for call_output, op_output in zip(
            fw_call["outputs"], fw_op["outputs"]
        ):
            assert (
                call_output["typename"] == op_output["typename"]
            ), f"type mismatch in {name} and {fw_name}"