# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import re

import yaml

from api_base import BaseAPI


22
class BackwardAPI(BaseAPI):
    """Code generator for one backward C++ API item from the backward yaml.

    Extends BaseAPI: a backward op's outputs are emitted as out-parameters
    (``Tensor*`` / ``std::vector<Tensor*>``) and the generated C++ function
    returns ``void``.
    """

    def __init__(self, backward_item_yaml):
        super().__init__(backward_item_yaml)
        # Validate the backward signature against its forward op before
        # caching the no_need_buffer list.
        self.check_args(backward_item_yaml['forward'])
        self.no_need_buffer = self.parse_no_need_buffer(backward_item_yaml)

    def get_api_name(self, api_item_yaml):
        # Backward items are keyed by 'backward_op' in the yaml.
        return api_item_yaml['backward_op']

    def parse_forward_config(self, forward_config):
        """Parse a forward config string into (api, inputs, attrs, outputs).

        Expected shape, e.g.:
            api_name (const Tensor& input, ... , int attr, ...) -> Tensor(out)
        """
        result = re.search(
            r"(?P<op>[a-z][a-z0-9_]+)\s*(?P<args>\([^\)]+\))\s*->\s*(?P<outputs>.+)",
            forward_config,
        )
        api = result.group('op')
        (
            _,
            outputs,
            _,
        ) = self.parse_output(self.api, result.group('outputs'))
        # Strip any '@...' suffix (e.g. intermediate markers) from output names.
        outputs = [item.split('@')[0] for item in outputs]
        fw_inputs, fw_attrs = self.parse_input_and_attr(
            api, result.group('args')
        )

        return api, fw_inputs, fw_attrs, outputs

    def parse_no_need_buffer(self, api_item_yaml):
        """Return the list of input names marked 'no_need_buffer' (may be empty)."""
        if 'no_need_buffer' not in api_item_yaml:
            return []
        return [
            item.strip()
            for item in api_item_yaml['no_need_buffer'].split(',')
        ]

    def check_args(self, forward_config):
        """Assert that backward inputs/attrs/outputs are consistent with the forward op."""
        # parse the forward and backward config
        _, fw_inputs, fw_attrs, fw_outputs = self.parse_forward_config(
            forward_config
        )

        # check the inputs of backward
        # (renamed loop var: don't shadow the builtin `input`)
        for bw_input in self.inputs['names']:
            if (
                bw_input not in fw_inputs['names']
                and bw_input not in fw_outputs
            ):
                if bw_input.endswith('_grad'):
                    original_name = bw_input[:-5]
                    assert (
                        original_name in fw_outputs
                    ), f"{self.api} : Input Tensor error: the input tensor({bw_input}) of backward should be an input or output or grad of output in forward api. \
                         Please check the forward of {self.api} in yaml."

        # check the attributes of backward: each must either match a forward
        # attribute (same type) or carry a default value of its own.
        for attr in self.attrs['names']:
            assert (
                attr in fw_attrs['names']
                and self.attrs['attr_info'][attr][0]
                == fw_attrs['attr_info'][attr][0]
            ) or self.attrs['attr_info'][attr][
                1
            ] is not None, f"{self.api} : Attribute error: The attribute({attr}) of backward isn't consistent with forward api or doesn't have default value. \
                 Please check the args of {self.api} in yaml."

        # check the output of backward: one grad output per forward input at most
        assert len(self.outputs['types']) <= len(
            fw_inputs['names']
        ), f"{self.api} : Output error: The number of outputs should be less then the number of inputs of forward api. \
             Please check the output of {self.api} in yaml."

    def get_declare_args(self, inplace_flag=False):
        # Declaration args are identical to definition args for backward ops.
        return self.get_define_args()

    def get_define_args(self, inplace_flag=False):
        """Build the C++ parameter list: forward-style inputs/attrs, then out-params."""
        out_type_map = {
            'Tensor': 'Tensor*',
            'std::vector<Tensor>': 'std::vector<Tensor*>',
        }
        inputs_and_attrs = super().get_define_args()  # fixed 'intputs' typo
        outs = []
        for i, name in enumerate(self.outputs['names']):
            outs.append(
                out_type_map[self.outputs['types'][i]]
                + ' '
                + name.split('@')[0]
            )
        return inputs_and_attrs + ', ' + ", ".join(outs)

    def gene_return_code(self):
        # Backward APIs return void, so no return statement is generated.
        return ""

    def gene_api_declaration(self):
        """Generate the PADDLE_API declaration; empty for pure-invoke wrappers."""
        if not self.is_base_api:
            invoke_func_name = self.invoke.split('(')[0]
            # Only *_grad / *_impl invocations get a declaration of their own.
            if not invoke_func_name.endswith(('_grad', '_impl')):
                return ""
        api_func_name = self.get_api_func_name()
        api_declaration = f"""
PADDLE_API void {api_func_name}({self.get_declare_args()});
"""
        return api_declaration

    def gene_kernel_backend_select(self):
        # If every input is no_need_buffer, the backend can't be inferred
        # from inputs; fall back to the globally expected place.
        if all(
            in_name in self.no_need_buffer for in_name in self.inputs['names']
        ):
            return """
  kernel_backend = ParseBackend(egr::Controller::Instance().GetExpectedPlace());
"""
        return super().gene_kernel_backend_select()

    def get_return_type(self, inplace_flag=False):
        # Outputs are passed as out-parameters, so the C++ return type is void.
        return 'void'

    def gene_output(
        self,
        out_dtype_list,
        out_tensor_type_list=None,
        code_indent='',
        inplace_flag=False,
    ):
        """Generate kernel-output C++ setup code.

        Returns (kernel_output names, output_names, output_create code).
        Raises ValueError when the op declares no outputs.
        """
        kernel_output = []
        output_names = []
        output_create = ""

        if len(out_dtype_list) == 1:
            kernel_output.append('kernel_out')
            output_names.append('kernel_out')
            # (removed dead local `inplace_assign`: it was computed but never used)
            set_out_func = (
                'SetKernelOutput'
                if out_tensor_type_list is None
                or out_tensor_type_list[0] == 'dense'
                else 'SetSelectedRowsKernelOutput'
            )
            if out_dtype_list[0] == 'std::vector<Tensor>':
                assert (
                    self.outputs['out_size_expr'] is not None
                ), f"{self.api}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
                output_create = (
                    output_create
                    + f"""
{code_indent}  auto kernel_out = {set_out_func}(&{self.outputs['names'][0]});"""
                )

            else:
                output_create = (
                    output_create
                    + f"""
{code_indent}  auto kernel_out = {set_out_func}({self.outputs['names'][0]});"""
                )

        elif len(out_dtype_list) > 1:
            output_create = ""
            for i, out_type_item in enumerate(out_dtype_list):
                kernel_output.append(f'kernel_out_{i}')
                output_names.append(f'kernel_out_{i}')
                set_out_func = (
                    'SetKernelOutput'
                    if out_tensor_type_list is None
                    or out_tensor_type_list[i] == 'dense'
                    else 'SetSelectedRowsKernelOutput'
                )
                if out_type_item == 'Tensor':
                    if (
                        inplace_flag
                        and self.inplace_map is not None
                        and self.outputs['names'][i] in self.inplace_map
                    ):
                        # Inplace output: alias the out-param to its source first.
                        output_create = (
                            output_create
                            + f"""
{code_indent}  *{self.outputs['names'][i]} = {self.inplace_map[self.outputs['names'][i]]};"""
                        )

                    output_create = (
                        output_create
                        + f"""
{code_indent}  auto kernel_out_{i} = {set_out_func}({self.outputs['names'][i]});"""
                    )

                else:
                    if (
                        inplace_flag
                        and self.inplace_map is not None
                        and self.outputs['names'][i] in self.inplace_map
                    ):
                        output_create = (
                            output_create
                            + f"""
{code_indent}  *{self.outputs['names'][i]} = {self.inplace_map[self.outputs['names'][i]]};"""
                        )

                    assert (
                        self.outputs['out_size_expr'][i] is not None
                    ), f"{self.api}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
                    output_create = (
                        output_create
                        + f"""
{code_indent}  auto kernel_out_{i} = {set_out_func}(&{self.outputs['names'][i]});"""
                    )

        else:
            raise ValueError(
                "{} : Output error: the output should not be empty.".format(
                    self.api
                )
            )

        return kernel_output, output_names, output_create

    def gene_invoke_code(self, invoke_code, params_code):
        """Generate the wrapper body for invoke-style (*_grad / *_impl) backward APIs."""
        invoke_func_name = invoke_code.split('(')[0].strip()
        if invoke_func_name.endswith(('_grad', '_impl')):
            return f"""
PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
  {invoke_code};
}}"""

        return ""
