codegen_utils.py 17.2 KB
Newer Older
1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
#
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9 10 11 12 13 14 15 16
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

17 18
import yaml

19 20 21
####################
# Global Variables #
####################
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
# Backward op names that opt in to having empty (undefined) incoming grads
# filled with zeros; presumably consumed by the grad-node generator -- confirm.
ops_to_fill_zero_for_empty_grads = {
    "split_grad",
    "split_with_num_grad",
    "rnn_grad",
    "matmul_double_grad",
    "matmul_triple_grad",
    "sigmoid_double_grad",
    "sigmoid_triple_grad",
    "add_double_grad",
    "add_triple_grad",
    "multiply_grad",
    "multiply_double_grad",
    "multiply_triple_grad",
    "conv2d_grad_grad",
    "batch_norm_double_grad",
    "tanh_grad",
    "tanh_double_grad",
    "tanh_triple_grad",
    "sin_double_grad",
    "sin_triple_grad",
    "cos_double_grad",
    "cos_triple_grad",
    "subtract_double_grad",
    "divide_double_grad",
    "log_double_grad",
    "elu_double_grad",
    "leaky_relu_double_grad",
    "sqrt_double_grad",
    "rsqrt_double_grad",
    "square_double_grad",
    "celu_double_grad",
    "pad_double_grad",
    "pad3d_double_grad",
    "squeeze_double_grad",
    "unsqueeze_double_grad",
    "instance_norm_double_grad",
    "conv3d_double_grad",
    "depthwise_conv2d_grad_grad",
    "concat_double_grad",
    "expand_grad",
    "argsort_grad",
}
66 67 68 69 70 71 72 73

# For API dispatch used at python-level
# { op_name : [arg_name, ...] }
core_ops_returns_info = {}
core_ops_args_info = {}
core_ops_args_type_info = {}

# Maps a type spelled in the op yaml files to the C++ type emitted by the
# generator. Lookups elsewhere in this file assert the key is present.
yaml_types_mapping = {
    'int': 'int',
    'int32_t': 'int32_t',
    'int64_t': 'int64_t',
    'size_t': 'size_t',
    'float': 'float',
    'double': 'double',
    'bool': 'bool',
    'str': 'std::string',
    'str[]': 'std::vector<std::string>',
    'float[]': 'std::vector<float>',
    'bool[]': 'std::vector<bool>',
    'Place': 'paddle::Place',
    'DataLayout': 'phi::DataLayout',
    'DataType': 'paddle::experimental::DataType',
    'int64_t[]': 'std::vector<int64_t>',
    'int[]': 'std::vector<int>',
    'Tensor': 'Tensor',
    'Tensor[]': 'std::vector<Tensor>',
    'Tensor[Tensor[]]': 'std::vector<std::vector<Tensor>>',
    'Scalar': 'paddle::experimental::Scalar',
    'Scalar(int)': 'paddle::experimental::Scalar',
    'Scalar(int64_t)': 'paddle::experimental::Scalar',
    'Scalar(float)': 'paddle::experimental::Scalar',
    'Scalar(double)': 'paddle::experimental::Scalar',
    'Scalar[]': 'std::vector<phi::Scalar>',
    'IntArray': 'paddle::experimental::IntArray',
}


103 104 105
#########################
#  File Reader Helpers  #
#########################
106 107 108 109
def AssertMessage(lhs_str, rhs_str):
    """Format an assertion message showing both compared values.

    Returns the text "lhs: <lhs_str>, rhs: <rhs_str>".
    """
    template = "lhs: {}, rhs: {}"
    return template.format(lhs_str, rhs_str)


110 111
def ReadFwdFile(filepath):
    """Load a forward-op yaml file.

    Args:
        filepath: path of the yaml file.

    Returns:
        The parsed yaml document, or [] when the file is empty.
    """
    # `with` closes the file even if yaml parsing raises.
    with open(filepath, 'r', encoding='utf-8') as f:
        # NOTE(review): PyYAML docs advise safe_load for untrusted input;
        # FullLoader is kept to preserve existing parsing behavior.
        contents = yaml.load(f, Loader=yaml.FullLoader)
    # empty file loaded by yaml is None
    return contents if contents is not None else []
116 117 118 119 120 121


