backward_api_gen.py 11.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
import argparse
import re

20
from api_base import BaseAPI
21 22


23
class BackwardAPI(BaseAPI):
    """Code generator for one backward (gradient) API from backward yaml.

    Builds on ``BaseAPI``: the API name comes from the ``backward_api`` key,
    the entry is validated against its ``forward`` signature on construction,
    and the generated C++ function returns ``void`` and writes its results
    through ``Tensor*`` / ``std::vector<Tensor*>`` output parameters.
    """

    def __init__(self, backward_item_yaml):
        """Parse and validate one backward yaml item.

        Args:
            backward_item_yaml (dict): yaml mapping for one backward API; must
                contain a ``forward`` key in addition to whatever ``BaseAPI``
                itself reads.
        """
        super(BackwardAPI, self).__init__(backward_item_yaml)
        self.check_args(backward_item_yaml['forward'])
        # Input names listed under ``no_need_buffer`` (empty when absent).
        self.no_need_buffer = self.parse_no_need_buffer(backward_item_yaml)

    def get_api_name(self, api_item_yaml):
        """Return the API name, read from the ``backward_api`` key."""
        return api_item_yaml['backward_api']

    def parse_forward_config(self, forward_config):
        """Parse the ``forward`` entry of a backward yaml item.

        Expected form:
        ``api_name (const Tensor& input, ... , int attr, ...) -> Tensor(out)``

        Returns:
            tuple: ``(api, fw_inputs, fw_attrs, outputs)``, where ``fw_inputs``
            and ``fw_attrs`` come from ``parse_input_and_attr`` and
            ``outputs`` is the list of forward output names with any ``@``
            suffix removed.

        Raises:
            ValueError: if ``forward_config`` does not match the expected form.
        """
        result = re.search(
            r"(?P<api>[a-z][a-z0-9_]+)\s*(?P<args>\([^\)]+\))\s*->\s*(?P<outputs>.+)",
            forward_config)
        if result is None:
            # Fail with a readable message instead of an AttributeError on None.
            raise ValueError(
                f"{self.api} : Forward config error: cannot parse forward "
                f"'{forward_config}'; expected 'api_name (args) -> outputs'.")
        api = result.group('api')
        _, outputs, _ = self.parse_output(self.api, result.group('outputs'))
        outputs = [item.split('@')[0] for item in outputs]
        fw_inputs, fw_attrs = self.parse_input_and_attr(api,
                                                        result.group('args'))
        return api, fw_inputs, fw_attrs, outputs

    def parse_no_need_buffer(self, api_item_yaml):
        """Return the comma-separated ``no_need_buffer`` entry as a list.

        Each name is stripped of surrounding whitespace; an absent key yields
        an empty list.
        """
        if 'no_need_buffer' not in api_item_yaml:
            return []
        return [
            item.strip()
            for item in api_item_yaml['no_need_buffer'].split(',')
        ]

    def check_args(self, forward_config):
        """Validate this backward API's signature against its forward API.

        Args:
            forward_config (str): the ``forward`` entry of the backward yaml
                item, in the form accepted by ``parse_forward_config``.

        Raises:
            AssertionError: if an input ending in ``_grad`` does not name the
                gradient of a forward output, if an attribute neither matches
                a same-typed forward attribute nor has a default value, or if
                the backward API has more outputs than the forward API has
                inputs.
        """
        # parse the forward and backward config
        _, fw_inputs, fw_attrs, fw_outputs = self.parse_forward_config(
            forward_config)

        # check the inputs of backward
        for in_name in self.inputs['names']:
            if in_name in fw_inputs['names'] or in_name in fw_outputs:
                continue
            # NOTE(review): an input that is neither a forward input/output nor
            # ends in '_grad' is accepted without error; confirm this is meant.
            if in_name.endswith('_grad'):
                original_name = in_name[:-5]
                assert original_name in fw_outputs, \
                    f"{self.api} : Input Tensor error: the input tensor({in_name}) of backward should be an input or output or grad of output in forward api. \
                         Please check the forward of {self.api} in yaml."

        # check the attributes of backward: each one must either match a
        # forward attribute of the same type or carry a default value
        for attr in self.attrs['names']:
            same_as_forward = (attr in fw_attrs['names']
                               and self.attrs['attr_info'][attr][0]
                               == fw_attrs['attr_info'][attr][0])
            assert same_as_forward or \
                 self.attrs['attr_info'][attr][1] is not None, \
                f"{self.api} : Attribute error: The attribute({attr}) of backward isn't consistent with forward api or doesn't have default value. \
                 Please check the args of {self.api} in yaml."

        # check the output of backward
        assert len(self.outputs['types']) <= len(fw_inputs['names']), \
            f"{self.api} : Output error: The number of outputs should be less then the number of inputs of forward api. \
             Please check the output of {self.api} in yaml."

    def get_declare_args(self, inplace_flag=False):
        """Return the declaration parameter list (same as the definition's).

        ``inplace_flag`` is accepted for interface compatibility and ignored.
        """
        return self.get_define_args()

    def get_define_args(self, inplace_flag=False):
        """Return the C++ parameter list of the generated backward function.

        The base class's input/attribute parameters are followed by one
        pointer-typed parameter per output (``Tensor`` -> ``Tensor*``,
        ``std::vector<Tensor>`` -> ``std::vector<Tensor*>``), each named after
        the output with any ``@`` suffix removed. ``inplace_flag`` is ignored.
        """
        out_type_map = {
            'Tensor': 'Tensor*',
            'std::vector<Tensor>': 'std::vector<Tensor*>'
        }
        inputs_and_attrs = super(BackwardAPI, self).get_define_args()
        outs = [
            out_type_map[self.outputs['types'][i]] + ' ' + name.split('@')[0]
            for i, name in enumerate(self.outputs['names'])
        ]
        return inputs_and_attrs + ', ' + ", ".join(outs)

    def gene_return_code(self):
        """Return no return statement: backward APIs write through outputs."""
        return ""

    def gene_kernel_backend_select(self):
        """Generate the C++ code that selects ``kernel_backend``.

        When every input is listed in ``no_need_buffer`` the backend is taken
        from ``egr::Controller``'s expected place (presumably because no input
        buffer is available to infer it from -- TODO confirm); otherwise the
        ``BaseAPI`` selection code is used.
        """
        if all(in_name in self.no_need_buffer
               for in_name in self.inputs['names']):
            return """
  kernel_backend = ParseBackend(egr::Controller::Instance().GetExpectedPlace());
"""
        return super().gene_kernel_backend_select()

    def get_return_type(self, inplace_flag=False):
        """Return the C++ return type of every backward API: ``void``."""
        return 'void'

    def _get_inplace_input(self, out_name, inplace_flag):
        """Return the input mapped to ``out_name`` for in-place use, or None."""
        if inplace_flag and self.inplace_map is not None \
                and out_name in self.inplace_map:
            return self.inplace_map[out_name]
        return None

    @staticmethod
    def _get_set_out_func(out_tensor_type_list, index):
        """Return the C++ helper used to bind output ``index`` to the kernel."""
        if out_tensor_type_list is None or \
                out_tensor_type_list[index] == 'dense':
            return 'SetKernelOutput'
        return 'SetSelectedRowsKernelOutput'

    def gene_output(self,
                    out_dtype_list,
                    out_tensor_type_list=None,
                    code_indent='',
                    inplace_flag=False):
        """Generate the C++ code that binds kernel outputs to API outputs.

        Args:
            out_dtype_list (list[str]): C++ type of each output.
            out_tensor_type_list (list[str] | None): per-output tensor kind;
                ``None`` or ``'dense'`` selects ``SetKernelOutput``, anything
                else ``SetSelectedRowsKernelOutput``.
            code_indent (str): prefix prepended to every generated line.
            inplace_flag (bool): when true, outputs present in
                ``self.inplace_map`` are first assigned from their mapped
                input.

        Returns:
            tuple: ``(kernel_output, output_names, output_create)`` -- the
            kernel output variable names (twice) and the generated C++ code.

        Raises:
            ValueError: if ``out_dtype_list`` is empty.
            AssertionError: if a ``std::vector<Tensor>`` output has no size
                expression.
        """
        kernel_output = []
        output_names = []
        output_create = ""

        if len(out_dtype_list) == 1:
            kernel_output.append('kernel_out')
            output_names.append('kernel_out')
            raw_name = self.outputs['names'][0]
            # C++ identifier, matching the parameter name in get_define_args.
            out_name = raw_name.split('@')[0]
            set_out_func = self._get_set_out_func(out_tensor_type_list, 0)
            if out_dtype_list[0] == 'std::vector<Tensor>':
                assert self.outputs['out_size_expr'] is not None, \
                     f"{self.api}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
                output_create += (
                    f"\n{code_indent}  auto kernel_out = {set_out_func}(&{out_name});"
                )
            else:
                inplace_input = self._get_inplace_input(raw_name, inplace_flag)
                if inplace_input is not None:
                    # Assign the mapped input before binding, mirroring the
                    # multi-output branch below.
                    output_create += (
                        f"\n{code_indent}  *{out_name} = {inplace_input};")
                output_create += (
                    f"\n{code_indent}  auto kernel_out = {set_out_func}(kernel_backend, {out_name});"
                )

        elif len(out_dtype_list) > 1:
            for i, out_type_item in enumerate(out_dtype_list):
                kernel_out = f'kernel_out_{i}'
                kernel_output.append(kernel_out)
                output_names.append(kernel_out)
                raw_name = self.outputs['names'][i]
                out_name = raw_name.split('@')[0]
                set_out_func = self._get_set_out_func(out_tensor_type_list, i)

                inplace_input = self._get_inplace_input(raw_name, inplace_flag)
                if inplace_input is not None:
                    output_create += (
                        f"\n{code_indent}  *{out_name} = {inplace_input};")

                if out_type_item == 'Tensor':
                    output_create += (
                        f"\n{code_indent}  auto {kernel_out} = {set_out_func}(kernel_backend, {out_name});"
                    )
                else:
                    assert self.outputs['out_size_expr'][i] is not None, \
                        f"{self.api}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
                    output_create += (
                        f"\n{code_indent}  auto {kernel_out} = {set_out_func}(&{out_name});"
                    )

        else:
            raise ValueError(
                "{} : Output error: the output should not be empty.".format(
                    self.api))

        return kernel_output, output_names, output_create

    def gene_invoke_code(self, invoke_code, params_code):
        """Generate a backward API whose body delegates to ``invoke_code``.

        If the invoked function's name ends in ``_grad`` or ``_grad_impl``,
        the call is emitted as a plain statement (presumably it fills the
        output pointers itself -- TODO confirm). Otherwise its result is
        assigned to the first output (``@`` suffix stripped).
        """
        invoke_func_name = invoke_code.split('(')[0].strip()
        if invoke_func_name.endswith(('_grad', '_grad_impl')):
            return f"""
PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
  {invoke_code};
}}"""

        return f"""
PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
  *{self.outputs['names'][0].split('@')[0]} = {invoke_code};
}}"""

