提交 3f916bdd 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3340 modify device id

Merge pull request !3340 from changzherui/mod_device_id
......@@ -45,6 +45,20 @@ std::map<std::string, MsBackendPolicy> MsContext::policy_map_ = {{"ge", kMsBacke
{"ge_only", kMsBackendGeOnly},
{"vm_prior", kMsBackendVmPrior}};
bool IsCloudTransDeviceId() {
auto deploy_mode = common::GetEnv("DEPLOY_MODE");
if (deploy_mode.empty() || deploy_mode != "1") {
return false;
}
auto rank_size = common::GetEnv("RANK_SIZE");
if (rank_size.empty() || rank_size != "1") {
return false;
}
return true;
}
MsContext::MsContext(const std::string &policy, const std::string &target) {
save_graphs_flag_ = false;
save_graphs_path_ = ".";
......@@ -63,6 +77,12 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
} else {
device_id_ = 0;
}
physics_id_ = device_id_;
if (IsCloudTransDeviceId()) {
device_id_ = 0;
}
backend_policy_ = policy_map_[policy];
device_target_ = target;
execution_mode_ = kPynativeMode;
......@@ -147,6 +167,13 @@ bool MsContext::set_device_target(const std::string &target) {
bool MsContext::set_device_id(uint32_t device_id) {
device_id_ = device_id;
MS_LOG(INFO) << "ms set context device id:" << device_id;
physics_id_ = device_id_;
if (IsCloudTransDeviceId()) {
device_id_ = 0;
}
MS_LOG(INFO) << "ms set context logic id:" << device_id;
return true;
}
......@@ -166,7 +193,8 @@ bool MsContext::OpenTsd() {
unsigned int device_id;
unsigned int rank_size = 1;
device_id = device_id_;
device_id = physics_id_;
MS_LOG(INFO) << "Open and init tsd, device = " << device_id << ".";
auto rank_size_env = common::GetEnv("RANK_SIZE");
if (rank_size_env.empty()) {
......
......@@ -172,6 +172,7 @@ class MsContext {
MsBackendPolicy backend_policy_;
std::string device_target_;
uint32_t device_id_;
uint32_t physics_id_;
int execution_mode_;
bool enable_pynative_infer_;
bool enable_pynative_hook_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册