未验证 提交 bfb85779 编写于 作者: J Jiabin Yang 提交者: GitHub

optimize dygraph performance with move runtime import to begining (#37759)

* optimize dygraph probl

* refine code

* fix convert dtype error

* fix import datafeeder error
上级 cfd6a8fc
...@@ -93,10 +93,10 @@ def check_type(input, input_name, expected_type, op_name, extra_message=''): ...@@ -93,10 +93,10 @@ def check_type(input, input_name, expected_type, op_name, extra_message=''):
if in_dygraph_mode(): if in_dygraph_mode():
return return
from .dygraph.dygraph_to_static.program_translator import in_declarative_mode
# NOTE: `in_declarative_mode` is used to determined whether this op is called under # NOTE: `in_declarative_mode` is used to determined whether this op is called under
# @declarative in transformation from dygrah to static layer. We add VarBase in # @declarative in transformation from dygrah to static layer. We add VarBase in
# expected_type to skip checking because varBase may be created and used in unusual way. # expected_type to skip checking because varBase may be created and used in unusual way.
from .dygraph.base import in_declarative_mode
# Need a better design to be fix this. # Need a better design to be fix this.
if in_declarative_mode(): if in_declarative_mode():
if not isinstance(expected_type, tuple): if not isinstance(expected_type, tuple):
......
...@@ -33,6 +33,17 @@ __all__ = [ ...@@ -33,6 +33,17 @@ __all__ = [
'enabled', 'to_variable' 'enabled', 'to_variable'
] ]
# Flag that indicates whether running code under `@declarative`
_in_declarative_mode_ = False
def in_declarative_mode():
"""
Return a bool value that indicates whether running code under `@declarative`
"""
return _in_declarative_mode_
def _switch_to_static_graph_(func): def _switch_to_static_graph_(func):
def __impl__(*args, **kwargs): def __impl__(*args, **kwargs):
...@@ -45,6 +56,16 @@ def _switch_to_static_graph_(func): ...@@ -45,6 +56,16 @@ def _switch_to_static_graph_(func):
switch_to_static_graph = wrap_decorator(_switch_to_static_graph_) switch_to_static_graph = wrap_decorator(_switch_to_static_graph_)
@signature_safe_contextmanager
def _switch_declarative_mode_guard_(is_declarative=True):
global _in_declarative_mode_
original_val = _in_declarative_mode_
_in_declarative_mode_ = is_declarative
yield
_in_declarative_mode_ = original_val
@signature_safe_contextmanager @signature_safe_contextmanager
def program_desc_tracing_guard(enable): def program_desc_tracing_guard(enable):
tracer = framework._dygraph_tracer() tracer = framework._dygraph_tracer()
...@@ -63,7 +84,6 @@ _functional_dygraph_context_manager = None ...@@ -63,7 +84,6 @@ _functional_dygraph_context_manager = None
@signature_safe_contextmanager @signature_safe_contextmanager
def param_guard(parameters): def param_guard(parameters):
from paddle.fluid.dygraph.dygraph_to_static.program_translator import in_declarative_mode
# Note: parameters is a reference of self._parameters or self._buffers # Note: parameters is a reference of self._parameters or self._buffers
if in_declarative_mode() and not framework.in_dygraph_mode() and parameters: if in_declarative_mode() and not framework.in_dygraph_mode() and parameters:
origin_parameters = parameters.copy() origin_parameters = parameters.copy()
......
...@@ -573,28 +573,6 @@ class StaticFunction(object): ...@@ -573,28 +573,6 @@ class StaticFunction(object):
return self._function_spec return self._function_spec
# Flag that indicates whether running code under `@declarative`
_in_declarative_mode_ = False
def in_declarative_mode():
"""
Return a bool value that indicates whether running code under `@declarative`
"""
return _in_declarative_mode_
@signature_safe_contextmanager
def _switch_declarative_mode_guard_(is_declarative=True):
global _in_declarative_mode_
original_val = _in_declarative_mode_
_in_declarative_mode_ = is_declarative
yield
_in_declarative_mode_ = original_val
def _verify_init_in_dynamic_mode(class_instance): def _verify_init_in_dynamic_mode(class_instance):
""" """
Verifies the instance is initialized in dynamic mode. Verifies the instance is initialized in dynamic mode.
...@@ -658,6 +636,7 @@ class ConcreteProgram(object): ...@@ -658,6 +636,7 @@ class ConcreteProgram(object):
startup_program.random_seed = framework.default_startup_program( startup_program.random_seed = framework.default_startup_program(
).random_seed ).random_seed
from paddle.fluid.dygraph.base import _switch_declarative_mode_guard_
with framework.program_guard(main_program, startup_program): with framework.program_guard(main_program, startup_program):
with _switch_declarative_mode_guard_(is_declarative=True): with _switch_declarative_mode_guard_(is_declarative=True):
# 1. Adds `fluid.data` layers for input if needed # 1. Adds `fluid.data` layers for input if needed
......
...@@ -31,7 +31,7 @@ from .. import unique_name ...@@ -31,7 +31,7 @@ from .. import unique_name
from paddle.fluid import core from paddle.fluid import core
from .layer_object_helper import LayerObjectHelper from .layer_object_helper import LayerObjectHelper
from .layer_hooks import record_program_ops_pre_hook, set_op_customized_attrs_post_hook, LayerOpsRecoder from .layer_hooks import record_program_ops_pre_hook, set_op_customized_attrs_post_hook, LayerOpsRecoder
from .base import program_desc_tracing_guard, param_guard from .base import program_desc_tracing_guard, param_guard, in_declarative_mode
from paddle.fluid import framework from paddle.fluid import framework
from ..param_attr import ParamAttr from ..param_attr import ParamAttr
from paddle.fluid.executor import Executor, global_scope from paddle.fluid.executor import Executor, global_scope
...@@ -917,7 +917,6 @@ class Layer(object): ...@@ -917,7 +917,6 @@ class Layer(object):
# In case of ControlFlow, true_fn and false_fn will contain # In case of ControlFlow, true_fn and false_fn will contain
# parameters that may not trigger logic of `Operator` to create # parameters that may not trigger logic of `Operator` to create
# them. we add this to make sure all parameters is available. # them. we add this to make sure all parameters is available.
from paddle.fluid.dygraph.dygraph_to_static.program_translator import in_declarative_mode
if in_declarative_mode() and not framework.in_dygraph_mode(): if in_declarative_mode() and not framework.in_dygraph_mode():
with param_guard(self._parameters), param_guard(self._buffers): with param_guard(self._parameters), param_guard(self._buffers):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册