def ReadBwdFile(filepath):
    """Load a backward-op yaml file into a dict keyed by backward op name.

    Args:
        filepath: path of the yaml file.

    Returns:
        {entry['backward_op']: entry, ...}; {} when the file is empty.

    Raises:
        AssertionError: if an entry lacks a 'backward_op' key.
    """
    # `with` closes the file even if yaml parsing raises.
    with open(filepath, 'r', encoding='utf-8') as f:
        # NOTE(review): PyYAML docs advise safe_load for untrusted input;
        # FullLoader is kept to preserve existing parsing behavior.
        contents = yaml.load(f, Loader=yaml.FullLoader)

    ret = {}
    # empty file loaded by yaml is None
    if contents is None:
        return ret

    for content in contents:
        assert 'backward_op' in content.keys(), AssertMessage(
            'backward_op', content.keys()
        )
        # Index unconditionally. The former `if` re-tested the assert and,
        # with asserts disabled (python -O), silently reused the previous
        # entry's name (or raised NameError on the first entry).
        ret[content['backward_op']] = content
    return ret


135 136 137
##############################
#  Generic Helper Functions  #
##############################
138 139 140 141 142 143 144 145 146 147
def FindGradName(string):
    """Return the backward op name for *string* ("<string>_grad")."""
    grad_suffix = "_grad"
    return string + grad_suffix


def FindForwardName(string):
    """Strip a trailing "_grad" from *string*.

    Returns None when *string* does not end with "_grad".
    """
    suffix = "_grad"
    return string[: -len(suffix)] if string.endswith(suffix) else None


148 149 150 151
def IsGradName(string):
    """Return True when *string* ends with "_grad"."""
    tail = "_grad"
    return string[-len(tail):] == tail


152 153 154 155 156 157 158 159 160
def IsPlainTensorType(string):
    """Return True when *string* spells a single (non-vector) Tensor type,
    with or without const / reference qualifiers."""
    return string in ('Tensor&', 'Tensor', 'const Tensor&', 'const Tensor')


def IsVectorTensorType(string):
    """Return True when *string* is one of the C++ vector-of-Tensor types
    produced by yaml_types_mapping ('Tensor[]' and 'Tensor[Tensor[]]')."""
    vector_tensor_types = (
        'std::vector<std::vector<Tensor>>',
        'std::vector<Tensor>',
    )
    return string in vector_tensor_types


def GetSavedName(string):
    """Return the saved-variable member name for *string* (name + "_")."""
    trailing = "_"
    return string + trailing


def GetConstReference(string):
    """Turn a C++ type into a const reference.

    Adds a "const " prefix and a "&" suffix only where they are missing,
    e.g. "Tensor" -> "const Tensor&".
    """
    head = "" if string.startswith("const ") else "const "
    tail = "" if string.endswith("&") else "&"
    return head + string + tail


def RemoveConstAndReference(string):
    """Drop a leading "const " and a trailing "&" from a C++ type string.

    Both checks are made against the original *string*.
    """
    core = string[len("const "):] if string.startswith("const ") else string
    if string.endswith("&"):
        core = core[:-1]
    return core


def GetGradNodeName(string):
    """Return the GradNode class name for an op name.

    The snake_case name becomes CamelCase (empty segments from repeated
    underscores are dropped). One trailing "Grad" is removed and "GradNode"
    is appended, e.g. "matmul_double_grad" -> "MatmulDoubleGradNode".
    """

    def str2Hump(text):
        # "matmul_double_grad" -> "MatmulDoubleGrad"
        return ''.join(
            part[0].upper() + part[1:] for part in text.split('_') if part
        )

    string = str2Hump(string)
    # endswith() instead of rfind("Grad") == len(string) - 4: rfind returns
    # -1 when "Grad" is absent, which equals len - 4 for 3-character names,
    # so e.g. "abs" was wrongly truncated to "" and yielded "GradNode".
    if string.endswith("Grad"):
        string = string[:-4]
    return f"{string}GradNode"
204 205 206


def GetDygraphForwardFunctionName(string):
    """Return the dygraph forward function name for op *string*
    ("<string>_ad_func")."""
    return f"{string}_ad_func"


