# sparse_api_gen.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
import argparse
import re

from api_gen import ForwardAPI
from api_base import PREFIX_TENSOR_NAME


class SparseAPI(ForwardAPI):
    """Code generator for one sparse C++ API entry.

    Builds on ``ForwardAPI`` and overrides the pieces that differ for sparse
    kernels: output creation, kernel-context population and the per-kernel
    dispatch condition on the tensor type of each input.
    """

    def __init__(self, api_item_yaml):
        """Initialise from one API item parsed from the sparse ops YAML."""
        super().__init__(api_item_yaml)

    def gene_api_declaration(self):
        """Return the base declaration preceded by a comment naming the outputs."""
        output_names = ", ".join(self.outputs['names'])
        base_declaration = super().gene_api_declaration()
        return f"""
// {output_names}
{base_declaration}
"""

    def gene_output(
        self,
        out_dtype_list,
        out_tensor_type_list=None,
        code_indent='',
        inplace_flag=False,
    ):
        """Generate the C++ code that creates the API/kernel output variables.

        Args:
            out_dtype_list: Tensor-type keys ('dense', 'sparse_coo',
                'sparse_csr'), one per kernel output.
            out_tensor_type_list: Unused in this method.
            code_indent: Unused in this method.
            inplace_flag: Whether code for the inplace variant is generated.

        Returns:
            A tuple ``(kernel_output, output_names, output_create)``: two lists
            holding the same kernel output variable names, and the C++ code
            string that declares ``api_output`` and those variables.

        Raises:
            ValueError: If ``out_dtype_list`` is empty.
            KeyError: If an entry of ``out_dtype_list`` is not a known key.
        """
        kernel_output = []
        output_names = []
        output_create = ""
        return_type = self.get_return_type_with_intermediate(inplace_flag)
        output_type_map = {
            'dense': 'TensorType::DENSE_TENSOR',
            'sparse_coo': 'TensorType::SPARSE_COO',
            'sparse_csr': 'TensorType::SPARSE_CSR',
        }
        # Only consult the inplace map for inplace variants; a missing (None)
        # map is treated like an empty one in both branches below.
        inplace_map = (self.inplace_map or {}) if inplace_flag else {}

        if len(out_dtype_list) == 1:
            kernel_output.append('kernel_out')
            output_names.append('kernel_out')
            inplace_assign = ""
            if inplace_map:
                first_output = self.outputs['names'][0]
                if first_output in inplace_map:
                    inplace_assign = " = " + inplace_map[first_output]
            output_create = f"""
    {return_type} api_output{inplace_assign};
    auto* kernel_out = SetSparseKernelOutput(&api_output, {output_type_map[out_dtype_list[0]]});"""

        elif len(out_dtype_list) > 1:
            output_create = f"""
    {return_type} api_output;"""

            if inplace_flag:
                # Brace-initialise the outputs: inplace outputs reuse the
                # mapped input, all others start as an empty Tensor().
                initializers = ", ".join(
                    inplace_map.get(out_name, 'Tensor()')
                    for out_name in self.outputs['names']
                )
                output_create = f"""
    {return_type} api_output{{{initializers}}};"""

            for i, out_dtype in enumerate(out_dtype_list):
                kernel_out = f'kernel_out_{i}'
                kernel_output.append(kernel_out)
                output_names.append(kernel_out)
                output_create += f"""
    auto* {kernel_out} = SetSparseKernelOutput(&std::get<{i}>(api_output), {output_type_map[out_dtype]});"""

        else:
            raise ValueError(
                f"{self.api} : Output error: the output should not be empty."
            )

        return kernel_output, output_names, output_create

    def gen_sparse_kernel_context(self, kernel_output_names):
        """Generate the C++ statements that fill ``kernel_context``.

        Kernel parameters are emitted in order: inputs through
        ``EmplaceBackInput``, everything else through ``EmplaceBackAttr``;
        afterwards every name in ``kernel_output_names`` is emitted through
        ``EmplaceBackOutput``.

        Args:
            kernel_output_names: Names of the C++ kernel output variables.

        Returns:
            The generated C++ code as a string (each statement is preceded by
            a newline).
        """
        input_names = self.inputs['names']
        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            # No explicit parameter list: pass all inputs, then all attributes.
            kernel_param = input_names + attr_names

        statements = []
        for param in kernel_param:
            if param in input_names:
                if param in self.optional_vars:
                    statements.append(
                        f"""
    kernel_context.EmplaceBackInput({param} ? {param}->impl().get() : nullptr);"""
                    )
                else:
                    statements.append(
                        f"""
    kernel_context.EmplaceBackInput({param}.impl().get());"""
                    )
                continue

            if param in attr_names:
                # Wrap attributes whose declared type is IntArray/Scalar in
                # the corresponding phi type.
                attr_type = self.attrs['attr_info'][param][0]
                if 'IntArray' in attr_type:
                    param = 'phi::IntArray(' + param + ')'
                elif 'Scalar' in attr_type:
                    param = 'phi::Scalar(' + param + ')'
            elif isinstance(param, bool):
                # Python True/False -> C++ true/false.
                param = str(param).lower()
            # Any other literal is emitted as-is; the f-string below performs
            # the string conversion.
            statements.append(
                f"""
    kernel_context.EmplaceBackAttr({param});"""
            )

        for out_name in kernel_output_names:
            statements.append(
                f"""
    kernel_context.EmplaceBackOutput({out_name});"""
            )

        return "".join(statements)

    def prepare_input(self):
        """Generate C++ declarations of the input variables used by infer-meta.

        Reads ``self.inputs['tensor_type']``, which ``get_condition_code``
        stores, so that method must have run for the current kernel first.
        Inputs that are neither ``const Tensor&`` nor optional produce no code.

        Returns:
            The generated C++ code, one declaration per line.
        """
        input_names = self.inputs['names']
        input_types = self.inputs['tensor_type']
        attr_names = self.attrs['names']
        infer_meta = self.infer_meta

        infer_meta_params = (
            infer_meta['param']
            if infer_meta['param'] is not None
            else input_names + attr_names
        )

        tensor_type_map = {
            'dense': 'phi::DenseTensor',
            'sparse_coo': 'phi::SparseCooTensor',
            'sparse_csr': 'phi::SparseCsrTensor',
        }

        declarations = []
        for param in infer_meta_params:
            if param not in input_names:
                continue
            var_name = "auto " + PREFIX_TENSOR_NAME + param + " = "
            if self.inputs['input_info'][param] == "const Tensor&":
                declarations.append(var_name + param + ".impl();\n")
            elif param in self.optional_vars:
                # Use the tensor type recorded for this input; fall back to a
                # dense tensor when no type was recorded for it.
                input_type = next(
                    (
                        in_type
                        for name, in_type in zip(input_names, input_types)
                        if name == param
                    ),
                    None,
                )
                tensor_type = (
                    'phi::DenseTensor'
                    if input_type is None
                    else tensor_type_map[input_type]
                )
                optional_var = "paddle::optional<" + tensor_type + ">("
                declarations.append(
                    f"{var_name}{param} ? {optional_var}"
                    f"*static_cast<{tensor_type}*>((*{param}).impl().get())) : "
                    f"{optional_var}paddle::none);\n"
                )
        return "".join(declarations)

    def gen_sparse_kernel_code(self, kernel_name, inplace_flag=False):
        """Generate the C++ body that selects and invokes one sparse kernel.

        Args:
            kernel_name: Kernel function name; also the key into
                ``self.kernel['dispatch']``.
            inplace_flag: Whether code for the inplace variant is generated.

        Returns:
            The generated C++ code as a string.
        """
        _, kernel_output_names, output_create = self.gene_output(
            self.kernel['dispatch'][kernel_name][1], None, '', inplace_flag
        )

        kernel_context_code = self.gen_sparse_kernel_context(
            kernel_output_names
        )
        # Generate the return statement once; indent it only when non-empty.
        return_code = self.gene_return_code()
        if len(return_code) != 0:
            return_code = "  " + return_code
        return f"""
    VLOG(6) << "{self.api} api sparse kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
    auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
        "{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}});
    const auto& phi_kernel = kernel_result.kernel;
    VLOG(6) << "{self.api} api sparse kernel: " << phi_kernel;

    auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend);
    auto kernel_context = phi::KernelContext(dev_ctx);
{output_create}
{self.prepare_input()}
{self.gene_infer_meta(kernel_output_names, '')}
{kernel_context_code}
    phi_kernel(&kernel_context);
  {return_code}"""

    def get_condition_code(self, kernel_name):
        """Build the C++ boolean expression guarding dispatch to a kernel.

        Produces one type check per input, joined with ``&&``. As a side
        effect the input tensor types of this kernel are stored in
        ``self.inputs['tensor_type']``.

        Args:
            kernel_name: Key into ``self.kernel['dispatch']``.

        Returns:
            The C++ condition expression as a string.
        """
        assert self.kernel['dispatch'][
            kernel_name
        ], f"{self.api} api: the tensor type of inputs and outputs for kernel isn't set, see also 'kernel:func' of 'conv3d' in sparse_ops.yaml."
        input_types = self.kernel['dispatch'][kernel_name][0]
        input_names = self.inputs['names']

        condition_list = []
        for i, in_type in enumerate(input_types):
            name = input_names[i]
            if in_type == "dense":
                if name in self.optional_vars:
                    # An absent optional dense input also satisfies the check.
                    condition_list.append(
                        f"(!{name} || phi::DenseTensor::classof({name}->impl().get()))"
                    )
                else:
                    condition_list.append(
                        f"phi::DenseTensor::classof({name}.impl().get())"
                    )
            elif in_type == 'sparse_coo':
                # NOTE(review): optional sparse inputs get the same `.`-access
                # check as required ones - confirm this is intended.
                condition_list.append(f"{name}.is_sparse_coo_tensor()")
            else:
                condition_list.append(f"{name}.is_sparse_csr_tensor()")
        # Remember the tensor types so prepare_input() can pick matching
        # C++ tensor classes for this kernel.
        self.inputs['tensor_type'] = list(input_types)

        return " && ".join(condition_list)

    def gene_dispatch_code(self, kernel_name, inplace_flag=False):
        """Wrap the code for one kernel in an ``if`` on its dispatch condition."""
        # Evaluate the condition first: it records the input tensor types that
        # the kernel code generation reads.
        condition = self.get_condition_code(kernel_name)
        kernel_code = self.gen_sparse_kernel_code(kernel_name, inplace_flag)
        return f"""
  if ({condition}) {{
{kernel_code}
  }}
"""

    def gene_base_api_code(self, inplace_flag=False):
        """Generate the full C++ definition of the API function.

        The body contains the kernel-selection code, one dispatch block per
        kernel in ``self.kernel['func']`` and a final ``PADDLE_THROW`` that
        follows the dispatch blocks.

        Args:
            inplace_flag: Whether the inplace variant is generated; its
                function name gets a trailing underscore if it lacks one.

        Returns:
            The generated C++ function definition as a string.
        """
        api_func_name = self.get_api_func_name()
        if inplace_flag and api_func_name[-1] != '_':
            api_func_name += '_'
        kernel_dispatch_code = f"{self.gene_kernel_select()}\n"
        for kernel_name in self.kernel['func']:
            kernel_dispatch_code += self.gene_dispatch_code(
                kernel_name, inplace_flag
            )

        return f"""
PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{
{kernel_dispatch_code}
  PADDLE_THROW(phi::errors::Unimplemented(
          "The kernel of ({self.api}) for input tensors is unimplemented, please check the type of input tensors."));
}}
"""


