未验证 提交 96c95b3d 编写于 作者: A Aurelius84 提交者: GitHub

[Eager] Add _fallback_legacy_dygraph for npu/xpu/rocm (#41774) (#41898)

* [Eager] add _fallback_legacy_dygraph for npu/xpu/rocm

* fix import
上级 f92dbfb7
......@@ -166,6 +166,40 @@ def _in_eager_without_dygraph_check():
return _in_eager_mode_
# FIXME(dev): We haven't fully verified eager mode on XPU/NPU et.al but
# only GPU/CPU. Remove this after we improve this feature.
_is_first_import_ = True
def _fallback_legacy_dygraph():
global _in_eager_mode_
global _is_first_import_
need_fallback = False
# Only enable eager on CPU/GPU
is_not_support = core.is_compiled_with_xpu() or core.is_compiled_with_npu(
) or core.is_compiled_with_ipu() or core.is_compiled_with_mlu(
) or core.is_compiled_with_rocm()
if _in_eager_mode_ and is_not_support:
# switch into legacy dygraph mode
warnings.warn(
"We will fallback into legacy dygraph on NPU/XPU/MLU/IPU/ROCM devices. Because we only support new eager dygraph mode on CPU/GPU currently. "
)
_in_eager_mode_ = False
if not _is_first_import_:
_enable_legacy_dygraph()
need_fallback = True
need_fallback = False
_is_first_import_ = False
return need_fallback
# switch into legacy mode if need while import paddle
_fallback_legacy_dygraph()
def in_dygraph_mode():
"""
......@@ -206,11 +240,16 @@ def _non_static_mode():
@signature_safe_contextmanager
def _test_eager_guard(place=None):
_disable_legacy_dygraph()
# FIXME(dev): We haven't fully verified eager mode on XPU/NPU et.al but
# only GPU/CPU. Remove this after we improve this feature.
already_fallback = _fallback_legacy_dygraph()
if not already_fallback:
_disable_legacy_dygraph()
try:
yield
finally:
_enable_legacy_dygraph()
if not already_fallback:
_enable_legacy_dygraph()
global_ipu_index = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册