def header_include():
    """Return the fixed #include preamble for the generated backward header."""
    return """
#include <tuple>

#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/utils/optional.h"
"""


def source_include(header_file_path):
    """Return the #include preamble for the generated backward source file.

    Args:
        header_file_path: path of the generated header, included first.
    """
    return f"""
#include "{header_file_path}"
#include <memory>

#include "glog/logging.h"

#include "paddle/phi/api/lib/api_custom_impl.h"
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"

#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/platform/profiler/supplement_tracing.h"

DECLARE_bool(conv2d_disable_cudnn);
"""


def backward_api_namespace():
    """Return the (opening, closing) C++ namespace snippets for generated files."""
    return (
        """
namespace paddle {
namespace experimental {

""",
        """

}  // namespace experimental
}  // namespace paddle
""",
    )
def generate_backward_api(
    backward_yaml_path, header_file_path, source_file_path
):
    """Generate the backward C++ API header and source from backward yaml files.

    Args:
        backward_yaml_path: iterable of backward yaml file paths.
        header_file_path: output path for the generated C++ header.
        source_file_path: output path for the generated C++ source.
    """
    bw_apis = []
    for each_api_yaml in backward_yaml_path:
        with open(each_api_yaml, 'r') as f:
            api_list = yaml.load(f, Loader=yaml.FullLoader)
            if api_list:
                bw_apis.extend(api_list)

    namespace = backward_api_namespace()

    # Context managers guarantee both files are closed even if a yaml item
    # fails validation mid-generation (the originals were left open on error).
    with open(header_file_path, 'w') as header_file, open(
        source_file_path, 'w'
    ) as source_file:
        header_file.write("#pragma once\n")
        header_file.write(header_include())
        header_file.write(namespace[0])

        include_header_file = "paddle/phi/api/backward/backward_api.h"
        source_file.write(source_include(include_header_file))
        source_file.write(namespace[0])

        for bw_api in bw_apis:
            backward_api = BackwardAPI(bw_api)
            header_file.write(backward_api.gene_api_declaration())
            source_file.write(backward_api.gene_api_code())

        header_file.write(namespace[1])
        source_file.write(namespace[1])


def main():
    """Parse command-line options and run the backward API generator."""
    parser = argparse.ArgumentParser(
        description='Generate PaddlePaddle C++ backward API files'
    )
    parser.add_argument(
        '--backward_yaml_path',
        help='path to backward yaml file',
        nargs='+',
        default=['paddle/phi/api/yaml/backward.yaml'],
    )
    parser.add_argument(
        '--backward_header_path',
        help='output of generated backward header code file',
        default='paddle/phi/api/backward/backward_api.h',
    )
    parser.add_argument(
        '--backward_source_path',
        help='output of generated backward source code file',
        default='paddle/phi/api/lib/backward_api.cc',
    )

    options = parser.parse_args()

    generate_backward_api(
        options.backward_yaml_path,
        options.backward_header_path,
        options.backward_source_path,
    )
# Script entry point: run the generator when executed directly.
if __name__ == '__main__':
    main()