def header_include():
    """Return the ``#include`` preamble written into the generated header.

    The returned text must contain only valid C++ preprocessor lines because
    it is copied verbatim into the generated backward API header.
    """
    return """
#include <tuple>

#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/utils/optional.h"
"""


def source_include(header_file_path):
    """Return the preamble written at the top of the generated source file.

    Args:
        header_file_path (str): include path of the generated header; it is
            emitted as the first ``#include``.

    Returns:
        str: ``#include`` lines followed by a ``DECLARE_bool`` flag
        declaration, copied verbatim into the generated C++ source.
    """
    return f"""
#include "{header_file_path}"
#include <memory>

#include "glog/logging.h"

#include "paddle/phi/api/lib/api_custom_impl.h"
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"

#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/platform/profiler/supplement_tracing.h"

DECLARE_bool(conv2d_disable_cudnn);
"""


def backward_api_namespace():
    """Return the ``(opening, closing)`` C++ namespace wrappers.

    Generated declarations and definitions are placed between the two
    strings, inside ``paddle::experimental``.
    """
    opening = "\nnamespace paddle {\nnamespace experimental {\n\n"
    closing = "\n\n}  // namespace experimental\n}  // namespace paddle\n"
    return opening, closing


def generate_backward_api(backward_yaml_path, header_file_path,
                          source_file_path):
    """Generate the C++ backward API header and source from yaml definitions.

    Args:
        backward_yaml_path: one yaml path, or an iterable of yaml paths; each
            file holds a list of backward API items (files that load as empty
            are skipped).
        header_file_path (str): path of the header file to write.
        source_file_path (str): path of the source file to write.
    """
    if isinstance(backward_yaml_path, str):
        # A single path must not be iterated character by character.
        backward_yaml_path = [backward_yaml_path]

    bw_apis = []
    for each_api_yaml in backward_yaml_path:
        with open(each_api_yaml, 'r', encoding='utf-8') as f:
            api_list = yaml.load(f, Loader=yaml.FullLoader)
        if api_list:
            bw_apis.extend(api_list)

    namespace = backward_api_namespace()
    # The generated source includes the header by this fixed in-tree path,
    # independent of header_file_path.
    include_header_file = "paddle/phi/api/backward/backward_api.h"

    # Context managers close (and flush) both files even if generation fails.
    with open(header_file_path, 'w', encoding='utf-8') as header_file, \
            open(source_file_path, 'w', encoding='utf-8') as source_file:
        header_file.write("#pragma once\n")
        header_file.write(header_include())
        header_file.write(namespace[0])

        source_file.write(source_include(include_header_file))
        source_file.write(namespace[0])

        for bw_api_item in bw_apis:
            bw_api = BackwardAPI(bw_api_item)
            header_file.write(bw_api.gene_api_declaration())
            source_file.write(bw_api.gene_api_code())

        header_file.write(namespace[1])
        source_file.write(namespace[1])


def main():
    """Parse command-line options and generate the backward API files."""
    parser = argparse.ArgumentParser(
        description='Generate PaddlePaddle C++ backward API files')
    # nargs='+' yields a list when the option is given, so the default must be
    # a list as well; a bare string would be iterated character by character.
    parser.add_argument('--backward_yaml_path',
                        help='path to backward yaml file',
                        nargs='+',
                        default=['paddle/phi/api/yaml/backward.yaml'])

    parser.add_argument('--backward_header_path',
                        help='output of generated backward header code file',
                        default='paddle/phi/api/backward/backward_api.h')

    parser.add_argument('--backward_source_path',
                        help='output of generated backward source code file',
                        default='paddle/phi/api/lib/backward_api.cc')

    options = parser.parse_args()

    generate_backward_api(options.backward_yaml_path,
                          options.backward_header_path,
                          options.backward_source_path)


if __name__ == "__main__":
    # Run the generator only when executed as a script, not on import.
    main()