From 96c95b3d6b03ab0c489cbc06b07ac44cab8f80c4 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 18 Apr 2022 19:14:50 +0800 Subject: [PATCH] [Eager] Add _fallback_legacy_dygraph for npu/xpu/rocm (#41774) (#41898) * [Eager] add _fallback_legacy_dygraph for npu/xpu/rocm * fix import --- python/paddle/fluid/framework.py | 43 ++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index a329610eeae..5dab39a35d4 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -166,6 +166,40 @@ def _in_eager_without_dygraph_check(): return _in_eager_mode_ +# FIXME(dev): We haven't fully verified eager mode on XPU/NPU et.al but +# only GPU/CPU. Remove this after we improve this feature. +_is_first_import_ = True + + +def _fallback_legacy_dygraph(): + global _in_eager_mode_ + global _is_first_import_ + need_fallback = False + # Only enable eager on CPU/GPU + is_not_support = core.is_compiled_with_xpu() or core.is_compiled_with_npu( + ) or core.is_compiled_with_ipu() or core.is_compiled_with_mlu( + ) or core.is_compiled_with_rocm() + + if _in_eager_mode_ and is_not_support: + # switch into legacy dygraph mode + warnings.warn( + "We will fallback into legacy dygraph on NPU/XPU/MLU/IPU/ROCM devices. Because we only support new eager dygraph mode on CPU/GPU currently. " + ) + _in_eager_mode_ = False + if not _is_first_import_: + _enable_legacy_dygraph() + need_fallback = True + + need_fallback = False + _is_first_import_ = False + + return need_fallback + + +# switch into legacy mode if need while import paddle +_fallback_legacy_dygraph() + + def in_dygraph_mode(): """ @@ -206,11 +240,16 @@ def _non_static_mode(): @signature_safe_contextmanager def _test_eager_guard(place=None): - _disable_legacy_dygraph() + # FIXME(dev): We haven't fully verified eager mode on XPU/NPU et.al but + # only GPU/CPU. Remove this after we improve this feature. + already_fallback = _fallback_legacy_dygraph() + if not already_fallback: + _disable_legacy_dygraph() try: yield finally: - _enable_legacy_dygraph() + if not already_fallback: + _enable_legacy_dygraph() global_ipu_index = None -- GitLab