# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from copy import copy
from typing import Any, Dict, List, Tuple

from tests import is_attr, is_input, is_output, is_vec
from type_mapping import opmaker_attr_types_map


def to_named_dict(items: List[Dict]) -> Dict[str, Dict]:
    """Index a sequence of records by their ``name`` field.

    Later records silently replace earlier ones that share the same name.

    Raises:
        KeyError: if a record has no ``name`` key.
    """
    by_name = {}
    for record in items:
        if "name" in record:
            by_name[record["name"]] = record
            continue
        raise KeyError(f"name not in {record}")
    return by_name


def parse_arg(op_name: str, s: str) -> Dict[str, str]:
    """Parse a single argument declaration of an op.

    Accepted formats:
        1. ``typename name``
        2. ``typename name = default_value``

    The type name may be separated from the rest by any run of whitespace
    (spaces or tabs).

    Args:
        op_name: name of the op being parsed, used in error messages.
        s: the argument declaration.

    Returns:
        ``{"typename": ..., "name": ...}``, plus ``"default_value"`` when a
        default is given.

    Raises:
        AssertionError: if the type name, the name or the default value is
            empty, or if the declaration contains more than one ``=``.
    """
    # Split off the type name on the first whitespace run. A declaration with
    # no name at all (e.g. just "Tensor") leaves ``rest`` empty and is then
    # reported by the "arg name should not be empty" assertion below instead
    # of an opaque tuple-unpacking ValueError.
    parts = s.split(None, 1)
    typename = parts[0] if parts else ""
    rest = parts[1].strip() if len(parts) > 1 else ""

    assert (
        len(typename) > 0
    ), f"The arg typename should not be empty. Please check the args of {op_name} in yaml."

    assert (
        rest.count("=") <= 1
    ), f"There is more than 1 = in an arg in {op_name}"

    if rest.count("=") == 1:
        name, default_value = [item.strip() for item in rest.split("=", 1)]
        assert (
            len(name) > 0
        ), f"The arg name should not be empty. Please check the args of {op_name} in yaml."
        assert (
            len(default_value) > 0
        ), f"The default value should not be empty. Please check the args of {op_name} in yaml."
        return {
            "typename": typename,
            "name": name,
            "default_value": default_value,
        }

    name = rest.strip()
    assert (
        len(name) > 0
    ), f"The arg name should not be empty. Please check the args of {op_name} in yaml."
    return {"typename": typename, "name": name}


def parse_input_and_attr(
    op_name: str, arguments: str
) -> Tuple[List, List]:
    """Parse the ``args`` declaration of an op into inputs and attributes.

    ``arguments`` must be a parenthesised, comma-separated list of argument
    declarations (each parsed by ``parse_arg``). Every input must come before
    every attribute. Once an attribute with a default value has appeared,
    every later attribute must also have one. An empty list ``()`` is allowed
    and yields no inputs and no attributes.

    Args:
        op_name: name of the op being parsed, used in error messages.
        arguments: the declaration, e.g. ``(<type> <name>, <type> <name> = <default>)``.

    Returns:
        A tuple ``(inputs, attrs)`` of lists of argument dicts, in declaration
        order. ``Scalar*`` and ``IntArray`` attributes additionally get a
        ``data_type`` entry looked up in ``opmaker_attr_types_map``.

    Raises:
        AssertionError: on a malformed declaration or an ordering violation.
        KeyError: if an argument type is neither an input nor an attribute type.
    """
    args_str = arguments.strip()
    assert args_str.startswith('(') and args_str.endswith(')'), (
        f"Args declaration should start with '(' and end with ')', "
        f"please check the args of {op_name} in yaml."
    )
    args_str = args_str[1:-1].strip()
    # parse_plain_list("") returns [""], which parse_arg would reject; an empty
    # parenthesised list simply declares no arguments.
    args = parse_plain_list(args_str) if args_str else []

    inputs = []
    attrs = []

    # Tracks whether a defaulted attribute has been seen: after that point a
    # non-defaulted attribute is an error.
    met_attr_with_default_value = False

    for arg in args:
        item = parse_arg(op_name, arg)
        typename = item["typename"]
        name = item["name"]
        if is_input(typename):
            assert len(attrs) == 0, (
                f"The input Tensor should appear before attributes. "
                f"please check the position of {op_name}:input({name}) "
                f"in yaml."
            )
            inputs.append(item)
        elif is_attr(typename):
            if met_attr_with_default_value:
                assert (
                    "default_value" in item
                ), f"{op_name}: Arguments with default value should not precede those without default value"
            elif "default_value" in item:
                met_attr_with_default_value = True
            if typename.startswith('Scalar') or typename == 'IntArray':
                item['data_type'] = opmaker_attr_types_map[typename]
            attrs.append(item)
        else:
            raise KeyError(f"{op_name}: Invalid argument type {typename}.")
    return inputs, attrs


