diff --git a/mindspore/context.py b/mindspore/context.py index 51418d3965ca1fc294e157071bf6ffc95bb774c8..0de6084caf520674a7dd4fee44657bb61e5fbfd2 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -176,10 +176,7 @@ class _Context: self._context_switches.push(True, None) else: if self.enable_debug_runtime: - if self.device_target == "CPU": - self.set_backend_policy("vm") - else: - self.set_backend_policy("ge") + self.set_backend_policy("ge") self._context_switches.push(False, None) def set_backend_policy(self, policy): @@ -221,6 +218,8 @@ class _Context: success = self._context_handle.set_device_target(target) if not success: raise ValueError("Target device name is invalid!!!") + if self.enable_debug_runtime and self.device_target == "CPU": + self.set_backend_policy("vm") @property def device_id(self): diff --git a/tests/ut/python/pipeline/parse/test_cell_bprop.py b/tests/ut/python/pipeline/parse/test_cell_bprop.py index 7207160cac150f73a0bc426f884e31027cd160ef..e896ddc9ac78d7cf2dee848a755d6a87f979cf65 100644 --- a/tests/ut/python/pipeline/parse/test_cell_bprop.py +++ b/tests/ut/python/pipeline/parse/test_cell_bprop.py @@ -29,8 +29,7 @@ from .....mindspore_test_framework.utils.bprop_util import bprop def setup_module(module): - context.set_context(device_target="CPU") - context.set_context(mode=context.GRAPH_MODE) + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") def teardown_module(module): context.set_context(device_target="Ascend")