# reduce_lib_size_util.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script removes grad kernel registrations from the operator sources.
Use it when building with cmake ON_INFER=ON; it can greatly reduce the size
of the inference library.
"""

import glob
import os


def is_balanced(content):
    """
    Tell whether a string holds at least one '{' and has all of its
    round and curly brackets correctly paired and nested.

    Args:
       content (str): text to inspect.

    Returns:
        boolean: True only when '{' occurs in the text and every '(' / '{'
        is closed by the matching ')' / '}' in the right order.
    """

    # Text without any opening brace is never accepted, even if its
    # parentheses are paired.
    if '{' not in content:
        return False

    opener_for = {')': '(', '}': '{'}
    pending = []
    for ch in content:
        if ch == '(' or ch == '{':
            pending.append(ch)
        elif ch in opener_for:
            # A closer with nothing open, or one closing the wrong kind of
            # bracket, makes the text invalid.
            if not pending or pending.pop() != opener_for[ch]:
                return False

    # Anything still open means the text is incomplete.
    return not pending


def _balanced_end(content, begin):
    """
    Find where the shortest balanced, brace-containing span starting at
    ``begin`` ends.

    Args:
       content(str): text to scan
       begin(int): index at which the span starts

    Returns:
        int: exclusive end index of the span, or -1 when the brackets are
        mismatched or the text ends before the span is closed.
    """

    opener_for = {')': '(', '}': '{'}
    stack = []
    seen_brace = False
    for pos in range(begin, len(content)):
        ch = content[pos]
        if ch == '(' or ch == '{':
            stack.append(ch)
            if ch == '{':
                seen_brace = True
        elif ch in opener_for:
            # A stray or wrongly paired closer can never be repaired by
            # reading further, so give up immediately.
            if not stack or stack.pop() != opener_for[ch]:
                return -1
            # The span is complete once everything is closed and a '{...}'
            # body has been seen; balanced '(...)' alone is not enough.
            if seen_brace and not stack:
                return pos + 1
    return -1


def grad_kernel_definition(content, kernel_pattern, grad_pattern):
    """
    Collect the kernel registrations in ``content`` that mention
    ``grad_pattern``.

    A registration is the shortest text that starts at an occurrence of
    ``kernel_pattern``, contains at least one '{', and has balanced '()' and
    '{}'. Brackets are counted character by character; string literals and
    comments get no special treatment. Scanning stops at the first
    registration whose brackets are mismatched or never closed.

    The text is scanned once per registration instead of re-validating every
    growing prefix, so the cost is linear in the size of ``content``.

    Args:
       content(str): file content
       kernel_pattern(str): kernel pattern
       grad_pattern(str): grad pattern

    Returns:
        (list, int): grad kernel definitions in file and count.
    """

    results = []
    start = 0
    while True:
        index = content.find(kernel_pattern, start)
        if index == -1:
            break
        end = _balanced_end(content, index)
        if end == -1:
            # No complete registration from here on; keep what was found.
            break
        definition = content[index:end]
        if grad_pattern in definition:
            results.append(definition)
        # Resume right after this registration.
        start = end
    return results, len(results)


def remove_grad_kernels(dry_run=False):
    """
    Strip grad kernel registrations from the fluid operator sources.

    Every ``*.cc`` and ``*.cu`` file found recursively under
    ``../paddle/fluid/operators`` (relative to this script's directory) is
    searched for ``PD_REGISTER_STRUCT_KERNEL`` registrations containing
    ``_grad,``. Each such registration is deleted from its own file. Files
    without any grad registration are left untouched on disk.

    Args:
       dry_run(bool): whether just print; when True the file name and each
           registration that would be removed are printed and nothing is
           written.

    Returns:
        int: number of kernel(grad) removed (or found, when ``dry_run``).
    """

    pd_kernel_pattern = 'PD_REGISTER_STRUCT_KERNEL'
    register_op_pd_kernel_count = 0

    tool_dir = os.path.dirname(os.path.abspath(__file__))
    all_op = glob.glob(
        os.path.join(tool_dir, '../paddle/fluid/operators/**/*.cc'),
        recursive=True,
    )
    all_op += glob.glob(
        os.path.join(tool_dir, '../paddle/fluid/operators/**/*.cu'),
        recursive=True,
    )

    for op_file in all_op:
        with open(op_file, 'r', encoding='utf-8') as f:
            content = f.read()

        # Matches are tracked per file so that one file's registrations are
        # never replayed against, or reported under, another file.
        pd_kernel, pd_kernel_count = grad_kernel_definition(
            content, pd_kernel_pattern, '_grad,'
        )
        register_op_pd_kernel_count += pd_kernel_count

        if dry_run:
            for to_remove in pd_kernel:
                print(op_file, to_remove)
            continue

        if not pd_kernel:
            # Nothing to strip: skip the rewrite.
            continue

        for to_remove in pd_kernel:
            content = content.replace(to_remove, '')

        with open(op_file, 'w', encoding='utf-8') as f:
            f.write(content)

    return register_op_pd_kernel_count