未验证 提交 380d2414 编写于 作者: Z Zhanghuihong Guan 提交者: GitHub

Change maybe to optional (#6611)

* initial commit, add code for async construct tensor from numpy array

* inital commit to change Maybe to Optional

* delete redundant code

* replace Maybe with Optional

* fix compile errors

* format code

* changes based on review

* format code, fix based on review

* format code

* fix multiclient type

* changes based on review

* changes based on review

* unify calling to IsMultiClirnt

* refector multi_client related code

* restore InMultiClient interface

* double check for unnecessary changes

* remove unnecessary changes

* format code

* Update oneflow/api/python/symbol/job_conf_symbol.cpp

* Update oneflow/api/python/symbol/op_conf_symbol.cpp

* Update oneflow/api/python/symbol/op_node_signature_symbol.cpp

* Update oneflow/core/common/optional.h

* Update oneflow/api/python/symbol/string_symbol.cpp

* Update oneflow/api/python/symbol/scope_symbol.cpp

* Update oneflow/api/python/symbol/placement_symbol.cpp

* Update oneflow/api/python/symbol/op_conf_symbol.cpp
Co-authored-by: NHoujiang Chen <chenhoujiangcug@gmail.com>
Co-authored-by: NTwice <i@twice.moe>
上级 8c619789
......@@ -18,12 +18,13 @@ limitations under the License.
#include <string>
#include <google/protobuf/text_format.h>
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/cluster.h"
#include "oneflow/core/job/cluster_instruction.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/job/env_global_objects_scope.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/rpc/include/base.h"
......@@ -45,19 +46,6 @@ inline Maybe<void> EnableEagerEnvironment(bool enable_eager_execution) {
return Maybe<void>::Ok();
}
inline Maybe<bool>* IsMultiClientPtr() { return Global<Maybe<bool>, MultiClient>::Get(); }
inline Maybe<bool> IsMultiClient() {
CHECK_NOTNULL_OR_RETURN(IsMultiClientPtr());
return *IsMultiClientPtr();
}
inline Maybe<void> SetIsMultiClient(bool is_multi_client) {
CHECK_NOTNULL_OR_RETURN(IsMultiClientPtr());
*IsMultiClientPtr() = is_multi_client;
return Maybe<void>::Ok();
}
inline Maybe<bool> IsEnvInited() { return Global<EnvGlobalObjectsScope>::Get() != nullptr; }
inline Maybe<void> DestroyEnv() {
......
......@@ -15,6 +15,7 @@ limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "oneflow/api/python/framework/throw.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/job_conf.cfg.h"
......@@ -36,7 +37,12 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
return CreateJobConfSymbol(symbol_id, symbol_conf).GetPtrOrThrow();
}))
.def_property_readonly("symbol_id",
[](const JobDesc& x) { return x.symbol_id().GetOrThrow(); })
[](const JobDesc& x) {
if (!x.symbol_id().has_value()) {
THROW(RuntimeError) << "symbol_id not initialized";
}
return CHECK_JUST(x.symbol_id());
})
.def_property_readonly("data", &JobDesc::cfg_job_conf);
}
......
......@@ -15,8 +15,10 @@ limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "oneflow/api/python/framework/throw.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/operator/op_conf_symbol.h"
#include "oneflow/core/common/maybe.h"
namespace py = pybind11;
......@@ -25,7 +27,12 @@ namespace oneflow {
ONEFLOW_API_PYBIND11_MODULE("", m) {
py::class_<OperatorConfSymbol, std::shared_ptr<OperatorConfSymbol>>(m, "OpConfSymbol")
.def_property_readonly("symbol_id",
[](const OperatorConfSymbol& x) { return x.symbol_id().GetOrThrow(); })
[](const OperatorConfSymbol& x) {
if (!x.symbol_id().has_value()) {
THROW(RuntimeError) << "symbol_id not initialized";
}
return CHECK_JUST(x.symbol_id());
})
.def_property_readonly("data", &OperatorConfSymbol::data);
}
......
......@@ -15,6 +15,7 @@ limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "oneflow/api/python/framework/throw.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/operator/op_node_signature_desc.h"
#include "oneflow/core/operator/op_node_signature.pb.h"
......@@ -37,8 +38,13 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
py::init([](int64_t symbol_id, const std::shared_ptr<cfg::OpNodeSignature>& symbol_conf) {
return CreateScopeSymbol(symbol_id, symbol_conf).GetPtrOrThrow();
}))
.def_property_readonly(
"symbol_id", [](const OpNodeSignatureDesc& x) { return x.symbol_id().GetOrThrow(); })
.def_property_readonly("symbol_id",
[](const OpNodeSignatureDesc& x) {
if (!x.symbol_id().has_value()) {
THROW(RuntimeError) << "symbol_id not initialized";
}
return CHECK_JUST(x.symbol_id());
})
.def("data", &OpNodeSignatureDesc::op_node_signature);
}
......
......@@ -16,6 +16,7 @@ limitations under the License.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/operators.h>
#include "oneflow/api/python/framework/throw.h"
#include "oneflow/api/python/framework/size.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/control/global_process_ctx.h"
......@@ -198,7 +199,12 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
hierarchy);
}))
.def_property_readonly("symbol_id",
[](const ParallelDesc& x) { return x.symbol_id().GetOrThrow(); })
[](const ParallelDesc& x) {
if (!x.symbol_id().has_value()) {
THROW(RuntimeError) << "symbol_id not initialized";
}
return CHECK_JUST(x.symbol_id());
})
.def_property_readonly("parallel_conf", &ParallelDesc::cfg_parallel_conf)
.def_property_readonly("parallel_num", &ParallelDesc::parallel_num)
.def_property_readonly("device_tag", &ParallelDesc::device_tag)
......
......@@ -15,6 +15,7 @@ limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "oneflow/api/python/framework/throw.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/job/scope.h"
......@@ -36,7 +37,13 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
.def(py::init([](int64_t symbol_id, const std::shared_ptr<cfg::ScopeProto>& symbol_conf) {
return CreateScopeSymbol(symbol_id, symbol_conf).GetPtrOrThrow();
}))
.def_property_readonly("symbol_id", [](const Scope& x) { return x.symbol_id().GetOrThrow(); })
.def_property_readonly("symbol_id",
[](const Scope& x) {
if (!x.symbol_id().has_value()) {
THROW(RuntimeError) << "symbol_id not initialized";
}
return CHECK_JUST(x.symbol_id());
})
.def_property_readonly("_proto_str",
[](const Scope& x) { return PbMessage2TxtString(x.scope_proto()); })
.def("auto_increment_id", &Scope::auto_increment_id)
......
......@@ -14,7 +14,9 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/api/python/framework/throw.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/vm/string_symbol.h"
namespace py = pybind11;
......@@ -31,7 +33,12 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
return CreateStringSymbol(symbol_id, data).GetPtrOrThrow();
}))
.def_property_readonly("symbol_id",
[](const StringSymbol& x) { return x.symbol_id().GetOrThrow(); })
[](const StringSymbol& x) {
if (!x.symbol_id().has_value()) {
THROW(RuntimeError) << "symbol_id not initialized";
}
return CHECK_JUST(x.symbol_id());
})
.def_property_readonly("data", &StringSymbol::data);
}
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_MULTICLIENT_H_
#define ONEFLOW_CORE_COMMON_MULTICLIENT_H_
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/job/global_for.h"
namespace oneflow {
inline Optional<bool>* IsMultiClientPtr() { return Global<Optional<bool>, MultiClient>::Get(); }
inline Maybe<bool> IsMultiClient() {
auto* opt = Global<Optional<bool>, MultiClient>::Get();
return !opt || opt->value_or(true);
}
inline Maybe<void> SetIsMultiClient(bool is_multi_client) {
CHECK_NOTNULL_OR_RETURN(IsMultiClientPtr());
*IsMultiClientPtr() = is_multi_client;
return Maybe<void>::Ok();
}
} // namespace oneflow
#endif
......@@ -16,6 +16,7 @@ limitations under the License.
#include <mutex>
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/common/global.h"
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/device/node_device_descriptor_manager.h"
#include "oneflow/core/device/cuda_device_descriptor.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
......@@ -180,7 +181,7 @@ void CublasMathModeGuard::SetMathMode(cublasMath_t new_mode) {
int GetCudaDeviceIndex() {
int cuda_device_index = 0;
if (CHECK_JUST(GlobalMultiClientEnv())) {
if (CHECK_JUST(IsMultiClient())) {
cuda_device_index = GlobalProcessCtx::LocalRank();
} else {
OF_CUDA_CHECK(cudaGetDevice(&cuda_device_index));
......
......@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/control/ctrl_client.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/eager/eager_oneflow.h"
......@@ -24,7 +25,6 @@ limitations under the License.
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/vm/string_symbol.h"
#include "oneflow/core/eager/eager_symbol.cfg.h"
#include "oneflow/core/job/env_desc.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/scope.h"
#include "oneflow/core/job/cluster_instruction.h"
......@@ -94,7 +94,7 @@ Maybe<void> EagerOneflow::RunPhysicalInstruction(vm::InstructionMsgList* instruc
Maybe<void> EagerOneflow::RunLogicalInstruction(vm::InstructionMsgList* instruction_list,
const vm::cfg::EagerSymbolList& eager_symbol_list) {
if (JUST(GlobalMultiClientEnv())) {
if (JUST(IsMultiClient())) {
// NOTE(chengcheng): in Multi-Client LogicalRun will degenerate directly to PhysicalRun,
// because each rank will process instructions ONLY from itself, NOT the master.
return RunPhysicalInstruction(instruction_list, eager_symbol_list);
......
......@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include <atomic>
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/symbol_storage_util.h"
#include "oneflow/core/eager/eager_symbol.cfg.h"
......@@ -42,7 +43,6 @@ limitations under the License.
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/instruction_replay.h"
#include "oneflow/core/job/env_desc.h"
namespace oneflow {
......@@ -1750,7 +1750,7 @@ InstructionsBuilder::GetMut2OperandBlobObjects(
}
Maybe<void> LogicalRun(const std::function<Maybe<void>(InstructionsBuilder*)>& Build) {
if (JUST(GlobalMultiClientEnv())) {
if (JUST(IsMultiClient())) {
// NOTE(chengcheng): in Multi-Client LogicalRun will degenerate directly to PhysicalRun,
// because each rank will process instructions ONLY from itself, NOT the master.
return PhysicalRun(Build);
......
......@@ -33,7 +33,7 @@ namespace {
class TestVirtualMachineScope {
public:
TestVirtualMachineScope(int64_t gpu_device_num, int64_t cpu_device_num) {
*Global<Maybe<bool>, MultiClient>::Get() = false;
*Global<Optional<bool>, MultiClient>::Get() = false;
test_resource_desc_scope_.reset(new vm::TestResourceDescScope(gpu_device_num, cpu_device_num));
virtual_machine_scope_.reset(
new vm::VirtualMachineScope(Global<ResourceDesc, ForSession>::Get()->resource()));
......@@ -42,8 +42,7 @@ class TestVirtualMachineScope {
~TestVirtualMachineScope() {
virtual_machine_scope_.reset();
test_resource_desc_scope_.reset();
Global<Maybe<bool>, MultiClient>::SetAllocated(
new Maybe<bool>(Error::InvalidValueError("is_multi_client is not set")));
Global<Optional<bool>, MultiClient>::SetAllocated(new Optional<bool>());
}
private:
......
......@@ -14,6 +14,8 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/framework/multi_client_session_context.h"
#include "oneflow/core/framework/load_library.h"
#include "oneflow/core/job/version.h"
......@@ -27,7 +29,6 @@ limitations under the License.
#include "oneflow/core/memory/memory_allocator.h"
#include "oneflow/core/register/register_manager.h"
#include "oneflow/user/summary/events_writer.h"
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/memory/chunk_manager.h"
#include "oneflow/core/vm/vm_util.h"
......@@ -55,7 +56,7 @@ int32_t GetGpuDeviceNum() {
Maybe<void> MultiClientSessionContext::TryInit(const ConfigProto& config_proto) {
if (!is_inited_) {
CHECK_OR_RETURN(JUST(GlobalMultiClientEnv()));
CHECK_OR_RETURN(JUST(IsMultiClient()));
DumpVersionInfo();
Resource resource = config_proto.resource();
......
......@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/global.h"
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/framework/attr_map.h"
......@@ -126,7 +127,7 @@ class ConsistentRandFunctor {
const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);
const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));
if (!JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
if (!JUST(IsMultiClient())) {
JUST(attrs.SetAttr<std::string>("nd_sbp", nd_sbp->DebugString()));
}
auto result = JUST(OpInterpUtil::Dispatch<Tensor>(
......@@ -198,7 +199,7 @@ class ConsistentRandNFunctor {
const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);
const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));
if (!JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
if (!JUST(IsMultiClient())) {
JUST(attrs.SetAttr<std::string>("nd_sbp", nd_sbp->DebugString()));
}
auto result = JUST(OpInterpUtil::Dispatch<Tensor>(
......
......@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/graph/task_graph.h"
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/graph/inplace_lbi_graph.h"
#include "oneflow/core/graph/id_serialization.h"
......@@ -26,7 +27,6 @@ limitations under the License.
#include "oneflow/core/job/scope.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/job_rewriter/calculation_pass.h"
#include "oneflow/core/job/env_desc.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h"
#include "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.h"
#include "oneflow/core/graph/stream_index_getter_registry_manager.h"
......@@ -544,7 +544,7 @@ void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_node
}
void TaskGraph::AddCtrlEdgeBetweenSrcDstTickAndInputOutputInSameRank() {
if (!CHECK_JUST(GlobalMultiClientEnv())) { return; }
if (!CHECK_JUST(IsMultiClient())) { return; }
HashMap<int64_t, TaskNode*> rank_id2src_tick;
HashMap<int64_t, TaskNode*> rank_id2dst_tick;
HashMap<int64_t, HashSet<TaskNode*>> rank_id2input_output_nodes;
......
......@@ -50,10 +50,4 @@ int64_t EnvDesc::GetMachineId(const std::string& addr) const {
return machine_id;
}
Maybe<bool> GlobalMultiClientEnv() {
Maybe<bool>* is_multi_client = Global<Maybe<bool>, MultiClient>::Get();
CHECK_NOTNULL_OR_RETURN(is_multi_client);
return *is_multi_client;
}
} // namespace oneflow
......@@ -44,8 +44,6 @@ class EnvDesc final {
EnvProto env_proto_;
};
Maybe<bool> GlobalMultiClientEnv();
} // namespace oneflow
#endif // ONEFLOW_CORE_JOB_CLUSTER_DESC_H_
......@@ -16,12 +16,12 @@ limitations under the License.
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/common/error.h"
#include "oneflow/core/common/global.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/common/util.h"
namespace oneflow {
COMMAND(Global<bool, EagerExecution>::SetAllocated(new bool(false)));
COMMAND(Global<Maybe<bool>, MultiClient>::SetAllocated(
new Maybe<bool>(Error::InvalidValueError("is_multi_client is not set"))));
COMMAND(Global<Optional<bool>, MultiClient>::SetAllocated(new Optional<bool>()));
} // namespace oneflow
......@@ -15,10 +15,10 @@ limitations under the License.
*/
#include "oneflow/core/job/job_build_and_infer_ctx_mgr.h"
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/core/job/env_desc.h"
#include "oneflow/core/common/util.h"
#include <json.hpp>
namespace oneflow {
......@@ -106,7 +106,7 @@ Maybe<void> EagerJobBuildAndInferCtxMgr::VirtualCloseJob() {
bool EagerExecutionEnabled() { return *Global<bool, EagerExecution>::Get(); }
Maybe<JobBuildAndInferCtxMgr*> GlobalJobBuildAndInferCtxMgr() {
if (JUST(GlobalMultiClientEnv())) {
if (JUST(IsMultiClient())) {
return JUST(GlobalMaybe<LazyJobBuildAndInferCtxMgr>());
} else {
// single-client
......
......@@ -43,14 +43,14 @@ void CheckFunctionConfig(const JobConfigProto& job_conf) {
} // namespace
JobDesc::JobDesc(const JobConfigProto& job_conf, int64_t job_id)
: job_conf_(job_conf), job_id_(job_id), symbol_id_(Error::SymbolIdUninitializedError()) {
: job_conf_(job_conf), job_id_(job_id), symbol_id_(NullOpt) {
CHECK_JUST(Init());
Global<ResourceDesc, ForSession>::Get()->DumpCudnnConf(job_conf);
}
Maybe<JobDesc> JobDesc::New(int64_t symbol_id, const JobConfigProto& job_conf) {
auto job_desc = std::make_shared<JobDesc>(job_conf);
job_desc->symbol_id_ = Maybe<int64_t>(symbol_id);
job_desc->symbol_id_ = symbol_id;
return job_desc;
}
......
......@@ -16,6 +16,7 @@ limitations under the License.
#ifndef ONEFLOW_CORE_JOB_JOB_DESC_H_
#define ONEFLOW_CORE_JOB_JOB_DESC_H_
#include "oneflow/core/common/optional.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/job/dlnet_conf.pb.h"
#include "oneflow/core/job/job.pb.h"
......@@ -41,7 +42,7 @@ class JobDesc final {
~JobDesc() = default;
static Maybe<JobDesc> New(int64_t symbol_id, const JobConfigProto& job_conf);
const Maybe<int64_t>& symbol_id() const { return symbol_id_; }
const Optional<int64_t>& symbol_id() const { return symbol_id_; }
const std::shared_ptr<cfg::JobConfigProto>& cfg_job_conf() const { return cfg_job_conf_; }
// Common
......@@ -84,7 +85,7 @@ class JobDesc final {
JobConfigProto job_conf_;
int64_t job_id_;
Maybe<int64_t> symbol_id_;
Optional<int64_t> symbol_id_;
// merge job_conf_ and cfg_job_conf_ after cfg::JobConfigProto taken as a constructor argument
std::shared_ptr<cfg::JobConfigProto> cfg_job_conf_;
};
......
......@@ -71,8 +71,7 @@ Maybe<OFRecord> ParseMachineAndDeviceIdList(const ParallelConf& parallel_conf) {
return machine2device_list;
}
ParallelDesc::ParallelDesc(const ParallelConf& user_conf)
: symbol_id_(Error::SymbolIdUninitializedError()) {
ParallelDesc::ParallelDesc(const ParallelConf& user_conf) : symbol_id_(NullOpt) { // NOLINT
CHECK_JUST(MaybeInit(user_conf));
CHECK_JUST(CheckWithResourceDesc(*(Global<ResourceDesc, ForSession>::Get())));
}
......
......@@ -58,7 +58,7 @@ class ParallelDesc final {
Maybe<void> MaybeInit(const ParallelConf& user_conf);
// Getters
const Maybe<int64_t>& symbol_id() const { return symbol_id_; }
const Optional<int64_t>& symbol_id() const { return symbol_id_; }
bool containing_current_rank() const { return containing_current_rank_; }
DeviceType device_type() const { return device_type_; }
const std::string& device_tag() const { return parallel_conf_.device_tag(); }
......@@ -108,7 +108,7 @@ class ParallelDesc final {
private:
friend Maybe<OFRecord> ParseMachineAndDeviceIdList(const ParallelConf& parallel_conf);
ParallelDesc() : symbol_id_(Error::SymbolIdUninitializedError()) {}
ParallelDesc() : symbol_id_(NullOpt) {}
ParallelDesc(int64_t symbol_id) : symbol_id_(symbol_id) {}
void ClearUp();
Maybe<void> SetMachineIdAndDeviceIdsByParsingDeviceName(const std::string& device_name,
......@@ -117,7 +117,7 @@ class ParallelDesc final {
Maybe<void> CheckWithResourceDesc(const ResourceDesc& resource_desc);
bool EqualsMachineId2SortedDevPhyIds(const ParallelDesc& rhs) const;
Maybe<int64_t> symbol_id_;
Optional<int64_t> symbol_id_;
DeviceType device_type_;
ParallelConf parallel_conf_;
std::shared_ptr<Shape> hierarchy_;
......
......@@ -14,10 +14,10 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/constant.h"
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/common/str_util.h"
#include "oneflow/core/job/plan_util.h"
#include "oneflow/core/job/env_desc.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/common/str_util.h"
#include "oneflow/core/graph/plan_task_graph.h"
#include "oneflow/core/graph/boxing/collective_boxing_util.h"
#include "oneflow/core/memory/chunk_manager.h"
......@@ -325,7 +325,7 @@ void PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan(
}
}
if (CHECK_JUST(GlobalMultiClientEnv())) {
if (CHECK_JUST(IsMultiClient())) {
GenChunkForMultiNNGraphMemoryReuseInMultiClient(plan, &mem_block_id2mem_block);
} else {
CHECK(variable_op_names.empty());
......
......@@ -26,9 +26,7 @@ limitations under the License.
namespace oneflow {
Scope::Scope(const ScopeProto& scope_proto)
: auto_increment_id_(0),
symbol_id_(Error::SymbolIdUninitializedError()),
scope_proto_(scope_proto) {
: auto_increment_id_(0), symbol_id_(NullOpt), scope_proto_(scope_proto) {
CHECK_OK(Init()) << scope_proto_.DebugString();
}
......
......@@ -22,6 +22,7 @@ limitations under the License.
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/framework/attr_value.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/common/symbol.h"
namespace oneflow {
......@@ -40,7 +41,7 @@ class Scope final {
~Scope() = default;
static Maybe<Scope> New(int64_t symbol_id, const ScopeProto& scope_proto);
const Maybe<int64_t>& symbol_id() const { return symbol_id_; }
const Optional<int64_t>& symbol_id() const { return symbol_id_; }
int64_t auto_increment_id() { return ++auto_increment_id_; }
int64_t session_id() const { return scope_proto().session_id(); }
const std::shared_ptr<JobDesc>& job_desc_symbol() const { return job_desc_; }
......@@ -78,7 +79,7 @@ class Scope final {
const AttrValue& GetAttrValue(const std::string& attr_name) const;
int64_t auto_increment_id_;
Maybe<int64_t> symbol_id_;
Optional<int64_t> symbol_id_;
const ScopeProto scope_proto_;
std::shared_ptr<JobDesc> job_desc_;
Symbol<PlacementScope> placement_scope_;
......
......@@ -13,11 +13,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/job_rewriter/autotick.h"
#include "oneflow/core/job/job_builder.h"
#include "oneflow/core/job/critical_section_desc.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/job/global_for.h"
namespace oneflow {
......@@ -562,7 +563,7 @@ Maybe<void> AutoSourceAndSinkTick(
}
Maybe<void> SingleClientAutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder) {
if (JUST(*Global<Maybe<bool>, MultiClient>::Get())) { return Maybe<void>::Ok(); }
if (JUST(IsMultiClient())) { return Maybe<void>::Ok(); }
auto* critical_section =
Global<CriticalSectionDesc>::Get()->AddCriticalSection(GlobalJobDesc().job_id());
critical_section->mutable_total_job_critical_section();
......@@ -581,7 +582,7 @@ Maybe<void> SingleClientAutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilde
}
Maybe<void> MultiClientAutoSourceAndSinkTick(const OpGraph& op_graph, Job* job) {
if (!JUST(*Global<Maybe<bool>, MultiClient>::Get())) { return Maybe<void>::Ok(); }
if (!JUST(IsMultiClient())) { return Maybe<void>::Ok(); }
HashMap<int64_t, std::string> machine_id2src_op_name;
HashMap<int64_t, std::string> machine_id2sink_op_name;
{
......@@ -610,7 +611,7 @@ Maybe<void> MultiClientAutoSourceAndSinkTick(const OpGraph& op_graph, Job* job)
Maybe<void> SingleClientAddGlobalInputCriticalSections(const OpGraph& op_graph,
JobBuilder* job_builder) {
if (JUST(*Global<Maybe<bool>, MultiClient>::Get())) { return Maybe<void>::Ok(); }
if (JUST(IsMultiClient())) { return Maybe<void>::Ok(); }
JUST(ForEachInputCriticalSectionOpNodes(
op_graph,
[&](const HashSet<const OpNode*>& op_nodes,
......@@ -623,7 +624,7 @@ Maybe<void> SingleClientAddGlobalInputCriticalSections(const OpGraph& op_graph,
Maybe<void> SingleClientAddGlobalOutputCriticalSections(const OpGraph& op_graph,
JobBuilder* job_builder) {
if (JUST(*Global<Maybe<bool>, MultiClient>::Get())) { return Maybe<void>::Ok(); }
if (JUST(IsMultiClient())) { return Maybe<void>::Ok(); }
JUST(ForEachOutputCriticalSectionOpNodes(
op_graph,
[&](const HashSet<const OpNode*>& op_nodes,
......
......@@ -13,9 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/job_rewriter/job_pass.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/job/env_desc.h"
namespace oneflow {
......@@ -34,7 +34,7 @@ Maybe<void> GradientAccumulationRewritePass::Apply(Job* job, JobPassCtx* ctx) co
|| job_conf.num_gradient_accumulation_steps() <= 1) {
return Maybe<void>::Ok();
}
const bool is_multi_client = CHECK_JUST(GlobalMultiClientEnv());
const bool is_multi_client = CHECK_JUST(IsMultiClient());
const OpGraph op_graph(*job);
JobBuilder job_builder(job);
HashMap<std::string, OperatorConf> name2op_conf;
......
......@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/job_rewriter/job_completer.h"
#include "oneflow/core/job_rewriter/job_pass.h"
#include "oneflow/core/job_rewriter/autograd.h"
......@@ -35,7 +36,7 @@ Maybe<void> CheckOpGraph(const OpGraph& op_graph) {
// NOTE(chengcheng):
// in single-client source op is SourceTickOpConf,
// in multi-client source op is WaitAndSendIdsOpConf_
if (JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
if (JUST(IsMultiClient())) {
CHECK_OR_RETURN(op_node->op().op_conf().has_wait_and_send_ids_conf());
} else {
CHECK_OR_RETURN(op_node->op().op_conf().has_source_tick_conf());
......@@ -49,7 +50,7 @@ Maybe<void> CheckOpGraph(const OpGraph& op_graph) {
// NOTE(chengcheng):
// in single-client source op is SinkTickOpConf,
// in multi-client source op is CallbackNotifyOpConf.
if (JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
if (JUST(IsMultiClient())) {
CHECK_OR_RETURN(op_node->op().op_conf().has_callback_notify_conf());
} else {
CHECK_OR_RETURN(op_node->op().op_conf().has_sink_tick_conf());
......
......@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/job_rewriter/job_pass.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/job_builder.h"
......@@ -31,7 +32,7 @@ class SetDefaultVariableConf final : public JobPass {
}
Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {
if (JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
if (JUST(IsMultiClient())) {
// NOTE(chengcheng): Multi-Client Variable is inited by Eager.
return Maybe<void>::Ok();
}
......
......@@ -37,7 +37,7 @@ template<typename T>
void CallbackNotifyKernel<T>::ForwardDataContent(KernelContext* ctx) const {
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
std::string buffer_name;
if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
if (CHECK_JUST(*Global<Optional<bool>, MultiClient>::Get())) {
CHECK(this->op_conf().callback_notify_conf().has_job_name());
buffer_name = GetCallbackNotifierBufferName(this->op_conf().callback_notify_conf().job_name());
} else {
......
......@@ -31,7 +31,7 @@ class InputKernel final : public Kernel {
private:
void ForwardDataContent(KernelContext* ctx) const override {
if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
if (CHECK_JUST(*Global<Optional<bool>, MultiClient>::Get())) {
CHECK(this->op_conf().input_conf().has_job_name());
const auto& job_name = this->op_conf().input_conf().job_name();
const auto& op_name = this->op_conf().name();
......
......@@ -32,7 +32,7 @@ class OutputKernel final : public Kernel {
};
void OutputKernel::ForwardDataContent(KernelContext* ctx) const {
if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
if (CHECK_JUST(*Global<Optional<bool>, MultiClient>::Get())) {
CHECK(this->op_conf().output_conf().has_job_name());
const auto& job_name = this->op_conf().output_conf().job_name();
const auto& op_name = this->op_conf().name();
......@@ -51,7 +51,7 @@ void OutputKernel::ForwardDataContent(KernelContext* ctx) const {
}
void OutputKernel::ForwardHeader(KernelContext* ctx) const {
if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
if (CHECK_JUST(*Global<Optional<bool>, MultiClient>::Get())) {
// Do nothing.
} else {
ctx->BnInOp2Blob("out")->CopyHeaderFrom(ctx->BnInOp2Blob("in"));
......
......@@ -32,7 +32,7 @@ class ReturnKernel final : public Kernel {
};
void ReturnKernel::ForwardDataContent(KernelContext* ctx) const {
if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
if (CHECK_JUST(*Global<Optional<bool>, MultiClient>::Get())) {
CHECK(this->op_conf().return_conf().has_job_name());
const auto& job_name = this->op_conf().return_conf().job_name();
const auto& op_name = this->op_conf().name();
......@@ -52,7 +52,7 @@ void ReturnKernel::ForwardDataContent(KernelContext* ctx) const {
}
void ReturnKernel::ForwardHeader(KernelContext* ctx) const {
if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
if (CHECK_JUST(*Global<Optional<bool>, MultiClient>::Get())) {
// Do nothing.
} else {
ctx->BnInOp2Blob("out")->CopyHeaderFrom(ctx->BnInOp2Blob("in"));
......
......@@ -31,7 +31,7 @@ void WaitAndSendIdsKernel<T>::ForwardDataContent(KernelContext* ctx) const {
auto* status = CHECK_NOTNULL(dynamic_cast<WaitAndSendIdsStatus*>(ctx->state().get()));
const auto& conf = this->op_conf().wait_and_send_ids_conf();
if (status->out_idx_ >= status->out_num_) {
if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
if (CHECK_JUST(*Global<Optional<bool>, MultiClient>::Get())) {
CHECK(this->op_conf().wait_and_send_ids_conf().has_job_name());
const auto& job_name = this->op_conf().wait_and_send_ids_conf().job_name();
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
......@@ -53,7 +53,7 @@ void WaitAndSendIdsKernel<T>::ForwardDataContent(KernelContext* ctx) const {
}
}
if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
if (CHECK_JUST(*Global<Optional<bool>, MultiClient>::Get())) {
*ctx->BnInOp2Blob("out")->mut_dptr<T>() = 0;
} else {
*ctx->BnInOp2Blob("out")->mut_dptr<T>() = conf.id_list(status->in_id_).value(status->out_idx_);
......
......@@ -17,7 +17,7 @@ limitations under the License.
#define ONEFLOW_CORE_OPERATOR_OP_CONF_SYMBOL_H_
#include <string>
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/operator/op_conf.pb.h"
#include "oneflow/core/operator/op_conf.cfg.h"
......@@ -32,11 +32,11 @@ class OperatorConfSymbol final {
~OperatorConfSymbol() = default;
const OperatorConf& op_conf() const { return op_conf_; }
const Maybe<int64_t>& symbol_id() const { return symbol_id_; }
const Optional<int64_t>& symbol_id() const { return symbol_id_; }
std::shared_ptr<cfg::OperatorConf> data() const { return data_; }
private:
Maybe<int64_t> symbol_id_;
Optional<int64_t> symbol_id_;
OperatorConf op_conf_;
std::shared_ptr<cfg::OperatorConf> data_;
};
......
......@@ -31,7 +31,7 @@ class OpNodeSignatureDesc final {
OpNodeSignatureDesc(OpNodeSignatureDesc&&) = delete;
OpNodeSignatureDesc(int64_t symbol_id, const OpNodeSignature& op_node_signature);
const Maybe<int64_t>& symbol_id() const { return symbol_id_; }
const Optional<int64_t>& symbol_id() const { return symbol_id_; }
const std::shared_ptr<cfg::OpNodeSignature>& op_node_signature() const {
return op_node_signature_;
}
......@@ -43,7 +43,7 @@ class OpNodeSignatureDesc final {
Maybe<const BlobDesc&> LogicalBlobDesc4BnInOp(const std::string& bn_in_op) const;
private:
Maybe<int64_t> symbol_id_;
Optional<int64_t> symbol_id_;
std::shared_ptr<cfg::OpNodeSignature> op_node_signature_;
HashMap<std::string, std::unique_ptr<BlobDesc>> bn_in_op2blob_desc_;
};
......
......@@ -13,10 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/operator/interface_op_util.h"
#include "oneflow/core/operator/output_op.h"
#include "oneflow/core/job/sbp_signature_builder.h"
#include "oneflow/core/job/env_desc.h"
#include "oneflow/core/operator/interface_op_util.h"
namespace oneflow {
......@@ -31,7 +32,7 @@ Maybe<void> OutputOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
BlobDesc* out_blob_desc = BlobDesc4BnInOp("out");
if (CHECK_JUST(GlobalMultiClientEnv())) {
if (CHECK_JUST(IsMultiClient())) {
*out_blob_desc = *BlobDesc4BnInOp("in");
} else {
JUST(InterfaceOpUtil::InferLogicalOutBlobDesc(op_conf().output_conf().blob_conf(),
......@@ -45,7 +46,7 @@ Maybe<void> OutputOp::InferOutBlobDescs(
const ParallelContext* parallel_ctx) const {
const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
if (CHECK_JUST(GlobalMultiClientEnv())) {
if (CHECK_JUST(IsMultiClient())) {
// NOTE(chengcheng):
// In multi-client, in blob shape maybe changed and NOT equal with output_conf.blob_conf,
// and the output op actually is return op (used in single-client) with NO blob conf.
......
......@@ -15,10 +15,10 @@ limitations under the License.
*/
#ifdef RPC_BACKEND_GRPC
#include "oneflow/core/rpc/include/grpc.h"
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/control/ctrl_bootstrap.h"
#include "oneflow/core/control/ctrl_server.h"
#include "oneflow/core/job/env_desc.h"
#include "oneflow/core/rpc/include/grpc.h"
namespace oneflow {
......
......@@ -13,16 +13,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/vm/id_generator.h"
#include "oneflow/core/vm/id_util.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/job/env_desc.h"
namespace oneflow {
namespace vm {
Maybe<int64_t> LogicalIdGenerator::NewSymbolId() {
if (JUST(GlobalMultiClientEnv())) {
if (JUST(IsMultiClient())) {
// NOTE(chengcheng): in Multi-Client LogicalIdGenerator will degenerate directly to
// PhysicalIdGenerator, because each rank will generate id ONLY from itself, NOT the master.
return IdUtil::NewPhysicalSymbolId(GlobalProcessCtx::Rank());
......@@ -32,7 +32,7 @@ Maybe<int64_t> LogicalIdGenerator::NewSymbolId() {
}
Maybe<int64_t> LogicalIdGenerator::NewObjectId() {
if (JUST(GlobalMultiClientEnv())) {
if (JUST(IsMultiClient())) {
// NOTE(chengcheng): in Multi-Client LogicalIdGenerator will degenerate directly to
// PhysicalIdGenerator, because each rank will generate id ONLY from itself, NOT the master.
return IdUtil::NewPhysicalObjectId(GlobalProcessCtx::Rank());
......
......@@ -40,7 +40,7 @@ Maybe<void> ForEachThreadCtx(vm::VirtualMachine* vm,
void GetSchedulerThreadInitializer(std::function<void()>* Initializer) {
*Initializer = [&]() {
if (!CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) { return; }
if (!CHECK_JUST(*Global<Optional<bool>, MultiClient>::Get())) { return; }
CHECK_JUST(InitThisThreadUniqueConsistentId(kThreadConsistentIdScheduler, "scheduler"));
};
}
......@@ -76,7 +76,7 @@ void GetWorkerThreadInitializer(intrusive::shared_ptr<vm::VirtualMachine> vm,
stream_type_index2consistent_id[stream_type_index] = thread_consistent_id++;
}
*Initializer = [stream_type_index2consistent_id](vm::ThreadCtx* thread_ctx) {
if (!CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) { return; }
if (!CHECK_JUST(*Global<Optional<bool>, MultiClient>::Get())) { return; }
const auto& stream_type_index = GetStreamTypeIndex(thread_ctx);
const auto& iter = stream_type_index2consistent_id.find(stream_type_index);
if (iter != stream_type_index2consistent_id.end()) {
......
......@@ -17,7 +17,7 @@ limitations under the License.
#define ONEFLOW_CORE_VM_STRING_DESC_H_
#include <string>
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/optional.h"
namespace oneflow {
......@@ -29,11 +29,11 @@ class StringSymbol final {
~StringSymbol() = default;
const Maybe<int64_t>& symbol_id() const { return symbol_id_; }
const Optional<int64_t>& symbol_id() const { return symbol_id_; }
const std::string& data() const { return data_; }
private:
Maybe<int64_t> symbol_id_;
Optional<int64_t> symbol_id_;
std::string data_;
};
......
......@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/blocking_counter.h"
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/job/cluster_instruction.h"
......@@ -42,7 +43,7 @@ Maybe<void> Run(vm::InstructionMsgList* instr_msg_list) {
Maybe<void> ClusterSync() {
Maybe<void> (*Run)(const std::function<Maybe<void>(InstructionsBuilder*)>& Build) =
JUST(*Global<Maybe<bool>, MultiClient>::Get()) ? &PhysicalRun : &LogicalRun;
JUST(IsMultiClient()) ? &PhysicalRun : &LogicalRun;
BlockingCounter bc(1);
JUST(Run([&bc](InstructionsBuilder* builder) -> Maybe<void> {
JUST(builder->ComputeGlobalFrontSeqBarrier());
......
......@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/multi_client.h"
#include "oneflow/user/data/coco_data_reader.h"
#include "oneflow/user/data/coco_dataset.h"
#include "oneflow/user/data/distributed_training_dataset.h"
......@@ -21,7 +22,6 @@ limitations under the License.
#include "oneflow/core/persistence/file_system.h"
#include "oneflow/core/persistence/persistent_in_stream.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/job/env_desc.h"
namespace oneflow {
namespace data {
......@@ -37,7 +37,7 @@ COCODataReader::COCODataReader(user_op::KernelInitContext* ctx) : DataReader<COC
// NOTE(zwx): COCODataReader is not consistent since attr nd_sbp is empty,
// we assume that it works in DDP
auto nd_sbp_str_vec = ctx->Attr<std::vector<std::string>>("nd_sbp");
if (nd_sbp_str_vec.empty() && CHECK_JUST(GlobalMultiClientEnv())) {
if (nd_sbp_str_vec.empty() && CHECK_JUST(IsMultiClient())) {
parallel_id = GlobalProcessCtx::Rank();
parallel_num = GlobalProcessCtx::WorldSize();
} else {
......
......@@ -16,14 +16,15 @@ limitations under the License.
#ifndef ONEFLOW_USER_DATA_OFRECORD_DATASET_H_
#define ONEFLOW_USER_DATA_OFRECORD_DATASET_H_
#include "oneflow/user/data/dataset.h"
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/common/str_util.h"
#include "oneflow/core/framework/op_kernel.h"
#include "oneflow/core/persistence/persistent_in_stream.h"
#include "oneflow/core/job/job_set.pb.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/job/env_desc.h"
#include "oneflow/user/data/dataset.h"
namespace oneflow {
namespace data {
......@@ -61,7 +62,7 @@ class OFRecordDataset final : public Dataset<TensorBuffer> {
auto nd_sbp_str_vec = ctx->Attr<std::vector<std::string>>("nd_sbp");
// NOTE(zwx): OFRecordDataset is not consistent since attr nd_sbp is empty,
// we assume that it works in DDP
if (nd_sbp_str_vec.empty() && CHECK_JUST(GlobalMultiClientEnv())) { is_local = true; }
if (nd_sbp_str_vec.empty() && CHECK_JUST(IsMultiClient())) { is_local = true; }
}
if (is_local) {
parallel_id_ = GlobalProcessCtx::Rank();
......
......@@ -13,11 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/common/nd_index_offset_helper.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/data/gpt_dataset.h"
#include "oneflow/core/common/nd_index_offset_helper.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/job/env_desc.h"
namespace oneflow {
......@@ -74,7 +74,7 @@ class GPTDataLoader final : public OpKernelState {
// NOTE(zwx): GPTDataLoader is not consistent since attr nd_sbp is empty,
// we assume that it works in DDP
auto nd_sbp_str_vec = ctx->Attr<std::vector<std::string>>("nd_sbp");
if (nd_sbp_str_vec.empty() && CHECK_JUST(GlobalMultiClientEnv())) {
if (nd_sbp_str_vec.empty() && CHECK_JUST(IsMultiClient())) {
num_shards_ = GlobalProcessCtx::WorldSize();
shard_index_ = GlobalProcessCtx::Rank();
} else {
......
......@@ -13,10 +13,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/global.h"
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/job/global_for.h"
namespace oneflow {
......@@ -41,7 +41,7 @@ REGISTER_NO_GRAD_USER_OP("randperm")
Maybe<void> InferRandpermNdSbp(user_op::InferNdSbpFnContext* ctx) {
cfg::NdSbp* out = ctx->NdSbp4ArgNameAndIndex("out", 0);
if (JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
if (JUST(IsMultiClient())) {
const auto& pb_str = ctx->user_op_conf().attr<std::string>("nd_sbp");
NdSbp pb;
CHECK_OR_RETURN(TxtString2PbMessage(pb_str, &pb));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册