Unverified commit 56eedf27, authored by wanghuancoder, committed by GitHub

refine eager code gen (#45540)

Parent 32f42e94
@@ -1054,7 +1054,7 @@ static std::string GenerateGradNodeCreationContent(
 // If single output slotname and not duplicable,
 // then generate: "egr::AutogradMeta* p_autograd_out =
 // egr::EagerUtils::autograd_meta("op_proto->outputs()[0].name()")"
-std::string get_input_autograd_meta_str = " // Prepare Autograd Meta \n";
+std::string get_input_autograd_meta_str = " // Prepare Autograd Meta\n";
 std::string get_output_autograd_meta_str = "";
 // If single output slotname and not duplicable,
 // then generate: "egr::AutogradMeta* p_autograd_out =
@@ -1390,7 +1390,7 @@ static std::string GenerateGradNodeCreationContent(
 "paddle::platform::TracerEventType::OperatorInner, 1);\n"
 "%s"
 " if(require_any_grad) {\n"
-" VLOG(6) << \" Construct Grad for %s \"; \n"
+" VLOG(6) << \" Construct Grad for %s \";\n"
 " egr::EagerUtils::PassStopGradient(%s);\n"
 " %s\n"
 " }\n"
@@ -1750,7 +1750,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
 " if (egr::Controller::Instance().GetAMPLevel() != "
 "paddle::imperative::AmpLevel::O0) {\n"
 " VLOG(5) << \"Check and Prepare For AMP\";\n"
-"  \n"
+" \n"
 "%s\n"
 " }\n";
 std::string amp_logic_str = "";
@@ -1875,7 +1875,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
 " paddle::framework::AttributeMap attrs = attr_map;\n"
 " paddle::framework::AttributeMap default_attrs;\n"
 " egr::Controller::Instance().GetCurrentTracer()->TraceOp(\"%s\", ins, "
-"outs, attrs, \n"
+"outs, attrs,\n"
 " egr::Controller::Instance().GetExpectedPlace(),\n"
 " &default_attrs, true, {%s});\n";
 std::string trace_op_str = paddle::string::Sprintf(
@@ -2152,7 +2152,7 @@ static std::string GenerateSingleOpBase(
 size_t fwd_output_position = fwd_outputs_name_pos_map.at(
 grad_ins_grad_slotname_map.at(grad_input_name));
 const char* FILL_ZERO_TEMPLATE =
-"egr::EagerUtils::FillZeroForEmptyOptionalGradInput(&grads[%d], "
+" egr::EagerUtils::FillZeroForEmptyOptionalGradInput(&grads[%d], "
 "this->InputMeta()[%d]);\n";
 fill_zero_str += paddle::string::Sprintf(
 FILL_ZERO_TEMPLATE, fwd_output_position, fwd_output_position);
@@ -2385,9 +2385,9 @@ static std::string GenerateSingleOpBase(
 size_t grads_position = fwd_outputs_name_pos_map.at(fwd_name);
 const char* GRAD_OUTS_CONTENT_TEMPLATE =
 " if((!out_metas[%d].empty()) && "
-"(!(out_metas[%d][0].IsStopGradient()))){ \n %s.insert({ \"%s\", "
-"egr::EagerUtils::TrySyncToVars(%s[%d])});} \n ";
+"(!(out_metas[%d][0].IsStopGradient()))){ %s.insert({ \"%s\", "
+"egr::EagerUtils::TrySyncToVars(%s[%d])});}\n";
 outs_contents_str += paddle::string::Sprintf(GRAD_OUTS_CONTENT_TEMPLATE,
 grads_position,
 grads_position,
@@ -2406,7 +2406,7 @@ static std::string GenerateSingleOpBase(
 !is_op_base_per_duplicable_input) {
 const char* GRAD_OUTS_CONTENT_TEMPLATE =
 " if(!out_metas[%d].empty()){ %s.insert({ \"%s\", "
-"egr::EagerUtils::CreateVars(out_metas[%d].size())});} \n ";
+"egr::EagerUtils::CreateVars(out_metas[%d].size())});}\n";
 outs_contents_str +=
 paddle::string::Sprintf(GRAD_OUTS_CONTENT_TEMPLATE,
 fwd_input_position,
@@ -2415,10 +2415,10 @@ static std::string GenerateSingleOpBase(
 fwd_input_position);
 } else {
 const char* GRAD_OUTS_CONTENT_TEMPLATE =
 " if((!out_metas[%d].empty()) && "
 "(!(out_metas[%d][0].IsStopGradient()))){ %s.insert({ \"%s\", "
 "{std::make_shared<egr::EagerVariable>(egr::Controller::Instance("
-").GenerateUniqueName())}});} \n ";
+").GenerateUniqueName())}});}\n";
 outs_contents_str +=
 paddle::string::Sprintf(GRAD_OUTS_CONTENT_TEMPLATE,
 fwd_input_position,
@@ -2565,7 +2565,7 @@ static std::string GenerateSingleOpBase(
 size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
 if (!is_op_base_per_duplicable_input) {
 const char* BWD_OUTPUT_TEMPLATE =
-" if (%s.find(\"%s\") != %s.end()) { outputs[%d] = "
+"  if (%s.find(\"%s\") != %s.end()) { outputs[%d] = "
 "egr::EagerUtils::GetOutputs(%s[\"%s\"]); }\n";
 outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE,
 outs_name,
@@ -2754,7 +2754,7 @@ static std::string GenerateGradNodeCCContents(
 " const auto& out_metas = OutputMeta();\n"
 " paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
 "egr::kSlotSmallVectorSize> outputs(%d);\n"
-" %s\n"
+"%s\n"
 " if(NeedComplexToRealConversion()) "
 "HandleComplexGradToRealGrad(&outputs);\n"
 " return outputs;\n";
@@ -2813,17 +2813,17 @@ static std::string GenerateGradNodeHeaderContents(
 "create_graph = false, bool is_new_grad = false) "
 "override;\n"
 "\n"
-" void ClearTensorWrappers() override { \n"
+" void ClearTensorWrappers() override {\n"
 "%s\n"
 " SetIsTensorWrappersCleared(true);\n"
 " }\n"
-" std::string name() override { return \"%sGradNodeCompat\"; } \n "
+" std::string name() override { return \"%sGradNodeCompat\"; }\n"
 "\n"
-"std::shared_ptr<GradNodeBase> Copy() const override {{\n "
+"std::shared_ptr<GradNodeBase> Copy() const override {{\n"
 " auto copied_node = std::shared_ptr<%sGradNodeCompat>(new "
-"%sGradNodeCompat(*this));\n "
-" return copied_node;\n "
-"}}\n "
+"%sGradNodeCompat(*this));\n"
+" return copied_node;\n"
+"}}\n"
 "\n"
 " // SetX, SetY, ...\n"
 "%s\n"
@@ -2838,12 +2838,12 @@ static std::string GenerateGradNodeHeaderContents(
 // [Generation] Handle Attributes
 std::string set_attr_map_str =
 " void SetAttrMap(paddle::framework::AttributeMap&& attr_map) {\n "
 "attr_map_ = std::move(attr_map);\n }\n";
 set_attr_map_str +=
 " void SetDefaultAttrMap(paddle::framework::AttributeMap&& "
 "default_attr_map) {\n default_attr_map_ = "
 "std::move(default_attr_map);\n }\n";
 std::string attr_members_str =
 " paddle::framework::AttributeMap attr_map_;\n";
 attr_members_str += " paddle::framework::AttributeMap default_attr_map_;";
@@ -2935,7 +2935,7 @@ static std::string GenerateGradNodeHeaderContents(
 CLEAR_TENSOR_WRAPPER_TEMPLATE, struct_tensor_wrapper_name);
 }
 const char* SET_TENSOR_WRAPPER_TEMPLATE =
 " void SetTensorWrapper%s(%s) {\n %s\n }\n";
 set_tensor_wrappers_str +=
 paddle::string::Sprintf(SET_TENSOR_WRAPPER_TEMPLATE,
 tensor_wrapper_name,
...
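The hunks above are almost entirely whitespace fixes inside `const char*` code templates. Because `paddle::string::Sprintf` stamps these literals verbatim into every generated forward/backward function, a stray blank before a `\n` or a missing indent is reproduced in each generated source file. A minimal sketch of that expansion step, using `std::snprintf` as a stand-in for `paddle::string::Sprintf` (the `Expand` helper is hypothetical; only `FILL_ZERO_TEMPLATE` comes from the diff above):

#include <cstdio>
#include <string>

// Hypothetical stand-in for paddle::string::Sprintf: expand a
// printf-style template into a std::string.
template <typename... Args>
std::string Expand(const char* tmpl, Args... args) {
  int n = std::snprintf(nullptr, 0, tmpl, args...);   // measure
  std::string out(static_cast<size_t>(n) + 1, '\0');
  std::snprintf(&out[0], out.size(), tmpl, args...);  // format
  out.resize(static_cast<size_t>(n));                 // drop the NUL
  return out;
}

int main() {
  // Same shape as the FILL_ZERO_TEMPLATE hunk above: after the fix,
  // the literal carries the indentation the generated .cc file needs.
  const char* FILL_ZERO_TEMPLATE =
      " egr::EagerUtils::FillZeroForEmptyOptionalGradInput(&grads[%d], "
      "this->InputMeta()[%d]);\n";
  int fwd_output_position = 0;  // illustrative slot index
  std::fputs(
      Expand(FILL_ZERO_TEMPLATE, fwd_output_position, fwd_output_position)
          .c_str(),
      stdout);
}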
@@ -87,11 +87,9 @@ PYTHON_C_FUNCTION_TEMPLATE = \
 """
 static PyObject * eager_api_{}(PyObject *self, PyObject *args, PyObject *kwargs) {{
 {}
-
 PyThreadState *tstate = nullptr;
 try {{
 VLOG(6) << "Running Eager Final State API: {}";
-
 // Get EagerTensors from args
 {}
 // Parse Attributes if needed
@@ -116,7 +114,7 @@ static PyObject * eager_api_{}(PyObject *self, PyObject *args, PyObject *kwargs)
 }}
 """
-NOAMP_DYGRAPH_FUNCTION_TEMPLATE = "decltype({}({})) out = {}({});\n"
+NOAMP_DYGRAPH_FUNCTION_TEMPLATE = "decltype({}({})) out = {}({});"
 FUNCTION_SET_DEVICE_TEMPLATE = \
@@ -145,10 +143,7 @@ FUNCTION_NAME_TEMPLATE = \
 PYTHON_C_FUNCTION_REG_TEMPLATE = \
-"""
-{{\"{}{}\", (PyCFunction)(void(*)(void)) {}eager_api_{}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {} in dygraph.\"}}
-"""
+" {{\"{}{}\", (PyCFunction)(void(*)(void)) {}eager_api_{}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {} in dygraph.\"}},\n"
 PYTHON_C_WRAPPER_TEMPLATE = \
@@ -173,7 +168,7 @@ namespace pybind {{
 {}
 static PyMethodDef EagerFinalStateMethods[] = {{
 {}
 }};
 void BindFinalStateEagerOpFunctions(pybind11::module *module) {{
@@ -195,8 +190,7 @@ CORE_OPS_INFO = \
 """
 static PyObject * eager_get_core_ops_args_info(PyObject *self) {
 PyThreadState *tstate = nullptr;
-try
-{
+try {
 return ToPyObject(core_ops_args_info);
 }
 catch(...) {
@@ -210,8 +204,7 @@ static PyObject * eager_get_core_ops_args_info(PyObject *self) {
 static PyObject * eager_get_core_ops_args_type_info(PyObject *self) {
 PyThreadState *tstate = nullptr;
-try
-{
+try {
 return ToPyObject(core_ops_args_type_info);
 }
 catch(...) {
@@ -225,8 +218,7 @@ static PyObject * eager_get_core_ops_args_type_info(PyObject *self) {
 static PyObject * eager_get_core_ops_returns_info(PyObject *self) {
 PyThreadState *tstate = nullptr;
-try
-{
+try {
 return ToPyObject(core_ops_returns_info);
 }
 catch(...) {
@@ -242,16 +234,9 @@ static PyObject * eager_get_core_ops_returns_info(PyObject *self) {
 CORE_OPS_INFO_REGISTRY = \
 """
-{\"get_core_ops_args_info\",
-(PyCFunction)(void(*)(void))eager_get_core_ops_args_info, METH_NOARGS,
-\"C++ interface function for eager_get_core_ops_args_info.\"},
-{\"get_core_ops_args_type_info\",
-(PyCFunction)(void(*)(void))eager_get_core_ops_args_type_info,
-METH_NOARGS,
-\"C++ interface function for eager_get_core_ops_args_type_info.\"},
-{\"get_core_ops_returns_info\",
-(PyCFunction)(void(*)(void))eager_get_core_ops_returns_info,
-METH_NOARGS, \"C++ interface function for eager_get_core_ops_returns_info.\"},
+{\"get_core_ops_args_info\", (PyCFunction)(void(*)(void))eager_get_core_ops_args_info, METH_NOARGS, \"C++ interface function for eager_get_core_ops_args_info.\"},
+{\"get_core_ops_args_type_info\", (PyCFunction)(void(*)(void))eager_get_core_ops_args_type_info, METH_NOARGS, \"C++ interface function for eager_get_core_ops_args_type_info.\"},
+{\"get_core_ops_returns_info\", (PyCFunction)(void(*)(void))eager_get_core_ops_returns_info, METH_NOARGS, \"C++ interface function for eager_get_core_ops_returns_info.\"},
 """
 NAMESPACE_WRAPPER_TEMPLATE = \
@@ -429,7 +414,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
 else:
 self.python_c_function_str += python_c_inplace_func_str
 # Generate Python-C Function Registration
-self.python_c_function_reg_str += "\n," + python_c_inplace_func_reg_str
+self.python_c_function_reg_str += python_c_inplace_func_reg_str
 def run(self):
 # Initialized is_forward_only
@@ -480,7 +465,7 @@ class PythonCGenerator(GeneratorBase):
 if status == True:
 self.python_c_functions_str += f_generator.python_c_function_str + "\n"
-self.python_c_functions_reg_str += f_generator.python_c_function_reg_str + ",\n"
+self.python_c_functions_reg_str += f_generator.python_c_function_reg_str
 def AttachNamespace(self):
 namespace = self.namespace
@@ -530,7 +515,7 @@ def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str):
 python_c_function_str += core_ops_infos_definition
 python_c_function_reg_str += core_ops_infos_registry
-python_c_function_reg_str += "\n {nullptr,nullptr,0,nullptr}"
+python_c_function_reg_str += " {nullptr,nullptr,0,nullptr}"
 python_c_str = PYTHON_C_WRAPPER_TEMPLATE.format(python_c_function_str,
 python_c_function_reg_str)
@@ -556,7 +541,7 @@ if __name__ == "__main__":
 py_c_generator.run()
 generated_python_c_functions += py_c_generator.python_c_functions_str + "\n"
-generated_python_c_registration += py_c_generator.python_c_functions_reg_str + "\n"
+generated_python_c_registration += py_c_generator.python_c_functions_reg_str
 python_c_str = GeneratePythonCWrappers(generated_python_c_functions,
 generated_python_c_registration)
...
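The Python generator changes above move separator handling into the templates themselves: `PYTHON_C_FUNCTION_REG_TEMPLATE` now ends with `,\n`, so callers that used to append `"\n,"` or `",\n"` simply concatenate, and `GeneratePythonCWrappers` still closes the table with the `{nullptr,nullptr,0,nullptr}` sentinel that CPython requires as a terminator. A minimal hand-written sketch of the generated table's shape (`eager_api_scale` is an illustrative name, not the exact generated symbol):

#include <Python.h>

// Placeholder body standing in for one generated eager API binding.
static PyObject* eager_api_scale(PyObject* self, PyObject* args,
                                 PyObject* kwargs) {
  Py_RETURN_NONE;
}

// Shape of the generated registration table: one entry per API,
// terminated by the null sentinel CPython uses to find the end.
static PyMethodDef EagerFinalStateMethods[] = {
    {"scale", (PyCFunction)(void (*)(void))eager_api_scale,
     METH_VARARGS | METH_KEYWORDS,
     "C++ interface function for scale in dygraph."},
    {nullptr, nullptr, 0, nullptr}};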
@@ -113,8 +113,7 @@ R"(
 static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs)
 {
 PyThreadState *tstate = nullptr;
-try
-{
+try {
 %s
 framework::AttributeMap attrs;
 ConstructAttrMapFromPyArgs("%s", args, %d, PyTuple_GET_SIZE(args) , attrs);
@@ -123,8 +122,7 @@ static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs)
 PyEval_RestoreThread(tstate);
 tstate = nullptr;
 %s
-}
-catch(...) {
+} catch(...) {
 if (tstate) {
 PyEval_RestoreThread(tstate);
 }
@@ -361,11 +359,9 @@ static std::string GenerateCoreOpsInfoMap() {
 std::string result =
 "static PyObject * eager_get_core_ops_args_info(PyObject *self) {\n"
 " PyThreadState *tstate = nullptr;\n"
-" try\n"
-" {\n"
+" try {\n"
 " return ToPyObject(core_ops_legacy_args_info);\n"
-" }\n"
-" catch(...) {\n"
+" } catch(...) {\n"
 " if (tstate) {\n"
 " PyEval_RestoreThread(tstate);\n"
 " }\n"
@@ -376,11 +372,9 @@ static std::string GenerateCoreOpsInfoMap() {
 "\n"
 "static PyObject * eager_get_core_ops_args_type_info(PyObject *self) {\n"
 " PyThreadState *tstate = nullptr;\n"
-" try\n"
-" {\n"
+" try {\n"
 " return ToPyObject(core_ops_legacy_args_type_info);\n"
-" }\n"
-" catch(...) {\n"
+" } catch(...) {\n"
 " if (tstate) {\n"
 " PyEval_RestoreThread(tstate);\n"
 " }\n"
@@ -391,11 +385,9 @@ static std::string GenerateCoreOpsInfoMap() {
 "\n"
 "static PyObject * eager_get_core_ops_returns_info(PyObject *self) {\n"
 " PyThreadState *tstate = nullptr;\n"
-" try\n"
-" {\n"
+" try {\n"
 " return ToPyObject(core_ops_legacy_returns_info);\n"
-" }\n"
-" catch(...) {\n"
+" } catch(...) {\n"
 " if (tstate) {\n"
 " PyEval_RestoreThread(tstate);\n"
 " }\n"
@@ -516,10 +508,10 @@ int main(int argc, char* argv[]) {
 auto op_funcs = GenerateOpFunctions();
 auto core_ops_infos = GenerateCoreOpsInfoMap();
 std::string core_ops_infos_registry =
-"{\"get_core_ops_args_info\", "
+" {\"get_core_ops_args_info\", "
 "(PyCFunction)(void(*)(void))eager_get_core_ops_args_info, METH_NOARGS, "
 "\"C++ interface function for eager_get_core_ops_args_info.\"},\n"
-"{\"get_core_ops_args_type_info\", "
+" {\"get_core_ops_args_type_info\", "
 "(PyCFunction)(void(*)(void))eager_get_core_ops_args_type_info, "
 "METH_NOARGS, "
 "\"C++ interface function for eager_get_core_ops_args_type_info.\"},\n"
@@ -553,7 +545,7 @@ int main(int argc, char* argv[]) {
 "core.eager.ops failed!\"));\n"
 << " }\n\n"
-<< " BindFinalStateEagerOpFunctions(&m);\n\n"
+<< " BindFinalStateEagerOpFunctions(&m);\n"
 << "}\n\n"
 << "} // namespace pybind\n"
 << "} // namespace paddle\n";
...
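The remaining hunks fold the generated `try` and `catch(...)` onto single lines without changing behavior. Each generated binding saves the Python thread state so the GIL can be released while the op runs, restores it on the happy path, and in the catch-all restores it again before reporting failure. A sketch of that generated shape (names and the error report are simplified; the real generator fills the body in via `%s`, and I'm assuming the elided body pairs `PyEval_SaveThread` with the `PyEval_RestoreThread` calls visible in the diff):

#include <Python.h>

static PyObject* generated_op(PyObject* self, PyObject* args,
                              PyObject* kwargs) {
  PyThreadState* tstate = nullptr;
  try {
    tstate = PyEval_SaveThread();  // release the GIL while the op runs
    // ... run the C++ kernel here (emitted by the generator) ...
    PyEval_RestoreThread(tstate);  // re-acquire the GIL
    tstate = nullptr;
    Py_RETURN_NONE;
  } catch (...) {
    if (tstate) {
      PyEval_RestoreThread(tstate);  // never leave the thread detached
    }
    PyErr_SetString(PyExc_RuntimeError, "op execution failed");
    return nullptr;
  }
}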