Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
b9342a80
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
b9342a80
编写于
5月 18, 2022
作者:
W
Weilong Wu
提交者:
GitHub
5月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Eager] Polish eager code generation (#42822)
* [Eager] Polish eager code generation * Remove useless code in codegen
上级
570d0322
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
80 addition
and
66 deletion
+80
-66
paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py
...uto_code_generator/final_state_generator/codegen_utils.py
+1
-1
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
...er/auto_code_generator/final_state_generator/eager_gen.py
+59
-50
paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
...auto_code_generator/final_state_generator/python_c_gen.py
+14
-9
python/paddle/utils/code_gen/api_base.py
python/paddle/utils/code_gen/api_base.py
+4
-4
python/paddle/utils/code_gen/backward_api_gen.py
python/paddle/utils/code_gen/backward_api_gen.py
+2
-2
未找到文件。
paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py
浏览文件 @
b9342a80
...
...
@@ -418,7 +418,7 @@ class FunctionGeneratorBase:
return_name
]
=
[
return_type
,
return_pos
]
class
Yaml
GeneratorBase
:
class
GeneratorBase
:
def
__init__
(
self
,
api_yaml_path
):
self
.
namespace
=
""
self
.
api_yaml_path
=
api_yaml_path
...
...
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
浏览文件 @
b9342a80
...
...
@@ -29,7 +29,7 @@ from codegen_utils import RemoveSpecialSymbolsInName, RecoverBaseNameOfInplaceFu
from
codegen_utils
import
GetInplacedFunctionName
from
codegen_utils
import
ParseYamlArgs
,
ParseYamlReturns
,
ParseYamlForwardFromBackward
from
codegen_utils
import
ParseYamlForward
,
ParseYamlBackward
from
codegen_utils
import
FunctionGeneratorBase
,
Yaml
GeneratorBase
from
codegen_utils
import
FunctionGeneratorBase
,
GeneratorBase
from
codegen_utils
import
ops_to_fill_zero_for_empty_grads
from
codegen_utils
import
AssertMessage
,
GetIndent
...
...
@@ -60,14 +60,6 @@ SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = \
}}
"""
PLAIN_TENSOR_MEMBER_TEMPLATE
=
\
""" egr::TensorWrapper {};
"""
CLEAR_TENSOR_WRAPPER_TEMPLATE
=
\
""" {}.clear();
"""
SET_VECTOR_TENSOR_WRAPPER_TEMPLATE
=
\
""" void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}) {{
for(const auto& eager_tensor : {}) {{
...
...
@@ -76,10 +68,18 @@ SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = \
}}
"""
PLAIN_TENSOR_MEMBER_TEMPLATE
=
\
""" egr::TensorWrapper {};
"""
VECTOR_TENSOR_MEMBER_TEMPLATE
=
\
""" std::vector<egr::TensorWrapper> {};
"""
CLEAR_TENSOR_WRAPPER_TEMPLATE
=
\
""" {}.clear();
"""
CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE
=
\
""" for (auto& tw : {}) {{
tw.clear();
...
...
@@ -423,9 +423,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
self
.
forward_returns_list
=
[
]
#[ [ret_name, ret_type, orig_position], ...]
self
.
backward_inputs_list
=
[
]
#[ [attr_name, attr_type, default_value, orig_position], ...]
self
.
backward_attrs_list
=
[
]
#[ [attr_name, attr_type, default_value, orig_position], ...]
self
.
backward_inputs_list
=
[
]
#[ [arg_name, arg_type, orig_position], ...]
self
.
backward_returns_list
=
[
]
#[ [ret_name, ret_type, orig_position], ...]
...
...
@@ -504,11 +504,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
for
_
,
_
,
pos
in
forward_inputs_list
:
max_input_position
=
max
(
max_input_position
,
pos
)
max_attr_position
=
-
1
for
_
,
_
,
_
,
pos
in
forward_attrs_list
:
assert
pos
>
max_input_position
,
AssertMessage
(
pos
,
max_input_position
)
max_attr_position
=
max
(
max_attr_position
,
pos
)
def
BackwardValidationCheck
(
self
):
backward_forward_inputs_map
=
self
.
backward_forward_inputs_map
...
...
@@ -692,12 +690,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
else
:
set_tensor_wrappers
=
f
"
{
indent
}
grad_node->SetTensorWrapper
{
name
}
(
{
name
}
);"
set_input_tensor_wrappers_list
.
append
(
set_tensor_wrappers
)
else
:
else
:
# Forwad's output as backward's input
if
num_fwd_outputs
>
1
:
# Aligned with forward output position
assert
name
in
forward_outputs_position_map
.
keys
(
),
AssertMessage
(
name
,
forward_outputs_position_map
.
keys
())
fwd_output_pos
=
forward_outputs_position_map
[
name
][
1
]
if
is_optional
:
set_tensor_wrappers
=
f
"
{
indent
}
if(
{
name
}
.get_ptr() != nullptr) grad_node->SetTensorWrapper
{
name
}
(*(
{
name
}
.get_ptr()));"
...
...
@@ -733,7 +730,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
set_grad_out_meta_list
.
append
(
set_grad_out_meta
)
set_grad_out_meta_str
=
"
\n
"
.
join
(
set_grad_out_meta_list
)
# SetOutRank & SetHistory & SetGradInMeta
# SetOutRank & SetHistory & SetGradInMeta
& CheckAndRetainGrad
set_out_rank_list
=
[]
set_history_list
=
[]
set_grad_in_meta_list
=
[]
...
...
@@ -741,11 +738,12 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
num_outputs
=
len
(
forward_outputs_position_map
.
keys
())
for
name
,
(
_
,
pos
)
in
forward_outputs_position_map
.
items
():
output_autograd_meta_name
=
GetAutoGradMetaName
(
name
)
set_out_rank
=
f
"
{
indent
}
egr::EagerUtils::SetOutRankWithSlot(
{
output_autograd_meta_name
}
,
{
pos
}
);"
set_history
=
f
"
{
indent
}
egr::EagerUtils::SetHistory(
{
output_autograd_meta_name
}
, grad_node);"
set_retain_grad
=
f
"
{
indent
}
egr::EagerUtils::CheckAndRetainGrad(
{
name
}
);"
set_grad_in_meta
=
f
"
{
indent
}
grad_node->SetGradInMeta(
{
name
}
,
{
pos
}
);"
set_retain_grad
=
f
"
{
indent
}
egr::EagerUtils::CheckAndRetainGrad(
{
name
}
);"
set_out_rank_list
.
append
(
set_out_rank
)
set_history_list
.
append
(
set_history
)
set_grad_in_meta_list
.
append
(
set_grad_in_meta
)
...
...
@@ -806,7 +804,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
self
.
DetermineForwardPositionMap
(
self
.
forward_inputs_list
,
self
.
forward_returns_list
)
# Initialize
forward_inputs_position_map, forward_outputs_position
_map
# Initialize
backward_forward_inputs_map, backward_grad_inputs_map, backward_grad_outputs
_map
self
.
SlotNameMatching
()
# Backward Validation Check
...
...
@@ -822,18 +820,16 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
self
.
forward_definition_str
=
""
self
.
forward_declaration_str
=
""
def
GenerateForwardDefinition
(
self
,
is_inplaced
):
def
GenerateForwardDefinition
AndDeclaration
(
self
,
is_inplaced
):
namespace
=
self
.
namespace
forward_api_name
=
GetInplacedFunctionName
(
self
.
forward_api_name
)
if
is_inplaced
else
self
.
forward_api_name
backward_api_name
=
self
.
backward_api_name
forward_inputs_position_map
=
self
.
forward_inputs_position_map
forward_outputs_position_map
=
self
.
forward_outputs_position_map
forward_attrs_list
=
self
.
forward_attrs_list
backward_forward_inputs_map
=
self
.
backward_forward_inputs_map
backward_grad_inputs_map
=
self
.
backward_grad_inputs_map
backward_grad_outputs_map
=
self
.
backward_grad_outputs_map
backward_attrs_list
=
self
.
backward_attrs_list
optional_inputs
=
self
.
optional_inputs
intermediate_outputs
=
self
.
intermediate_outputs
inplace_map
=
self
.
inplace_map
if
is_inplaced
else
{}
...
...
@@ -845,6 +841,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
inputs_args_definition_list
=
[
""
for
i
in
range
(
num_inputs
)]
inputs_args_declaration_list
=
[
""
for
i
in
range
(
num_inputs
)]
inputs_call_list
=
[
""
for
i
in
range
(
num_inputs
)]
amp_inputs_call_list
=
[
""
for
i
in
range
(
num_inputs
)]
amp_tensors_vector_list
=
[]
amp_tensors_vector_optional_list
=
[]
...
...
@@ -1019,9 +1016,10 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
bump_inplace_version_str
+=
BUMP_INPLACE_VERSION_TEMPLATE
.
format
(
inplace_name
,
inplace_name
)
# Node Creation
self
.
GenerateNodeCreationCodes
()
node_creation_str
=
self
.
node_creation_str
dygraph_event_str
=
f
"
{
indent
}
paddle::platform::RecordEvent dygraph_entrance_record_event(
\"
{
forward_api_name
}
dygraph
\"
, paddle::platform::TracerEventType::Operator, 1);
\n
"
forward_function_name
=
GetDygraphForwardFunctionName
(
forward_api_name
)
...
...
@@ -1045,6 +1043,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
amp_tensors_vector_optional_list_str
,
amp_get_dst_dtype_str
,
amp_autocast_list_str
,
amp_call_str
)
# Generate forward_definition_str and forward_declaration_str
self
.
forward_definition_str
+=
FORWARD_FUNCTION_TEMPLATE
.
format
(
returns_type_str
,
forward_function_name
,
inputs_args_definition_str
,
dygraph_event_str
,
amp_logic_str
,
inputs_autograd_meta_str
,
...
...
@@ -1061,8 +1060,8 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
if
forward_api_name
!=
"sum"
and
"inplace"
in
forward_api_contents
.
keys
(
):
#
Node Defini
tion Generation
self
.
GenerateForwardDefinition
(
is_inplaced
=
True
)
#
Function Definition and Declara
tion Generation
self
.
GenerateForwardDefinition
AndDeclaration
(
is_inplaced
=
True
)
self
.
UpdateCoreOpsInformation
(
is_inplaced
=
True
)
def
UpdateCoreOpsInformation
(
self
,
is_inplaced
):
...
...
@@ -1083,6 +1082,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
final_state_fwd_api_name
]
=
[
""
for
i
in
range
(
num_args
)]
core_ops_args_type_info
[
final_state_fwd_api_name
]
=
[
""
for
i
in
range
(
num_args
)]
for
name
,
(
ttype
,
pos
)
in
forward_inputs_position_map
.
items
():
core_ops_args_info
[
final_state_fwd_api_name
][
pos
]
=
name
if
IsPlainTensorType
(
ttype
):
...
...
@@ -1104,7 +1104,9 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
#####################
## Code Generation ##
#####################
self
.
GenerateForwardDefinition
(
is_inplaced
=
False
)
# Definition And Declaration
self
.
GenerateForwardDefinitionAndDeclaration
(
is_inplaced
=
False
)
self
.
UpdateCoreOpsInformation
(
is_inplaced
=
False
)
...
...
@@ -1164,9 +1166,10 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
grad_api_contents
=
self
.
grad_api_contents
next_grad_api_contents
=
self
.
next_grad_api_contents
grad_node_creation_str
=
""
grad_node_out_list
=
[]
next_
grad_node_creation_str
=
""
next_
grad_node_out_list
=
[]
if
next_grad_api_contents
:
# Fake forward_api_contents and backward_api_contents
forward_api_contents
=
grad_api_contents
forward_api_contents
[
'api'
]
=
forward_api_contents
[
'backward_api'
]
backward_api_contents
=
next_grad_api_contents
...
...
@@ -1175,12 +1178,12 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
forward_api_contents
,
backward_api_contents
,
namespace
)
next_node_generator
.
run
()
next_node_generator
.
GenerateNodeCreationCodes
()
grad_node_creation_str
=
next_node_generator
.
node_creation_str
grad_node_out_list
=
next_node_generator
.
grad_node_out_list
next_
grad_node_creation_str
=
next_node_generator
.
node_creation_str
next_
grad_node_out_list
=
next_node_generator
.
grad_node_out_list
self
.
RecordGrad2NextGradNameMapping
(
next_node_generator
)
return
grad_node_creation_str
,
grad_node_out_list
return
next_grad_node_creation_str
,
next_
grad_node_out_list
def
GenerateNodeDeclaration
(
self
):
forward_op_name
=
self
.
forward_api_name
...
...
@@ -1188,7 +1191,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
backward_attrs_list
=
self
.
backward_attrs_list
no_need_buffers
=
self
.
no_need_buffers
# SetTensorWrapper Methods & TensorWrapper Members
# SetTensorWrapper Methods & TensorWrapper Members
& ClearTensorWrappers
set_tensor_wrapper_methods_str
=
""
tensor_wrapper_members_str
=
""
clear_tensor_wrapper_str
=
""
...
...
@@ -1241,8 +1244,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
set_attribute_methods_str
,
tensor_wrapper_members_str
,
attribute_members_str
)
def
GenerateNodeDefinition
(
self
,
grad_node_creation_str
,
grad_node_out_list
):
def
GenerateNodeDefinition
(
self
,
next_
grad_node_creation_str
,
next_
grad_node_out_list
):
namespace
=
self
.
namespace
forward_api_name
=
self
.
forward_api_name
backward_api_name
=
self
.
backward_api_name
...
...
@@ -1362,14 +1365,14 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
inputs_autograd_meta_str
=
""
outputs_autograd_meta_str
=
""
compute_require_grad_str
=
""
if
len
(
grad_node_creation_str
)
>
0
:
# 1. Get Input AutoGradMeta
if
len
(
next_
grad_node_creation_str
)
>
0
:
# 1. Get
Grad
Input AutoGradMeta
inputs_autograd_meta_list
=
[]
compute_require_grad_args_list
=
[
"trace_backward"
]
for
name
,
(
ttype
,
pos
,
grad_api_position
)
in
backward_grad_inputs_map
.
items
():
transformed_tensor_name
=
self
.
TransformToNextGradName
(
name
)
if
transformed_tensor_name
in
grad_node_out_list
:
if
transformed_tensor_name
in
next_
grad_node_out_list
:
input_autograd_meta_name
=
GetAutoGradMetaName
(
transformed_tensor_name
)
if
IsPlainTensorType
(
ttype
):
...
...
@@ -1388,7 +1391,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
# 2. Get TensorWrapper AutoGradMeta
for
name
,
(
ttype
,
_
,
pos
),
in
backward_forward_inputs_map
.
items
():
transformed_tensor_name
=
self
.
TransformToNextGradName
(
name
)
if
transformed_tensor_name
in
grad_node_out_list
:
if
transformed_tensor_name
in
next_
grad_node_out_list
:
input_autograd_meta_name
=
GetAutoGradMetaName
(
transformed_tensor_name
)
if
IsPlainTensorType
(
ttype
):
...
...
@@ -1447,7 +1450,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
grad_node_name
,
fill_zero_str
,
get_grad_in_args_str
,
grad_node_name
,
grad_function_call_str
,
check_nan_inf_str
,
inputs_autograd_meta_str
,
outputs_autograd_meta_str
,
compute_require_grad_str
,
grad_node_creation_str
,
returns_str
)
next_
grad_node_creation_str
,
returns_str
)
def
run
(
self
):
super
().
run
()
...
...
@@ -1458,27 +1461,29 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
## Code Generation ##
#####################
# Higher-order GradNode generation
grad_node_creation_str
,
grad_node_out_list
=
self
.
GenerateHigherOrderNodeCreationCode
(
next_grad_node_creation_str
,
next_
grad_node_out_list
=
self
.
GenerateHigherOrderNodeCreationCode
(
)
self
.
GenerateNodeDeclaration
()
self
.
GenerateNodeDefinition
(
grad_node_creation_str
,
grad_node_out_list
)
self
.
GenerateNodeDefinition
(
next_grad_node_creation_str
,
next_grad_node_out_list
)
class
Dygraph
YamlGenerator
(
Yaml
GeneratorBase
):
class
Dygraph
ForwardAndNodesGenerator
(
GeneratorBase
):
def
__init__
(
self
,
api_yaml_path
,
backward_yaml_path
):
# Parent members:
# self.namespace
# self.api_yaml_path
# self.forward_api_list
Yaml
GeneratorBase
.
__init__
(
self
,
api_yaml_path
)
GeneratorBase
.
__init__
(
self
,
api_yaml_path
)
self
.
backward_yaml_path
=
backward_yaml_path
self
.
grad_api_dict
=
{}
self
.
forward_definition_str
=
""
self
.
forward_declaration_str
=
""
self
.
forward_definition_str
=
""
self
.
node_declaration_str
=
""
self
.
node_definition_str
=
""
...
...
@@ -1518,6 +1523,7 @@ class DygraphYamlGenerator(YamlGeneratorBase):
self
.
forward_definition_str
+=
function_generator
.
forward_definition_str
+
"
\n
"
self
.
forward_declaration_str
+=
function_generator
.
forward_declaration_str
+
"
\n
"
# Generate Dygraph GradNode Function
while
True
:
next_grad_api_contents
=
self
.
GetBackwardAPIContents
(
backward_api_contents
)
...
...
@@ -1611,20 +1617,23 @@ if __name__ == "__main__":
# Generate per Dygraph API
node_declaration_str
=
""
node_definition_str
=
""
forward_definition_str
=
""
forward_declaration_str
=
""
forward_definition_str
=
""
for
i
in
range
(
len
(
api_yaml_paths
)):
api_yaml_path
=
api_yaml_paths
[
i
]
backward_yaml_path
=
backward_yaml_paths
[
i
]
generator
=
DygraphYamlGenerator
(
api_yaml_path
,
backward_yaml_path
)
generator
=
DygraphForwardAndNodesGenerator
(
api_yaml_path
,
backward_yaml_path
)
generator
.
run
()
node_declaration_str
+=
generator
.
node_declaration_str
+
"
\n
"
node_definition_str
+=
generator
.
node_definition_str
+
"
\n
"
forward_definition_str
+=
generator
.
forward_definition_str
+
"
\n
"
forward_declaration_str
+=
generator
.
forward_declaration_str
+
"
\n
"
forward_definition_str
+=
generator
.
forward_definition_str
+
"
\n
"
# Generate Files
nodes_h_path
=
args
.
nodes_h_path
...
...
paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
浏览文件 @
b9342a80
...
...
@@ -15,7 +15,7 @@
import
os
import
argparse
import
logging
from
codegen_utils
import
FunctionGeneratorBase
,
Yaml
GeneratorBase
from
codegen_utils
import
FunctionGeneratorBase
,
GeneratorBase
from
codegen_utils
import
yaml_types_mapping
from
codegen_utils
import
ReadFwdFile
,
IsVectorTensorType
,
GetForwardFunctionName
from
codegen_utils
import
ParseYamlForward
,
GetInplacedFunctionName
...
...
@@ -100,6 +100,7 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
// Set Device ID
{}
// Call dygraph function
decltype({}({})) out = {}({});
PyEval_RestoreThread(tstate);
...
...
@@ -341,6 +342,8 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
# Generate Record Event for performance profiling
pythonc_record_event_str
=
RECORD_EVENT_TEMPLATE
.
format
(
"pythonc_record_event"
,
forward_api_name
,
"pybind_imperative_func"
)
# Generate Python-C Function Definetion
self
.
python_c_function_str
=
PYTHON_C_FUNCTION_TEMPLATE
.
format
(
forward_api_name
,
pythonc_record_event_str
,
forward_api_name
,
get_eager_tensor_str
,
parse_attributes_str
,
set_device_str
,
...
...
@@ -350,6 +353,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
# Set prefix of forward_api_name to avoid conflicts
prefix
=
self
.
namespace
.
strip
(
"::"
)
forward_api_name_prefix
=
""
if
prefix
==
""
else
prefix
+
"_"
# Generate Python-C Function Registration
self
.
python_c_function_reg_str
=
PYTHON_C_FUNCTION_REG_TEMPLATE
.
format
(
forward_api_name_prefix
,
forward_api_name
,
namespace
,
...
...
@@ -376,6 +380,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
inplaced_forward_api_name
,
inplace_output
)
break
# Generate Python-C Function Definetion
self
.
python_c_function_str
+=
PYTHON_C_FUNCTION_TEMPLATE
.
format
(
inplaced_forward_api_name
,
pythonc_record_event_str
,
inplaced_forward_api_name
,
get_eager_tensor_str
,
...
...
@@ -414,17 +419,17 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
return
True
class
PythonC
YamlGenerator
(
Yaml
GeneratorBase
):
class
PythonC
Generator
(
GeneratorBase
):
def
__init__
(
self
,
path
):
# Parent members:
# self.namespace
# self.api_yaml_path
# self.forward_api_list
Yaml
GeneratorBase
.
__init__
(
self
,
api_yaml_path
)
GeneratorBase
.
__init__
(
self
,
api_yaml_path
)
# Generated Result
self
.
python_c_functions_reg_str
=
""
self
.
python_c_functions_str
=
""
self
.
python_c_functions_reg_str
=
""
def
GeneratePythonCFunctions
(
self
):
namespace
=
self
.
namespace
...
...
@@ -436,8 +441,8 @@ class PythonCYamlGenerator(YamlGeneratorBase):
status
=
f_generator
.
run
()
if
status
==
True
:
self
.
python_c_functions_reg_str
+=
f_generator
.
python_c_function_reg_str
+
",
\n
"
self
.
python_c_functions_str
+=
f_generator
.
python_c_function_str
+
"
\n
"
self
.
python_c_functions_reg_str
+=
f_generator
.
python_c_function_reg_str
+
",
\n
"
def
AttachNamespace
(
self
):
namespace
=
self
.
namespace
...
...
@@ -509,11 +514,11 @@ if __name__ == "__main__":
for
i
in
range
(
len
(
api_yaml_paths
)):
api_yaml_path
=
api_yaml_paths
[
i
]
y_generator
=
PythonCYaml
Generator
(
api_yaml_path
)
y
_generator
.
run
()
py_c_generator
=
PythonC
Generator
(
api_yaml_path
)
py_c
_generator
.
run
()
generated_python_c_functions
+=
y
_generator
.
python_c_functions_str
+
"
\n
"
generated_python_c_registration
+=
y
_generator
.
python_c_functions_reg_str
+
"
\n
"
generated_python_c_functions
+=
py_c
_generator
.
python_c_functions_str
+
"
\n
"
generated_python_c_registration
+=
py_c
_generator
.
python_c_functions_reg_str
+
"
\n
"
python_c_str
=
GeneratePythonCWrappers
(
generated_python_c_functions
,
generated_python_c_registration
)
...
...
python/paddle/utils/code_gen/api_base.py
浏览文件 @
b9342a80
...
...
@@ -434,7 +434,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
vars_list
=
kernel
[
'data_type'
].
split
(
','
)
assert
len
(
vars_list
)
==
1
,
f
"
{
api
}
api: The number of params to set data_type only allows
2
, but received
{
len
(
vars_list
)
}
."
)
==
1
,
f
"
{
api
}
api: The number of params to set data_type only allows
1
, but received
{
len
(
vars_list
)
}
."
kernel_select_code
=
kernel_select_code
+
f
"""
kernel_data_type = ParseDataType(
{
vars_list
[
0
].
strip
()
}
);
"""
...
...
@@ -837,10 +837,10 @@ PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
return
api_code
else
:
inv
e
ke_func_name
=
self
.
invoke
.
split
(
'('
)[
0
].
strip
()
if
inv
e
ke_func_name
in
self
.
attrs
[
'names'
]:
inv
o
ke_func_name
=
self
.
invoke
.
split
(
'('
)[
0
].
strip
()
if
inv
o
ke_func_name
in
self
.
attrs
[
'names'
]:
# Adjust the param whose name is same with api invoked.
pattern
=
r
'\W'
+
inv
e
ke_func_name
+
'[^A-Za-z0-9_(]'
pattern
=
r
'\W'
+
inv
o
ke_func_name
+
'[^A-Za-z0-9_(]'
def
adjust_name
(
matched
):
matched_str
=
matched
.
group
()
...
...
python/paddle/utils/code_gen/backward_api_gen.py
浏览文件 @
b9342a80
...
...
@@ -172,8 +172,8 @@ class BackwardAPI(BaseAPI):
return
kernel_output
,
output_names
,
output_create
def
gene_invoke_code
(
self
,
invoke_code
,
params_code
):
inv
e
ke_func_name
=
invoke_code
.
split
(
'('
)[
0
].
strip
()
if
inv
eke_func_name
.
endswith
(
'_grad'
)
or
inve
ke_func_name
.
endswith
(
inv
o
ke_func_name
=
invoke_code
.
split
(
'('
)[
0
].
strip
()
if
inv
oke_func_name
.
endswith
(
'_grad'
)
or
invo
ke_func_name
.
endswith
(
'_grad_impl'
):
return
f
"""
PADDLE_API
{
self
.
get_return_type
()
}
{
self
.
api
}
(
{
params_code
}
) {{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录