# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse

from codegen_utils import FunctionGeneratorBase, GeneratorBase
from codegen_utils import GetForwardFunctionName, IsVectorTensorType
from codegen_utils import GetInplacedFunctionName

###########################
## Global Configurations ##
###########################
# Names of forward APIs for which no Python-C binding is generated
# (consulted by SkipAPIGeneration). Empty by default.
skipped_forward_api_names = set()


def SkipAPIGeneration(forward_api_name):
    """Return True if `forward_api_name` is listed in skipped_forward_api_names."""
    return forward_api_name in skipped_forward_api_names


# Maps a C++ attribute type (as written in the API yaml) to the name of the
# C++ helper that converts a Python object to that type in generated code.
atype_to_parsing_function = {
    "bool": "CastPyArg2Boolean",
    "int": "CastPyArg2Int",
    "long": "CastPyArg2Long",
    "int64_t": "CastPyArg2Long",
    "float": "CastPyArg2Float",
    "double": "CastPyArg2Double",
    "std::string": "CastPyArg2String",
    "std::vector<bool>": "CastPyArg2Booleans",
    "std::vector<int>": "CastPyArg2Ints",
    "std::vector<long>": "CastPyArg2Longs",
    "std::vector<int64_t>": "CastPyArg2Longs",
    "std::vector<float>": "CastPyArg2Floats",
    "std::vector<double>": "CastPyArg2Float64s",
    "std::vector<std::string>": "CastPyArg2Strings",
    "paddle::experimental::Scalar": "CastPyArg2Scalar",
    "std::vector<phi::Scalar>": "CastPyArg2ScalarArray",
    "paddle::experimental::IntArray": "CastPyArg2IntArray",
    "paddle::Place": "CastPyArg2Place",
    "paddle::experimental::DataType": "CastPyArg2DataType",
}


def FindParsingFunctionFromAttributeType(atype):
    """Return the C++ conversion helper name for attribute type `atype`.

    Raises:
        KeyError: if `atype` has no entry in atype_to_parsing_function.
            (An explicit raise instead of `assert False`, which is stripped
            under `python -O` and would lose the message.)
    """
    parsing_function = atype_to_parsing_function.get(atype)
    if parsing_function is None:
        raise KeyError(
            f"Unable to find {atype} in atype_to_parsing_function.")
    return parsing_function


##########################
## Refactored Functions ##
##########################
# One C++ statement fetching a tensor input from the Python `args` tuple.
# Placeholders: variable name, getter function (GetTensorFromArgs and its
# list/optional variants), forward API name, input name, position in `args`,
# "true"/"false" for whether the input is optional.
PARSE_PYTHON_C_TENSORS_TEMPLATE = \
"    auto {} = {}(\"{}\", \"{}\", args, {}, {});\n"


# C++ statements reading one attribute from `args` and converting it.
# Placeholders: name, position, C++ type, name, conversion function
# (from atype_to_parsing_function), name, forward API name, position.
PARSE_PYTHON_C_ARGS_TEMPLATE = \
"""    PyObject* {}_obj = PyTuple_GET_ITEM(args, {});
    {} {} = {}({}_obj, \"{}\", {});
"""


# RecordEvent declaration used for profiling.
# Placeholders: variable name, API name, event-name suffix.
RECORD_EVENT_TEMPLATE = \
"paddle::platform::RecordEvent {}(\"{} {}\", paddle::platform::TracerEventType::UserDefined, 1);"


# One entry of the generated C++ `inplace_var_idx_map`.
# Placeholders: output position, input position.
RETURN_INPLACE_PYOBJECT_TEMPLATE = \
"""
    inplace_var_idx_map[{}] = {};
"""


