Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b9342a80
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b9342a80
编写于
5月 18, 2022
作者:
W
Weilong Wu
提交者:
GitHub
5月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Eager] Polish eager code generation (#42822)
* [Eager] Polish eager code generation * Remove useless code in codegen
上级
570d0322
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
80 addition
and
66 deletion
+80
-66
paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py
...uto_code_generator/final_state_generator/codegen_utils.py
+1
-1
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
...er/auto_code_generator/final_state_generator/eager_gen.py
+59
-50
paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
...auto_code_generator/final_state_generator/python_c_gen.py
+14
-9
python/paddle/utils/code_gen/api_base.py
python/paddle/utils/code_gen/api_base.py
+4
-4
python/paddle/utils/code_gen/backward_api_gen.py
python/paddle/utils/code_gen/backward_api_gen.py
+2
-2
未找到文件。
paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py
浏览文件 @
b9342a80
...
...
@@ -418,7 +418,7 @@ class FunctionGeneratorBase:
return_name
]
=
[
return_type
,
return_pos
]
class
Yaml
GeneratorBase
:
class
GeneratorBase
:
def
__init__
(
self
,
api_yaml_path
):
self
.
namespace
=
""
self
.
api_yaml_path
=
api_yaml_path
...
...
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
浏览文件 @
b9342a80
...
...
@@ -29,7 +29,7 @@ from codegen_utils import RemoveSpecialSymbolsInName, RecoverBaseNameOfInplaceFu
from
codegen_utils
import
GetInplacedFunctionName
from
codegen_utils
import
ParseYamlArgs
,
ParseYamlReturns
,
ParseYamlForwardFromBackward
from
codegen_utils
import
ParseYamlForward
,
ParseYamlBackward
from
codegen_utils
import
FunctionGeneratorBase
,
Yaml
GeneratorBase
from
codegen_utils
import
FunctionGeneratorBase
,
GeneratorBase
from
codegen_utils
import
ops_to_fill_zero_for_empty_grads
from
codegen_utils
import
AssertMessage
,
GetIndent
...
...
@@ -60,14 +60,6 @@ SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = \
}}
"""
PLAIN_TENSOR_MEMBER_TEMPLATE
=
\
""" egr::TensorWrapper {};
"""
CLEAR_TENSOR_WRAPPER_TEMPLATE
=
\
""" {}.clear();
"""
SET_VECTOR_TENSOR_WRAPPER_TEMPLATE
=
\
""" void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}) {{
for(const auto& eager_tensor : {}) {{
...
...
@@ -76,10 +68,18 @@ SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = \
}}
"""
PLAIN_TENSOR_MEMBER_TEMPLATE
=
\
""" egr::TensorWrapper {};
"""
VECTOR_TENSOR_MEMBER_TEMPLATE
=
\
""" std::vector<egr::TensorWrapper> {};
"""
CLEAR_TENSOR_WRAPPER_TEMPLATE
=
\
""" {}.clear();
"""
CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE
=
\
""" for (auto& tw : {}) {{
tw.clear();
...
...
@@ -423,9 +423,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
self
.
forward_returns_list
=
[
]
#[ [ret_name, ret_type, orig_position], ...]
self
.
backward_inputs_list
=
[
]
#[ [attr_name, attr_type, default_value, orig_position], ...]
self
.
backward_attrs_list
=
[
]
#[ [attr_name, attr_type, default_value, orig_position], ...]
self
.
backward_inputs_list
=
[
]
#[ [arg_name, arg_type, orig_position], ...]
self
.
backward_returns_list
=
[
]
#[ [ret_name, ret_type, orig_position], ...]
...
...
@@ -504,11 +504,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
for
_
,
_
,
pos
in
forward_inputs_list
:
max_input_position
=
max
(
max_input_position
,
pos
)
max_attr_position
=
-
1
for
_
,
_
,
_
,
pos
in
forward_attrs_list
:
assert
pos
>
max_input_position
,
AssertMessage
(
pos
,
max_input_position
)
max_attr_position
=
max
(
max_attr_position
,
pos
)
def
BackwardValidationCheck
(
self
):
backward_forward_inputs_map
=
self
.
backward_forward_inputs_map
...
...
@@ -692,12 +690,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
else
:
set_tensor_wrappers
=
f
"
{
indent
}
grad_node->SetTensorWrapper
{
name
}
(
{
name
}
);"
set_input_tensor_wrappers_list
.
append
(
set_tensor_wrappers
)
else
:
else
:
# Forwad's output as backward's input
if
num_fwd_outputs
>
1
:
# Aligned with forward output position
assert
name
in
forward_outputs_position_map
.
keys
(
),
AssertMessage
(
name
,
forward_outputs_position_map
.
keys
())
fwd_output_pos
=
forward_outputs_position_map
[
name
][
1
]
if
is_optional
:
set_tensor_wrappers
=
f
"
{
indent
}
if(
{
name
}
.get_ptr() != nullptr) grad_node->SetTensorWrapper
{
name
}
(*(
{
name
}
.get_ptr()));"
...
...
@@ -733,7 +730,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
set_grad_out_meta_list
.
append
(
set_grad_out_meta
)
set_grad_out_meta_str
=
"
\n
"
.
join
(
set_grad_out_meta_list
)
# SetOutRank & SetHistory & SetGradInMeta
# SetOutRank & SetHistory & SetGradInMeta
& CheckAndRetainGrad
set_out_rank_list
=
[]
set_history_list
=
[]
set_grad_in_meta_list
=
[]
...
...
@@ -741,11 +738,12 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
num_outputs
=
len
(
forward_outputs_position_map
.
keys
())
for
name
,
(
_
,
pos
)
in
forward_outputs_position_map
.
items
():
output_autograd_meta_name
=
GetAutoGradMetaName
(
name
)
set_out_rank
=
f
"
{
indent
}
egr::EagerUtils::SetOutRankWithSlot(
{
output_autograd_meta_name
}
,
{
pos
}
);"
set_history
=
f
"
{
indent
}
egr::EagerUtils::SetHistory(
{
output_autograd_meta_name
}
, grad_node);"
set_retain_grad
=
f
"
{
indent
}
egr::EagerUtils::CheckAndRetainGrad(
{
name
}
);"
set_grad_in_meta
=
f
"
{
indent
}
grad_node->SetGradInMeta(
{
name
}
,
{
pos
}
);"
set_retain_grad
=
f
"
{
indent
}
egr::EagerUtils::CheckAndRetainGrad(
{
name
}
);"
set_out_rank_list
.
append
(
set_out_rank
)
set_history_list
.
append
(
set_history
)
set_grad_in_meta_list
.
append
(
set_grad_in_meta
)
...
...
@@ -806,7 +804,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
self
.
DetermineForwardPositionMap
(
self
.
forward_inputs_list
,
self
.
forward_returns_list
)
# Initialize
forward_inputs_position_map, forward_outputs_position
_map
# Initialize
backward_forward_inputs_map, backward_grad_inputs_map, backward_grad_outputs
_map
self
.
SlotNameMatching
()
# Backward Validation Check
...
...
@@ -822,18 +820,16 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
self
.
forward_definition_str
=
""
self
.
forward_declaration_str
=
""
def
GenerateForwardDefinition
(
self
,
is_inplaced
):
def
GenerateForwardDefinition
AndDeclaration
(
self
,
is_inplaced
):
namespace
=
self
.
namespace
forward_api_name
=
GetInplacedFunctionName
(
self
.
forward_api_name
)
if
is_inplaced
else
self
.
forward_api_name
backward_api_name
=
self
.
backward_api_name
forward_inputs_position_map
=
self
.
forward_inputs_position_map
forward_outputs_position_map
=
self
.
forward_outputs_position_map
forward_attrs_list
=
self
.
forward_attrs_list
backward_forward_inputs_map
=
self
.
backward_forward_inputs_map
backward_grad_inputs_map
=
self
.
backward_grad_inputs_map
backward_grad_outputs_map
=
self
.
backward_grad_outputs_map
backward_attrs_list
=
self
.
backward_attrs_list
optional_inputs
=
self
.
optional_inputs
intermediate_outputs
=
self
.
intermediate_outputs
inplace_map
=
self
.
inplace_map
if
is_inplaced
else
{}
...
...
@@ -845,6 +841,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
inputs_args_definition_list
=
[
""
for
i
in
range
(
num_inputs
)]
inputs_args_declaration_list
=
[
""
for
i
in
range
(
num_inputs
)]
inputs_call_list
=
[
""
for
i
in
range
(
num_inputs
)]
amp_inputs_call_list
=
[
""
for
i
in
range
(
num_inputs
)]
amp_tensors_vector_list
=
[]
amp_tensors_vector_optional_list
=
[]
...
...
@@ -1019,9 +1016,10 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
bump_inplace_version_str
+=
BUMP_INPLACE_VERSION_TEMPLATE
.
format
(
inplace_name
,
inplace_name
)
# Node Creation
self
.
GenerateNodeCreationCodes
()
node_creation_str
=
self
.
node_creation_str
dygraph_event_str
=
f
"
{
indent
}
paddle::platform::RecordEvent dygraph_entrance_record_event(
\"
{
forward_api_name
}
dygraph
\"
, paddle::platform::TracerEventType::Operator, 1);
\n
"
forward_function_name
=
GetDygraphForwardFunctionName
(
forward_api_name
)
...
...
@@ -1045,6 +1043,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
amp_tensors_vector_optional_list_str
,
amp_get_dst_dtype_str
,
amp_autocast_list_str
,
amp_call_str
)
# Generate forward_definition_str and forward_declaration_str
self
.
forward_definition_str
+=
FORWARD_FUNCTION_TEMPLATE
.
format
(
returns_type_str
,
forward_function_name
,
inputs_args_definition_str
,
dygraph_event_str
,
amp_logic_str
,
inputs_autograd_meta_str
,
...
...
@@ -1061,8 +1060,8 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
if
forward_api_name
!=
"sum"
and
"inplace"
in
forward_api_contents
.
keys
(
):
#
Node Defini
tion Generation
self
.
GenerateForwardDefinition
(
is_inplaced
=
True
)
#
Function Definition and Declara
tion Generation
self
.
GenerateForwardDefinition
AndDeclaration
(
is_inplaced
=
True
)
self
.
UpdateCoreOpsInformation
(
is_inplaced
=
True
)
def
UpdateCoreOpsInformation
(
self
,
is_inplaced
):
...
...
@@ -1083,6 +1082,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
final_state_fwd_api_name
]
=
[
""
for
i
in
range
(
num_args
)]
core_ops_args_type_info
[
final_state_fwd_api_name
]
=
[
""
for
i
in
range
(
num_args
)]
for
name
,
(
ttype
,
pos
)
in
forward_inputs_position_map
.
items
():
core_ops_args_info
[
final_state_fwd_api_name
][
pos
]
=
name
if
IsPlainTensorType
(
ttype
):
...
...
@@ -1104,7 +1104,9 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
#####################
## Code Generation ##
#####################
self
.
GenerateForwardDefinition
(
is_inplaced
=
False
)
# Definition And Declaration
self
.
GenerateForwardDefinitionAndDeclaration
(
is_inplaced
=
False
)
self
.
UpdateCoreOpsInformation
(
is_inplaced
=
False
)
...
...
@@ -1164,9 +1166,10 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
grad_api_contents
=
self
.
grad_api_contents
next_grad_api_contents
=
self
.
next_grad_api_contents
grad_node_creation_str
=
""
grad_node_out_list
=
[]
next_
grad_node_creation_str
=
""
next_
grad_node_out_list
=
[]
if
next_grad_api_contents
:
# Fake forward_api_contents and backward_api_contents
forward_api_contents
=
grad_api_contents
forward_api_contents
[
'api'
]
=
forward_api_contents
[
'backward_api'
]
backward_api_contents
=
next_grad_api_contents
...
...
@@ -1175,12 +1178,12 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
forward_api_contents
,
backward_api_contents
,
namespace
)
next_node_generator
.
run
()
next_node_generator
.
GenerateNodeCreationCodes
()
grad_node_creation_str
=
next_node_generator
.
node_creation_str
grad_node_out_list
=
next_node_generator
.
grad_node_out_list
next_
grad_node_creation_str
=
next_node_generator
.
node_creation_str
next_
grad_node_out_list
=
next_node_generator
.
grad_node_out_list
self
.
RecordGrad2NextGradNameMapping
(
next_node_generator
)
return
grad_node_creation_str
,
grad_node_out_list
return
next_grad_node_creation_str
,
next_
grad_node_out_list
def
GenerateNodeDeclaration
(
self
):
forward_op_name
=
self
.
forward_api_name
...
...
@@ -1188,7 +1191,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
backward_attrs_list
=
self
.
backward_attrs_list
no_need_buffers
=
self
.
no_need_buffers
# SetTensorWrapper Methods & TensorWrapper Members
# SetTensorWrapper Methods & TensorWrapper Members
& ClearTensorWrappers
set_tensor_wrapper_methods_str
=
""
tensor_wrapper_members_str
=
""
clear_tensor_wrapper_str
=
""
...
...
@@ -1241,8 +1244,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
set_attribute_methods_str
,
tensor_wrapper_members_str
,
attribute_members_str
)
def
GenerateNodeDefinition
(
self
,
grad_node_creation_str
,
grad_node_out_list
):
def
GenerateNodeDefinition
(
self
,
next_
grad_node_creation_str
,
next_
grad_node_out_list
):
namespace
=
self
.
namespace
forward_api_name
=
self
.
forward_api_name
backward_api_name
=
self
.
backward_api_name
...
...
@@ -1362,14 +1365,14 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
inputs_autograd_meta_str
=
""
outputs_autograd_meta_str
=
""
compute_require_grad_str
=
""
if
len
(
grad_node_creation_str
)
>
0
:
# 1. Get Input AutoGradMeta
if
len
(
next_
grad_node_creation_str
)
>
0
:
# 1. Get
Grad
Input AutoGradMeta
inputs_autograd_meta_list
=
[]
compute_require_grad_args_list
=
[
"trace_backward"
]
for
name
,
(
ttype
,
pos
,
grad_api_position
)
in
backward_grad_inputs_map
.
items
():
transformed_tensor_name
=
self
.
TransformToNextGradName
(
name
)
if
transformed_tensor_name
in
grad_node_out_list
:
if
transformed_tensor_name
in
next_
grad_node_out_list
:
input_autograd_meta_name
=
GetAutoGradMetaName
(
transformed_tensor_name
)
if
IsPlainTensorType
(
ttype
):
...
...
@@ -1388,7 +1391,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
# 2. Get TensorWrapper AutoGradMeta
for
name
,
(
ttype
,
_
,
pos
),
in
backward_forward_inputs_map
.
items
():
transformed_tensor_name
=
self
.
TransformToNextGradName
(
name
)
if
transformed_tensor_name
in
grad_node_out_list
:
if
transformed_tensor_name
in
next_
grad_node_out_list
:
input_autograd_meta_name
=
GetAutoGradMetaName
(
transformed_tensor_name
)
if
IsPlainTensorType
(
ttype
):
...
...
@@ -1447,7 +1450,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
grad_node_name
,
fill_zero_str
,
get_grad_in_args_str
,
grad_node_name
,
grad_function_call_str
,
check_nan_inf_str
,
inputs_autograd_meta_str
,
outputs_autograd_meta_str
,
compute_require_grad_str
,
grad_node_creation_str
,
returns_str
)
next_
grad_node_creation_str
,
returns_str
)
def
run
(
self
):
super
().
run
()
...
...
@@ -1458,27 +1461,29 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
## Code Generation ##
#####################
# Higher-order GradNode generation
grad_node_creation_str
,
grad_node_out_list
=
self
.
GenerateHigherOrderNodeCreationCode
(
next_grad_node_creation_str
,
next_
grad_node_out_list
=
self
.
GenerateHigherOrderNodeCreationCode
(
)
self
.
GenerateNodeDeclaration
()
self
.
GenerateNodeDefinition
(
grad_node_creation_str
,
grad_node_out_list
)
self
.
GenerateNodeDefinition
(
next_grad_node_creation_str
,
next_grad_node_out_list
)
class
Dygraph
YamlGenerator
(
Yaml
GeneratorBase
):
class
Dygraph
ForwardAndNodesGenerator
(
GeneratorBase
):
def
__init__
(
self
,
api_yaml_path
,
backward_yaml_path
):
# Parent members:
# self.namespace
# self.api_yaml_path
# self.forward_api_list
Yaml
GeneratorBase
.
__init__
(
self
,
api_yaml_path
)
GeneratorBase
.
__init__
(
self
,
api_yaml_path
)
self
.
backward_yaml_path
=
backward_yaml_path
self
.
grad_api_dict
=
{}
self
.
forward_definition_str
=
""
self
.
forward_declaration_str
=
""
self
.
forward_definition_str
=
""
self
.
node_declaration_str
=
""
self
.
node_definition_str
=
""
...
...
@@ -1518,6 +1523,7 @@ class DygraphYamlGenerator(YamlGeneratorBase):
self
.
forward_definition_str
+=
function_generator
.
forward_definition_str
+
"
\n
"
self
.
forward_declaration_str
+=
function_generator
.
forward_declaration_str
+
"
\n
"
# Generate Dygraph GradNode Function
while
True
:
next_grad_api_contents
=
self
.
GetBackwardAPIContents
(
backward_api_contents
)
...
...
@@ -1611,20 +1617,23 @@ if __name__ == "__main__":
# Generate per Dygraph API
node_declaration_str
=
""
node_definition_str
=
""
forward_definition_str
=
""
forward_declaration_str
=
""
forward_definition_str
=
""
for
i
in
range
(
len
(
api_yaml_paths
)):
api_yaml_path
=
api_yaml_paths
[
i
]
backward_yaml_path
=
backward_yaml_paths
[
i
]
generator
=
DygraphYamlGenerator
(
api_yaml_path
,
backward_yaml_path
)
generator
=
DygraphForwardAndNodesGenerator
(
api_yaml_path
,
backward_yaml_path
)
generator
.
run
()
node_declaration_str
+=
generator
.
node_declaration_str
+
"
\n
"
node_definition_str
+=
generator
.
node_definition_str
+
"
\n
"
forward_definition_str
+=
generator
.
forward_definition_str
+
"
\n
"
forward_declaration_str
+=
generator
.
forward_declaration_str
+
"
\n
"
forward_definition_str
+=
generator
.
forward_definition_str
+
"
\n
"
# Generate Files
nodes_h_path
=
args
.
nodes_h_path
...
...
paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
浏览文件 @
b9342a80
...
...
@@ -15,7 +15,7 @@
import
os
import
argparse
import
logging
from
codegen_utils
import
FunctionGeneratorBase
,
Yaml
GeneratorBase
from
codegen_utils
import
FunctionGeneratorBase
,
GeneratorBase
from
codegen_utils
import
yaml_types_mapping
from
codegen_utils
import
ReadFwdFile
,
IsVectorTensorType
,
GetForwardFunctionName
from
codegen_utils
import
ParseYamlForward
,
GetInplacedFunctionName
...
...
@@ -100,6 +100,7 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
// Set Device ID
{}
// Call dygraph function
decltype({}({})) out = {}({});
PyEval_RestoreThread(tstate);
...
...
@@ -341,6 +342,8 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
# Generate Record Event for performance profiling
pythonc_record_event_str
=
RECORD_EVENT_TEMPLATE
.
format
(
"pythonc_record_event"
,
forward_api_name
,
"pybind_imperative_func"
)
# Generate Python-C Function Definetion
self
.
python_c_function_str
=
PYTHON_C_FUNCTION_TEMPLATE
.
format
(
forward_api_name
,
pythonc_record_event_str
,
forward_api_name
,
get_eager_tensor_str
,
parse_attributes_str
,
set_device_str
,
...
...
@@ -350,6 +353,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
# Set prefix of forward_api_name to avoid conflicts
prefix
=
self
.
namespace
.
strip
(
"::"
)
forward_api_name_prefix
=
""
if
prefix
==
""
else
prefix
+
"_"
# Generate Python-C Function Registration
self
.
python_c_function_reg_str
=
PYTHON_C_FUNCTION_REG_TEMPLATE
.
format
(
forward_api_name_prefix
,
forward_api_name
,
namespace
,
...
...
@@ -376,6 +380,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
inplaced_forward_api_name
,
inplace_output
)
break
# Generate Python-C Function Definetion
self
.
python_c_function_str
+=
PYTHON_C_FUNCTION_TEMPLATE
.
format
(
inplaced_forward_api_name
,
pythonc_record_event_str
,
inplaced_forward_api_name
,
get_eager_tensor_str
,
...
...
@@ -414,17 +419,17 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
return
True
class
PythonC
YamlGenerator
(
Yaml
GeneratorBase
):
class
PythonC
Generator
(
GeneratorBase
):
def
__init__
(
self
,
path
):
# Parent members:
# self.namespace
# self.api_yaml_path
# self.forward_api_list
Yaml
GeneratorBase
.
__init__
(
self
,
api_yaml_path
)
GeneratorBase
.
__init__
(
self
,
api_yaml_path
)
# Generated Result
self
.
python_c_functions_reg_str
=
""
self
.
python_c_functions_str
=
""
self
.
python_c_functions_reg_str
=
""
def
GeneratePythonCFunctions
(
self
):
namespace
=
self
.
namespace
...
...
@@ -436,8 +441,8 @@ class PythonCYamlGenerator(YamlGeneratorBase):
status
=
f_generator
.
run
()
if
status
==
True
:
self
.
python_c_functions_reg_str
+=
f_generator
.
python_c_function_reg_str
+
",
\n
"
self
.
python_c_functions_str
+=
f_generator
.
python_c_function_str
+
"
\n
"
self
.
python_c_functions_reg_str
+=
f_generator
.
python_c_function_reg_str
+
",
\n
"
def
AttachNamespace
(
self
):
namespace
=
self
.
namespace
...
...
@@ -509,11 +514,11 @@ if __name__ == "__main__":
for
i
in
range
(
len
(
api_yaml_paths
)):
api_yaml_path
=
api_yaml_paths
[
i
]
y_generator
=
PythonCYaml
Generator
(
api_yaml_path
)
y
_generator
.
run
()
py_c_generator
=
PythonC
Generator
(
api_yaml_path
)
py_c
_generator
.
run
()
generated_python_c_functions
+=
y
_generator
.
python_c_functions_str
+
"
\n
"
generated_python_c_registration
+=
y
_generator
.
python_c_functions_reg_str
+
"
\n
"
generated_python_c_functions
+=
py_c
_generator
.
python_c_functions_str
+
"
\n
"
generated_python_c_registration
+=
py_c
_generator
.
python_c_functions_reg_str
+
"
\n
"
python_c_str
=
GeneratePythonCWrappers
(
generated_python_c_functions
,
generated_python_c_registration
)
...
...
python/paddle/utils/code_gen/api_base.py
浏览文件 @
b9342a80
...
...
@@ -434,7 +434,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
vars_list
=
kernel
[
'data_type'
].
split
(
','
)
assert
len
(
vars_list
)
==
1
,
f
"
{
api
}
api: The number of params to set data_type only allows
2
, but received
{
len
(
vars_list
)
}
."
)
==
1
,
f
"
{
api
}
api: The number of params to set data_type only allows
1
, but received
{
len
(
vars_list
)
}
."
kernel_select_code
=
kernel_select_code
+
f
"""
kernel_data_type = ParseDataType(
{
vars_list
[
0
].
strip
()
}
);
"""
...
...
@@ -837,10 +837,10 @@ PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
return
api_code
else
:
inv
e
ke_func_name
=
self
.
invoke
.
split
(
'('
)[
0
].
strip
()
if
inv
e
ke_func_name
in
self
.
attrs
[
'names'
]:
inv
o
ke_func_name
=
self
.
invoke
.
split
(
'('
)[
0
].
strip
()
if
inv
o
ke_func_name
in
self
.
attrs
[
'names'
]:
# Adjust the param whose name is same with api invoked.
pattern
=
r
'\W'
+
inv
e
ke_func_name
+
'[^A-Za-z0-9_(]'
pattern
=
r
'\W'
+
inv
o
ke_func_name
+
'[^A-Za-z0-9_(]'
def
adjust_name
(
matched
):
matched_str
=
matched
.
group
()
...
...
python/paddle/utils/code_gen/backward_api_gen.py
浏览文件 @
b9342a80
...
...
@@ -172,8 +172,8 @@ class BackwardAPI(BaseAPI):
return
kernel_output
,
output_names
,
output_create
def
gene_invoke_code
(
self
,
invoke_code
,
params_code
):
inv
e
ke_func_name
=
invoke_code
.
split
(
'('
)[
0
].
strip
()
if
inv
eke_func_name
.
endswith
(
'_grad'
)
or
inve
ke_func_name
.
endswith
(
inv
o
ke_func_name
=
invoke_code
.
split
(
'('
)[
0
].
strip
()
if
inv
oke_func_name
.
endswith
(
'_grad'
)
or
invo
ke_func_name
.
endswith
(
'_grad_impl'
):
return
f
"""
PADDLE_API
{
self
.
get_return_type
()
}
{
self
.
api
}
(
{
params_code
}
) {{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录