def GetDygraphLogName(string):
    """Return *string* with underscores removed and every segment lowercased.

    e.g. "Matmul_Double_Grad" -> "matmuldoublegrad".
    """
    # Lowercase per segment (as before) rather than the joined string.
    return ''.join(part.lower() for part in string.split('_') if part)
220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244


def GetIntermediateAPIFunctionName(string):
    """Return the intermediate-API function name ("<string>_intermediate")."""
    return "".join([string, "_intermediate"])


def GetAutoGradMetaName(string):
    """Return the autograd-meta variable name for *string*."""
    return "{}_autograd_meta".format(string)


def GetAutoGradMetaVectorName(string):
    """Return the autograd-meta vector variable name for *string*."""
    return "{}_autograd_meta_vec".format(string)


def RemoveSpecialSymbolsInName(string):
    """Return the part of *string* before the first '@' (all of it if none)."""
    base, _sep, _rest = string.partition("@")
    return base


def RecoverBaseNameOfInplaceFunction(function_name):
    """Drop the last character of an inplace function name.

    The character is removed unconditionally; callers are expected to pass a
    name that ends with the inplace '_' marker.
    """
    last = len(function_name) - 1
    return function_name[:last]


def GetInplacedFunctionName(function_name):
    """Return the inplace variant of *function_name*.

    A trailing '_' is appended unless the name already ends with one.
    endswith() also handles "" (the former [-1] index raised IndexError).
    """
    if function_name.endswith('_'):
        return function_name
    return function_name + '_'
249 250 251


def GetForwardFunctionName(string):
    """Return the forward function name for op *string* ("<string>_ad_func")."""
    return f"{string}_ad_func"
253 254


255
def GetIndent(num):
    """Return *num* levels of two-space indentation ("" for num <= 0)."""
    tab = "  "
    return tab * num


260 261 262
##################
#  Yaml Parsers  #
##################
263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279
def ParseYamlArgs(string):
    """Split a yaml argument list into tensor inputs and attributes.

    Example: Tensor x, Tensor y, bool transpose_x=false, bool transpose_y=false

    Args:
        string: the comma-separated argument list, without the surrounding
            parentheses.

    Returns:
        (inputs_list, attrs_list) where
            inputs_list = [ [arg_name, arg_type, orig_position], ...]
            attrs_list = [ [arg_name, arg_type, default_value, orig_position], ...]
        arg_type is translated through yaml_types_mapping; default_value is
        None when no "=default" is given. A blank string yields two empty
        lists.

    Raises:
        AssertionError: on an unparsable argument, a type missing from
            yaml_types_mapping, or a Tensor argument with a default value.
    """
    inputs_list = []
    attrs_list = []

    # An argument-less op, e.g. "()", arrives here as "". Previously the
    # regex failed to match and m.group() raised AttributeError.
    if not string.strip():
        return inputs_list, attrs_list

    atype = r'((const )?\S+) '
    aname = r'(.*)'
    # Compiled once, reused for every argument.
    pattern = re.compile(f'{atype}{aname}')

    args = [x.strip() for x in string.strip().split(",")]
    for i, arg in enumerate(args):
        m = pattern.search(arg)
        assert m is not None, f"Unable to parse argument \"{arg}\" in yaml args."
        arg_type = m.group(1).strip()
        # m.group(3) is "name" or "name=default".
        name_and_default = m.group(3).split("=")
        arg_name = name_and_default[0].strip()
        default_value = (
            name_and_default[1].strip() if len(name_and_default) > 1 else None
        )

        assert (
            arg_type in yaml_types_mapping
        ), f"The argument type {arg_type} in yaml config is not supported in yaml_types_mapping."
        # Qualify enum defaults with the paddle::experimental namespace.
        if arg_type in ["DataType", "DataLayout"] and default_value is not None:
            default_value = f"paddle::experimental::{default_value}"
        arg_type = yaml_types_mapping[arg_type]

        arg_name = RemoveSpecialSymbolsInName(arg_name)
        if "Tensor" in arg_type:
            assert (
                default_value is None
            ), f"Tensor argument {arg_name} must not have a default value."
            inputs_list.append([arg_name, arg_type, i])
        else:
            attrs_list.append([arg_name, arg_type, default_value, i])

    return inputs_list, attrs_list