# Complete definition of a generated `eager_api_<name>` Python-C function.
# Placeholders, in order: API name, record-event statement, API name (for the
# VLOG message), tensor-parsing code, attribute-parsing code, set-device code,
# dygraph function call, return code.
# The GIL is released (PyEval_SaveThread) before the set-device code and the
# dygraph call, and re-acquired either after the call or in the catch block.
PYTHON_C_FUNCTION_TEMPLATE = \
"""
static PyObject * eager_api_{}(PyObject *self, PyObject *args, PyObject *kwargs) {{
  {}
  PyThreadState *tstate = nullptr;
  try {{
    VLOG(6) << "Running Eager Final State API: {}";
    // Get EagerTensors from args
{}
    // Parse Attributes if needed
{}
    tstate = PyEval_SaveThread();

    // Set Device ID
{}
    // Call dygraph function
    {}

    PyEval_RestoreThread(tstate);
    tstate = nullptr;
{}
  }} catch(...) {{
    if (tstate) {{
      PyEval_RestoreThread(tstate);
    }}
    ThrowExceptionToPython(std::current_exception());
    return nullptr;
  }}
}}
"""


# Call of the (non-AMP) dygraph function, storing its result in `out`.
# Placeholders: function name, call arguments, function name, call arguments.
NOAMP_DYGRAPH_FUNCTION_TEMPLATE = "decltype({}({})) out = {}({});"


# Code switching the current device to `place` for GPU and custom-device
# places. The single placeholder is a prefix that declares `place`; it is left
# empty when `place` was parsed from the Python arguments instead.
FUNCTION_SET_DEVICE_TEMPLATE = \
"""{}    if (paddle::platform::is_gpu_place(place)) {{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      phi::backends::gpu::SetDeviceId(place.device);
      VLOG(1) <<"CurrentDeviceId: " << phi::backends::gpu::GetCurrentDeviceId() << " from " << (int)place.device;
#else
      PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
        "PaddlePaddle should compile with GPU if use CUDAPlace."));
#endif
    }}
    if (paddle::platform::is_custom_place(place)) {{
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
      phi::DeviceManager::SetDevice(place);
      VLOG(1) <<"CurrentDeviceId: " << phi::DeviceManager::GetDevice(place.GetDeviceType()) << " from " << (int)place.device;
#else
      PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
        "PaddlePaddle should compile with CUSTOM_DEVICE if use CustomPlace."));
#endif
    }}
"""


# Plain concatenation of three parts, used as "::" + namespace + function name.
FUNCTION_NAME_TEMPLATE = \
"{}{}{}"


# One PyMethodDef entry. Placeholders: Python-name prefix, API name, C++
# namespace qualifier, API name, API name (for the doc string).
PYTHON_C_FUNCTION_REG_TEMPLATE = \
"  {{\"{}{}\", (PyCFunction)(void(*)(void)) {}eager_api_{}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {} in dygraph.\"}},\n"


# Skeleton of the generated C++ source file. Placeholders: all generated
# function definitions, then all PyMethodDef entries (GeneratePythonCWrappers
# appends the `{nullptr,nullptr,0,nullptr}` sentinel before formatting).
PYTHON_C_WRAPPER_TEMPLATE = \
"""
#include <Python.h>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/api/include/strings_api.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/pybind/eager_custom_python_api.h"
#include "paddle/fluid/pybind/eager.h"
#include "paddle/fluid/eager/amp_utils.h"
#include "paddle/fluid/eager/eager_amp_auto_cast.h"

namespace paddle {{
namespace pybind {{

{}

static PyMethodDef EagerFinalStateMethods[] = {{
{}
}};

void BindFinalStateEagerOpFunctions(pybind11::module *module) {{
  if (PyModule_AddFunctions(module->ptr(), EagerFinalStateMethods) < 0) {{
    PADDLE_THROW(platform::errors::Fatal ("Add functions to core.eager.ops failed!"));
  }}

  if (PyModule_AddFunctions(module->ptr(), CustomEagerFinalStateMethods) < 0) {{
    PADDLE_THROW(platform::errors::Fatal ("Add functions to core.eager.ops failed!"));
  }}
}}

}} // namespace pybind
}} // namespace paddle
"""


