diff --git a/mindinsight/profiler/profiling.py b/mindinsight/profiler/profiling.py index c0ac6d716e520fbad78886011a2957dd5d710498..174ded8ae44b14d9c554e50a6123d9d733662491 100644 --- a/mindinsight/profiler/profiling.py +++ b/mindinsight/profiler/profiling.py @@ -65,21 +65,29 @@ class Profiler: def __init__(self, subgraph='all', is_detail=True, is_show_op_path=False, output_path='./data', optypes_to_deal='', optypes_not_deal='Variable', job_id=""): - # get device_id + # get device_id and device_target + device_target = "" try: import mindspore.context as context - dev_id = context.get_context("device_id") + dev_id = str(context.get_context("device_id")) + device_target = context.get_context("device_target") except ImportError: logger.error("Profiling: fail to import context from mindspore.") except ValueError as err: logger.error("Profiling: fail to get context %s", err.message) if not dev_id: - dev_id = os.getenv('DEVICE_ID') + dev_id = str(os.getenv('DEVICE_ID')) if not dev_id: dev_id = "0" logger.error("Fail to get DEVICE_ID, use 0 instead.") + if device_target and device_target != "Davinci" \ + and device_target != "Ascend": + msg = ("Profiling: unsupport backend: %s" \ + % device_target) + raise RuntimeError(msg) + self._dev_id = dev_id self._container_path = os.path.join(self._base_profiling_container_path, dev_id) data_path = os.path.join(self._container_path, "data")