diff --git a/mindspore/ccsrc/utils/context/ms_context.cc b/mindspore/ccsrc/utils/context/ms_context.cc index 9f283319a75a0237f2ee9bc0a48b8440d1d4f418..6c3986e42230c560fa10501180d771057c142a25 100644 --- a/mindspore/ccsrc/utils/context/ms_context.cc +++ b/mindspore/ccsrc/utils/context/ms_context.cc @@ -45,6 +45,20 @@ std::map MsContext::policy_map_ = {{"ge", kMsBacke {"ge_only", kMsBackendGeOnly}, {"vm_prior", kMsBackendVmPrior}}; +bool IsCloudTransDeviceId() { + auto deploy_mode = common::GetEnv("DEPLOY_MODE"); + if (deploy_mode.empty() || deploy_mode != "1") { + return false; + } + + auto rank_size = common::GetEnv("RANK_SIZE"); + if (rank_size.empty() || rank_size != "1") { + return false; + } + + return true; +} + MsContext::MsContext(const std::string &policy, const std::string &target) { save_graphs_flag_ = false; save_graphs_path_ = "."; @@ -63,6 +77,12 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { } else { device_id_ = 0; } + + physics_id_ = device_id_; + if (IsCloudTransDeviceId()) { + device_id_ = 0; + } + backend_policy_ = policy_map_[policy]; device_target_ = target; execution_mode_ = kPynativeMode; @@ -147,6 +167,13 @@ bool MsContext::set_device_target(const std::string &target) { bool MsContext::set_device_id(uint32_t device_id) { device_id_ = device_id; MS_LOG(INFO) << "ms set context device id:" << device_id; + + physics_id_ = device_id_; + if (IsCloudTransDeviceId()) { + device_id_ = 0; + } + MS_LOG(INFO) << "ms set context logic id:" << device_id; + return true; } @@ -166,7 +193,8 @@ bool MsContext::OpenTsd() { unsigned int device_id; unsigned int rank_size = 1; - device_id = device_id_; + device_id = physics_id_; + MS_LOG(INFO) << "Open and init tsd, device = " << device_id << "."; auto rank_size_env = common::GetEnv("RANK_SIZE"); if (rank_size_env.empty()) { diff --git a/mindspore/ccsrc/utils/context/ms_context.h b/mindspore/ccsrc/utils/context/ms_context.h index 9afe1fa5aaca61ba10849ba8050558d71d1b7ee0..9187fb0b753c3d4c85e399b257b885483f798f04 100644 --- a/mindspore/ccsrc/utils/context/ms_context.h +++ b/mindspore/ccsrc/utils/context/ms_context.h @@ -172,6 +172,7 @@ class MsContext { MsBackendPolicy backend_policy_; std::string device_target_; uint32_t device_id_; + uint32_t physics_id_; int execution_mode_; bool enable_pynative_infer_; bool enable_pynative_hook_;