generate_sparse_op.py 6.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from pathlib import Path

import yaml
20
from filters import (
21
    cartesian_prod_mapping,
J
Jiabin Yang 已提交
22
    to_composite_grad_opmaker_name,
23
    to_input_name,
24 25
    to_int_array_tensor_name,
    to_int_array_tensors_name,
26 27 28 29
    to_op_attr_type,
    to_opmaker_name,
    to_opmaker_name_cstr,
    to_pascal_case,
30
    to_scalar_tensor_name,
31
    to_variable_names,
32
)
33
from generate_op import add_fluid_name, process_invoke_op
34 35
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from parse_utils import to_named_dict
36
from tests import (
37
    is_base_op,
38
    is_composite_op,
39
    is_initializer_list,
40 41
    is_scalar,
    is_vec,
42 43 44
    supports_inplace,
    supports_no_need_buffer,
)
45 46

# Jinja2 environment that renders the sparse operator templates stored in the
# "templates" directory next to this script. StrictUndefined makes a reference
# to an undefined template variable raise instead of rendering as empty text;
# the "do" extension enables `{% do %}` statements inside templates.
file_loader = FileSystemLoader(Path(__file__).parent / "templates")
env = Environment(
    loader=file_loader,
    undefined=StrictUndefined,
    trim_blocks=True,
    lstrip_blocks=True,
    keep_trailing_newline=True,
    extensions=['jinja2.ext.do'],
)

# Custom filters, usable in templates as `value | filter_name`.
env.filters.update(
    {
        "to_op_attr_type": to_op_attr_type,
        "to_opmaker_name": to_opmaker_name,
        "to_pascal_case": to_pascal_case,
        "to_scalar_tensor_name": to_scalar_tensor_name,
        "to_int_array_tensor_name": to_int_array_tensor_name,
        "to_int_array_tensors_name": to_int_array_tensors_name,
        "to_input_name": to_input_name,
        "to_opmaker_name_cstr": to_opmaker_name_cstr,
        "cartesian_prod_mapping": cartesian_prod_mapping,
        "to_composite_grad_opmaker_name": to_composite_grad_opmaker_name,
        "to_variable_names": to_variable_names,
    }
)

# Custom tests, usable in templates as `value is test_name`.
env.tests.update(
    {
        "base_op": is_base_op,
        "composite_op": is_composite_op,
        "vec": is_vec,
        "scalar": is_scalar,
        "initializer_list": is_initializer_list,
        "supports_inplace": supports_inplace,
        "supports_no_need_buffer": supports_no_need_buffer,
    }
)


75 76 77 78 79
def restruct_io(op):
    """Attach dict views of an op's inputs, attrs and outputs.

    Sets ``input_dict``, ``attr_dict`` and ``output_dict`` on ``op``, each
    built with ``to_named_dict`` from the matching list (presumably keyed by
    each entry's name — see ``parse_utils.to_named_dict``). ``op`` is mutated
    in place and also returned, so the call can be used in a comprehension.

    Args:
        op: parsed op description with ``inputs``, ``attrs`` and ``outputs``.

    Returns:
        The same ``op`` object.
    """
    for list_key, view_key in (
        ("inputs", "input_dict"),
        ("attrs", "attr_dict"),
        ("outputs", "output_dict"),
    ):
        op[view_key] = to_named_dict(op[list_key])
    return op
80 81 82 83 84


# Prefix prepended to the names of the generated sparse ops: forward and
# backward op names, the `backward` links between them, and invoked functions
# that are forward ops. (Presumably keeps them distinct from the dense ops of
# the same name — confirm.)
SPARSE_OP_PREFIX = 'sparse_'


85 86 87 88 89
def _load_ops(yaml_path):
    """Load a parsed op YAML file and return its ops after ``restruct_io``.

    ``yaml.safe_load`` returns ``None`` for an empty document. That case is
    treated as an empty op list, so ``main`` can reach its "nothing to
    generate" branch instead of failing while iterating ``None``.
    """
    with open(yaml_path, "rt", encoding="utf-8") as f:
        ops = yaml.safe_load(f) or []
    return [restruct_io(op) for op in ops]


def _prefix_forward_ops(ops):
    """Rename forward ops in place to their prefixed names and add fluid names."""
    for op in ops:
        # Drop one trailing '_' from the YAML name before prefixing.
        # NOTE(review): presumably the inplace-variant suffix — confirm.
        if op['name'].endswith('_'):
            op['name'] = op['name'][:-1]
        op['op_name'] = SPARSE_OP_PREFIX + op['name']
        op['name'] = op['op_name']
        if op["backward"] is not None:
            op["backward"] = SPARSE_OP_PREFIX + op["backward"]
        for section in ("inputs", "attrs", "outputs"):
            add_fluid_name(op[section])


