Unverified commit 56eedf27, authored by wanghuancoder, committed by GitHub

refine eager code gen (#45540)

Parent 32f42e94
@@ -1054,7 +1054,7 @@ static std::string GenerateGradNodeCreationContent(
// If single output slotname and not duplicable,
// then generate: "egr::AutogradMeta* p_autograd_out =
// egr::EagerUtils::autograd_meta("op_proto->outputs()[0].name()")"
-std::string get_input_autograd_meta_str = " // Prepare Autograd Meta \n";
+std::string get_input_autograd_meta_str = " // Prepare Autograd Meta\n";
std::string get_output_autograd_meta_str = "";
// If single output slotname and not duplicable,
// then generate: "egr::AutogradMeta* p_autograd_out =
@@ -1390,7 +1390,7 @@ static std::string GenerateGradNodeCreationContent(
"paddle::platform::TracerEventType::OperatorInner, 1);\n"
"%s"
" if(require_any_grad) {\n"
" VLOG(6) << \" Construct Grad for %s \"; \n"
" VLOG(6) << \" Construct Grad for %s \";\n"
" egr::EagerUtils::PassStopGradient(%s);\n"
" %s\n"
" }\n"
@@ -1750,7 +1750,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
" if (egr::Controller::Instance().GetAMPLevel() != "
"paddle::imperative::AmpLevel::O0) {\n"
" VLOG(5) << \"Check and Prepare For AMP\";\n"
" \n"
" \n"
"%s\n"
" }\n";
std::string amp_logic_str = "";
@@ -1875,7 +1875,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
" paddle::framework::AttributeMap attrs = attr_map;\n"
" paddle::framework::AttributeMap default_attrs;\n"
" egr::Controller::Instance().GetCurrentTracer()->TraceOp(\"%s\", ins, "
"outs, attrs, \n"
"outs, attrs,\n"
" egr::Controller::Instance().GetExpectedPlace(),\n"
" &default_attrs, true, {%s});\n";
std::string trace_op_str = paddle::string::Sprintf(
@@ -2152,7 +2152,7 @@ static std::string GenerateSingleOpBase(
size_t fwd_output_position = fwd_outputs_name_pos_map.at(
grad_ins_grad_slotname_map.at(grad_input_name));
const char* FILL_ZERO_TEMPLATE =
"egr::EagerUtils::FillZeroForEmptyOptionalGradInput(&grads[%d], "
" egr::EagerUtils::FillZeroForEmptyOptionalGradInput(&grads[%d], "
"this->InputMeta()[%d]);\n";
fill_zero_str += paddle::string::Sprintf(
FILL_ZERO_TEMPLATE, fwd_output_position, fwd_output_position);
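Note: the only change is the added leading spaces, which indent the emitted statement to match the surrounding generated code. A minimal sketch of the expansion, with std::printf standing in for paddle::string::Sprintf and 0 as a made-up slot index:

```cpp
#include <cstdio>

int main() {
  // Same template text as in the hunk above; both %d placeholders receive
  // fwd_output_position (0 here is a made-up value).
  const char* FILL_ZERO_TEMPLATE =
      "  egr::EagerUtils::FillZeroForEmptyOptionalGradInput(&grads[%d], "
      "this->InputMeta()[%d]);\n";
  std::printf(FILL_ZERO_TEMPLATE, 0, 0);
  return 0;
}
```

Running it prints the single generated source line, now indented by two spaces.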
@@ -2385,9 +2385,9 @@ static std::string GenerateSingleOpBase(
size_t grads_position = fwd_outputs_name_pos_map.at(fwd_name);
const char* GRAD_OUTS_CONTENT_TEMPLATE =
" if((!out_metas[%d].empty()) && "
"(!(out_metas[%d][0].IsStopGradient()))){ \n %s.insert({ \"%s\", "
"egr::EagerUtils::TrySyncToVars(%s[%d])});} \n ";
" if((!out_metas[%d].empty()) && "
"(!(out_metas[%d][0].IsStopGradient()))){ %s.insert({ \"%s\", "
"egr::EagerUtils::TrySyncToVars(%s[%d])});}\n";
outs_contents_str += paddle::string::Sprintf(GRAD_OUTS_CONTENT_TEMPLATE,
grads_position,
grads_position,
@@ -2406,7 +2406,7 @@ static std::string GenerateSingleOpBase(
!is_op_base_per_duplicable_input) {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
" if(!out_metas[%d].empty()){ %s.insert({ \"%s\", "
"egr::EagerUtils::CreateVars(out_metas[%d].size())});} \n ";
"egr::EagerUtils::CreateVars(out_metas[%d].size())});}\n";
outs_contents_str +=
paddle::string::Sprintf(GRAD_OUTS_CONTENT_TEMPLATE,
fwd_input_position,
@@ -2415,10 +2415,10 @@ static std::string GenerateSingleOpBase(
fwd_input_position);
} else {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
" if((!out_metas[%d].empty()) && "
" if((!out_metas[%d].empty()) && "
"(!(out_metas[%d][0].IsStopGradient()))){ %s.insert({ \"%s\", "
"{std::make_shared<egr::EagerVariable>(egr::Controller::Instance("
").GenerateUniqueName())}});} \n ";
").GenerateUniqueName())}});}\n";
outs_contents_str +=
paddle::string::Sprintf(GRAD_OUTS_CONTENT_TEMPLATE,
fwd_input_position,
@@ -2565,7 +2565,7 @@ static std::string GenerateSingleOpBase(
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
if (!is_op_base_per_duplicable_input) {
const char* BWD_OUTPUT_TEMPLATE =
" if (%s.find(\"%s\") != %s.end()) { outputs[%d] = "
" if (%s.find(\"%s\") != %s.end()) { outputs[%d] = "
"egr::EagerUtils::GetOutputs(%s[\"%s\"]); }\n";
outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE,
outs_name,
@@ -2754,7 +2754,7 @@ static std::string GenerateGradNodeCCContents(
" const auto& out_metas = OutputMeta();\n"
" paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
"egr::kSlotSmallVectorSize> outputs(%d);\n"
" %s\n"
"%s\n"
" if(NeedComplexToRealConversion()) "
"HandleComplexGradToRealGrad(&outputs);\n"
" return outputs;\n";
@@ -2813,17 +2813,17 @@ static std::string GenerateGradNodeHeaderContents(
"create_graph = false, bool is_new_grad = false) "
"override;\n"
"\n"
" void ClearTensorWrappers() override { \n"
" void ClearTensorWrappers() override {\n"
"%s\n"
" SetIsTensorWrappersCleared(true);\n"
" }\n"
" std::string name() override { return \"%sGradNodeCompat\"; } \n "
" std::string name() override { return \"%sGradNodeCompat\"; }\n"
"\n"
"std::shared_ptr<GradNodeBase> Copy() const override {{\n "
"std::shared_ptr<GradNodeBase> Copy() const override {{\n"
" auto copied_node = std::shared_ptr<%sGradNodeCompat>(new "
"%sGradNodeCompat(*this));\n "
" return copied_node;\n "
"}}\n "
"%sGradNodeCompat(*this));\n"
" return copied_node;\n"
"}}\n"
"\n"
" // SetX, SetY, ...\n"
"%s\n"
@@ -2838,12 +2838,12 @@ static std::string GenerateGradNodeHeaderContents(
// [Generation] Handle Attributes
std::string set_attr_map_str =
" void SetAttrMap(paddle::framework::AttributeMap&& attr_map) {\n "
"attr_map_ = std::move(attr_map);\n }\n";
" void SetAttrMap(paddle::framework::AttributeMap&& attr_map) {\n "
"attr_map_ = std::move(attr_map);\n }\n";
set_attr_map_str +=
" void SetDefaultAttrMap(paddle::framework::AttributeMap&& "
"default_attr_map) {\n default_attr_map_ = "
"std::move(default_attr_map);\n }\n";
"default_attr_map) {\n default_attr_map_ = "
"std::move(default_attr_map);\n }\n";
std::string attr_members_str =
" paddle::framework::AttributeMap attr_map_;\n";
attr_members_str += " paddle::framework::AttributeMap default_attr_map_;";
@@ -2935,7 +2935,7 @@ static std::string GenerateGradNodeHeaderContents(
CLEAR_TENSOR_WRAPPER_TEMPLATE, struct_tensor_wrapper_name);
}
const char* SET_TENSOR_WRAPPER_TEMPLATE =
" void SetTensorWrapper%s(%s) {\n %s\n }\n";
" void SetTensorWrapper%s(%s) {\n %s\n }\n";
set_tensor_wrappers_str +=
paddle::string::Sprintf(SET_TENSOR_WRAPPER_TEMPLATE,
tensor_wrapper_name,
......
@@ -87,11 +87,9 @@ PYTHON_C_FUNCTION_TEMPLATE = \
"""
static PyObject * eager_api_{}(PyObject *self, PyObject *args, PyObject *kwargs) {{
{}
-PyThreadState *tstate = nullptr;
try {{
VLOG(6) << "Running Eager Final State API: {}";
// Get EagerTensors from args
{}
// Parse Attributes if needed
@@ -116,7 +114,7 @@ static PyObject * eager_api_{}(PyObject *self, PyObject *args, PyObject *kwargs)
}}
"""
-NOAMP_DYGRAPH_FUNCTION_TEMPLATE = "decltype({}({})) out = {}({});\n"
+NOAMP_DYGRAPH_FUNCTION_TEMPLATE = "decltype({}({})) out = {}({});"
FUNCTION_SET_DEVICE_TEMPLATE = \
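The NOAMP template change only drops the trailing \n, leaving line breaks to the call sites. For a hypothetical relu op the template renders to a single C++ statement; a runnable sketch with a stand-in forward function:

```cpp
#include <iostream>

// Stand-in for a generated *_dygraph_function; the signature is made up.
double relu_dygraph_function(double x) { return x > 0.0 ? x : 0.0; }

int main() {
  double tensor_x = -1.5;
  // What "decltype({}({})) out = {}({});" formats to: decltype keeps the type
  // of `out` in sync with whatever the forward function returns.
  decltype(relu_dygraph_function(tensor_x)) out =
      relu_dygraph_function(tensor_x);
  std::cout << out << "\n";  // prints 0
}
```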
@@ -145,10 +143,7 @@ FUNCTION_NAME_TEMPLATE = \
PYTHON_C_FUNCTION_REG_TEMPLATE = \
"""
{{\"{}{}\", (PyCFunction)(void(*)(void)) {}eager_api_{}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {} in dygraph.\"}}
"""
" {{\"{}{}\", (PyCFunction)(void(*)(void)) {}eager_api_{}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {} in dygraph.\"}},\n"
PYTHON_C_WRAPPER_TEMPLATE = \
@@ -173,7 +168,7 @@ namespace pybind {{
{}
static PyMethodDef EagerFinalStateMethods[] = {{
-{}
+{}
}};
void BindFinalStateEagerOpFunctions(pybind11::module *module) {{
@@ -195,8 +190,7 @@ CORE_OPS_INFO = \
"""
static PyObject * eager_get_core_ops_args_info(PyObject *self) {
PyThreadState *tstate = nullptr;
-try
-{
+try {
return ToPyObject(core_ops_args_info);
}
catch(...) {
@@ -210,8 +204,7 @@ static PyObject * eager_get_core_ops_args_info(PyObject *self) {
static PyObject * eager_get_core_ops_args_type_info(PyObject *self) {
PyThreadState *tstate = nullptr;
-try
-{
+try {
return ToPyObject(core_ops_args_type_info);
}
catch(...) {
@@ -225,8 +218,7 @@ static PyObject * eager_get_core_ops_args_type_info(PyObject *self) {
static PyObject * eager_get_core_ops_returns_info(PyObject *self) {
PyThreadState *tstate = nullptr;
-try
-{
+try {
return ToPyObject(core_ops_returns_info);
}
catch(...) {
@@ -242,16 +234,9 @@ static PyObject * eager_get_core_ops_returns_info(PyObject *self) {
CORE_OPS_INFO_REGISTRY = \
"""
{\"get_core_ops_args_info\",
(PyCFunction)(void(*)(void))eager_get_core_ops_args_info, METH_NOARGS,
\"C++ interface function for eager_get_core_ops_args_info.\"},
{\"get_core_ops_args_type_info\",
(PyCFunction)(void(*)(void))eager_get_core_ops_args_type_info,
METH_NOARGS,
\"C++ interface function for eager_get_core_ops_args_type_info.\"},
{\"get_core_ops_returns_info\",
(PyCFunction)(void(*)(void))eager_get_core_ops_returns_info,
METH_NOARGS, \"C++ interface function for eager_get_core_ops_returns_info.\"},
{\"get_core_ops_args_info\", (PyCFunction)(void(*)(void))eager_get_core_ops_args_info, METH_NOARGS, \"C++ interface function for eager_get_core_ops_args_info.\"},
{\"get_core_ops_args_type_info\", (PyCFunction)(void(*)(void))eager_get_core_ops_args_type_info, METH_NOARGS, \"C++ interface function for eager_get_core_ops_args_type_info.\"},
{\"get_core_ops_returns_info\", (PyCFunction)(void(*)(void))eager_get_core_ops_returns_info, METH_NOARGS, \"C++ interface function for eager_get_core_ops_returns_info.\"},
"""
NAMESPACE_WRAPPER_TEMPLATE = \
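A note on the (PyCFunction)(void(*)(void)) idiom these entries share: eager_get_core_ops_args_info and friends are declared with only a self parameter, while PyMethodDef stores a two-argument PyCFunction, so the pointer is laundered through a generic function-pointer type to avoid an incompatible-cast warning. A minimal sketch with a hypothetical demo_noargs function (requires the CPython headers):

```cpp
#include <Python.h>

// A METH_NOARGS function is invoked as (self, ignored); declaring only `self`
// is common shorthand, and it is what forces the two-step cast below.
static PyObject* demo_noargs(PyObject* self) { Py_RETURN_NONE; }

static PyMethodDef demo_entry = {
    "demo_noargs",
    (PyCFunction)(void (*)(void))demo_noargs,  // launder via generic fn ptr
    METH_NOARGS,
    "demo"};
```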
@@ -429,7 +414,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
else:
self.python_c_function_str += python_c_inplace_func_str
# Generate Python-C Function Registration
-self.python_c_function_reg_str += "\n," + python_c_inplace_func_reg_str
+self.python_c_function_reg_str += python_c_inplace_func_reg_str
def run(self):
# Initialized is_forward_only
@@ -480,7 +465,7 @@ class PythonCGenerator(GeneratorBase):
if status == True:
self.python_c_functions_str += f_generator.python_c_function_str + "\n"
-self.python_c_functions_reg_str += f_generator.python_c_function_reg_str + ",\n"
+self.python_c_functions_reg_str += f_generator.python_c_function_reg_str
def AttachNamespace(self):
namespace = self.namespace
@@ -530,7 +515,7 @@ def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str):
python_c_function_str += core_ops_infos_definition
python_c_function_reg_str += core_ops_infos_registry
-python_c_function_reg_str += "\n {nullptr,nullptr,0,nullptr}"
+python_c_function_reg_str += " {nullptr,nullptr,0,nullptr}"
python_c_str = PYTHON_C_WRAPPER_TEMPLATE.format(python_c_function_str,
python_c_function_reg_str)
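The appended entry is the mandatory all-null sentinel that terminates the PyMethodDef array; the change only drops a stray leading newline so the registration block formats cleanly. A minimal sketch of the table shape the generated wrapper ends up with (placeholder method name, sketch table name):

```cpp
#include <Python.h>

static PyObject* demo(PyObject* self, PyObject* args) { Py_RETURN_NONE; }

static PyMethodDef EagerFinalStateMethodsSketch[] = {
    {"demo", (PyCFunction)demo, METH_VARARGS, "placeholder entry"},
    // CPython stops scanning the table at this all-null sentinel; without it,
    // module initialization would read past the end of the array.
    {nullptr, nullptr, 0, nullptr}};
```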
@@ -556,7 +541,7 @@ if __name__ == "__main__":
py_c_generator.run()
generated_python_c_functions += py_c_generator.python_c_functions_str + "\n"
-generated_python_c_registration += py_c_generator.python_c_functions_reg_str + "\n"
+generated_python_c_registration += py_c_generator.python_c_functions_reg_str
python_c_str = GeneratePythonCWrappers(generated_python_c_functions,
generated_python_c_registration)
......
@@ -113,8 +113,7 @@ R"(
static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs)
{
PyThreadState *tstate = nullptr;
-try
-{
+try {
%s
framework::AttributeMap attrs;
ConstructAttrMapFromPyArgs("%s", args, %d, PyTuple_GET_SIZE(args) , attrs);
@@ -123,8 +122,7 @@ static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs)
PyEval_RestoreThread(tstate);
tstate = nullptr;
%s
-}
-catch(...) {
+} catch(...) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
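Both hunks in this template reshape braces only; the generated wrapper keeps the same control flow. A hedged sketch of that flow (the GIL release happens inside the generated body, shown inline here; translating the C++ exception back to Python is elided):

```cpp
#include <Python.h>

static PyObject* demo_op(PyObject* self, PyObject* args, PyObject* kwargs) {
  PyThreadState* tstate = nullptr;
  try {
    // ... parse inputs and build the attribute map ...
    tstate = PyEval_SaveThread();  // release the GIL around the op run
    // ... trace and execute the op ...
    PyEval_RestoreThread(tstate);  // reacquire the GIL before touching Python
    tstate = nullptr;
    Py_RETURN_NONE;                // the real wrapper returns the op outputs
  } catch (...) {
    if (tstate) {
      PyEval_RestoreThread(tstate);
    }
    // the real wrapper converts the C++ exception to a Python exception here
    return nullptr;
  }
}
```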
@@ -361,11 +359,9 @@ static std::string GenerateCoreOpsInfoMap() {
std::string result =
"static PyObject * eager_get_core_ops_args_info(PyObject *self) {\n"
" PyThreadState *tstate = nullptr;\n"
" try\n"
" {\n"
" try {\n"
" return ToPyObject(core_ops_legacy_args_info);\n"
" }\n"
" catch(...) {\n"
" } catch(...) {\n"
" if (tstate) {\n"
" PyEval_RestoreThread(tstate);\n"
" }\n"
@@ -376,11 +372,9 @@ static std::string GenerateCoreOpsInfoMap() {
"\n"
"static PyObject * eager_get_core_ops_args_type_info(PyObject *self) {\n"
" PyThreadState *tstate = nullptr;\n"
" try\n"
" {\n"
" try {\n"
" return ToPyObject(core_ops_legacy_args_type_info);\n"
" }\n"
" catch(...) {\n"
" } catch(...) {\n"
" if (tstate) {\n"
" PyEval_RestoreThread(tstate);\n"
" }\n"
@@ -391,11 +385,9 @@ static std::string GenerateCoreOpsInfoMap() {
"\n"
"static PyObject * eager_get_core_ops_returns_info(PyObject *self) {\n"
" PyThreadState *tstate = nullptr;\n"
" try\n"
" {\n"
" try {\n"
" return ToPyObject(core_ops_legacy_returns_info);\n"
" }\n"
" catch(...) {\n"
" } catch(...) {\n"
" if (tstate) {\n"
" PyEval_RestoreThread(tstate);\n"
" }\n"
@@ -516,10 +508,10 @@ int main(int argc, char* argv[]) {
auto op_funcs = GenerateOpFunctions();
auto core_ops_infos = GenerateCoreOpsInfoMap();
std::string core_ops_infos_registry =
"{\"get_core_ops_args_info\", "
" {\"get_core_ops_args_info\", "
"(PyCFunction)(void(*)(void))eager_get_core_ops_args_info, METH_NOARGS, "
"\"C++ interface function for eager_get_core_ops_args_info.\"},\n"
"{\"get_core_ops_args_type_info\", "
" {\"get_core_ops_args_type_info\", "
"(PyCFunction)(void(*)(void))eager_get_core_ops_args_type_info, "
"METH_NOARGS, "
"\"C++ interface function for eager_get_core_ops_args_type_info.\"},\n"
@@ -553,7 +545,7 @@ int main(int argc, char* argv[]) {
"core.eager.ops failed!\"));\n"
<< " }\n\n"
<< " BindFinalStateEagerOpFunctions(&m);\n\n"
<< " BindFinalStateEagerOpFunctions(&m);\n"
<< "}\n\n"
<< "} // namespace pybind\n"
<< "} // namespace paddle\n";
......