From 370864ddca4d1f54120dd1feefbfcfffe2919d82 Mon Sep 17 00:00:00 2001
From: Jiabin Yang
Date: Wed, 1 Dec 2021 11:30:40 +0800
Subject: [PATCH] optimizer __call__ to make dygraph faster (#37713)

* optimizer __call__ to make dygraph faster

* fix return type
---
 python/paddle/fluid/dygraph/layers.py | 68 +++++++++++++++------------
 1 file changed, 38 insertions(+), 30 deletions(-)

diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py
index 8ff960a90ea..662e233bd40 100644
--- a/python/paddle/fluid/dygraph/layers.py
+++ b/python/paddle/fluid/dygraph/layers.py
@@ -881,41 +881,49 @@ class Layer(core.Layer):
     def _build_once(self, *args, **kwargs):
         pass
 
+    def _dygraph_call_func(self, *inputs, **kwargs):
+        for forward_pre_hook in self._forward_pre_hooks.values():
+            hook_result = forward_pre_hook(self, inputs)
+            if hook_result is not None:
+                if not isinstance(hook_result, tuple):
+                    hook_result = (hook_result, )
+                inputs = hook_result
+
+        if not self._built:
+            with program_desc_tracing_guard(False):
+                self._build_once(*inputs, **kwargs)
+
+                # TODO(liuyuhui) Only xpu broadcast parameters here.
+                # Other devices call _sync_params_buffers in DataParallel
+                # to synchronize parameters among multiple cards.
+                if parallel_helper._is_data_parallel_mode(
+                ) and paddle.is_compiled_with_xpu():
+                    parallel_helper._broadcast_parameters(
+                        self._parameters.values())
+
+            self._built = True
+
+        outputs = self.forward(*inputs, **kwargs)
+
+        for forward_post_hook in self._forward_post_hooks.values():
+            hook_result = forward_post_hook(self, inputs, outputs)
+            if hook_result is not None:
+                outputs = hook_result
+
+        return outputs
+
     def __call__(self, *inputs, **kwargs):
         # NOTE(Aurelius84): Why we still need param_guard here?
         # In case of ControlFlow, true_fn and false_fn will contain
         # parameters that may not trigger logic of `Operator` to create
         # them. we add this to make sure all parameters is available.
-        with param_guard(self._parameters), param_guard(self._buffers):
-            for forward_pre_hook in self._forward_pre_hooks.values():
-                hook_result = forward_pre_hook(self, inputs)
-                if hook_result is not None:
-                    if not isinstance(hook_result, tuple):
-                        hook_result = (hook_result, )
-                    inputs = hook_result
-
-            if not self._built:
-                with program_desc_tracing_guard(False):
-                    self._build_once(*inputs, **kwargs)
-
-                    # TODO(liuyuhui) Only xpu broadcast parameters here.
-                    # The other device is to call _sync_params_buffers in DataParallel
-                    # to realize the parameter synchronization among multiply cards.
-                    if parallel_helper._is_data_parallel_mode(
-                    ) and paddle.is_compiled_with_xpu():
-                        parallel_helper._broadcast_parameters(
-                            self._parameters.values())
-
-                self._built = True
-
-            outputs = self.forward(*inputs, **kwargs)
-
-            for forward_post_hook in self._forward_post_hooks.values():
-                hook_result = forward_post_hook(self, inputs, outputs)
-                if hook_result is not None:
-                    outputs = hook_result
-
-            return outputs
+        from paddle.fluid.dygraph.dygraph_to_static.program_translator import in_declarative_mode
+
+        if in_declarative_mode() and not framework.in_dygraph_mode():
+            with param_guard(self._parameters), param_guard(self._buffers):
+                return self._dygraph_call_func(*inputs, **kwargs)
+        else:
+            return self._dygraph_call_func(*inputs, **kwargs)
 
     def forward(self, *inputs, **kwargs):
         """
--
GitLab
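
Note on the hook machinery: the patch moves the forward pre-/post-hook dispatch
into `_dygraph_call_func` without changing its behavior. For readers unfamiliar
with that behavior, below is a minimal user-side sketch. `TinyNet`, `pre_hook`,
and `post_hook` are illustrative names invented here, but
`register_forward_pre_hook` and `register_forward_post_hook` are the public
`Layer` APIs that populate the `_forward_pre_hooks` / `_forward_post_hooks`
dicts iterated in the patch; the sketch assumes a Paddle 2.x build with dygraph
mode enabled by default.

import paddle

class TinyNet(paddle.nn.Layer):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.fc = paddle.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

def pre_hook(layer, inputs):
    # Runs before forward(); a non-None return replaces the inputs
    # (a non-tuple result is wrapped in a tuple by _dygraph_call_func).
    return (inputs[0] * 2.0, )

def post_hook(layer, inputs, outputs):
    # Runs after forward(); a non-None return replaces the outputs.
    return outputs + 1.0

net = TinyNet()
net.register_forward_pre_hook(pre_hook)
net.register_forward_post_hook(post_hook)

x = paddle.ones([1, 4])
y = net(x)  # __call__ -> _dygraph_call_func -> pre hooks, forward, post hooks
print(y.shape)  # [1, 2]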
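
Note on the speed-up: the dispatch at the bottom of the new `__call__` means a
plain dygraph call no longer enters the two `param_guard` context managers on
every invocation; only declarative (dygraph-to-static) mode pays that cost.
Below is a rough stdlib-only micro-benchmark sketch, not taken from this PR,
showing the general point that even a no-op Python context manager adds
measurable per-call overhead; `noop_guard` is a hypothetical stand-in for
`param_guard`, and absolute timings will vary by machine.

import contextlib
import timeit

@contextlib.contextmanager
def noop_guard():
    # Stand-in for a guard that does no useful work in pure dygraph mode.
    yield

def call_with_guard():
    with noop_guard():
        pass

def call_without_guard():
    pass

# Entering/exiting the context dominates; skipping it is the fast path.
print("with guard   :", timeit.timeit(call_with_guard, number=100_000))
print("without guard:", timeit.timeit(call_without_guard, number=100_000))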