def parse_output(op_name: str, s: str) -> Dict[str, str]:
    """Parse an output declaration.

    Accepted forms are ``typename``, ``typename(name)`` and
    ``typename(name){size_expr}``. The name defaults to ``out``. A size
    expression is only allowed on vector outputs.

    Args:
        op_name: name of the op being parsed, used in error messages.
        s: the output declaration.

    Returns:
        ``{"typename": ..., "name": ...}``, plus ``"size"`` when a size
        expression is given.

    Raises:
        AssertionError: if ``s`` is not a valid output declaration.
    """
    match = re.search(
        r"(?P<out_type>[a-zA-Z0-9_[\]]+)\s*(?P<name>\([a-zA-Z0-9_@]+\))?\s*(?P<expr>\{[^\}]+\})?",
        s,
    )
    # Without this check an unparsable declaration (e.g. an empty string)
    # fails with "'NoneType' object has no attribute 'group'".
    assert (
        match is not None
    ), f"Invalid output declaration '{s}' in op : {op_name}."

    typename = match.group("out_type")
    name = match.group("name")
    size_expr = match.group("expr")

    # Drop the surrounding "(...)" and "{...}" delimiters.
    name = name[1:-1] if name is not None else 'out'
    size_expr = size_expr[1:-1] if size_expr is not None else None

    assert is_output(typename), (
        f"Invalid output type: {typename} in op : {op_name}. "
        f"Supported types are Tensor and Tensor[]"
    )
    if size_expr is not None:
        assert is_vec(typename), (
            f"Invalid output size: output {name} in op : {op_name} is "
            f"not a vector but has size expr"
        )
        return {"typename": typename, "name": name, "size": size_expr}
    return {"typename": typename, "name": name}


def parse_outputs(op_name: str, outputs: str) -> List[Dict]:
    """Parse a comma-separated list of output declarations.

    Each piece is handed to ``parse_output``. The resulting dicts are returned
    in declaration order.
    """
    return [
        parse_output(op_name, declaration)
        for declaration in parse_plain_list(outputs, sep=",")
    ]


