未验证 提交 370864dd 编写于 作者: J Jiabin Yang 提交者: GitHub

optimizer __call__ to make dygraph faster (#37713)

* optimizer __call__ to make dygraph faster

* fix return type
上级 28b43111
......@@ -881,41 +881,49 @@ class Layer(core.Layer):
def _build_once(self, *args, **kwargs):
pass
def _dygraph_call_func(self, *inputs, **kwargs):
for forward_pre_hook in self._forward_pre_hooks.values():
hook_result = forward_pre_hook(self, inputs)
if hook_result is not None:
if not isinstance(hook_result, tuple):
hook_result = (hook_result, )
inputs = hook_result
if not self._built:
with program_desc_tracing_guard(False):
self._build_once(*inputs, **kwargs)
# TODO(liuyuhui) Only xpu broadcast parameters here.
# The other device is to call _sync_params_buffers in DataParallel
# to realize the parameter synchronization among multiply cards.
if parallel_helper._is_data_parallel_mode(
) and paddle.is_compiled_with_xpu():
parallel_helper._broadcast_parameters(
self._parameters.values())
self._built = True
outputs = self.forward(*inputs, **kwargs)
for forward_post_hook in self._forward_post_hooks.values():
hook_result = forward_post_hook(self, inputs, outputs)
if hook_result is not None:
outputs = hook_result
return outputs
def __call__(self, *inputs, **kwargs):
# NOTE(Aurelius84): Why we still need param_guard here?
# In case of ControlFlow, true_fn and false_fn will contain
# parameters that may not trigger logic of `Operator` to create
# them. we add this to make sure all parameters is available.
with param_guard(self._parameters), param_guard(self._buffers):
for forward_pre_hook in self._forward_pre_hooks.values():
hook_result = forward_pre_hook(self, inputs)
if hook_result is not None:
if not isinstance(hook_result, tuple):
hook_result = (hook_result, )
inputs = hook_result
if not self._built:
with program_desc_tracing_guard(False):
self._build_once(*inputs, **kwargs)
# TODO(liuyuhui) Only xpu broadcast parameters here.
# The other device is to call _sync_params_buffers in DataParallel
# to realize the parameter synchronization among multiply cards.
if parallel_helper._is_data_parallel_mode(
) and paddle.is_compiled_with_xpu():
parallel_helper._broadcast_parameters(
self._parameters.values())
self._built = True
outputs = self.forward(*inputs, **kwargs)
for forward_post_hook in self._forward_post_hooks.values():
hook_result = forward_post_hook(self, inputs, outputs)
if hook_result is not None:
outputs = hook_result
return outputs
from paddle.fluid.dygraph.dygraph_to_static.program_translator import in_declarative_mode
if in_declarative_mode() and not framework.in_dygraph_mode():
with param_guard(self._parameters), param_guard(self._buffers):
return self._dygraph_call_func(*inputs, **kwargs)
else:
return self._dygraph_call_func(*inputs, **kwargs)
def forward(self, *inputs, **kwargs):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册