# C++ getters exposing core_ops_args_info, core_ops_args_type_info and
# core_ops_returns_info to Python. This text is concatenated into the
# *arguments* of PYTHON_C_WRAPPER_TEMPLATE.format (never used as a format
# string itself), so its braces are single.
CORE_OPS_INFO = \
"""
static PyObject * eager_get_core_ops_args_info(PyObject *self) {
    PyThreadState *tstate = nullptr;
    try {
      return ToPyObject(core_ops_args_info);
    }
    catch(...) {
      if (tstate) {
        PyEval_RestoreThread(tstate);
      }
      ThrowExceptionToPython(std::current_exception());
      return nullptr;
    }
}

static PyObject * eager_get_core_ops_args_type_info(PyObject *self) {
    PyThreadState *tstate = nullptr;
    try {
      return ToPyObject(core_ops_args_type_info);
    }
    catch(...) {
      if (tstate) {
        PyEval_RestoreThread(tstate);
      }
      ThrowExceptionToPython(std::current_exception());
      return nullptr;
    }
}

static PyObject * eager_get_core_ops_returns_info(PyObject *self) {
    PyThreadState *tstate = nullptr;
    try {
      return ToPyObject(core_ops_returns_info);
    }
    catch(...) {
      if (tstate) {
        PyEval_RestoreThread(tstate);
      }
      ThrowExceptionToPython(std::current_exception());
      return nullptr;
    }
}
"""


# PyMethodDef entries registering the CORE_OPS_INFO getters.
CORE_OPS_INFO_REGISTRY = \
"""
  {\"get_core_ops_args_info\", (PyCFunction)(void(*)(void))eager_get_core_ops_args_info, METH_NOARGS, \"C++ interface function for eager_get_core_ops_args_info.\"},
  {\"get_core_ops_args_type_info\", (PyCFunction)(void(*)(void))eager_get_core_ops_args_type_info, METH_NOARGS, \"C++ interface function for eager_get_core_ops_args_type_info.\"},
  {\"get_core_ops_returns_info\", (PyCFunction)(void(*)(void))eager_get_core_ops_returns_info, METH_NOARGS, \"C++ interface function for eager_get_core_ops_returns_info.\"},
"""


# Wraps generated code in a C++ namespace. Placeholders: namespace, body.
NAMESPACE_WRAPPER_TEMPLATE = \
"""namespace {} {{
    {}
}}
"""