def parse_infer_meta(infer_meta: Dict[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of the ``infer_meta`` config with ``param`` defaulted.

    The caller's mapping is left untouched. If it has no ``param`` key, the
    copy gets ``param: None``.
    """
    normalized = copy(infer_meta)  # shallow copy so the input is never mutated
    normalized.setdefault("param", None)
    return normalized


def parse_candidates(s: str) -> Dict[str, Any]:
    """Split ``s`` into candidate names.

    If ``s`` contains ``>``, the candidates are split on ``>`` and marked as
    ordered. Otherwise they are split on ``,`` and marked as unordered.

    Returns:
        ``{"ordered": bool, "candidates": List[str]}``.
    """
    ordered = ">" in s
    candidates = parse_plain_list(s, ">" if ordered else ",")
    return {"ordered": ordered, "candidates": candidates}


def parse_plain_list(s: str, sep=",") -> List[str]:
    """Split ``s`` on ``sep`` and strip surrounding whitespace from every piece.

    Empty pieces are kept, so an empty or all-whitespace string yields ``[""]``.
    """
    return list(map(str.strip, s.strip().split(sep)))


def parse_kernel(op_name: str, kernel_config: Dict[str, Any]) -> Dict[str, Any]:
    """Parse the ``kernel`` section of an op entry.

    kernel :
       func : [], Kernel functions (example: scale, scale_sr)
       param : [], Input params of kernel
       backend : str, the names of param to choose the kernel backend, default is None
       layout : str, the names of param to choose the kernel layout, default is None
       data_type : str, the names of param to choose the kernel data_type, default is None
       dispatch : {}, the key is kernel_func, the value is type of inputs and outputs
           for kernel (example: {kernel_name : (['dense','sparse_coo']#input,['sparse_coo']#output)})

    Each ``func`` entry may carry a dispatch annotation such as
    ``scale_sr {selected_rows -> selected_rows}``. Functions without an
    annotation map to ``None`` in ``dispatch``.

    Args:
        op_name: name of the op being parsed, used in error messages.
        kernel_config: the raw ``kernel`` mapping from the yaml; must contain ``func``.

    Returns:
        A dict with keys ``func``, ``param``, ``backend``, ``layout``,
        ``data_type`` and ``dispatch``.

    Raises:
        AssertionError: if a dispatch annotation is malformed or names an
            unsupported tensor type.
    """
    # Tensor types accepted on either side of a ``{inputs -> outputs}`` annotation.
    supported_tensor_types = ('dense', 'selected_rows', 'sparse_coo', 'sparse_csr')

    kernel = {
        'func': [],  # up to 2 function names
        'param': None,
        'backend': None,
        'layout': None,
        'data_type': None,
        'dispatch': {},
    }
    if 'param' in kernel_config:
        kernel['param'] = kernel_config['param']

    if 'backend' in kernel_config:
        kernel['backend'] = parse_candidates(kernel_config["backend"])

    if 'layout' in kernel_config:
        kernel['layout'] = parse_candidates(kernel_config["layout"])

    if 'data_type' in kernel_config:
        data_type_item = parse_candidates(kernel_config["data_type"])
        candidates = data_type_item['candidates']
        data_type_item['to_complex_flag'] = [False] * len(candidates)
        for i, candidate in enumerate(candidates):
            # A candidate written as "complex(x)" is replaced by the bare param
            # name "x" and flagged in to_complex_flag. Presumably the code
            # generator then uses the complex counterpart of x's dtype; confirm there.
            complex_match_result = re.match(
                r"complex\((?P<param_name>\w+)\)", candidate
            )
            if complex_match_result:
                candidates[i] = complex_match_result.group('param_name')
                data_type_item['to_complex_flag'][i] = True
        kernel['data_type'] = data_type_item

    # Each match is (function_name, optional "{...}" dispatch annotation or "").
    kernel_funcs = re.compile(r'([a-zA-Z0-9_]+)\s*({[^}]+})?').findall(
        kernel_config['func']
    )

    def parse_kernel_in_out_type(in_out_str):
        """Turn ``{in1, in2 -> out1}`` into ``([in1, in2], [out1])``; ``""`` -> None."""
        if len(in_out_str) == 0:
            return None
        tmp_in_out_list = in_out_str[1:-1].split('->')
        # Without this check a missing "->" fails with an opaque IndexError.
        assert len(tmp_in_out_list) == 2, (
            f"{op_name} : Invalid kernel dispatch ('{in_out_str}'), it should be "
            f"in the form of '{{input_types -> output_types}}'."
        )
        inputs = [item.strip() for item in tmp_in_out_list[0].split(',')]
        outputs = [item.strip() for item in tmp_in_out_list[1].split(',')]

        # check the tensor type
        for item in inputs:
            assert (
                item in supported_tensor_types
            ), f"{op_name} : Invalid input tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."
        for item in outputs:
            assert (
                item in supported_tensor_types
            ), f"{op_name} : Invalid output tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."

        return (inputs, outputs)

    for func_name, in_out_str in kernel_funcs:
        kernel['func'].append(func_name)
        kernel['dispatch'][func_name] = parse_kernel_in_out_type(in_out_str)

    return kernel


def parse_inplace(op_name: str, inplace_cfg: str) -> Dict[str, str]:
    """Parse an inplace declaration such as ``(x -> out, y -> out2)``.

    Args:
        op_name: name of the op being parsed, used in error messages.
        inplace_cfg: the raw ``inplace`` value from the yaml.

    Returns:
        A dict mapping each output name to the input name it reuses.

    Raises:
        AssertionError: if a pair is not of the form ``input -> output``.
    """
    inplace_map = {}
    # Strip whitespace first so that e.g. " (x -> out)" still loses its parentheses.
    inplace_cfg = inplace_cfg.strip().lstrip("(").rstrip(")")
    for pair in parse_plain_list(inplace_cfg):
        names = parse_plain_list(pair, sep="->")
        assert len(names) == 2, (
            f"{op_name} : Invalid inplace pair ('{pair}'), "
            f"it should be in the form of 'input -> output'."
        )
        in_name, out_name = names
        inplace_map[out_name] = in_name
    return inplace_map


def parse_invoke(op_name: str, invoke_config: str) -> Dict[str, Any]:
    """Parse an ``invoke`` declaration of the form ``func(args)``.

    Args:
        op_name: name of the op being parsed, used in error messages.
        invoke_config: the raw ``invoke`` value from the yaml.

    Returns:
        ``{"func": <callee name>, "args": <argument text>}``. The argument text
        is returned verbatim (whitespace-stripped), nested calls included.

    Raises:
        AssertionError: if the declaration contains no ``(``.
    """
    invoke_config = invoke_config.strip()
    assert "(" in invoke_config, (
        f"{op_name} : Invalid invoke ('{invoke_config}'), "
        f"it should be in the form of 'func(args)'."
    )
    func, rest = invoke_config.split("(", 1)
    rest = rest.strip()
    # Remove only the call's own closing parenthesis. rstrip(")") would also
    # eat the closing parentheses of a nested call in the last argument,
    # e.g. "f(x, g(y))" would yield "x, g(y".
    if rest.endswith(")"):
        rest = rest[:-1]
    return {"func": func.strip(), "args": rest.strip()}


def extract_type_and_name(records: List[Dict]) -> List[Dict]:
    """Keep only the ``name`` and ``typename`` fields of each record.

    ``parse_forward`` uses this to describe the forward call of a backward op,
    which needs less detail than a full forward op entry.
    """
    wanted_fields = ("name", "typename")
    return [{field: record[field] for field in wanted_fields} for record in records]


def parse_forward(op_name: str, forward_config: str) -> Dict[str, Any]:
    """Parse the ``forward`` field of a backward op entry.

    Args:
        op_name: name of the backward op, used in error messages.
        forward_config: a declaration of the form
            ``forward_op_name (args) -> outputs``.

    Returns:
        ``{"name", "inputs", "attrs", "outputs"}``, where ``name`` is the
        forward op name and every argument/output is reduced to its ``name``
        and ``typename``.

    Raises:
        AssertionError: if ``forward_config`` does not have the expected form.
    """
    # op_name (const Tensor& input, ... , int attr, ...) -> Tensor(out)
    result = re.search(
        r"(?P<op>[a-z][a-z0-9_]+)\s*(?P<args>\([^\)]+\))\s*->\s*(?P<outputs>.+)",
        forward_config,
    )
    # Without this check a malformed declaration fails with
    # "'NoneType' object has no attribute 'group'".
    assert result is not None, (
        f"{op_name} : Invalid forward declaration ('{forward_config}'), "
        f"it should be in the form of 'op_name (args) -> outputs'."
    )

    op = result.group("op")
    outputs = extract_type_and_name(
        parse_outputs(op_name, result.group("outputs"))
    )
    inputs, attrs = parse_input_and_attr(op_name, result.group("args"))
    return {
        "name": op,
        "inputs": extract_type_and_name(inputs),
        "attrs": extract_type_and_name(attrs),
        "outputs": outputs,
    }


def parse_composite(
    op_name: str,
    composite_config: str,
) -> Dict[str, Any]:
    """Parse a ``composite`` declaration of the form ``func(args1, args2, ...)``.

    The function name runs up to the first ``(``. The arguments run up to the
    LAST ``)``, so nested calls such as ``f(a, g(b))`` are kept intact.

    Args:
        op_name: name of the op being parsed, used in error messages.
        composite_config: the raw ``composite`` value from the yaml.

    Returns:
        ``{"func_name": str, "func_args": str}`` with surrounding whitespace removed.

    Raises:
        AssertionError: if the declaration does not look like a call.
    """
    # Non-greedy name (stops at the first "("), greedy args (extend to the last
    # ")"). DOTALL lets a call span several lines.
    pattern = r'\s*(.*?)\s*\((.*)\)'
    m = re.search(pattern, composite_config, re.DOTALL)
    assert m is not None, (
        f"{op_name} : Invalid composite ('{composite_config}'), "
        f"it should be in the form of 'func(args1, args2, ...)'."
    )

    composite_dict = {}
    composite_dict["func_name"] = m.group(1)
    composite_dict["func_args"] = m.group(2).strip()
    return composite_dict


def check_op_config(op_entry, op_name):
    """Reject unknown keys in an op entry of the yaml.

    The top-level keys are checked first. The keys nested under
    ``infer_meta`` and ``kernel`` are checked next, when those sections exist.

    Raises:
        AssertionError: on the first unrecognised key.
    """
    allowed_keys = frozenset(
        (
            'op',
            'backward_op',
            'forward',
            'args',
            'output',
            'infer_meta',
            'kernel',
            'backward',
            'invoke',
            'inplace',
            'view',
            'optional',
            'intermediate',
            'no_need_buffer',
            'data_transform',
            'composite',
        )
    )
    allowed_infer_meta_keys = frozenset(('func', 'param'))
    allowed_kernel_keys = frozenset(
        ('func', 'param', 'data_type', 'layout', 'backend')
    )

    for top_key in op_entry:
        assert (
            top_key in allowed_keys
        ), f"Op ({op_name}) : invalid key ({top_key}) in Yaml."

    if 'infer_meta' in op_entry:
        for sub_key in op_entry['infer_meta'].keys():
            assert (
                sub_key in allowed_infer_meta_keys
            ), f"Op ({op_name}) : invalid key (infer_meta.{sub_key}) in Yaml."

    if 'kernel' in op_entry:
        for sub_key in op_entry['kernel'].keys():
            assert (
                sub_key in allowed_kernel_keys
            ), f"Op ({op_name}) : invalid key (kernel.{sub_key}) in Yaml."


def _normalize_enum_defaults(op_name: str, attrs: List[Dict]) -> None:
    """Check DataType/DataLayout defaults in place and cut off anything before the enum name."""
    for attr in attrs:
        if "default_value" not in attr:
            continue
        typename = attr["typename"]
        value = attr["default_value"]
        if typename == "DataType":
            assert (
                "DataType" in value
            ), f"invalid DataType default value in {op_name}"
            # remove namespace
            attr["default_value"] = value[value.find("DataType") :]
        elif typename == "DataLayout":
            assert (
                "DataLayout" in value
            ), f"invalid DataLayout default value in {op_name}"
            attr["default_value"] = value[value.find("DataLayout") :]


def parse_op_entry(op_entry: Dict[str, Any], name_field="op"):
    """Parse one op entry of the yaml into a normalized dict.

    ``name_field`` selects the key that holds the op's name: ``"op"`` for
    forward ops, ``"backward_op"`` for backward ops.

    The result always has ``name``, ``inputs``, ``attrs``, ``outputs``,
    ``no_need_buffer`` and ``backward``. Ops without ``invoke`` also get
    ``infer_meta``, ``kernel`` and ``inplace``. Ops with ``invoke`` get
    ``invoke`` instead. ``composite`` is added when declared. Backward ops
    additionally get ``forward`` (``None`` when the entry has none).

    Every input is tagged with boolean ``optional`` and ``no_need_buffer``
    flags, and every output with ``optional`` and ``intermediate``.

    Raises:
        AssertionError / KeyError: propagated from the validation and
            sub-parsers when the entry is malformed.
    """
    op_name = op_entry[name_field]
    inputs, attrs = parse_input_and_attr(op_name, op_entry["args"])
    outputs = parse_outputs(op_name, op_entry["output"])

    has_composite = "composite" in op_entry
    if has_composite:
        composite_dict = parse_composite(op_name, op_entry["composite"])
    check_op_config(op_entry, op_name)

    # validate default value of DataType and DataLayout
    _normalize_enum_defaults(op_name, attrs)

    input_names = [arg["name"] for arg in inputs]
    attr_names = [arg["name"] for arg in attrs]
    output_names = [arg["name"] for arg in outputs]

    # optional tag for every input and output
    if "optional" in op_entry:
        optional_args = parse_plain_list(op_entry["optional"])
    else:
        optional_args = []
    for name in optional_args:
        assert (
            name in input_names or name in output_names
        ), f"{op_name} has an optional tensor: '{name}' which is not in input or output."
    for tensor in inputs + outputs:
        tensor["optional"] = tensor["name"] in optional_args

    # intermediate tag for every output
    if "intermediate" in op_entry:
        intermediate_outs = parse_plain_list(op_entry["intermediate"])
    else:
        intermediate_outs = []
    for name in intermediate_outs:
        assert (
            name in output_names
        ), f"{op_name} has an intermediate output: '{name}' which is not an output."
    for tensor in outputs:
        tensor["intermediate"] = tensor["name"] in intermediate_outs

    # no_need_buffer tag for every input; the raw list (or None) is also
    # stored on the op itself
    if "no_need_buffer" in op_entry:
        no_buffer_args = parse_plain_list(op_entry["no_need_buffer"])
        for name in no_buffer_args:
            assert (
                name in input_names
            ), f"{op_name} has an no buffer input: '{name}' which is not an input."
    else:
        no_buffer_args = None
    for tensor in inputs:
        tensor["no_need_buffer"] = (
            no_buffer_args is not None and tensor["name"] in no_buffer_args
        )

    # TODO(chenfeiyu): data_transform

    op = {
        "name": op_name,
        "inputs": inputs,
        "attrs": attrs,
        "outputs": outputs,
        "no_need_buffer": no_buffer_args,
    }

    if "invoke" in op_entry:
        # the op is implemented by invoking another op
        op["invoke"] = parse_invoke(op_name, op_entry["invoke"])
    else:
        kernel = parse_kernel(op_name, op_entry["kernel"])
        if kernel["param"] is None:
            kernel["param"] = input_names + attr_names

        infer_meta = parse_infer_meta(op_entry["infer_meta"])
        if infer_meta["param"] is None:
            infer_meta["param"] = copy(kernel["param"])

        if "inplace" in op_entry:
            inplace_pairs = parse_inplace(op_name, op_entry["inplace"])
        else:
            inplace_pairs = None

        op["infer_meta"] = infer_meta
        op["kernel"] = kernel
        op["inplace"] = inplace_pairs

    if has_composite:
        op["composite"] = composite_dict

    op["backward"] = op_entry.get("backward")

    # backward ops also record (and are validated against) their forward call
    if name_field == "backward_op":
        forward = None
        if "forward" in op_entry:
            forward = parse_forward(op_name, op_entry["forward"])
            validate_backward_inputs(
                op_name, forward["inputs"], forward["outputs"], inputs
            )
            validate_backward_attrs(op_name, forward["attrs"], attrs)
            validate_backward_outputs(op_name, forward["inputs"], outputs)
        op["forward"] = forward
    return op


def validate_backward_attrs(op, forward_attrs, backward_attrs):
    """Allow a backward op extra trailing attrs only when they all have defaults.

    Any attrs the backward op declares beyond the forward op's count (the
    trailing ones) must carry a ``default_value``.
    """
    # this is a not-that-clean trick to allow backward op to has more attrs
    # than the forward op , as long as they all have default value
    extra = len(backward_attrs) - len(forward_attrs)
    if extra <= 0:
        return
    for attr in backward_attrs[-extra:]:
        assert (
            "default_value" in attr
        ), f"{op } has exceptional attr without default value"


def validate_backward_inputs(
    op, forward_inputs, forward_outputs, backward_inputs
):
    """Check that a backward op declares a bounded number of inputs.

    The limit is every forward input plus two entries per forward output
    (presumably each output and its gradient; confirm against the yaml
    conventions). Every record must have a ``name``.
    """
    forward_in_names, forward_out_names, backward_in_names = (
        [arg["name"] for arg in group]
        for group in (forward_inputs, forward_outputs, backward_inputs)
    )
    limit = len(forward_in_names) + 2 * len(forward_out_names)
    assert len(backward_in_names) <= limit, f"{op } has too many inputs."


def validate_backward_outputs(op, forward_inputs, backward_outputs):
    """Check that a backward op has no more outputs than its forward op has inputs."""
    too_many = len(backward_outputs) > len(forward_inputs)
    assert not too_many, f"{op } has too many outputs"


def cross_validate(ops):
    """Cross-check every backward op's ``forward`` call against the forward op.

    For each op with a non-None ``forward`` entry:
      * if the named forward op does not exist, or does not declare any
        ``backward``, a message is printed and checking continues;
      * if the forward op declares a different ``backward``, an assertion fails;
      * the forward call may not have more inputs or attrs than the forward op,
        and the declared ones must match pairwise by ``typename``. A bare
        ``Scalar`` in the call matches any ``Scalar(...)`` attr;
      * the forward call must have exactly as many outputs as the forward op,
        with matching ``typename``s.

    Args:
        ops: mapping from op name to a parsed op dict (as built by
            ``parse_op_entry``).

    Raises:
        AssertionError: on a name, count or type mismatch.
    """
    for name, op in ops.items():
        # Backward ops parsed without a ``forward`` field carry ``forward: None``;
        # indexing that would raise TypeError, and there is nothing to check.
        fw_call = op.get("forward")
        if fw_call is None:
            continue
        fw_name = fw_call["name"]
        if fw_name not in ops:
            print(
                f"Something Wrong here, this backward op ({name})'s forward op ({fw_name}) does not exist."
            )
            continue

        fw_op = ops[fw_name]
        if fw_op.get("backward") is None:
            print(
                f"Something Wrong here, {name}'s forward op ({fw_name}) does not claim {name} as its backward."
            )
        else:
            assert (
                fw_op["backward"] == name
            ), f"{name}: backward and forward name mismatch"

        assert len(fw_call["inputs"]) <= len(
            fw_op["inputs"]
        ), f"{name}: forward call has more inputs than the op "
        for call_input, op_input in zip(fw_call["inputs"], fw_op["inputs"]):
            assert (
                call_input["typename"] == op_input["typename"]
            ), f"type mismatch in {name} and {fw_name}"

        assert len(fw_call["attrs"]) <= len(
            fw_op["attrs"]
        ), f"{name}: forward call has more attrs than the op "
        for call_attr, op_attr in zip(fw_call["attrs"], fw_op["attrs"]):
            if call_attr["typename"] == "Scalar":
                # special case for Scalar, fw_call can omit the type
                assert re.match(
                    r"Scalar(\(\w+\))*", op_attr["typename"]
                ), f"type mismatch in {name} and {fw_name}"
            else:
                assert (
                    call_attr["typename"] == op_attr["typename"]
                ), f"type mismatch in {name} and {fw_name}"

        # The counts must be equal, so the message must not say "more".
        assert len(fw_call["outputs"]) == len(
            fw_op["outputs"]
        ), f"{name}: forward call has a different number of outputs than the op"
        for call_output, op_output in zip(fw_call["outputs"], fw_op["outputs"]):
            assert (
                call_output["typename"] == op_output["typename"]
            ), f"type mismatch in {name} and {fw_name}"