未验证 提交 43f84d0f 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Added python-c code generation for final state Eager Dygraph (#39233)

* Removed debug info

* Added automatic code generation for final state Eager Dygraph

* Modified backward yaml

* Added EagerUtils helper functions for final state CodeGen

* Adjusted CMakeFiles to support compilation for final state auto generated codes

* Added python-c code generation for final state Eager Dygraph

* Fixed minor issue

* Fixed yaml.load() method failure

* Fixed minor issues

* Refactored Python-C Attributes Parsing Functions

* Fixed minor issue with Python-C AddFunctions

* Fixed issues from merge

* Fixed merge issues
上级 f7a3389e
...@@ -24,3 +24,13 @@ add_custom_target(eager_final_state_codegen ...@@ -24,3 +24,13 @@ add_custom_target(eager_final_state_codegen
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_nodes_h_path} ${nodes_h_path} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_nodes_h_path} ${nodes_h_path}
VERBATIM VERBATIM
) )
set(tmp_python_c_output_path "${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/tmp_eager_final_state_op_function_impl.h")
set(python_c_output_path "${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/eager_final_state_op_function_impl.h")
add_custom_target(eager_final_state_python_c_codegen
COMMAND "${PYTHON_EXECUTABLE}" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py"
"--api_yaml_path=${api_yaml_path}"
"--output_path=${tmp_python_c_output_path}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_python_c_output_path} ${python_c_output_path}
VERBATIM
)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
from eager_gen import ReadFwdFile, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap
atype_to_parsing_function = {
"bool": "CastPyArg2Boolean",
"int": "CastPyArg2Int",
"long": "CastPyArg2Long",
"float": "CastPyArg2Float",
"string": "CastPyArg2String",
"bool[]": "CastPyArg2Booleans",
"int[]": "CastPyArg2Ints",
"long[]": "CastPyArg2Longs",
"float[]": "CastPyArg2Floats",
"double[]": "CastPyArg2Float64s",
"string[]": "CastPyArg2Strings"
}
atype_to_cxx_type = {
"bool": "bool",
"int": "int",
"long": "long",
"float": "float",
"string": "std::string",
"bool[]": "std::vector<bool>",
"int[]": "std::vector<int>",
"long[]": "std::vector<long>",
"float[]": "std::vector<float>",
"double[]": "std::vector<double>",
"string[]": "std::vector<std::string>"
}
def ParseArguments():
parser = argparse.ArgumentParser(
description='Eager Code Generator Args Parser')
parser.add_argument('--api_yaml_path', type=str)
parser.add_argument('--output_path', type=str)
args = parser.parse_args()
return args
def GetCxxType(atype):
if atype not in atype_to_cxx_type.keys():
assert False
return atype_to_cxx_type[atype]
def FindParsingFunctionFromAttributeType(atype):
if atype not in atype_to_parsing_function.keys():
assert False
return atype_to_parsing_function[atype]
def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map,
forward_attrs_list, forward_outputs_position_map):
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
# forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
# Get EagerTensor from args
# Get dygraph function call args
num_args = len(forward_inputs_position_map.keys()) + len(forward_attrs_list)
num_input_tensors = len(forward_inputs_position_map.keys())
dygraph_function_call_list = ["" for i in range(num_args)]
get_eager_tensor_str = ""
for name, (ttype, pos) in forward_inputs_position_map.items():
get_eager_tensor_str += f" auto& {name} = GetTensorFromArgs(\"{fwd_api_name}\", \"{name}\", args, {pos}, false);\n"
dygraph_function_call_list[pos] = f"{name}"
parse_attributes_str = " paddle::framework::AttributeMap attrs;\n"
# Get Attributes
for name, atype, _, pos in forward_attrs_list:
parsing_function = FindParsingFunctionFromAttributeType(atype)
cxx_type = GetCxxType(atype)
key = f"{name}"
parse_attributes_str += f" PyObject* {name}_obj = PyTuple_GET_ITEM(args, {pos});\n"
parse_attributes_str += f" {cxx_type} {name} = {parsing_function}({name}_obj, \"{fwd_api_name}\", {pos});\n"
dygraph_function_call_list[pos] = f"{name}"
dygraph_function_call_str = ",".join(dygraph_function_call_list)
PYTHON_C_FUNCTION_TEMPLATE = """
static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObject *kwargs)
{{
PyThreadState *tstate = nullptr;
try
{{
// Get EagerTensors from args
{}
// Parse Attributes
{}
tstate = PyEval_SaveThread();
auto out = {}({});
PyEval_RestoreThread(tstate);
tstate = nullptr;
return ToPyObject(out);
}}
catch(...) {{
if (tstate) {{
PyEval_RestoreThread(tstate);
}}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}}
}}
"""
python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
fwd_api_name, get_eager_tensor_str, parse_attributes_str,
GetForwardFunctionName(fwd_api_name), dygraph_function_call_str)
python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void))eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}}"
return python_c_function_str, python_c_function_reg_str
def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str):
PYTHON_C_WRAPPER_TEMPLATE = """
#pragma once
#include "pybind11/detail/common.h"
#include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/pybind/exception.h"
#include <Python.h>
namespace paddle {{
namespace pybind {{
{}
static PyMethodDef EagerFinalStateMethods[] = {{
{}
}};
}} // namespace pybind
}} // namespace paddle
"""
python_c_str = PYTHON_C_WRAPPER_TEMPLATE.format(python_c_function_str,
python_c_function_reg_str)
return python_c_str
def GeneratePythonCFile(filepath, python_c_str):
with open(filepath, 'a') as f:
f.write(python_c_str)
if __name__ == "__main__":
args = ParseArguments()
api_yaml_path = args.api_yaml_path
fwd_api_list = ReadFwdFile(api_yaml_path)
python_c_function_list = []
python_c_function_reg_list = []
for fwd_api in fwd_api_list:
# We only generate Ops with grad
if 'backward' not in fwd_api.keys():
continue
assert 'api' in fwd_api.keys()
assert 'args' in fwd_api.keys()
assert 'output' in fwd_api.keys()
assert 'backward' in fwd_api.keys()
fwd_api_name = fwd_api['api']
fwd_args_str = fwd_api['args']
fwd_returns_str = fwd_api['output']
# Collect Original Forward Inputs/Outputs and then perform validation checks
forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForward(
fwd_args_str, fwd_returns_str)
print("Parsed Original Forward Inputs List: ", forward_inputs_list)
print("Prased Original Forward Attrs List: ", forward_attrs_list)
print("Parsed Original Forward Returns List: ", forward_returns_list)
forward_inputs_position_map, forward_outputs_position_map = DetermineForwardPositionMap(
forward_inputs_list, forward_returns_list)
print("Generated Forward Input Position Map: ",
forward_inputs_position_map)
print("Generated Forward Output Position Map: ",
forward_outputs_position_map)
python_c_function_str, python_c_function_reg_str = GeneratePythonCFunction(
fwd_api_name, forward_inputs_position_map, forward_attrs_list,
forward_outputs_position_map)
python_c_function_list.append(python_c_function_str)
python_c_function_reg_list.append(python_c_function_reg_str)
print("Generated Python-C Function: ", python_c_function_str)
python_c_function_reg_list.append("{nullptr,nullptr,0,nullptr}")
python_c_functions_str = "\n".join(python_c_function_list)
python_c_functions_reg_str = ",\n".join(python_c_function_reg_list)
python_c_str = GeneratePythonCWrappers(python_c_functions_str,
python_c_functions_reg_str)
print("Generated Python-C Codes: ", python_c_str)
output_path = args.output_path
for path in [output_path]:
if os.path.exists(path):
os.remove(path)
GeneratePythonCFile(output_path, python_c_str)
...@@ -152,7 +152,7 @@ if(WITH_PYTHON) ...@@ -152,7 +152,7 @@ if(WITH_PYTHON)
set(tmp_eager_impl_file ${eager_impl_file}.tmp) set(tmp_eager_impl_file ${eager_impl_file}.tmp)
set(OP_IMPL_DEPS op_function_generator) set(OP_IMPL_DEPS op_function_generator)
set(EAGER_OP_IMPL_DEPS eager_op_function_generator) set(EAGER_OP_IMPL_DEPS eager_op_function_generator eager_final_state_python_c_codegen)
if(WIN32) if(WIN32)
if("${CMAKE_GENERATOR}" STREQUAL "Ninja") if("${CMAKE_GENERATOR}" STREQUAL "Ninja")
...@@ -276,7 +276,7 @@ if(WITH_PYTHON) ...@@ -276,7 +276,7 @@ if(WITH_PYTHON)
if(NOT ON_INFER) if(NOT ON_INFER)
cc_library(paddle_eager cc_library(paddle_eager
SRCS eager.cc eager_functions.cc eager_method.cc eager_properties.cc eager_utils.cc SRCS eager.cc eager_functions.cc eager_method.cc eager_properties.cc eager_utils.cc
DEPS eager_api autograd_meta backward grad_node_info pten op_function_common dygraph_function dygraph_node accumulation_node global_utils utils python) DEPS eager_api autograd_meta backward grad_node_info pten op_function_common final_dygraph_function final_dygraph_node dygraph_function dygraph_node accumulation_node global_utils utils python)
add_dependencies(paddle_eager eager_codegen) add_dependencies(paddle_eager eager_codegen)
add_dependencies(paddle_eager eager_op_function_generator_cmd) add_dependencies(paddle_eager eager_op_function_generator_cmd)
list(APPEND PYBIND_DEPS paddle_eager) list(APPEND PYBIND_DEPS paddle_eager)
......
...@@ -396,6 +396,7 @@ int main(int argc, char* argv[]) { ...@@ -396,6 +396,7 @@ int main(int argc, char* argv[]) {
std::vector<std::string> headers{ std::vector<std::string> headers{
"\"pybind11/detail/common.h\"", "\"pybind11/detail/common.h\"",
"\"paddle/fluid/pybind/eager_final_state_op_function_impl.h\"",
"\"paddle/fluid/pybind/op_function_common.h\"", "\"paddle/fluid/pybind/op_function_common.h\"",
"\"paddle/fluid/eager/api/generated/fluid_generated/" "\"paddle/fluid/eager/api/generated/fluid_generated/"
"dygraph_forward_api.h\"", "dygraph_forward_api.h\"",
...@@ -444,6 +445,10 @@ int main(int argc, char* argv[]) { ...@@ -444,6 +445,10 @@ int main(int argc, char* argv[]) {
<< " PADDLE_THROW(platform::errors::Fatal (\"Add functions to " << " PADDLE_THROW(platform::errors::Fatal (\"Add functions to "
"core.eager.ops failed!\"));\n" "core.eager.ops failed!\"));\n"
<< " }\n\n" << " }\n\n"
<< " if (PyModule_AddFunctions(m.ptr(), EagerFinalStateMethods) < 0) {\n"
<< " PADDLE_THROW(platform::errors::Fatal (\"Add functions to "
"core.eager.ops failed!\"));\n"
<< " }\n\n"
<< "}\n\n" << "}\n\n"
<< "} // namespace pybind\n" << "} // namespace pybind\n"
<< "} // namespace paddle\n"; << "} // namespace paddle\n";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册