未验证 提交 75d58371 编写于 作者: qq_22305325's avatar qq_22305325 提交者: GitHub

Mig parallel conf util (#4168)

* mig parallel_conf_util

* mig BuildInitialScope BuildScopeWithNewParallelDesc BuildScopeWithNewParallelConf

* add test of GetDeviceTagAndMachineDeviceIds

* fix BuildScopeWithNewParallelDesc input type error

* use TRY
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 ac8e472e
......@@ -112,6 +112,28 @@ void ReplaceMirrored(const std::shared_ptr<InstructionsBuilder>& x,
return x->ReplaceMirrored(parallel_desc_sym, lhs_objects, rhs_objects).GetOrThrow();
}
std::shared_ptr<Scope> BuildInitialScope(const std::shared_ptr<InstructionsBuilder>& x,
int64_t session_id,
const std::shared_ptr<cfg::JobConfigProto>& job_conf,
const std::string& device_tag,
const std::vector<std::string>& machine_device_ids,
bool is_mirrored) {
return x->BuildInitialScope(session_id, job_conf, device_tag, machine_device_ids, is_mirrored)
.GetPtrOrThrow();
}
std::shared_ptr<Scope> BuildScopeWithNewParallelDesc(
const std::shared_ptr<InstructionsBuilder>& x, const std::shared_ptr<Scope>& scope,
const std::string& device_tag, const std::vector<std::string>& machine_device_ids) {
return x->BuildScopeWithNewParallelDesc(scope, device_tag, machine_device_ids).GetPtrOrThrow();
}
std::shared_ptr<Scope> BuildScopeWithNewParallelConf(
const std::shared_ptr<InstructionsBuilder>& x, const std::shared_ptr<Scope>& scope,
const std::shared_ptr<cfg::ParallelConf>& parallel_conf) {
return x->BuildScopeWithNewParallelConf(scope, parallel_conf).GetPtrOrThrow();
}
std::shared_ptr<Scope> BuildScopeWithNewIsMirrored(const std::shared_ptr<InstructionsBuilder>& x,
const std::shared_ptr<Scope>& scope,
bool is_mirrored) {
......@@ -195,6 +217,9 @@ ONEFLOW_API_PYBIND11_MODULE("deprecated", m) {
.def("UnpackLogicalBlobToPhysicalBlobs", &UnpackLogicalBlobToPhysicalBlobs)
.def("MakeReferenceBlobObject", &MakeReferenceBlobObject)
.def("ReplaceMirrored", &ReplaceMirrored)
.def("BuildInitialScope", &BuildInitialScope)
.def("BuildScopeWithNewParallelDesc", &BuildScopeWithNewParallelDesc)
.def("BuildScopeWithNewParallelConf", &BuildScopeWithNewParallelConf)
.def("BuildScopeWithNewIsMirrored", &BuildScopeWithNewIsMirrored)
.def("BuildScopeWithNewScopeName", &BuildScopeWithNewScopeName)
.def("BuildScopeByProtoSetter", &BuildScopeByProtoSetter)
......
"""
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
......@@ -12,32 +12,32 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import absolute_import
import re
import oneflow_api.oneflow.core.job.placement as placement_cfg
def GetDeviceTagAndMachineDeviceIds(parallel_conf):
machine_device_ids = []
for device_name in list(parallel_conf.device_name()):
machine_device_ids.append(device_name)
device_tag = parallel_conf.device_tag()
return device_tag, machine_device_ids
def MakeParallelConf(device_tag, machine_device_ids):
assert isinstance(machine_device_ids, (list, tuple))
parallel_conf = placement_cfg.ParallelConf()
parallel_conf.set_device_tag(device_tag)
for machine_device_id in machine_device_ids:
assert isinstance(
machine_device_id, str
), "type of machine_device_id (%s) is not string" % type(machine_device_id)
assert re.match("^\d+:\d+(-\d+)?$", machine_device_id) is not None, (
"machine_device_id: %s is not valid" % machine_device_id
)
parallel_conf.add_device_name(machine_device_id)
return parallel_conf
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/parallel_conf_util.h"
namespace oneflow {
namespace {
std::pair<std::string, std::vector<std::string>> PyGetDeviceTagAndMachineDeviceIds(
const std::shared_ptr<cfg::ParallelConf>& parallel_conf) {
return *(GetDeviceTagAndMachineDeviceIds(parallel_conf).GetPtrOrThrow());
}
std::shared_ptr<cfg::ParallelConf> PyMakeParallelConf(
const std::string& device_tag, const std::vector<std::string>& machine_device_ids) {
return MakeParallelConf(device_tag, machine_device_ids).GetPtrOrThrow();
}
} // namespace
ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("GetDeviceTagAndMachineDeviceIds", &PyGetDeviceTagAndMachineDeviceIds);
m.def("MakeParallelConf", &PyMakeParallelConf);
}
} // namespace oneflow
......@@ -20,6 +20,7 @@ limitations under the License.
#include "oneflow/core/job/job_conf.cfg.h"
#include "oneflow/core/job/placement.cfg.h"
#include "oneflow/core/job/scope.cfg.h"
#include "oneflow/core/framework/parallel_conf_util.h"
namespace oneflow {
......@@ -399,6 +400,58 @@ Maybe<void> InstructionsBuilder::ReplaceMirrored(
return Maybe<void>::Ok();
}
Maybe<Scope> InstructionsBuilder::BuildInitialScope(
int64_t session_id, const std::shared_ptr<cfg::JobConfigProto>& job_conf,
const std::string& device_tag, const std::vector<std::string>& machine_device_ids,
bool is_mirrored) {
std::shared_ptr<cfg::ScopeProto> scope_proto = std::make_shared<cfg::ScopeProto>();
scope_proto->set_session_id(session_id);
std::shared_ptr<JobDesc> job_conf_sym = JUST(GetJobConfSymbol(job_conf));
scope_proto->set_job_desc_symbol_id(JUST(job_conf_sym->symbol_id()));
std::shared_ptr<cfg::ParallelConf> parallel_conf =
JUST(MakeParallelConf(device_tag, machine_device_ids));
std::shared_ptr<ParallelDesc> device_parallel_desc_sym =
JUST(GetParallelDescSymbol(parallel_conf));
scope_proto->set_device_parallel_desc_symbol_id(JUST(device_parallel_desc_sym->symbol_id()));
parallel_conf = JUST(MakeParallelConf("cpu", machine_device_ids));
std::shared_ptr<ParallelDesc> host_parallel_desc_sym = JUST(GetParallelDescSymbol(parallel_conf));
scope_proto->set_host_parallel_desc_symbol_id(JUST(host_parallel_desc_sym->symbol_id()));
if (is_mirrored) {
scope_proto->mutable_opt_mirrored_parallel_conf()->mutable_mirrored_parallel();
} else {
scope_proto->mutable_opt_mirrored_parallel_conf()->clear_mirrored_parallel();
}
return GetScopeSymbol(scope_proto);
}
Maybe<Scope> InstructionsBuilder::BuildScopeWithNewParallelDesc(
const std::shared_ptr<Scope>& scope, const std::string& device_tag,
const std::vector<std::string>& machine_device_ids) {
const auto SetScopeProto =
[this, &device_tag,
&machine_device_ids](const std::shared_ptr<cfg::ScopeProto>& scope_proto) -> Maybe<void> {
std::shared_ptr<cfg::ParallelConf> parallel_conf =
JUST(MakeParallelConf(device_tag, machine_device_ids));
std::shared_ptr<ParallelDesc> device_parallel_desc_sym =
JUST(GetParallelDescSymbol(parallel_conf));
parallel_conf = JUST(MakeParallelConf("cpu", machine_device_ids));
std::shared_ptr<ParallelDesc> host_parallel_desc_sym =
JUST(GetParallelDescSymbol(parallel_conf));
scope_proto->set_device_parallel_desc_symbol_id(JUST(device_parallel_desc_sym->symbol_id()));
scope_proto->set_host_parallel_desc_symbol_id(JUST(host_parallel_desc_sym->symbol_id()));
return Maybe<void>::Ok();
};
return BuildScopeByProtoSetter(scope, SetScopeProto);
}
Maybe<Scope> InstructionsBuilder::BuildScopeWithNewParallelConf(
const std::shared_ptr<Scope>& scope, const std::shared_ptr<cfg::ParallelConf>& parallel_conf) {
std::pair<std::string, std::vector<std::string>> tag_and_dev_ids =
*JUST(GetDeviceTagAndMachineDeviceIds(parallel_conf));
return BuildScopeWithNewParallelDesc(scope, tag_and_dev_ids.first, tag_and_dev_ids.second);
}
Maybe<Scope> InstructionsBuilder::BuildScopeWithNewIsMirrored(const std::shared_ptr<Scope>& scope,
bool is_mirrored) {
const auto SetScopeProto = [is_mirrored](const std::shared_ptr<cfg::ScopeProto>& scope_proto) {
......
......@@ -118,6 +118,19 @@ class InstructionsBuilder {
std::vector<std::shared_ptr<compatible_py::BlobObject>> lhs_objects,
std::vector<std::shared_ptr<compatible_py::BlobObject>> rhs_objects);
Maybe<Scope> BuildInitialScope(int64_t session_id,
const std::shared_ptr<cfg::JobConfigProto>& job_conf,
const std::string& device_tag,
const std::vector<std::string>& machine_device_ids,
bool is_mirrored);
Maybe<Scope> BuildScopeWithNewParallelDesc(const std::shared_ptr<Scope>& scope,
const std::string& device_tag,
const std::vector<std::string>& machine_device_ids);
Maybe<Scope> BuildScopeWithNewParallelConf(
const std::shared_ptr<Scope>& scope, const std::shared_ptr<cfg::ParallelConf>& parallel_conf);
Maybe<Scope> BuildScopeWithNewIsMirrored(const std::shared_ptr<Scope>& scope, bool is_mirrored);
Maybe<Scope> BuildScopeWithNewScopeName(const std::shared_ptr<Scope>& scope,
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/str_util.h"
#include "oneflow/core/framework/parallel_conf_util.h"
namespace oneflow {
Maybe<std::pair<std::string, std::vector<std::string>>> GetDeviceTagAndMachineDeviceIds(
const std::shared_ptr<cfg::ParallelConf>& parallel_conf) {
std::vector<std::string> machine_device_ids;
for (const std::string& device_name : parallel_conf->device_name()) {
machine_device_ids.emplace_back(device_name);
}
return std::make_pair(parallel_conf->device_tag(), machine_device_ids);
}
Maybe<cfg::ParallelConf> MakeParallelConf(const std::string& device_tag,
const std::vector<std::string>& machine_device_ids) {
std::shared_ptr<cfg::ParallelConf> parallel_conf = std::make_shared<cfg::ParallelConf>();
parallel_conf->set_device_tag(device_tag);
for (const std::string& machine_device_id : machine_device_ids) {
size_t pos = machine_device_id.find(':');
CHECK_NE_OR_RETURN(pos, std::string::npos) << "device_name: " << machine_device_id;
std::string machine_id = machine_device_id.substr(0, pos);
CHECK_OR_RETURN(IsStrInt(machine_id));
std::string device_id = machine_device_id.substr(pos + 1);
size_t minus_pos = device_id.rfind('-');
if (minus_pos == std::string::npos) {
CHECK_OR_RETURN(IsStrInt(device_id));
} else {
std::string min_id = device_id.substr(0, minus_pos);
CHECK_OR_RETURN(IsStrInt(min_id));
std::string max_id = device_id.substr(minus_pos + 1);
CHECK_OR_RETURN(IsStrInt(max_id));
}
parallel_conf->add_device_name(machine_device_id);
}
return parallel_conf;
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_FRAMEWORK_PRARLLEL_CONF_UTIL_H_
#define ONEFLOW_CORE_FRAMEWORK_PRARLLEL_CONF_UTIL_H_
#include <utility>
#include <vector>
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/job/placement.cfg.h"
namespace oneflow {
Maybe<std::pair<std::string, std::vector<std::string>>> GetDeviceTagAndMachineDeviceIds(
const std::shared_ptr<cfg::ParallelConf>& parallel_conf);
Maybe<cfg::ParallelConf> MakeParallelConf(const std::string& device_tag,
const std::vector<std::string>& machine_device_ids);
} // namespace oneflow
#endif // ONEFLOW_CORE_FRAMEWORK_PRARLLEL_CONF_UTIL_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <algorithm>
#include "oneflow/core/common/util.h"
#include "oneflow/core/framework/parallel_conf_util.h"
namespace oneflow {
namespace test {
TEST(ParallelConfUtil, MakeParallelConfSuccess) {
std::string device_tag = "cpu";
std::vector<std::string> machine_device_ids;
machine_device_ids.emplace_back("0:0-3");
machine_device_ids.emplace_back("1:0-3");
auto parallel_conf = CHECK_JUST(MakeParallelConf(device_tag, machine_device_ids));
ASSERT_EQ(parallel_conf->device_tag(), "cpu");
ASSERT_EQ(parallel_conf->device_name().size(), 2);
}
TEST(ParallelConfUtil, MakeParallelConfError) {
std::string device_tag = "cpu";
std::vector<std::string> machine_device_ids;
machine_device_ids.emplace_back("0:0-3");
machine_device_ids.emplace_back("1:0-");
auto parallel_conf = TRY(MakeParallelConf(device_tag, machine_device_ids));
ASSERT_EQ(parallel_conf.error()->has_check_failed_error(), true);
}
TEST(ParallelConfUtil, GetDeviceTagAndMachineDeviceIds) {
std::shared_ptr<cfg::ParallelConf> parallel_conf = std::make_shared<cfg::ParallelConf>();
parallel_conf->set_device_tag("cpu");
parallel_conf->add_device_name("0:0-1");
parallel_conf->add_device_name("0:2-3");
parallel_conf->add_device_name("1:0-1");
parallel_conf->add_device_name("1:2-3");
std::pair<std::string, std::vector<std::string>> tag_and_dev_ids =
*CHECK_JUST(GetDeviceTagAndMachineDeviceIds(parallel_conf));
std::string device_tag = tag_and_dev_ids.first;
std::vector<std::string> machine_device_ids = tag_and_dev_ids.second;
ASSERT_EQ(device_tag, "cpu");
ASSERT_NE(std::count(machine_device_ids.begin(), machine_device_ids.end(), "0:0-1"), 0);
ASSERT_NE(std::count(machine_device_ids.begin(), machine_device_ids.end(), "0:2-3"), 0);
ASSERT_NE(std::count(machine_device_ids.begin(), machine_device_ids.end(), "1:0-1"), 0);
ASSERT_NE(std::count(machine_device_ids.begin(), machine_device_ids.end(), "1:2-3"), 0);
ASSERT_EQ(std::count(machine_device_ids.begin(), machine_device_ids.end(), "2:0-3"), 0);
}
} // namespace test
} // namespace oneflow
......@@ -20,7 +20,7 @@ limitations under the License.
namespace oneflow {
namespace test {
TEST(parallel_desc, continuous_1n4d) {
TEST(ParallelDesc, continuous_1n4d) {
ParallelConf parallel_conf;
parallel_conf.set_device_tag("cpu");
parallel_conf.add_device_name("0:0-3");
......@@ -29,7 +29,7 @@ TEST(parallel_desc, continuous_1n4d) {
ASSERT_EQ(parallel_desc.parallel_num(), 4);
}
TEST(parallel_desc, discrete_1n4d) {
TEST(ParallelDesc, discrete_1n4d) {
ParallelConf parallel_conf;
parallel_conf.set_device_tag("cpu");
parallel_conf.add_device_name("0:0-1");
......@@ -39,7 +39,7 @@ TEST(parallel_desc, discrete_1n4d) {
ASSERT_EQ(parallel_desc.parallel_num(), 4);
}
TEST(parallel_desc, continuous_2n8d) {
TEST(ParallelDesc, continuous_2n8d) {
ParallelConf parallel_conf;
parallel_conf.set_device_tag("cpu");
parallel_conf.add_device_name("0:0-3");
......@@ -49,7 +49,7 @@ TEST(parallel_desc, continuous_2n8d) {
ASSERT_EQ(parallel_desc.parallel_num(), 8);
}
TEST(parallel_desc, discrete_2n8d) {
TEST(ParallelDesc, discrete_2n8d) {
ParallelConf parallel_conf;
parallel_conf.set_device_tag("cpu");
parallel_conf.add_device_name("0:0-1");
......
......@@ -30,7 +30,6 @@ import oneflow.python.eager.boxing_util as boxing_util
import oneflow.python.eager.object_storage as object_storage
import oneflow.python.eager.symbol as symbol_util
import oneflow.python.eager.symbol_storage as symbol_storage
import oneflow.python.framework.parallel_conf_util as parallel_conf_util
import oneflow_api.oneflow.core.job.scope as scope_cfg
import oneflow.python.framework.balanced_splitter as balanced_splitter
import oneflow.python.framework.c_api_util as c_api_util
......@@ -286,50 +285,6 @@ def MakeLazyRefBlobObject(self, interface_op_name):
return blob_object
def BuildInitialScope(
self, session_id, job_conf, device_tag, machine_device_ids, is_mirrored,
):
scope_proto = scope_cfg.ScopeProto()
scope_proto.set_session_id(session_id)
job_conf_sym = self.GetJobConfSymbol(job_conf)
scope_proto.set_job_desc_symbol_id(job_conf_sym.symbol_id)
parallel_conf = parallel_conf_util.MakeParallelConf(device_tag, machine_device_ids)
device_parallel_desc_sym = self.GetParallelDescSymbol(parallel_conf)
scope_proto.set_device_parallel_desc_symbol_id(device_parallel_desc_sym.symbol_id)
parallel_conf = parallel_conf_util.MakeParallelConf("cpu", machine_device_ids)
host_parallel_desc_sym = self.GetParallelDescSymbol(parallel_conf)
scope_proto.set_host_parallel_desc_symbol_id(host_parallel_desc_sym.symbol_id)
if is_mirrored:
scope_proto.mutable_opt_mirrored_parallel_conf().mutable_mirrored_parallel()
else:
scope_proto.mutable_opt_mirrored_parallel_conf().clear_mirrored_parallel()
return self.GetScopeSymbol(scope_proto)
def BuildScopeWithNewParallelDesc(self, scope, device_tag, machine_device_ids):
if isinstance(machine_device_ids, str):
machine_device_ids = [machine_device_ids]
def SetScopeProto(scope_proto):
parallel_conf = parallel_conf_util.MakeParallelConf(
device_tag, machine_device_ids
)
device_parallel_desc_sym = self.GetParallelDescSymbol(parallel_conf)
parallel_conf = parallel_conf_util.MakeParallelConf("cpu", machine_device_ids)
host_parallel_desc_sym = self.GetParallelDescSymbol(parallel_conf)
scope_proto.set_device_parallel_desc_symbol_id(
device_parallel_desc_sym.symbol_id
)
scope_proto.set_host_parallel_desc_symbol_id(host_parallel_desc_sym.symbol_id)
return self.BuildScopeByProtoSetter(scope, SetScopeProto)
def BuildScopeWithNewParallelConf(self, scope, parallel_conf):
tag_and_dev_ids = parallel_conf_util.GetDeviceTagAndMachineDeviceIds(parallel_conf)
return self.BuildScopeWithNewParallelDesc(scope, *tag_and_dev_ids)
def GetSharedOpKernelObject4ParallelConfSymbol(self, parallel_desc_sym):
if object_storage.HasSharedOpKernelObject4ParallelConfSymbol(parallel_desc_sym):
return object_storage.GetSharedOpKernelObject4ParallelConfSymbol(
......@@ -827,13 +782,6 @@ def RegisterMethod4InstructionsBuilder():
oneflow_api.deprecated.InstructionsBuilder.MakeLazyRefBlobObject = (
MakeLazyRefBlobObject
)
oneflow_api.deprecated.InstructionsBuilder.BuildInitialScope = BuildInitialScope
oneflow_api.deprecated.InstructionsBuilder.BuildScopeWithNewParallelDesc = (
BuildScopeWithNewParallelDesc
)
oneflow_api.deprecated.InstructionsBuilder.BuildScopeWithNewParallelConf = (
BuildScopeWithNewParallelConf
)
oneflow_api.deprecated.InstructionsBuilder.GetSharedOpKernelObject4ParallelConfSymbol = (
GetSharedOpKernelObject4ParallelConfSymbol
)
......
......@@ -19,7 +19,6 @@ from contextlib import contextmanager
import inspect
import oneflow.python.framework.c_api_util as c_api_util
import oneflow.python.framework.parallel_conf_util as parallel_conf_util
import oneflow.python.framework.distribute as distribute_util
import oneflow.python.framework.input_blob_def as input_blob_util
import oneflow.python.framework.hob as hob
......
......@@ -22,9 +22,9 @@ import oneflow.core.job.placement_pb2 as placement_pb
import oneflow.python.framework.c_api_util as c_api_util
import oneflow.python.framework.op_util as op_util
import oneflow.python.framework.session_context as session_ctx
import oneflow.python.framework.parallel_conf_util as parallel_conf_util
import oneflow
import oneflow_api.oneflow.core.job.placement as placement_cfg
import oneflow_api
class PlacementScope(object):
......@@ -75,7 +75,7 @@ def MakeParallelConf4Resource(device_tag, resource):
machine_device_ids = GetCpuMachineDeviceIds(resource)
else:
raise NotImplementedError
return parallel_conf_util.MakeParallelConf(device_tag, machine_device_ids)
return oneflow_api.MakeParallelConf(device_tag, machine_device_ids)
def MakeMachineId2DeviceIdList(parallel_conf):
......
......@@ -97,6 +97,10 @@ def GetEmptyPlacementScope(device_tag, machine_device_ids):
@enable_if.condition(hob.in_normal_mode & hob.session_initialized)
def GetNormalModePlacementScope(device_tag, machine_device_ids):
if isinstance(machine_device_ids, tuple):
machine_device_ids = list(machine_device_ids)
if not isinstance(machine_device_ids, list):
machine_device_ids = [machine_device_ids]
sess = session_ctx.GetDefaultSession()
scope = scope_util.MakeScope(
lambda old_scope, builder: builder.BuildScopeWithNewParallelDesc(
......
......@@ -17,7 +17,6 @@ from __future__ import absolute_import
from oneflow.python.eager.symbol import Symbol
import oneflow.python.eager.symbol_storage as symbol_storage
import oneflow.python.framework.parallel_conf_util as parallel_conf_util
import oneflow_api.oneflow.core.job.scope as scope_cfg
import oneflow_api.oneflow.core.job.placement as placement_cfg
import oneflow_api
......@@ -90,9 +89,7 @@ class ScopeSymbol(Symbol):
return self.BuildBySetter(instruction_builder, SetScopeProto)
def BuildWithNewParallelConf(self, instruction_builder, parallel_conf):
tag_and_dev_ids = parallel_conf_util.GetDeviceTagAndMachineDeviceIds(
parallel_conf
)
tag_and_dev_ids = oneflow_api.GetDeviceTagAndMachineDeviceIds(parallel_conf)
return self.BuildWithNewParallelDesc(instruction_builder, *tag_and_dev_ids)
def BuildWithNewIsMirrored(self, instruction_builder, is_mirrored):
......
......@@ -26,7 +26,6 @@ import oneflow.python.framework.interpret_util as interpret_util
import oneflow.python.framework.hob as hob
import oneflow.python.framework.id_util as id_util
import oneflow.python.framework.interpret_util as interpret_util
import oneflow.python.framework.parallel_conf_util as parallel_conf_util
import oneflow.python.framework.placement_context as placement_ctx
import oneflow.python.framework.remote_blob as remote_blob_util
import oneflow.python.lib.core.enable_if as enable_if
......@@ -64,7 +63,7 @@ def api_system_assign(ref, value, validate_shape=None, use_locking=None, name=No
@enable_if.condition(hob.in_global_mode & ~hob.eager_execution_enabled)
def lazy_system_assign(ref, value, validate_shape=None, use_locking=None, name=None):
op_conf = _SystemAssignOpConf(ref, value, name=name)
device_tag, machine_device_ids = parallel_conf_util.GetDeviceTagAndMachineDeviceIds(
device_tag, machine_device_ids = oneflow_api.GetDeviceTagAndMachineDeviceIds(
ref.parallel_conf
)
with oneflow.scope.placement(device_tag, machine_device_ids):
......
......@@ -18,7 +18,6 @@ from __future__ import absolute_import
import uuid
from typing import Callable, Optional, Union
import oneflow.python.framework.parallel_conf_util as parallel_conf_util
import oneflow.core.operator.op_conf_pb2 as op_conf_util
import oneflow.python.framework.c_api_util as c_api_util
import oneflow.python.framework.session_context as session_ctx
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册