def ParseYamlReturns(string):
    """Parse a yaml output list into [ [ret_name, ret_type, orig_position], ...].

    Example0: Tensor(out), Tensor(out1)
    Example1: Tensor, Tensor
    Example2: Tensor[](out), Tensor

    Any "{...}" suffix on an entry is ignored. ret_name is "" when the entry
    has no "(name)" part. ret_type is translated through yaml_types_mapping.

    Raises:
        AssertionError: if a type is not in yaml_types_mapping or does not
            map to a Tensor type.
    """
    returns_list = []

    returns = [x.strip() for x in string.strip().split(",")]
    for i, item in enumerate(returns):
        # Drop any trailing "{...}" annotation.
        ret = item.split("{")[0].strip()

        ret_name = ""
        if "(" in ret and ")" in ret:
            # "Type(name)": remove the trailing ')' and split on '('.
            ret = ret[:-1]
            pieces = ret.split("(")
            ret_type = pieces[0].strip()
            ret_name = pieces[1].strip()
        else:
            ret_type = ret.strip()

        assert (
            ret_type in yaml_types_mapping
        ), f"The return type {ret_type} in yaml config is not supported in yaml_types_mapping."
        ret_type = yaml_types_mapping[ret_type]

        assert "Tensor" in ret_type, AssertMessage("Tensor", ret_type)
        ret_name = RemoveSpecialSymbolsInName(ret_name)
        returns_list.append([ret_name, ret_type, i])

    return returns_list


def ParseYamlForwardFromBackward(string):
    """Parse a "name (args) -> returns" forward signature string.

    Example: matmul (Tensor x, Tensor y, bool transpose_x, bool transpose_y) -> Tensor(out)

    Returns:
        (forward_inputs_list, forward_attrs_list, forward_returns_list), as
        produced by ParseYamlArgs and ParseYamlReturns.

    Raises:
        AssertionError: if *string* does not have the expected shape.
    """
    fname = r'(.*?)'
    wspace = r'\s*'
    fargs = r'(.*?)'
    frets = r'(.*)'
    pattern = (
        fr'{fname}{wspace}\({wspace}{fargs}{wspace}\){wspace}->{wspace}{frets}'
    )

    m = re.search(pattern, string)
    assert m is not None, f"Unable to parse forward signature \"{string}\"."
    # m.group(1) holds the op name, which callers of this helper don't need.
    function_args = m.group(2)
    function_returns = m.group(3)

    forward_inputs_list, forward_attrs_list = ParseYamlArgs(function_args)
    forward_returns_list = ParseYamlReturns(function_returns)

    return forward_inputs_list, forward_attrs_list, forward_returns_list


def ParseYamlForward(args_str, returns_str):
    """Parse a forward op's yaml 'args' and 'output' strings.

    args Example: (Tensor x, Tensor y, bool transpose_x = false, bool transpose_y = false)
    returns Example: Tensor, Tensor

    Returns:
        (inputs_list, attrs_list, returns_list), as produced by
        ParseYamlArgs and ParseYamlReturns.

    Raises:
        AssertionError: if args_str is not wrapped in "(...)".
    """
    fargs = r'(.*?)'
    # Anchored, so the lazy group still extends to the final ')'.
    args_pattern = fr'^\({fargs}\)$'
    m = re.search(args_pattern, args_str.strip())
    assert (
        m is not None
    ), f"Unable to parse forward args \"{args_str}\"; expected \"(...)\"."

    inputs_list, attrs_list = ParseYamlArgs(m.group(1))
    returns_list = ParseYamlReturns(returns_str)

    return inputs_list, attrs_list, returns_list


def ParseYamlBackward(args_str, returns_str):
    """Parse a backward op's yaml 'args' and 'output' strings.

    args Example: (Tensor x, Tensor y, Tensor out_grad, bool transpose_x=false, bool transpose_y=false)
    returns Example: Tensor(x_grad), Tensor(y_grad)

    Returns:
        (inputs_list, attrs_list, returns_list), as produced by
        ParseYamlArgs and ParseYamlReturns.

    Raises:
        AssertionError: if args_str contains no "(...)".
    """
    # Greedy capture up to the LAST ')'. The former lazy "(.*?)" stopped at
    # the first ')', truncating lists that contain types such as
    # "Scalar(float)" (a yaml_types_mapping key) or defaults with "()".
    fargs = r'(.*)'
    args_pattern = fr'\({fargs}\)'
    m = re.search(args_pattern, args_str)
    assert (
        m is not None
    ), f"Unable to parse backward args \"{args_str}\"; expected \"(...)\"."

    inputs_list, attrs_list = ParseYamlArgs(m.group(1))
    returns_list = ParseYamlReturns(returns_str)

    return inputs_list, attrs_list, returns_list


