未验证 提交 56eedf27 编写于 作者: W wanghuancoder 提交者: GitHub

refine eager code gen (#45540)

上级 32f42e94
......@@ -1054,7 +1054,7 @@ static std::string GenerateGradNodeCreationContent(
// If single output slotname and not duplicable,
// then generate: "egr::AutogradMeta* p_autograd_out =
// egr::EagerUtils::autograd_meta("op_proto->outputs()[0].name()")"
std::string get_input_autograd_meta_str = " // Prepare Autograd Meta \n";
std::string get_input_autograd_meta_str = " // Prepare Autograd Meta\n";
std::string get_output_autograd_meta_str = "";
// If single output slotname and not duplicable,
// then generate: "egr::AutogradMeta* p_autograd_out =
......@@ -1390,7 +1390,7 @@ static std::string GenerateGradNodeCreationContent(
"paddle::platform::TracerEventType::OperatorInner, 1);\n"
"%s"
" if(require_any_grad) {\n"
" VLOG(6) << \" Construct Grad for %s \"; \n"
" VLOG(6) << \" Construct Grad for %s \";\n"
" egr::EagerUtils::PassStopGradient(%s);\n"
" %s\n"
" }\n"
......@@ -1875,7 +1875,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
" paddle::framework::AttributeMap attrs = attr_map;\n"
" paddle::framework::AttributeMap default_attrs;\n"
" egr::Controller::Instance().GetCurrentTracer()->TraceOp(\"%s\", ins, "
"outs, attrs, \n"
"outs, attrs,\n"
" egr::Controller::Instance().GetExpectedPlace(),\n"
" &default_attrs, true, {%s});\n";
std::string trace_op_str = paddle::string::Sprintf(
......@@ -2152,7 +2152,7 @@ static std::string GenerateSingleOpBase(
size_t fwd_output_position = fwd_outputs_name_pos_map.at(
grad_ins_grad_slotname_map.at(grad_input_name));
const char* FILL_ZERO_TEMPLATE =
"egr::EagerUtils::FillZeroForEmptyOptionalGradInput(&grads[%d], "
" egr::EagerUtils::FillZeroForEmptyOptionalGradInput(&grads[%d], "
"this->InputMeta()[%d]);\n";
fill_zero_str += paddle::string::Sprintf(
FILL_ZERO_TEMPLATE, fwd_output_position, fwd_output_position);
......@@ -2386,8 +2386,8 @@ static std::string GenerateSingleOpBase(
const char* GRAD_OUTS_CONTENT_TEMPLATE =
" if((!out_metas[%d].empty()) && "
"(!(out_metas[%d][0].IsStopGradient()))){ \n %s.insert({ \"%s\", "
"egr::EagerUtils::TrySyncToVars(%s[%d])});} \n ";
"(!(out_metas[%d][0].IsStopGradient()))){ %s.insert({ \"%s\", "
"egr::EagerUtils::TrySyncToVars(%s[%d])});}\n";
outs_contents_str += paddle::string::Sprintf(GRAD_OUTS_CONTENT_TEMPLATE,
grads_position,
grads_position,
......@@ -2406,7 +2406,7 @@ static std::string GenerateSingleOpBase(
!is_op_base_per_duplicable_input) {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
" if(!out_metas[%d].empty()){ %s.insert({ \"%s\", "
"egr::EagerUtils::CreateVars(out_metas[%d].size())});} \n ";
"egr::EagerUtils::CreateVars(out_metas[%d].size())});}\n";
outs_contents_str +=
paddle::string::Sprintf(GRAD_OUTS_CONTENT_TEMPLATE,
fwd_input_position,
......@@ -2418,7 +2418,7 @@ static std::string GenerateSingleOpBase(
" if((!out_metas[%d].empty()) && "
"(!(out_metas[%d][0].IsStopGradient()))){ %s.insert({ \"%s\", "
"{std::make_shared<egr::EagerVariable>(egr::Controller::Instance("
").GenerateUniqueName())}});} \n ";
").GenerateUniqueName())}});}\n";
outs_contents_str +=
paddle::string::Sprintf(GRAD_OUTS_CONTENT_TEMPLATE,
fwd_input_position,
......@@ -2754,7 +2754,7 @@ static std::string GenerateGradNodeCCContents(
" const auto& out_metas = OutputMeta();\n"
" paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
"egr::kSlotSmallVectorSize> outputs(%d);\n"
" %s\n"
"%s\n"
" if(NeedComplexToRealConversion()) "
"HandleComplexGradToRealGrad(&outputs);\n"
" return outputs;\n";
......@@ -2813,17 +2813,17 @@ static std::string GenerateGradNodeHeaderContents(
"create_graph = false, bool is_new_grad = false) "
"override;\n"
"\n"
" void ClearTensorWrappers() override { \n"
" void ClearTensorWrappers() override {\n"
"%s\n"
" SetIsTensorWrappersCleared(true);\n"
" }\n"
" std::string name() override { return \"%sGradNodeCompat\"; } \n "
" std::string name() override { return \"%sGradNodeCompat\"; }\n"
"\n"
"std::shared_ptr<GradNodeBase> Copy() const override {{\n "
"std::shared_ptr<GradNodeBase> Copy() const override {{\n"
" auto copied_node = std::shared_ptr<%sGradNodeCompat>(new "
"%sGradNodeCompat(*this));\n "
" return copied_node;\n "
"}}\n "
"%sGradNodeCompat(*this));\n"
" return copied_node;\n"
"}}\n"
"\n"
" // SetX, SetY, ...\n"
"%s\n"
......
......@@ -87,11 +87,9 @@ PYTHON_C_FUNCTION_TEMPLATE = \
"""
static PyObject * eager_api_{}(PyObject *self, PyObject *args, PyObject *kwargs) {{
{}
PyThreadState *tstate = nullptr;
try {{
VLOG(6) << "Running Eager Final State API: {}";
// Get EagerTensors from args
{}
// Parse Attributes if needed
......@@ -116,7 +114,7 @@ static PyObject * eager_api_{}(PyObject *self, PyObject *args, PyObject *kwargs)
}}
"""
NOAMP_DYGRAPH_FUNCTION_TEMPLATE = "decltype({}({})) out = {}({});\n"
NOAMP_DYGRAPH_FUNCTION_TEMPLATE = "decltype({}({})) out = {}({});"
FUNCTION_SET_DEVICE_TEMPLATE = \
......@@ -145,10 +143,7 @@ FUNCTION_NAME_TEMPLATE = \
PYTHON_C_FUNCTION_REG_TEMPLATE = \
"""
{{\"{}{}\", (PyCFunction)(void(*)(void)) {}eager_api_{}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {} in dygraph.\"}}
"""
" {{\"{}{}\", (PyCFunction)(void(*)(void)) {}eager_api_{}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {} in dygraph.\"}},\n"
PYTHON_C_WRAPPER_TEMPLATE = \
......@@ -173,7 +168,7 @@ namespace pybind {{
{}
static PyMethodDef EagerFinalStateMethods[] = {{
{}
{}
}};
void BindFinalStateEagerOpFunctions(pybind11::module *module) {{
......@@ -195,8 +190,7 @@ CORE_OPS_INFO = \
"""
static PyObject * eager_get_core_ops_args_info(PyObject *self) {
PyThreadState *tstate = nullptr;
try
{
try {
return ToPyObject(core_ops_args_info);
}
catch(...) {
......@@ -210,8 +204,7 @@ static PyObject * eager_get_core_ops_args_info(PyObject *self) {
static PyObject * eager_get_core_ops_args_type_info(PyObject *self) {
PyThreadState *tstate = nullptr;
try
{
try {
return ToPyObject(core_ops_args_type_info);
}
catch(...) {
......@@ -225,8 +218,7 @@ static PyObject * eager_get_core_ops_args_type_info(PyObject *self) {
static PyObject * eager_get_core_ops_returns_info(PyObject *self) {
PyThreadState *tstate = nullptr;
try
{
try {
return ToPyObject(core_ops_returns_info);
}
catch(...) {
......@@ -242,16 +234,9 @@ static PyObject * eager_get_core_ops_returns_info(PyObject *self) {
CORE_OPS_INFO_REGISTRY = \
"""
{\"get_core_ops_args_info\",
(PyCFunction)(void(*)(void))eager_get_core_ops_args_info, METH_NOARGS,
\"C++ interface function for eager_get_core_ops_args_info.\"},
{\"get_core_ops_args_type_info\",
(PyCFunction)(void(*)(void))eager_get_core_ops_args_type_info,
METH_NOARGS,
\"C++ interface function for eager_get_core_ops_args_type_info.\"},
{\"get_core_ops_returns_info\",
(PyCFunction)(void(*)(void))eager_get_core_ops_returns_info,
METH_NOARGS, \"C++ interface function for eager_get_core_ops_returns_info.\"},
{\"get_core_ops_args_info\", (PyCFunction)(void(*)(void))eager_get_core_ops_args_info, METH_NOARGS, \"C++ interface function for eager_get_core_ops_args_info.\"},
{\"get_core_ops_args_type_info\", (PyCFunction)(void(*)(void))eager_get_core_ops_args_type_info, METH_NOARGS, \"C++ interface function for eager_get_core_ops_args_type_info.\"},
{\"get_core_ops_returns_info\", (PyCFunction)(void(*)(void))eager_get_core_ops_returns_info, METH_NOARGS, \"C++ interface function for eager_get_core_ops_returns_info.\"},
"""
NAMESPACE_WRAPPER_TEMPLATE = \
......@@ -429,7 +414,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
else:
self.python_c_function_str += python_c_inplace_func_str
# Generate Python-C Function Registration
self.python_c_function_reg_str += "\n," + python_c_inplace_func_reg_str
self.python_c_function_reg_str += python_c_inplace_func_reg_str
def run(self):
# Initialized is_forward_only
......@@ -480,7 +465,7 @@ class PythonCGenerator(GeneratorBase):
if status == True:
self.python_c_functions_str += f_generator.python_c_function_str + "\n"
self.python_c_functions_reg_str += f_generator.python_c_function_reg_str + ",\n"
self.python_c_functions_reg_str += f_generator.python_c_function_reg_str
def AttachNamespace(self):
namespace = self.namespace
......@@ -530,7 +515,7 @@ def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str):
python_c_function_str += core_ops_infos_definition
python_c_function_reg_str += core_ops_infos_registry
python_c_function_reg_str += "\n {nullptr,nullptr,0,nullptr}"
python_c_function_reg_str += " {nullptr,nullptr,0,nullptr}"
python_c_str = PYTHON_C_WRAPPER_TEMPLATE.format(python_c_function_str,
python_c_function_reg_str)
......@@ -556,7 +541,7 @@ if __name__ == "__main__":
py_c_generator.run()
generated_python_c_functions += py_c_generator.python_c_functions_str + "\n"
generated_python_c_registration += py_c_generator.python_c_functions_reg_str + "\n"
generated_python_c_registration += py_c_generator.python_c_functions_reg_str
python_c_str = GeneratePythonCWrappers(generated_python_c_functions,
generated_python_c_registration)
......
......@@ -113,8 +113,7 @@ R"(
static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs)
{
PyThreadState *tstate = nullptr;
try
{
try {
%s
framework::AttributeMap attrs;
ConstructAttrMapFromPyArgs("%s", args, %d, PyTuple_GET_SIZE(args) , attrs);
......@@ -123,8 +122,7 @@ static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs)
PyEval_RestoreThread(tstate);
tstate = nullptr;
%s
}
catch(...) {
} catch(...) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
......@@ -361,11 +359,9 @@ static std::string GenerateCoreOpsInfoMap() {
std::string result =
"static PyObject * eager_get_core_ops_args_info(PyObject *self) {\n"
" PyThreadState *tstate = nullptr;\n"
" try\n"
" {\n"
" try {\n"
" return ToPyObject(core_ops_legacy_args_info);\n"
" }\n"
" catch(...) {\n"
" } catch(...) {\n"
" if (tstate) {\n"
" PyEval_RestoreThread(tstate);\n"
" }\n"
......@@ -376,11 +372,9 @@ static std::string GenerateCoreOpsInfoMap() {
"\n"
"static PyObject * eager_get_core_ops_args_type_info(PyObject *self) {\n"
" PyThreadState *tstate = nullptr;\n"
" try\n"
" {\n"
" try {\n"
" return ToPyObject(core_ops_legacy_args_type_info);\n"
" }\n"
" catch(...) {\n"
" } catch(...) {\n"
" if (tstate) {\n"
" PyEval_RestoreThread(tstate);\n"
" }\n"
......@@ -391,11 +385,9 @@ static std::string GenerateCoreOpsInfoMap() {
"\n"
"static PyObject * eager_get_core_ops_returns_info(PyObject *self) {\n"
" PyThreadState *tstate = nullptr;\n"
" try\n"
" {\n"
" try {\n"
" return ToPyObject(core_ops_legacy_returns_info);\n"
" }\n"
" catch(...) {\n"
" } catch(...) {\n"
" if (tstate) {\n"
" PyEval_RestoreThread(tstate);\n"
" }\n"
......@@ -516,10 +508,10 @@ int main(int argc, char* argv[]) {
auto op_funcs = GenerateOpFunctions();
auto core_ops_infos = GenerateCoreOpsInfoMap();
std::string core_ops_infos_registry =
"{\"get_core_ops_args_info\", "
" {\"get_core_ops_args_info\", "
"(PyCFunction)(void(*)(void))eager_get_core_ops_args_info, METH_NOARGS, "
"\"C++ interface function for eager_get_core_ops_args_info.\"},\n"
"{\"get_core_ops_args_type_info\", "
" {\"get_core_ops_args_type_info\", "
"(PyCFunction)(void(*)(void))eager_get_core_ops_args_type_info, "
"METH_NOARGS, "
"\"C++ interface function for eager_get_core_ops_args_type_info.\"},\n"
......@@ -553,7 +545,7 @@ int main(int argc, char* argv[]) {
"core.eager.ops failed!\"));\n"
<< " }\n\n"
<< " BindFinalStateEagerOpFunctions(&m);\n\n"
<< " BindFinalStateEagerOpFunctions(&m);\n"
<< "}\n\n"
<< "} // namespace pybind\n"
<< "} // namespace paddle\n";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册