From 41d039cfcd6576919ac884979c56efe3ea4e1636 Mon Sep 17 00:00:00 2001 From: binbinHan Date: Fri, 15 Jan 2021 01:25:32 -0600 Subject: [PATCH] fix ParallelDesc Constructor bug (#4118) * fix parallel * optimize maybeinit * add test file Former-commit-id: e6aaee0cd3b07b6e3b103965054f9da3c54308ae --- oneflow/core/job/parallel_desc.cpp | 13 ++--- oneflow/core/job/parallel_desc_test.cpp | 65 +++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 8 deletions(-) create mode 100644 oneflow/core/job/parallel_desc_test.cpp diff --git a/oneflow/core/job/parallel_desc.cpp b/oneflow/core/job/parallel_desc.cpp index 755bcbf00e..c8e3ebe99a 100644 --- a/oneflow/core/job/parallel_desc.cpp +++ b/oneflow/core/job/parallel_desc.cpp @@ -80,20 +80,17 @@ Maybe ParallelDesc::New(int64_t symbol_id, const ParallelConf& par Maybe ParallelDesc::MaybeInit(const ParallelConf& user_conf) { parallel_conf_ = user_conf; cfg_parallel_conf_.reset(new cfg::ParallelConf(user_conf)); - HashSet machine_id_set; device_type_ = DeviceType::kInvalidDevice; const std::string& device_tag = parallel_conf_.device_tag(); DeviceType device_type = JUST(DeviceType4DeviceTag(device_tag)); CHECK_OR_RETURN(device_type_ == DeviceType::kInvalidDevice || device_type_ == device_type); device_type_ = device_type; + machine_id2sorted_dev_phy_ids_ = + std::make_shared>>>(); for (const std::string& device_name : parallel_conf_.device_name()) { int64_t mchn_id; std::string device_id_str; JUST(ParseDeviceNameConf(device_name, &mchn_id, &device_id_str)); - machine_id_set.insert(mchn_id); - if (machine_id_set.find(mchn_id) == machine_id_set.end()) { - sorted_machine_ids_.push_back(mchn_id); - } int64_t minus_pos = device_id_str.find("-"); if (minus_pos == std::string::npos) { device_id_str = device_id_str + "-" + device_id_str; @@ -101,10 +98,10 @@ Maybe ParallelDesc::MaybeInit(const ParallelConf& user_conf) { } int64_t min_id = oneflow_cast(device_id_str.substr(0, minus_pos)); int64_t max_id = oneflow_cast(device_id_str.substr(minus_pos + 1)); - machine_id2sorted_dev_phy_ids_ = - std::make_shared>>>(); - (*machine_id2sorted_dev_phy_ids_)[mchn_id] = std::make_shared>(); CHECK_LE_OR_RETURN(min_id, max_id); + if (!(*machine_id2sorted_dev_phy_ids_)[mchn_id]) { + (*machine_id2sorted_dev_phy_ids_)[mchn_id] = std::make_shared>(); + } for (int64_t dev_phy_id = min_id; dev_phy_id <= max_id; ++dev_phy_id) { (*machine_id2sorted_dev_phy_ids_)[mchn_id]->push_back(dev_phy_id); } diff --git a/oneflow/core/job/parallel_desc_test.cpp b/oneflow/core/job/parallel_desc_test.cpp new file mode 100644 index 0000000000..521bdb49f4 --- /dev/null +++ b/oneflow/core/job/parallel_desc_test.cpp @@ -0,0 +1,65 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/common/util.h" +#include "oneflow/core/job/placement.pb.h" +#include "oneflow/core/job/parallel_desc.h" + +namespace oneflow { +namespace test { + +TEST(parallel_desc, continuous_1n4d) { + ParallelConf parallel_conf; + parallel_conf.set_device_tag("cpu"); + parallel_conf.add_device_name("0:0-3"); + ParallelDesc parallel_desc(parallel_conf); + ASSERT_EQ(parallel_desc.device_tag(), "cpu"); + ASSERT_EQ(parallel_desc.parallel_num(), 4); +} + +TEST(parallel_desc, discrete_1n4d) { + ParallelConf parallel_conf; + parallel_conf.set_device_tag("cpu"); + parallel_conf.add_device_name("0:0-1"); + parallel_conf.add_device_name("0:2-3"); + ParallelDesc parallel_desc(parallel_conf); + ASSERT_EQ(parallel_desc.device_tag(), "cpu"); + ASSERT_EQ(parallel_desc.parallel_num(), 4); +} + +TEST(parallel_desc, continuous_2n8d) { + ParallelConf parallel_conf; + parallel_conf.set_device_tag("cpu"); + parallel_conf.add_device_name("0:0-3"); + parallel_conf.add_device_name("1:0-3"); + ParallelDesc parallel_desc(parallel_conf); + ASSERT_EQ(parallel_desc.device_tag(), "cpu"); + ASSERT_EQ(parallel_desc.parallel_num(), 8); +} + +TEST(parallel_desc, discrete_2n8d) { + ParallelConf parallel_conf; + parallel_conf.set_device_tag("cpu"); + parallel_conf.add_device_name("0:0-1"); + parallel_conf.add_device_name("0:2-3"); + parallel_conf.add_device_name("1:0-1"); + parallel_conf.add_device_name("1:2-3"); + ParallelDesc parallel_desc(parallel_conf); + ASSERT_EQ(parallel_desc.device_tag(), "cpu"); + ASSERT_EQ(parallel_desc.parallel_num(), 8); +} + +} // namespace test +} // namespace oneflow -- GitLab