未验证 提交 370864dd 编写于 作者: J Jiabin Yang 提交者: GitHub

optimizer __call__ to make dygraph faster (#37713)

* optimizer __call__ to make dygraph faster

* fix return type
上级 28b43111
......@@ -881,12 +881,7 @@ class Layer(core.Layer):
def _build_once(self, *args, **kwargs):
pass
def __call__(self, *inputs, **kwargs):
# NOTE(Aurelius84): Why we still need param_guard here?
# In case of ControlFlow, true_fn and false_fn will contain
# parameters that may not trigger logic of `Operator` to create
# them. we add this to make sure all parameters is available.
with param_guard(self._parameters), param_guard(self._buffers):
def _dygraph_call_func(self, *inputs, **kwargs):
for forward_pre_hook in self._forward_pre_hooks.values():
hook_result = forward_pre_hook(self, inputs)
if hook_result is not None:
......@@ -917,6 +912,19 @@ class Layer(core.Layer):
return outputs
def __call__(self, *inputs, **kwargs):
# NOTE(Aurelius84): Why we still need param_guard here?
# In case of ControlFlow, true_fn and false_fn will contain
# parameters that may not trigger logic of `Operator` to create
# them. we add this to make sure all parameters is available.
from paddle.fluid.dygraph.dygraph_to_static.program_translator import in_declarative_mode
if in_declarative_mode() and not framework.in_dygraph_mode():
with param_guard(self._parameters), param_guard(self._buffers):
return self._dygraph_call_func(*inputs, **kwargs)
else:
return self._dygraph_call_func(*inputs, **kwargs)
def forward(self, *inputs, **kwargs):
"""
Defines the computation performed at every call.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册