提交 41d039cf 编写于 作者: qq_22305325's avatar qq_22305325 提交者: GitHub

fix ParallelDesc Constructor bug (#4118)

* fix parallel

* optimize maybeinit

* add test file

Former-commit-id: e6aaee0c
上级 97abecce
...@@ -80,20 +80,17 @@ Maybe<ParallelDesc> ParallelDesc::New(int64_t symbol_id, const ParallelConf& par ...@@ -80,20 +80,17 @@ Maybe<ParallelDesc> ParallelDesc::New(int64_t symbol_id, const ParallelConf& par
Maybe<void> ParallelDesc::MaybeInit(const ParallelConf& user_conf) { Maybe<void> ParallelDesc::MaybeInit(const ParallelConf& user_conf) {
parallel_conf_ = user_conf; parallel_conf_ = user_conf;
cfg_parallel_conf_.reset(new cfg::ParallelConf(user_conf)); cfg_parallel_conf_.reset(new cfg::ParallelConf(user_conf));
HashSet<int64_t> machine_id_set;
device_type_ = DeviceType::kInvalidDevice; device_type_ = DeviceType::kInvalidDevice;
const std::string& device_tag = parallel_conf_.device_tag(); const std::string& device_tag = parallel_conf_.device_tag();
DeviceType device_type = JUST(DeviceType4DeviceTag(device_tag)); DeviceType device_type = JUST(DeviceType4DeviceTag(device_tag));
CHECK_OR_RETURN(device_type_ == DeviceType::kInvalidDevice || device_type_ == device_type); CHECK_OR_RETURN(device_type_ == DeviceType::kInvalidDevice || device_type_ == device_type);
device_type_ = device_type; device_type_ = device_type;
machine_id2sorted_dev_phy_ids_ =
std::make_shared<HashMap<int64_t, std::shared_ptr<std::vector<int64_t>>>>();
for (const std::string& device_name : parallel_conf_.device_name()) { for (const std::string& device_name : parallel_conf_.device_name()) {
int64_t mchn_id; int64_t mchn_id;
std::string device_id_str; std::string device_id_str;
JUST(ParseDeviceNameConf(device_name, &mchn_id, &device_id_str)); JUST(ParseDeviceNameConf(device_name, &mchn_id, &device_id_str));
machine_id_set.insert(mchn_id);
if (machine_id_set.find(mchn_id) == machine_id_set.end()) {
sorted_machine_ids_.push_back(mchn_id);
}
int64_t minus_pos = device_id_str.find("-"); int64_t minus_pos = device_id_str.find("-");
if (minus_pos == std::string::npos) { if (minus_pos == std::string::npos) {
device_id_str = device_id_str + "-" + device_id_str; device_id_str = device_id_str + "-" + device_id_str;
...@@ -101,10 +98,10 @@ Maybe<void> ParallelDesc::MaybeInit(const ParallelConf& user_conf) { ...@@ -101,10 +98,10 @@ Maybe<void> ParallelDesc::MaybeInit(const ParallelConf& user_conf) {
} }
int64_t min_id = oneflow_cast<int64_t>(device_id_str.substr(0, minus_pos)); int64_t min_id = oneflow_cast<int64_t>(device_id_str.substr(0, minus_pos));
int64_t max_id = oneflow_cast<int64_t>(device_id_str.substr(minus_pos + 1)); int64_t max_id = oneflow_cast<int64_t>(device_id_str.substr(minus_pos + 1));
machine_id2sorted_dev_phy_ids_ =
std::make_shared<HashMap<int64_t, std::shared_ptr<std::vector<int64_t>>>>();
(*machine_id2sorted_dev_phy_ids_)[mchn_id] = std::make_shared<std::vector<int64_t>>();
CHECK_LE_OR_RETURN(min_id, max_id); CHECK_LE_OR_RETURN(min_id, max_id);
if (!(*machine_id2sorted_dev_phy_ids_)[mchn_id]) {
(*machine_id2sorted_dev_phy_ids_)[mchn_id] = std::make_shared<std::vector<int64_t>>();
}
for (int64_t dev_phy_id = min_id; dev_phy_id <= max_id; ++dev_phy_id) { for (int64_t dev_phy_id = min_id; dev_phy_id <= max_id; ++dev_phy_id) {
(*machine_id2sorted_dev_phy_ids_)[mchn_id]->push_back(dev_phy_id); (*machine_id2sorted_dev_phy_ids_)[mchn_id]->push_back(dev_phy_id);
} }
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/placement.pb.h"
#include "oneflow/core/job/parallel_desc.h"
namespace oneflow {
namespace test {
// One machine, one contiguous device range: "0:0-3" names machine 0 with
// device ids 0 through 3, so the resulting placement should hold 4 devices.
TEST(parallel_desc, continuous_1n4d) {
  ParallelConf conf;
  conf.set_device_tag("cpu");
  conf.add_device_name("0:0-3");

  ParallelDesc desc(conf);

  // The device tag must round-trip unchanged through the descriptor.
  ASSERT_EQ(desc.device_tag(), "cpu");
  // A single inclusive range 0..3 contributes four devices.
  ASSERT_EQ(desc.parallel_num(), 4);
}
// One machine described by two separate device ranges. Both entries refer to
// machine 0, so their device ids (0-1 and 2-3) are expected to accumulate
// into a single 4-device placement rather than the later entry replacing
// the earlier one.
TEST(parallel_desc, discrete_1n4d) {
  ParallelConf conf;
  conf.set_device_tag("cpu");
  for (const char* device_name : {"0:0-1", "0:2-3"}) {
    conf.add_device_name(device_name);
  }

  ParallelDesc desc(conf);

  ASSERT_EQ(desc.device_tag(), "cpu");
  // 2 devices from each of the two ranges on machine 0.
  ASSERT_EQ(desc.parallel_num(), 4);
}
// Two machines, each given as one contiguous range of four devices
// (machine 0: ids 0-3, machine 1: ids 0-3), for 8 devices in total.
TEST(parallel_desc, continuous_2n8d) {
  ParallelConf conf;
  conf.set_device_tag("cpu");
  for (const char* device_name : {"0:0-3", "1:0-3"}) {
    conf.add_device_name(device_name);
  }

  ParallelDesc desc(conf);

  ASSERT_EQ(desc.device_tag(), "cpu");
  // 4 devices on machine 0 plus 4 devices on machine 1.
  ASSERT_EQ(desc.parallel_num(), 8);
}
// Two machines, each described by two separate device ranges. Entries that
// share a machine id are expected to be combined, giving 4 devices per
// machine and 8 devices overall.
TEST(parallel_desc, discrete_2n8d) {
  ParallelConf conf;
  conf.set_device_tag("cpu");
  for (const char* device_name : {"0:0-1", "0:2-3", "1:0-1", "1:2-3"}) {
    conf.add_device_name(device_name);
  }

  ParallelDesc desc(conf);

  ASSERT_EQ(desc.device_tag(), "cpu");
  // Four ranges of two devices each, spread over machines 0 and 1.
  ASSERT_EQ(desc.parallel_num(), 8);
}
} // namespace test
} // namespace oneflow
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册