# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from copy import copy
from typing import Dict, Any, List, Tuple
from tests import is_attr, is_input, is_output, is_vec


def to_named_dict(items: List[Dict]) -> Dict[str, Dict]:
    """Index a list of dicts by their "name" field.

    Raises:
        KeyError: if any item has no "name" key.
    """
    named_dict: Dict[str, Dict] = {}
    for item in items:
        if "name" not in item:
            raise KeyError(f"name not in {item}")
        named_dict[item["name"]] = item
    return named_dict


def parse_arg(op_name: str, s: str) -> Dict[str, str]:
    """Parse one argument declaration of ``op_name``.

    Accepted formats:
      1. ``typename name``
      2. ``typename name = default_value``
    """
    typename, rest = (part.strip() for part in s.split(" ", 1))
    assert (
        len(typename) > 0
    ), f"The arg typename should not be empty. Please check the args of {op_name} in yaml."

    assert (
        rest.count("=") <= 1
    ), f"There is more than 1 = in an arg in {op_name}"

    has_default = rest.count("=") == 1
    if not has_default:
        name = rest.strip()
        assert (
            len(name) > 0
        ), f"The arg name should not be empty. Please check the args of {op_name} in yaml."
        return {"typename": typename, "name": name}

    name, default_value = (part.strip() for part in rest.split("=", 1))
    assert (
        len(name) > 0
    ), f"The arg name should not be empty. Please check the args of {op_name} in yaml."
    assert (
        len(default_value) > 0
    ), f"The default value should not be empty. Please check the args of {op_name} in yaml."
    return {
        "typename": typename,
        "name": name,
        "default_value": default_value,
    }


def parse_input_and_attr(
    op_name: str, arguments: str
) -> Tuple[List, List]:
    """Parse a parenthesized argument list into (inputs, attrs).

    Args:
        op_name: op name, used only in error messages.
        arguments: the raw args string, e.g. "(Tensor x, int axis = 1)".

    Returns:
        A 2-tuple ``(inputs, attrs)`` of lists of parsed argument dicts.
        NOTE: the previous annotation claimed a 4-tuple, but the function
        has always returned exactly two lists.

    Raises:
        KeyError: if an argument's typename is neither an input nor an attr.
    """
    args_str = arguments.strip()
    assert args_str.startswith('(') and args_str.endswith(')'), (
        f"Args declaration should start with '(' and end with ')', "
        f"please check the args of {op_name} in yaml."
    )
    args_str = args_str[1:-1]
    args = parse_plain_list(args_str)

    inputs = []
    attrs = []

    # Once an attr with a default value is seen, all later attrs must
    # also have defaults (mirrors Python's own default-argument rule).
    met_attr_with_default_value = False

    for arg in args:
        item = parse_arg(op_name, arg)
        typename = item["typename"]
        name = item["name"]
        if is_input(typename):
            # inputs must precede attrs in the declaration
            assert len(attrs) == 0, (
                f"The input Tensor should appear before attributes. "
                f"please check the position of {op_name}:input({name}) "
                f"in yaml."
            )
            inputs.append(item)
        elif is_attr(typename):
            if met_attr_with_default_value:
                assert (
                    "default_value" in item
                ), f"{op_name}: Arguments with default value should not precede those without default value"
            elif "default_value" in item:
                met_attr_with_default_value = True
            attrs.append(item)
        else:
            raise KeyError(f"{op_name}: Invalid argument type {typename}.")
    return inputs, attrs