389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405
def ParseYamlInplaceInfo(string):
    """Parse an 'inplace' yaml field into {input_name: output_name}.

    Example: "(x -> out0), (y -> out2)" -> {"x": "out0", "y": "out2"}
    """

    def _strip_parens(text):
        # Trim whitespace, then at most one leading '(' and one trailing ')'.
        text = text.strip()
        if text.startswith("("):
            text = text[1:]
        if text.endswith(")"):
            text = text[:-1]
        return text

    result = {}
    for entry in map(_strip_parens, string.split(",")):
        sides = entry.split("->")
        result[sides[0].strip()] = sides[1].strip()
    return result


J
Jiabin Yang 已提交
406 407 408 409 410 411 412 413
def ParseYamlCompositeInfo(string):
    """Parse a 'composite' yaml field of the form "fun(args1, args2, ...)".

    Returns:
        {"name": <function name>, "args": [<stripped arg>, ...]}. "args" is
        always present; "fun()" gives [""].

    Raises:
        AssertionError: if *string* has no "(...)" part.
    """
    fname = r'(.*?)'
    wspace = r'\s*'
    fargs = r'(.*?)'
    pattern = fr'{fname}{wspace}\({wspace}{fargs}{wspace}\)'

    m = re.search(pattern, string)
    assert m is not None, f"Unable to parse composite info \"{string}\"."

    return {
        "name": m.group(1),
        "args": [fun_arg.strip() for fun_arg in m.group(2).split(",")],
    }