#######################
## Generator Classes ##
#######################
class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
    """Generate the Python-C binding for one forward API.

    After run() returns True:
      * python_c_function_str holds the generated C++ `eager_api_*` function
        definition(s);
      * python_c_function_reg_str holds the matching PyMethodDef entries.
    If the API has an inplace map, the inplace variant is appended to both
    strings, or replaces the regular variant when the API name ends with '_'.
    """

    def __init__(self, forward_api_contents, namespace):
        # Members from Parent:
        # self.namespace
        # self.forward_api_contents
        # self.forward_api_name
        # self.orig_forward_inputs_list
        # self.orig_forward_attrs_list
        # self.orig_forward_returns_list
        # self.forward_inputs_position_map
        # self.forward_outputs_position_map
        # self.optional_inputs
        # self.no_need_buffers
        # self.intermediate_outputs
        # self.forward_inplace_map
        FunctionGeneratorBase.__init__(self, forward_api_contents, namespace)

        self.is_forward_only = True

        # Generated Results
        self.python_c_function_str = ""
        self.python_c_function_reg_str = ""

    def CollectIsForwardOnly(self):
        """Mark the API forward-only unless its contents have a 'backward' key."""
        self.is_forward_only = 'backward' not in self.forward_api_contents

    def GeneratePythonCFunction(self):
        """Fill python_c_function_str and python_c_function_reg_str.

        Requires the parsing steps performed in run() (optional inputs,
        inplace map, original forward info and position maps).

        Raises:
            ValueError: if a Place attribute is not named 'place', which
                FUNCTION_SET_DEVICE_TEMPLATE hard-codes.
        """
        namespace = self.namespace
        forward_inplace_map = self.forward_inplace_map
        forward_api_name = self.forward_api_name
        orig_forward_attrs_list = self.orig_forward_attrs_list
        forward_inputs_position_map = self.forward_inputs_position_map
        forward_outputs_position_map = self.forward_outputs_position_map
        optional_inputs = self.optional_inputs

        inplace_args_pos_map = {}
        inplace_returns_pos_map = {}

        # Generate Python-C Tensors Parsing Logic
        get_eager_tensor_str = ""
        for name, (ttype, pos) in forward_inputs_position_map.items():
            if forward_inplace_map and name in forward_inplace_map:
                inplace_args_pos_map[name] = pos

            is_optional = (name in optional_inputs)
            if IsVectorTensorType(ttype):
                getter = ("GetOptionalTensorListFromArgs"
                          if is_optional else "GetTensorListFromArgs")
            else:
                getter = ("GetOptionalTensorFromArgs"
                          if is_optional else "GetTensorFromArgs")
            get_eager_tensor_str += PARSE_PYTHON_C_TENSORS_TEMPLATE.format(
                name, getter, forward_api_name, name, pos,
                "true" if is_optional else "false")

        if forward_inplace_map:
            inplaced_outputs = set(forward_inplace_map.values())
            for name, (_, pos) in forward_outputs_position_map.items():
                if name in inplaced_outputs:
                    inplace_returns_pos_map[name] = pos

        # Generate Python-C Attributes Parsing Logic
        parse_attributes_str = ""
        # By default `place` comes from the eager Controller; this declaration
        # is dropped when a Place attribute is parsed from the Python args.
        expected_place_str = "    auto place = egr::Controller::Instance().GetExpectedPlace();\n"
        for name, atype, _, pos in orig_forward_attrs_list:
            parsing_function_name = FindParsingFunctionFromAttributeType(atype)
            # Use the input argument place if specified from Python frontend.
            if expected_place_str and parsing_function_name == "CastPyArg2Place":
                expected_place_str = ""
                if name != "place":
                    raise ValueError(
                        "Only support 'place' as template argument name in FUNCTION_SET_DEVICE_TEMPLATE."
                    )

            parse_attributes_str += PARSE_PYTHON_C_ARGS_TEMPLATE.format(
                name, pos, atype, name, parsing_function_name, name,
                forward_api_name, pos)

        set_device_str = FUNCTION_SET_DEVICE_TEMPLATE.format(expected_place_str)

        # Generate Dygraph Function Call Logic: every input and attribute is
        # placed at its recorded position in the argument list.
        num_args = len(forward_inputs_position_map) + len(orig_forward_attrs_list)
        dygraph_function_call_list = [""] * num_args
        for name, (_, pos) in forward_inputs_position_map.items():
            dygraph_function_call_list[pos] = f"{name}"
        for name, _, _, pos in orig_forward_attrs_list:
            dygraph_function_call_list[pos] = f"{name}"
        dygraph_function_call_str = ",".join(dygraph_function_call_list)

        # Generate Python-C Function Definitions
        fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
            "::", namespace, GetForwardFunctionName(forward_api_name))

        return_str = "    return ToPyObject(out);"

        # Generate Record Event for performance profiling
        pythonc_record_event_str = RECORD_EVENT_TEMPLATE.format(
            "pythonc_record_event", forward_api_name, "pybind_imperative_func")

        noamp_dygraph_function_str = NOAMP_DYGRAPH_FUNCTION_TEMPLATE.format(
            fwd_function_name, dygraph_function_call_str, fwd_function_name,
            dygraph_function_call_str)

        # Generate Python-C Function Definition
        self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
            forward_api_name, pythonc_record_event_str, forward_api_name,
            get_eager_tensor_str, parse_attributes_str, set_device_str,
            noamp_dygraph_function_str, return_str)

        # Set prefix of forward_api_name to avoid conflicts
        # (str.strip takes a character set, so this strips every ':').
        prefix = namespace.strip(":")
        forward_api_name_prefix = "" if prefix == "" else prefix + "_"

        # Generate Python-C Function Registration
        self.python_c_function_reg_str = PYTHON_C_FUNCTION_REG_TEMPLATE.format(
            forward_api_name_prefix, forward_api_name, namespace,
            forward_api_name, forward_api_name)

        if forward_inplace_map:
            inplaced_forward_api_name = GetInplacedFunctionName(
                forward_api_name)
            inplaced_fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
                "::", namespace,
                GetForwardFunctionName(inplaced_forward_api_name))

            inplace_noamp_dygraph_function_str = NOAMP_DYGRAPH_FUNCTION_TEMPLATE.format(
                inplaced_fwd_function_name, dygraph_function_call_str,
                inplaced_fwd_function_name, dygraph_function_call_str)

            # For each inplace (input, output) pair, record
            # output position -> input position in the generated map.
            return_str = "    std::map<ssize_t, ssize_t> inplace_var_idx_map;"
            for inplace_input, inplace_output in forward_inplace_map.items():
                return_str += RETURN_INPLACE_PYOBJECT_TEMPLATE.format(
                    inplace_returns_pos_map[inplace_output],
                    inplace_args_pos_map[inplace_input])
            return_str += "    return ToPyObject(out, args, inplace_var_idx_map);"

            # Generate Python-C Function Definition
            python_c_inplace_func_str = PYTHON_C_FUNCTION_TEMPLATE.format(
                inplaced_forward_api_name, pythonc_record_event_str,
                inplaced_forward_api_name, get_eager_tensor_str,
                parse_attributes_str, set_device_str,
                inplace_noamp_dygraph_function_str, return_str)

            python_c_inplace_func_reg_str = PYTHON_C_FUNCTION_REG_TEMPLATE.format(
                forward_api_name_prefix, inplaced_forward_api_name, namespace,
                inplaced_forward_api_name, inplaced_forward_api_name)

            # forward_api_name ending with '_' means it only has inplace api
            if forward_api_name.endswith('_'):
                self.python_c_function_str = python_c_inplace_func_str
                self.python_c_function_reg_str = python_c_inplace_func_reg_str
            else:
                self.python_c_function_str += python_c_inplace_func_str
                self.python_c_function_reg_str += python_c_inplace_func_reg_str

    def run(self):
        """Parse this API's description and generate its Python-C code.

        Returns:
            False if SkipAPIGeneration() reports the API as skipped
            (nothing is generated), True otherwise.
        """
        # Initialize is_forward_only
        self.CollectIsForwardOnly()

        # Initialize optional_inputs
        self.ParseDispensable()

        # Initialize forward_inplace_map
        self.ParseForwardInplaceInfo()

        # Initialize orig_forward_inputs_list, orig_forward_returns_list,
        # orig_forward_attrs_list
        self.CollectOriginalForwardInfo()

        if SkipAPIGeneration(self.forward_api_name):
            return False

        # Initialize forward_inputs_position_map, forward_outputs_position_map
        self.DetermineForwardPositionMap(self.orig_forward_inputs_list,
                                         self.orig_forward_returns_list)

        # Code Generation
        self.GeneratePythonCFunction()

        return True