def header_include():
    """Return the ``#include`` block written at the top of the generated header."""
    return """
#include <tuple>

#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/utils/optional.h"
"""


def source_include(header_file_path):
    """Return the ``#include`` block written at the top of the generated source.

    Args:
        header_file_path: Include path of the generated API header; it is
            emitted as the first ``#include`` line.

    Returns:
        The include block as a string.
    """
    return f"""
#include "{header_file_path}"
#include <memory>

#include "glog/logging.h"

#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/utils/none.h"

#include "paddle/phi/infermeta/sparse/unary.h"
#include "paddle/phi/infermeta/sparse/binary.h"
#include "paddle/phi/infermeta/sparse/multiary.h"
"""


def api_namespace():
    """Return the C++ namespace opening and closing code.

    Returns:
        A 2-tuple ``(begin, end)``: ``begin`` opens
        ``paddle::experimental::sparse`` and ``end`` closes it.
    """
    return (
        """
namespace paddle {
namespace experimental {
namespace sparse {

""",
        """

}  // namespace sparse
}  // namespace experimental
}  // namespace paddle
""",
    )


def generate_api(api_yaml_path, header_file_path, source_file_path):
    """Generate the sparse API header and source files from a YAML description.

    Args:
        api_yaml_path: Path of the YAML file listing the sparse APIs.
        header_file_path: Path of the C++ header file to write.
        source_file_path: Path of the C++ source file to write.
    """
    with open(api_yaml_path, 'r') as f:
        # NOTE(review): FullLoader can construct more than plain data types;
        # this assumes the YAML is a trusted, in-repo file. Prefer
        # yaml.safe_load if the input may ever be untrusted.
        apis = yaml.load(f, Loader=yaml.FullLoader)

    namespace = api_namespace()
    include_header_file = "paddle/phi/api/include/sparse_api.h"

    # Context managers guarantee both outputs are flushed and closed even if
    # code generation raises part-way through.
    with open(header_file_path, 'w') as header_file:
        with open(source_file_path, 'w') as source_file:
            header_file.write("#pragma once\n")
            header_file.write(header_include())
            header_file.write(namespace[0])

            source_file.write(source_include(include_header_file))
            source_file.write(namespace[0])

            for api in apis:
                sparse_api = SparseAPI(api)
                # Clear the dygraph flag on every API before generating code.
                if sparse_api.is_dygraph_api:
                    sparse_api.is_dygraph_api = False
                header_file.write(sparse_api.gene_api_declaration())
                source_file.write(sparse_api.gene_api_code())

            header_file.write(namespace[1])
            source_file.write(namespace[1])


def main():
    """Parse the command-line options and generate the sparse API files."""
    parser = argparse.ArgumentParser(
        description='Generate PaddlePaddle C++ Sparse API files'
    )
    parser.add_argument(
        '--api_yaml_path',
        help='path to sparse api yaml file',
        default='paddle/phi/api/yaml/sparse_ops.yaml',
    )

    parser.add_argument(
        '--api_header_path',
        help='output of generated api header code file',
        default='paddle/phi/api/include/sparse_api.h',
    )

    parser.add_argument(
        '--api_source_path',
        help='output of generated api source code file',
        default='paddle/phi/api/lib/sparse_api.cc',
    )

    options = parser.parse_args()

    generate_api(
        options.api_yaml_path,
        options.api_header_path,
        options.api_source_path,
    )


# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()