def parse_output(op_name: str, s: str) -> Dict[str, str]:
    """Parse one output declaration.

    Accepted forms: ``typename``, ``typename(name)``, optionally followed
    by a ``{size_expr}`` suffix for vector outputs.
    """
    match = re.search(
        r"(?P<out_type>[a-zA-Z0-9_[\]]+)\s*(?P<name>\([a-zA-Z0-9_@]+\))?\s*(?P<expr>\{[^\}]+\})?",
        s,
    )
    # Fail with a clear message instead of an AttributeError on
    # `None.group(...)` when the declaration is unparseable.
    assert (
        match is not None
    ), f"Invalid output declaration: '{s}' in op : {op_name}."
    typename = match.group("out_type")
    name = match.group("name")
    size_expr = match.group("expr")

    # strip the surrounding "(...)" / "{...}" captured by the regex;
    # an unnamed output defaults to 'out'
    name = name[1:-1] if name is not None else 'out'
    size_expr = size_expr[1:-1] if size_expr is not None else None

    assert is_output(typename), (
        f"Invalid output type: {typename} in op : {op_name}."
        f"Supported types are Tensor and Tensor[]"
    )
    if size_expr is not None:
        # only vector outputs (Tensor[]) may carry a size expression
        assert is_vec(typename), (
            f"Invalid output size: output {name} in op : {op_name} is "
            f"not a vector but has size expr"
        )
        return {"typename": typename, "name": name, "size": size_expr}
    else:
        return {"typename": typename, "name": name}


def parse_outputs(op_name: str, outputs: str) -> List[Dict]:
    """Parse a comma-separated list of output declarations."""
    return [
        parse_output(op_name, output_str)
        for output_str in parse_plain_list(outputs, sep=",")
    ]