class PythonCGenerator(GeneratorBase):
    """Generate Python-C bindings for every forward API of one yaml file.

    After run(), python_c_functions_str holds the generated function
    definitions (wrapped in the yaml's namespace, if any) and
    python_c_functions_reg_str holds their PyMethodDef entries.
    """

    def __init__(self, path):
        # Parent members:
        # self.namespace
        # self.api_yaml_path
        # self.forward_api_list
        # Pass the `path` argument itself; reading a module-level variable
        # here would only work when this file runs as a script.
        GeneratorBase.__init__(self, path)

        # Generated Result
        self.python_c_functions_str = ""
        self.python_c_functions_reg_str = ""

    def GeneratePythonCFunctions(self):
        """Generate code for each forward API, skipping those whose
        PythonCSingleFunctionGenerator.run() returns False."""
        namespace = self.namespace

        for forward_api_content in self.forward_api_list:
            f_generator = PythonCSingleFunctionGenerator(
                forward_api_content, namespace)
            status = f_generator.run()

            if status:
                self.python_c_functions_str += f_generator.python_c_function_str + "\n"
                self.python_c_functions_reg_str += f_generator.python_c_function_reg_str

    def AttachNamespace(self):
        """Wrap the generated functions in `namespace <ns> { ... }` when the
        namespace is non-empty (a trailing '::' is dropped first)."""
        namespace = self.namespace
        python_c_functions_str = self.python_c_functions_str

        if namespace != "":
            if namespace.endswith("::"):
                namespace = namespace[:-2]
            self.python_c_functions_str = NAMESPACE_WRAPPER_TEMPLATE.format(
                namespace, python_c_functions_str)

    def run(self):
        """Infer the namespace, parse the yaml, generate and wrap the code."""
        # Infer namespace from yaml_path
        self.InferNameSpace()

        # Read Yaml file
        self.ParseForwardYamlContents()

        # Code Generation
        self.GeneratePythonCFunctions()

        # Wrap with namespace
        self.AttachNamespace()


