Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
370864dd
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
370864dd
编写于
12月 01, 2021
作者:
J
Jiabin Yang
提交者:
GitHub
12月 01, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimizer __call__ to make dygraph faster (#37713)
* optimizer __call__ to make dygraph faster * fix return type
上级
28b43111
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
38 addition
and
30 deletion
+38
-30
python/paddle/fluid/dygraph/layers.py
python/paddle/fluid/dygraph/layers.py
+38
-30
未找到文件。
python/paddle/fluid/dygraph/layers.py
浏览文件 @
370864dd
...
...
@@ -881,41 +881,49 @@ class Layer(core.Layer):
def
_build_once
(
self
,
*
args
,
**
kwargs
):
pass
def
_dygraph_call_func
(
self
,
*
inputs
,
**
kwargs
):
for
forward_pre_hook
in
self
.
_forward_pre_hooks
.
values
():
hook_result
=
forward_pre_hook
(
self
,
inputs
)
if
hook_result
is
not
None
:
if
not
isinstance
(
hook_result
,
tuple
):
hook_result
=
(
hook_result
,
)
inputs
=
hook_result
if
not
self
.
_built
:
with
program_desc_tracing_guard
(
False
):
self
.
_build_once
(
*
inputs
,
**
kwargs
)
# TODO(liuyuhui) Only xpu broadcast parameters here.
# The other device is to call _sync_params_buffers in DataParallel
# to realize the parameter synchronization among multiply cards.
if
parallel_helper
.
_is_data_parallel_mode
(
)
and
paddle
.
is_compiled_with_xpu
():
parallel_helper
.
_broadcast_parameters
(
self
.
_parameters
.
values
())
self
.
_built
=
True
outputs
=
self
.
forward
(
*
inputs
,
**
kwargs
)
for
forward_post_hook
in
self
.
_forward_post_hooks
.
values
():
hook_result
=
forward_post_hook
(
self
,
inputs
,
outputs
)
if
hook_result
is
not
None
:
outputs
=
hook_result
return
outputs
def
__call__
(
self
,
*
inputs
,
**
kwargs
):
# NOTE(Aurelius84): Why we still need param_guard here?
# In case of ControlFlow, true_fn and false_fn will contain
# parameters that may not trigger logic of `Operator` to create
# them. we add this to make sure all parameters is available.
with
param_guard
(
self
.
_parameters
),
param_guard
(
self
.
_buffers
):
for
forward_pre_hook
in
self
.
_forward_pre_hooks
.
values
():
hook_result
=
forward_pre_hook
(
self
,
inputs
)
if
hook_result
is
not
None
:
if
not
isinstance
(
hook_result
,
tuple
):
hook_result
=
(
hook_result
,
)
inputs
=
hook_result
if
not
self
.
_built
:
with
program_desc_tracing_guard
(
False
):
self
.
_build_once
(
*
inputs
,
**
kwargs
)
# TODO(liuyuhui) Only xpu broadcast parameters here.
# The other device is to call _sync_params_buffers in DataParallel
# to realize the parameter synchronization among multiply cards.
if
parallel_helper
.
_is_data_parallel_mode
(
)
and
paddle
.
is_compiled_with_xpu
():
parallel_helper
.
_broadcast_parameters
(
self
.
_parameters
.
values
())
self
.
_built
=
True
outputs
=
self
.
forward
(
*
inputs
,
**
kwargs
)
for
forward_post_hook
in
self
.
_forward_post_hooks
.
values
():
hook_result
=
forward_post_hook
(
self
,
inputs
,
outputs
)
if
hook_result
is
not
None
:
outputs
=
hook_result
return
outputs
from
paddle.fluid.dygraph.dygraph_to_static.program_translator
import
in_declarative_mode
if
in_declarative_mode
()
and
not
framework
.
in_dygraph_mode
():
with
param_guard
(
self
.
_parameters
),
param_guard
(
self
.
_buffers
):
return
self
.
_dygraph_call_func
(
*
inputs
,
**
kwargs
)
else
:
return
self
.
_dygraph_call_func
(
*
inputs
,
**
kwargs
)
def
forward
(
self
,
*
inputs
,
**
kwargs
):
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录