def _prefix_backward_ops(backward_ops):
    """Rename backward ops in place, add fluid names and split invoke args."""
    for bw_op in backward_ops:
        bw_op['op_name'] = SPARSE_OP_PREFIX + bw_op['name']
        bw_op['name'] = bw_op['op_name']
        for section in ("inputs", "attrs", "outputs"):
            add_fluid_name(bw_op[section])
        for section in ("inputs", "attrs", "outputs"):
            add_fluid_name(bw_op["forward"][section])
        if 'invoke' in bw_op:
            # "x, out_grad" -> ["x", "out_grad"]
            bw_op['invoke']['args'] = [
                param.strip() for param in bw_op['invoke']['args'].split(',')
            ]


def _prefix_invoke_funcs(backward_ops, forward_op_dict):
    """Prefix an invoked function's name when it is a key of forward_op_dict."""
    for bw_op in backward_ops:
        if 'invoke' not in bw_op:
            continue
        invoke = bw_op['invoke']
        if invoke['func'] in forward_op_dict:
            invoke['func'] = SPARSE_OP_PREFIX + invoke['func']


def _link_backward_ops(backward_op_dict):
    """Prefix the ``backward`` field of backward ops that another backward op
    names as its forward, first filling it with that op's key if unset."""
    for name, backward_op in backward_op_dict.items():
        forward_name = backward_op["forward"]["name"]
        if forward_name not in backward_op_dict:
            continue
        forward_op = backward_op_dict[forward_name]
        if forward_op["backward"] is None:
            forward_op["backward"] = name
        # NOTE(review): if several backward ops name the same forward op,
        # its ``backward`` field receives the prefix once per claimant.
        forward_op["backward"] = SPARSE_OP_PREFIX + forward_op["backward"]


def _render_to_file(template_name, output_path, **context):
    """Render ``template_name`` with ``context`` and write it to ``output_path``.

    Rendering happens before the output file is opened, so a template error
    does not leave a truncated output file behind.
    """
    content = env.get_template(template_name).render(**context)
    with open(output_path, "wt", encoding="utf-8") as f:
        f.write(content)


def main(op_yaml_path, backward_yaml_path, output_op_path, output_arg_map_path):
    """Generate the sparse operator and argument-mapping source files.

    Reads the parsed forward and backward sparse op YAML files. It then
    renames every op, and the names that refer to other ops, with
    ``SPARSE_OP_PREFIX``. Finally it renders ``sparse_op.c.j2`` into
    ``output_op_path`` and ``sparse_ks.c.j2`` into ``output_arg_map_path``.
    When neither YAML file contains any op, existing output files are removed
    and nothing is rendered.

    Args:
        op_yaml_path: path of the parsed sparse ops YAML file.
        backward_yaml_path: path of the parsed backward sparse ops YAML file.
        output_op_path: where to save the generated operators.
        output_arg_map_path: where to save the generated argument mapping
            functions.
    """
    ops = _load_ops(op_yaml_path)
    # Both lookup tables are built before any renaming below. They are
    # therefore keyed by the names as written in the YAML files (assuming
    # to_named_dict keys each op by its 'name' at call time).
    forward_op_dict = to_named_dict(ops)

    backward_ops = _load_ops(backward_yaml_path)
    backward_op_dict = to_named_dict(backward_ops)

    _prefix_forward_ops(ops)
    _prefix_backward_ops(backward_ops)

    # prepare for invoke case
    process_invoke_op(forward_op_dict, backward_op_dict)
    _prefix_invoke_funcs(backward_ops, forward_op_dict)

    # fill backward field for an op if another op claims it as forward
    _link_backward_ops(backward_op_dict)

    op_dict = {**forward_op_dict, **backward_op_dict}

    if not ops and not backward_ops:
        # Nothing to generate: remove previously generated outputs, if any.
        for path in (output_op_path, output_arg_map_path):
            if os.path.isfile(path):
                os.remove(path)
        return

    _render_to_file(
        'sparse_op.c.j2',
        output_op_path,
        ops=ops,
        backward_ops=backward_ops,
        op_dict=op_dict,
    )
    _render_to_file(
        'sparse_ks.c.j2',
        output_arg_map_path,
        ops=ops,
        backward_ops=backward_ops,
    )


if __name__ == "__main__":
    # All four paths are mandatory for main(). Marking them required lets
    # argparse report a usage error instead of main() failing later with a
    # TypeError from open(None).
    parser = argparse.ArgumentParser(
        description="Generate operator file from op yaml."
    )
    parser.add_argument(
        '--ops_yaml_path',
        type=str,
        required=True,
        help="parsed sparse ops yaml file.",
    )
    parser.add_argument(
        '--backward_ops_yaml_path',
        type=str,
        required=True,
        help="parsed backward sparse ops yaml file.",
    )
    parser.add_argument(
        "--output_op_path",
        type=str,
        required=True,
        help="path to save generated operators.",
    )
    parser.add_argument(
        "--output_arg_map_path",
        type=str,
        required=True,
        help="path to save generated argument mapping functions.",
    )

    args = parser.parse_args()
    main(
        args.ops_yaml_path,
        args.backward_ops_yaml_path,
        args.output_op_path,
        args.output_arg_map_path,
    )