diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index a329610eeae8353e05eb7bd5d2877abbbff0bafc..5dab39a35d4787a091e0d144e3aa79e053e8f9f7 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -166,6 +166,40 @@ def _in_eager_without_dygraph_check(): return _in_eager_mode_ +# FIXME(dev): We haven't fully verified eager mode on XPU/NPU et.al but +# only GPU/CPU. Remove this after we improve this feature. +_is_first_import_ = True + + +def _fallback_legacy_dygraph(): + global _in_eager_mode_ + global _is_first_import_ + need_fallback = False + # Only enable eager on CPU/GPU + is_not_support = core.is_compiled_with_xpu() or core.is_compiled_with_npu( + ) or core.is_compiled_with_ipu() or core.is_compiled_with_mlu( + ) or core.is_compiled_with_rocm() + + if _in_eager_mode_ and is_not_support: + # switch into legacy dygraph mode + warnings.warn( + "We will fallback into legacy dygraph on NPU/XPU/MLU/IPU/ROCM devices. Because we only support new eager dygraph mode on CPU/GPU currently. " + ) + _in_eager_mode_ = False + if not _is_first_import_: + _enable_legacy_dygraph() + need_fallback = True + + need_fallback = False + _is_first_import_ = False + + return need_fallback + + +# switch into legacy mode if need while import paddle +_fallback_legacy_dygraph() + + def in_dygraph_mode(): """ @@ -206,11 +240,16 @@ def _non_static_mode(): @signature_safe_contextmanager def _test_eager_guard(place=None): - _disable_legacy_dygraph() + # FIXME(dev): We haven't fully verified eager mode on XPU/NPU et.al but + # only GPU/CPU. Remove this after we improve this feature. + already_fallback = _fallback_legacy_dygraph() + if not already_fallback: + _disable_legacy_dygraph() try: yield finally: - _enable_legacy_dygraph() + if not already_fallback: + _enable_legacy_dygraph() global_ipu_index = None