Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
25591674
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
25591674
编写于
3月 27, 2022
作者:
P
pangyoki
提交者:
GitHub
3月 27, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix inplace bug in final_state eager_gen (#40979)
* fix inplace bug in final_state eager_gen * fix python_c_gen
上级
52f07ab4
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
22 addition
and
11 deletion
+22
-11
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
...er/auto_code_generator/final_state_generator/eager_gen.py
+11
-9
paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
...auto_code_generator/final_state_generator/python_c_gen.py
+11
-2
未找到文件。
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
浏览文件 @
25591674
...
@@ -807,7 +807,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
...
@@ -807,7 +807,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
f
"auto NEW_
{
name
}
= (
{
name
}
.get_ptr() != nullptr) ? paddle::make_optional<const paddle::experimental::Tensor&>(egr::EagerAmpAutoCast(
\"
{
name
}
\"
, *(
{
name
}
.get_ptr()), amp_dst_dtype, op_name)) :
{
name
}
;
\n
"
f
"auto NEW_
{
name
}
= (
{
name
}
.get_ptr() != nullptr) ? paddle::make_optional<const paddle::experimental::Tensor&>(egr::EagerAmpAutoCast(
\"
{
name
}
\"
, *(
{
name
}
.get_ptr()), amp_dst_dtype, op_name)) :
{
name
}
;
\n
"
)
)
else
:
else
:
if
inplace_map
and
name
in
inplace_map
.
keys
():
if
is_inplaced
and
inplace_map
and
name
in
inplace_map
.
keys
(
):
arg_str
=
f
"paddle::experimental::Tensor&
{
name
}
"
arg_str
=
f
"paddle::experimental::Tensor&
{
name
}
"
amp_tensors_vector_list
.
append
(
f
"{{
{
name
}
}}"
)
amp_tensors_vector_list
.
append
(
f
"{{
{
name
}
}}"
)
amp_autocast_list
.
append
(
amp_autocast_list
.
append
(
...
@@ -881,7 +882,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
...
@@ -881,7 +882,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
returns_str
=
", "
.
join
(
returns_list
)
returns_str
=
", "
.
join
(
returns_list
)
returns_str
=
f
"std::make_tuple(
{
returns_str
}
)"
returns_str
=
f
"std::make_tuple(
{
returns_str
}
)"
self
.
GenerateNodeCreationCodes
(
forward_call_str
)
self
.
GenerateNodeCreationCodes
(
forward_call_str
,
is_inplaced
)
node_creation_str
=
self
.
node_creation_str
node_creation_str
=
self
.
node_creation_str
dygraph_event_str
=
f
"paddle::platform::RecordEvent dygraph_entrance_record_event(
\"
{
forward_api_name
}
dygraph
\"
, paddle::platform::TracerEventType::Operator, 1);"
dygraph_event_str
=
f
"paddle::platform::RecordEvent dygraph_entrance_record_event(
\"
{
forward_api_name
}
dygraph
\"
, paddle::platform::TracerEventType::Operator, 1);"
...
@@ -917,7 +918,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
...
@@ -917,7 +918,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
logging
.
info
(
logging
.
info
(
f
"Generated Forward Declaration:
{
self
.
forward_declaration_str
}
"
)
f
"Generated Forward Declaration:
{
self
.
forward_declaration_str
}
"
)
def
GenerateNodeCreationCodes
(
self
,
forward_call_str
):
def
GenerateNodeCreationCodes
(
self
,
forward_call_str
,
is_inplaced
):
forward_api_name
=
self
.
forward_api_name
forward_api_name
=
self
.
forward_api_name
forward_inputs_position_map
=
self
.
forward_inputs_position_map
forward_inputs_position_map
=
self
.
forward_inputs_position_map
forward_outputs_position_map
=
self
.
forward_outputs_position_map
forward_outputs_position_map
=
self
.
forward_outputs_position_map
...
@@ -980,12 +981,13 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
...
@@ -980,12 +981,13 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
# Check Inplace
# Check Inplace
check_inplace_str
=
""
check_inplace_str
=
""
bump_inplace_version_str
=
""
bump_inplace_version_str
=
""
for
inplace_name
in
inplace_map
.
keys
():
if
is_inplaced
:
inplace_autograd_meta_name
=
GetAutoGradMetaName
(
inplace_name
)
for
inplace_name
in
inplace_map
.
keys
():
check_inplace_str
+=
CHECK_INPLACE_TEMPLATE
.
format
(
inplace_autograd_meta_name
=
GetAutoGradMetaName
(
inplace_name
)
inplace_name
,
inplace_autograd_meta_name
)
check_inplace_str
+=
CHECK_INPLACE_TEMPLATE
.
format
(
bump_inplace_version_str
+=
BUMP_INPLACE_VERSION_TEMPLATE
.
format
(
inplace_name
,
inplace_autograd_meta_name
)
inplace_name
,
inplace_name
)
bump_inplace_version_str
+=
BUMP_INPLACE_VERSION_TEMPLATE
.
format
(
inplace_name
,
inplace_name
)
# Node Construction
# Node Construction
num_backward_inputs
=
len
(
forward_outputs_position_map
.
keys
())
num_backward_inputs
=
len
(
forward_outputs_position_map
.
keys
())
...
...
paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py
浏览文件 @
25591674
...
@@ -333,11 +333,20 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
...
@@ -333,11 +333,20 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
forward_api_name
,
namespace
,
forward_api_name
,
forward_api_name
)
forward_api_name
,
namespace
,
forward_api_name
,
forward_api_name
)
if
len
(
inplace_map
)
>
0
:
if
len
(
inplace_map
)
>
0
:
inplaced_forward_api_name
=
GetInplacedFunctionName
(
self
.
forward_api_name
)
assert
len
(
assert
len
(
inplace_map
inplace_map
)
==
1
,
f
"size of inplace_map must be 1, but inplace_map of
\"
{
forward_api_name
}
\"
op got
{
len
(
inplace_map
)
}
"
)
==
1
,
f
"size of inplace_map must be 1, but inplace_map of
\"
{
forward_api_name
}
\"
op got
{
len
(
inplace_map
)
}
"
inplaced_forward_api_name
=
GetInplacedFunctionName
(
self
.
forward_api_name
)
# Generate Python-C Function Definitions
if
is_forward_only
:
fwd_function_name
=
FUNCTION_NAME_TEMPLATE
.
format
(
"paddle::experimental::"
,
namespace
,
inplaced_forward_api_name
)
elif
len
(
inplace_map
)
>
0
:
fwd_function_name
=
FUNCTION_NAME_TEMPLATE
.
format
(
"::"
,
namespace
,
GetForwardFunctionName
(
inplaced_forward_api_name
))
for
inplace_input
,
inplace_output
in
inplace_map
.
items
():
for
inplace_input
,
inplace_output
in
inplace_map
.
items
():
return_str
=
RETURN_INPLACE_PYOBJECT_TEMPLATE
.
format
(
return_str
=
RETURN_INPLACE_PYOBJECT_TEMPLATE
.
format
(
inplaced_forward_api_name
,
inplace_input
,
inplaced_forward_api_name
,
inplace_input
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录