426 427 428
####################
#  Generator Base  #
####################
429 430 431 432 433
class FunctionGeneratorBase:
    """Holds the parsed yaml description of a single forward op.

    The Parse*/Collect*/Determine* methods fill the attributes initialised in
    __init__.

    NOTE(review): ParseNoNeedBuffer and ParseComposite read
    ``self.grad_api_contents``, which this class never assigns; presumably a
    subclass sets it before calling them -- confirm.
    """

    def __init__(self, forward_api_contents, namespace):
        """
        Args:
            forward_api_contents: one forward op entry loaded from yaml
                (a mapping read for keys such as 'op', 'args', 'output').
            namespace: C++ namespace prefix string (e.g. "" or "sparse::").
        """
        self.forward_api_contents = forward_api_contents
        self.namespace = namespace

        # True when the yaml entry declares no 'backward' op.
        self.is_forward_only = 'backward' not in forward_api_contents

        self.forward_api_name = ""

        # [ [arg_name, arg_type, orig_position], ...]
        self.orig_forward_inputs_list = []
        # [ [attr_name, attr_type, default_value, orig_position], ...]
        self.orig_forward_attrs_list = []
        # [ [ret_name, ret_type, orig_position], ...]
        self.orig_forward_returns_list = []

        # Processed Forward Data
        # { "name" : [type, fwd_position] }
        self.forward_inputs_position_map = {}
        # { "name" : [type, fwd_position] }
        self.forward_outputs_position_map = {}

        # Special Op Attributes
        self.optional_inputs = []  # [name, ...]
        self.no_need_buffers = []  # [name, ...]
        # {name: func_name, args: [input_name, ...]}
        self.composite_func_info = {}
        self.intermediate_outputs = []  # [name, ...]
        self.forward_inplace_map = {}  # {name : name, ...}

    @staticmethod
    def _ParseNameList(names_str):
        """Split "a, b@X, c" into cleaned names ["a", "b", "c"].

        Whitespace is stripped both before and after removing the '@' suffix,
        so "b @X" also yields "b".
        """
        return [
            RemoveSpecialSymbolsInName(name.strip()).strip()
            for name in names_str.split(",")
        ]

    def ParseForwardInplaceInfo(self):
        """Fill forward_inplace_map from the optional 'inplace' yaml field."""
        forward_api_contents = self.forward_api_contents
        if 'inplace' not in forward_api_contents:
            return

        inplace_map_str = forward_api_contents['inplace']
        self.forward_inplace_map = ParseYamlInplaceInfo(inplace_map_str)

    def ParseNoNeedBuffer(self):
        """Append names from the backward 'no_need_buffer' field, if any."""
        grad_api_contents = self.grad_api_contents

        if 'no_need_buffer' in grad_api_contents:
            no_need_buffer_str = grad_api_contents['no_need_buffer']
            self.no_need_buffers.extend(self._ParseNameList(no_need_buffer_str))

    def ParseComposite(self):
        """Fill composite_func_info from the backward 'composite' field."""
        grad_api_contents = self.grad_api_contents

        if 'composite' in grad_api_contents:
            composite_str = grad_api_contents['composite']
            self.composite_func_info = ParseYamlCompositeInfo(composite_str)

    def ParseDispensable(self):
        """Append names from the forward 'optional' field, if any."""
        forward_api_contents = self.forward_api_contents

        if 'optional' in forward_api_contents:
            optional_inputs_str = forward_api_contents['optional']
            self.optional_inputs.extend(self._ParseNameList(optional_inputs_str))

    def ParseIntermediate(self):
        """Append names from the forward 'intermediate' field, if any."""
        forward_api_contents = self.forward_api_contents

        if 'intermediate' in forward_api_contents:
            intermediate_str = forward_api_contents['intermediate']
            self.intermediate_outputs.extend(
                self._ParseNameList(intermediate_str)
            )

    def CollectOriginalForwardInfo(self):
        """Parse the forward 'op', 'args' and 'output' fields.

        Raises:
            AssertionError: if any of those keys is missing.
        """
        forward_api_contents = self.forward_api_contents

        # Validate before indexing. The lookups used to come first, so a
        # missing key raised a bare KeyError and these checks never ran.
        for key in ('op', 'args', 'output'):
            assert (
                key in forward_api_contents
            ), f"Unable to find \"{key}\" in forward_api_contents keys"

        self.forward_api_name = forward_api_contents['op']
        forward_args_str = forward_api_contents['args']
        forward_returns_str = forward_api_contents['output']

        # Parse the original forward inputs / attributes / returns.
        (
            self.orig_forward_inputs_list,
            self.orig_forward_attrs_list,
            self.orig_forward_returns_list,
        ) = ParseYamlForward(forward_args_str, forward_returns_str)

    def DetermineForwardPositionMap(
        self, forward_inputs_list, forward_returns_list
    ):
        """Record {name: [type, position]} for forward inputs and returns.

        Unnamed returns are called "out" when there is exactly one return,
        and "out_<index + 1>" otherwise.
        """
        for forward_input in forward_inputs_list:
            input_name = forward_input[0]
            input_type = forward_input[1]
            input_pos = forward_input[2]
            self.forward_inputs_position_map[input_name] = [
                input_type,
                input_pos,
            ]

        num_returns = len(forward_returns_list)
        for i, forward_return in enumerate(forward_returns_list):
            return_name = forward_return[0]
            if len(return_name) == 0:
                return_name = "out" if num_returns == 1 else f"out_{i + 1}"
            return_type = forward_return[1]
            return_pos = forward_return[2]
            self.forward_outputs_position_map[return_name] = [
                return_type,
                return_pos,
            ]
566 567


568
class GeneratorBase:
569 570 571 572 573 574 575 576 577 578 579 580
    def __init__(self, api_yaml_path):
        self.namespace = ""
        self.api_yaml_path = api_yaml_path

        self.forward_api_list = []

    def ParseForwardYamlContents(self):
        """Load the forward op entries from ``self.api_yaml_path`` via
        ReadFwdFile into ``self.forward_api_list``."""
        self.forward_api_list = ReadFwdFile(self.api_yaml_path)

    def InferNameSpace(self):
        api_yaml_path = self.api_yaml_path
581
        if re.search(r"sparse[a-zA-Z0-9_]*\.yaml", api_yaml_path):
582
            self.namespace = "sparse::"
583
        elif re.search(r"strings[a-zA-Z0-9_]*\.yaml", api_yaml_path):
584
            self.namespace = "strings::"