def parse_infer_meta(infer_meta: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of the infer_meta config with a guaranteed "param" key.

    The input dict is never mutated; a missing "param" defaults to None.
    """
    normalized = copy(infer_meta)  # to prevent mutating the input
    normalized.setdefault("param", None)
    return normalized


def parse_candidates(s: str) -> Dict[str, Any]:
    "parse candidates joined by either '>'(ordered) or ','(unordered)"
    is_ordered = ">" in s
    delimiter = ">" if is_ordered else ","
    names = [piece.strip() for piece in s.strip().split(delimiter)]
    return {"ordered": is_ordered, "candidates": names}


def parse_plain_list(s: str, sep=",") -> List[str]:
    """Split ``s`` on ``sep`` and strip surrounding whitespace from each piece."""
    return [piece.strip() for piece in s.strip().split(sep)]


def parse_kernel(op_name: str, kernel_config: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize the yaml ``kernel`` section of ``op_name`` into a plain dict.

    Keys absent from ``kernel_config`` fall back to the defaults below
    (None / empty). ``backend``/``layout``/``data_type`` values are parsed
    with parse_candidates; ``func`` may list several kernel names, each
    optionally followed by a ``{in_types -> out_types}`` dispatch spec.
    """
    # kernel :
    #    func : [], Kernel functions (example: scale, scale_sr)
    #    param : [], Input params of kernel
    #    backend : str, the names of param to choose the kernel backend, default is None
    #    layout : str, the names of param to choose the kernel layout, default is None
    #    data_type : str, the names of param to choose the kernel data_type, default is None
    #    dispatch : {}, the key is kernel_func, the value is type of inputs and outputs for kernel (example: {kernel_name : (['dense','sparse_coo']#input,['sparse_coo']#output)})
    kernel = {
        'func': [],  # up to 2 function names
        'param': None,
        'backend': None,
        'layout': None,
        'data_type': None,
        'dispatch': {},
    }
    if 'param' in kernel_config:
        kernel['param'] = kernel_config['param']

    if 'backend' in kernel_config:
        kernel['backend'] = parse_candidates(kernel_config["backend"])

    if 'layout' in kernel_config:
        kernel['layout'] = parse_candidates(kernel_config["layout"])

    if 'data_type' in kernel_config:
        kernel['data_type'] = parse_candidates(kernel_config["data_type"])

    # Each match is (kernel_name, optional "{...}" dispatch spec); an entry
    # without a spec yields '' for the second group.
    kernel_funcs = re.compile(r'([a-zA-Z0-9_]+)\s*({[^}]+})?').findall(
        kernel_config['func']
    )

    def parse_kernel_in_out_type(in_out_str):
        # Turn "{dense, sparse_coo -> dense}" into
        # (['dense', 'sparse_coo'], ['dense']); '' means no dispatch info.
        if len(in_out_str) == 0:
            return None
        tmp_in_out_list = in_out_str[1:-1].split('->')
        inputs = [item.strip() for item in tmp_in_out_list[0].split(',')]
        outputs = [item.strip() for item in tmp_in_out_list[1].split(',')]

        # check the tensor type
        for item in inputs:
            assert item in [
                'dense',
                'selected_rows',
                'sparse_coo',
                'sparse_csr',
            ], f"{op_name} : Invalid input tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."
        for item in outputs:
            assert item in [
                'dense',
                'selected_rows',
                'sparse_coo',
                'sparse_csr',
            ], f"{op_name} : Invalid output tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."

        return (inputs, outputs)

    for func_item in kernel_funcs:
        kernel['func'].append(func_item[0])
        kernel['dispatch'][func_item[0]] = parse_kernel_in_out_type(
            func_item[1]
        )

    return kernel


def parse_inplace(op_name: str, inplace_cfg: str) -> Dict[str, str]:
    """Parse an inplace spec like "(x -> out)" into {output_name: input_name}."""
    spec = inplace_cfg.lstrip("(").rstrip(")")
    mapping: Dict[str, str] = {}
    for pair in [piece.strip() for piece in spec.strip().split(",")]:
        in_name, out_name = [token.strip() for token in pair.split("->")]
        mapping[out_name] = in_name
    return mapping


def parse_invoke(op_name: str, invoke_config: str) -> Dict[str, Any]:
    """Parse an invoke string "func(arg1, arg2, ...)" into its parts."""
    func_part, args_part = invoke_config.strip().split("(", 1)
    return {
        "func": func_part.strip(),
        "args": args_part.rstrip(")").strip(),
    }


def extract_type_and_name(records: List[Dict]) -> List[Dict]:
    """extract type and name from forward call, it is simpler than forward op ."""
    return [
        {"name": record["name"], "typename": record["typename"]}
        for record in records
    ]


def parse_forward(op_name: str, forward_config: str) -> Dict[str, Any]:
    """Parse a forward-op spec into {name, inputs, attrs, outputs}."""
    # op_name (const Tensor& input, ... , int attr, ...) -> Tensor(out)
    result = re.search(
        r"(?P<op>[a-z][a-z0-9_]+)\s*(?P<args>\([^\)]+\))\s*->\s*(?P<outputs>.+)",
        forward_config,
    )
    forward_op_name = result.group("op")
    # outputs are parsed first, matching the declaration order checks
    parsed_outputs = parse_outputs(op_name, result.group("outputs"))
    parsed_outputs = extract_type_and_name(parsed_outputs)

    parsed_inputs, parsed_attrs = parse_input_and_attr(
        op_name, result.group("args")
    )
    return {
        "name": forward_op_name,
        "inputs": extract_type_and_name(parsed_inputs),
        "attrs": extract_type_and_name(parsed_attrs),
        "outputs": parsed_outputs,
    }


def parse_op_entry(op_entry: Dict[str, Any], name_field="op"):
    """Parse one raw yaml op entry into a normalized op dict.

    ``name_field`` is "op" for forward ops and "backward_op" for backward
    ops; for the latter the "forward" spec is also parsed and validated
    against this op's inputs/attrs/outputs.
    """
    op_name = op_entry[name_field]
    inputs, attrs = parse_input_and_attr(op_name, op_entry["args"])
    outputs = parse_outputs(op_name, op_entry["output"])

    # validate default value of DataType and DataLayout
    for attr in attrs:
        if "default_value" in attr:
            typename = attr["typename"]
            default_value = attr["default_value"]
            if typename == "DataType":
                assert (
                    "DataType" in default_value
                ), f"invalid DataType default value in {op_name}"
                # remove namespace
                default_value = default_value[default_value.find("DataType") :]
                attr["default_value"] = default_value
            elif typename == "DataLayout":
                assert (
                    "DataLayout" in default_value
                ), f"invalid DataLayout default value in {op_name}"
                # remove namespace prefix, keeping "DataLayout..." onward
                default_value = default_value[
                    default_value.find("DataLayout") :
                ]
                attr["default_value"] = default_value

    input_names = [item["name"] for item in inputs]
    attr_names = [item["name"] for item in attrs]
    output_names = [item["name"] for item in outputs]

    # add optional tag for every input
    for input in inputs:
        input["optional"] = False
    if "optional" in op_entry:
        optional_args = parse_plain_list(op_entry["optional"])
        for name in optional_args:
            assert (
                name in input_names
            ), f"{op_name} has an optional input: '{name}' which is not an input."
        for input in inputs:
            if input["name"] in optional_args:
                input["optional"] = True

    # add intermediate tag for every output
    for output in outputs:
        output["intermediate"] = False
    if "intermediate" in op_entry:
        intermediate_outs = parse_plain_list(op_entry["intermediate"])
        for name in intermediate_outs:
            assert (
                name in output_names
            ), f"{op_name} has an intermediate output: '{name}' which is not an output."
        for output in outputs:
            if output["name"] in intermediate_outs:
                output["intermediate"] = True

    # add no_need_buffer for every input
    for input in inputs:
        input["no_need_buffer"] = False
    if "no_need_buffer" in op_entry:
        no_buffer_args = parse_plain_list(op_entry["no_need_buffer"])
        for name in no_buffer_args:
            assert (
                name in input_names
            ), f"{op_name} has an no buffer input: '{name}' which is not an input."
        for input in inputs:
            if input["name"] in no_buffer_args:
                input["no_need_buffer"] = True
    else:
        # NOTE: stays None (not []) when absent; downstream may rely on this
        no_buffer_args = None

    # TODO(chenfeiyu): data_transform

    op = {
        "name": op_name,
        "inputs": inputs,
        "attrs": attrs,
        "outputs": outputs,
        "no_need_buffer": no_buffer_args,
    }

    # invokes another op ?
    is_base_op = "invoke" not in op_entry

    if is_base_op:
        # kernel
        kernel = parse_kernel(op_name, op_entry["kernel"])
        if kernel["param"] is None:
            # default kernel params: all inputs then all attrs, in order
            kernel["param"] = input_names + attr_names

        # infer meta
        infer_meta = parse_infer_meta(op_entry["infer_meta"])
        if infer_meta["param"] is None:
            # default to the kernel's param list (copied, not shared)
            infer_meta["param"] = copy(kernel["param"])

        # inplace
        if "inplace" in op_entry:
            inplace_pairs = parse_inplace(op_name, op_entry["inplace"])
        else:
            inplace_pairs = None
        op.update(
            {
                "infer_meta": infer_meta,
                "kernel": kernel,
                "inplace": inplace_pairs,
            }
        )
    else:
        # invoke
        invoke = parse_invoke(op_name, op_entry["invoke"])
        op["invoke"] = invoke

    # backward
    if "backward" in op_entry:
        backward = op_entry["backward"]
    else:
        backward = None
    op["backward"] = backward

    # forward for backward_ops
    is_backward_op = name_field == "backward_op"
    if is_backward_op:
        if "forward" in op_entry:
            forward = parse_forward(op_name, op_entry["forward"])
            # validate_fb
            validate_backward_inputs(
                op_name, forward["inputs"], forward["outputs"], inputs
            )
            validate_backward_attrs(op_name, forward["attrs"], attrs)
            validate_backward_outputs(op_name, forward["inputs"], outputs)
        else:
            forward = None
        op["forward"] = forward
    return op


def validate_backward_attrs(op, forward_attrs, backward_attrs):
    """Check that any backward-only attrs (beyond the forward's) have defaults."""
    num_extra = len(backward_attrs) - len(forward_attrs)
    if num_extra <= 0:
        return
    # this is a not-that-clean trick to allow backward op to has more attrs
    # than the forward op , as long as they all have default value
    for extra_attr in backward_attrs[-num_extra:]:
        assert (
            "default_value" in extra_attr
        ), f"{op } has exceptional attr without default value"


def validate_backward_inputs(
    op, forward_inputs, forward_outputs, backward_inputs
):
    """Check the backward op's input count against the forward signature.

    A backward op may take at most: every forward input, plus every forward
    output and its gradient (hence the 2x factor on outputs).
    """
    forward_input_names = [entry["name"] for entry in forward_inputs]
    forward_output_names = [entry["name"] for entry in forward_outputs]
    backward_input_names = [entry["name"] for entry in backward_inputs]

    max_inputs = len(forward_input_names) + 2 * len(forward_output_names)
    assert len(backward_input_names) <= max_inputs, f"{op } has too many inputs."


def validate_backward_outputs(op, forward_inputs, backward_outputs):
    """Check the backward op emits at most one output per forward input."""
    num_outputs = len(backward_outputs)
    num_forward_inputs = len(forward_inputs)
    assert num_outputs <= num_forward_inputs, f"{op } has too many outputs"


def cross_validate(ops):
    """Cross-check forward/backward links across all parsed ops.

    ``ops`` maps op name -> parsed op dict (as produced by parse_op_entry).
    Broken links (missing forward op, forward op not claiming this backward)
    are reported via print; name/arity/type mismatches raise AssertionError.
    """
    for name, op in ops.items():
        if "forward" in op:
            fw_call = op["forward"]
            fw_name = fw_call["name"]
            if fw_name not in ops:
                print(
                    f"Something Wrong here, this backward op ({name})'s forward op ({fw_name}) does not exist."
                )
            else:
                fw_op = ops[fw_name]
                if "backward" not in fw_op or fw_op["backward"] is None:
                    print(
                        f"Something Wrong here, {name}'s forward op ({fw_name}) does not claim {name} as its backward."
                    )
                else:
                    assert (
                        fw_op["backward"] == name
                    ), f"{name}: backward and forward name mismatch"

                # the forward call may use a subset of the op's inputs,
                # but the ones it uses must match types positionally
                assert len(fw_call["inputs"]) <= len(
                    fw_op["inputs"]
                ), f"{name}: forward call has more inputs than the op "
                for (input, input_) in zip(fw_call["inputs"], fw_op["inputs"]):
                    assert (
                        input["typename"] == input_["typename"]
                    ), f"type mismatch in {name} and {fw_name}"

                assert len(fw_call["attrs"]) <= len(
                    fw_op["attrs"]
                ), f"{name}: forward call has more attrs than the op "
                for (attr, attr_) in zip(fw_call["attrs"], fw_op["attrs"]):
                    if attr["typename"] == "Scalar":
                        # special case for Scalar, fw_call can omit the type
                        assert re.match(
                            r"Scalar(\(\w+\))*", attr_["typename"]
                        ), f"type mismatch in {name} and {fw_name}"
                    else:
                        assert (
                            attr["typename"] == attr_["typename"]
                        ), f"type mismatch in {name} and {fw_name}"

                # outputs must match the op exactly, count and types
                assert len(fw_call["outputs"]) == len(
                    fw_op["outputs"]
                ), f"{name}: forward call has more outputs than the op "
                for (output, output_) in zip(
                    fw_call["outputs"], fw_op["outputs"]
                ):
                    assert (
                        output["typename"] == output_["typename"]
                    ), f"type mismatch in {name} and {fw_name}"