############################
## Code Generation Helper ##
############################
def ParseArguments(argv=None):
    """Parse the generator's command-line options.

    Args:
        argv: argument list to parse; None (the default) makes argparse
            read sys.argv[1:], as before.

    Returns:
        argparse.Namespace with `api_yaml_path` (a comma-separated list of
        yaml paths) and `output_path`.
    """
    parser = argparse.ArgumentParser(
        description='Eager Code Generator Args Parser')
    # Both options are used unconditionally by the script entry point, so
    # fail early with a usage message instead of an AttributeError on None.
    parser.add_argument('--api_yaml_path', type=str, required=True)
    parser.add_argument('--output_path', type=str, required=True)

    return parser.parse_args(argv)


def GenerateCoreOpsInfoMap():
    """Return (definitions, PyMethodDef entries) for the core-ops-info getters."""
    return CORE_OPS_INFO, CORE_OPS_INFO_REGISTRY


def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str):
    """Assemble the complete generated C++ source.

    Appends the core-ops-info getters and their registrations, terminates
    the PyMethodDef list with the all-null sentinel entry, and fills
    PYTHON_C_WRAPPER_TEMPLATE.

    Args:
        python_c_function_str: generated function definitions.
        python_c_function_reg_str: generated PyMethodDef entries.

    Returns:
        The C++ source text.
    """
    core_ops_infos_definition, core_ops_infos_registry = GenerateCoreOpsInfoMap()

    python_c_function_str += core_ops_infos_definition
    python_c_function_reg_str += core_ops_infos_registry
    # A PyMethodDef array must end with an all-null sentinel entry.
    python_c_function_reg_str += "  {nullptr,nullptr,0,nullptr}"

    return PYTHON_C_WRAPPER_TEMPLATE.format(python_c_function_str,
                                           python_c_function_reg_str)


def GeneratePythonCFile(filepath, python_c_str):
    """Append `python_c_str` to the file at `filepath`, encoded as UTF-8."""
    # Explicit encoding so the output does not depend on the platform's
    # locale default. Append mode is kept; the script entry point deletes
    # any existing output file before calling this.
    with open(filepath, 'a', encoding='utf-8') as f:
        f.write(python_c_str)


if __name__ == "__main__":
    args = ParseArguments()

    api_yaml_paths = args.api_yaml_path.split(",")

    generated_function_parts = []
    generated_registration_parts = []
    # The loop variable intentionally keeps the module-level name
    # `api_yaml_path`; do not rename it without checking that
    # PythonCGenerator.__init__ does not read that global.
    for api_yaml_path in api_yaml_paths:
        py_c_generator = PythonCGenerator(api_yaml_path)
        py_c_generator.run()

        generated_function_parts.append(
            py_c_generator.python_c_functions_str + "\n")
        generated_registration_parts.append(
            py_c_generator.python_c_functions_reg_str)

    python_c_str = GeneratePythonCWrappers(
        "".join(generated_function_parts),
        "".join(generated_registration_parts))

    output_path = args.output_path
    # GeneratePythonCFile appends, so remove any stale output first.
    if os.path.exists(output_path):
        os.remove(output_path)

    GeneratePythonCFile(output_path, python_c_str)