From 380d2414d2ebd45dbd04b9a22a3241098790aec3 Mon Sep 17 00:00:00 2001 From: Zhanghuihong Guan <31779698+Garfieldgzhh@users.noreply.github.com> Date: Mon, 1 Nov 2021 09:37:03 +0800 Subject: [PATCH] Change maybe to optional (#6611) * initial commit, add code for async construct tensor from numpy array * inital commit to change Maybe to Optional * delete redundant code * replace Maybe with Optional * fix compile errors * format code * changes based on review * format code, fix based on review * format code * fix multiclient type * changes based on review * changes based on review * unify calling to IsMultiClirnt * refector multi_client related code * restore InMultiClient interface * double check for unnecessary changes * remove unnecessary changes * format code * Update oneflow/api/python/symbol/job_conf_symbol.cpp * Update oneflow/api/python/symbol/op_conf_symbol.cpp * Update oneflow/api/python/symbol/op_node_signature_symbol.cpp * Update oneflow/core/common/optional.h * Update oneflow/api/python/symbol/string_symbol.cpp * Update oneflow/api/python/symbol/scope_symbol.cpp * Update oneflow/api/python/symbol/placement_symbol.cpp * Update oneflow/api/python/symbol/op_conf_symbol.cpp Co-authored-by: Houjiang Chen Co-authored-by: Twice --- oneflow/api/python/env/env.h | 18 ++------- oneflow/api/python/symbol/job_conf_symbol.cpp | 8 +++- oneflow/api/python/symbol/op_conf_symbol.cpp | 9 ++++- .../symbol/op_node_signature_symbol.cpp | 10 ++++- .../api/python/symbol/placement_symbol.cpp | 8 +++- oneflow/api/python/symbol/scope_symbol.cpp | 9 ++++- oneflow/api/python/symbol/string_symbol.cpp | 9 ++++- oneflow/core/common/multi_client.h | 39 +++++++++++++++++++ oneflow/core/device/cuda_util.cpp | 3 +- oneflow/core/eager/eager_oneflow.cpp | 4 +- .../core/framework/instructions_builder.cpp | 4 +- oneflow/core/framework/interpreter_test.cpp | 5 +-- .../multi_client_session_context.cpp | 5 ++- .../core/functional/impl/random_functor.cpp | 5 ++- oneflow/core/graph/task_graph.cpp | 4 +- oneflow/core/job/env_desc.cpp | 6 --- oneflow/core/job/env_desc.h | 2 - oneflow/core/job/global_for.cpp | 4 +- .../core/job/job_build_and_infer_ctx_mgr.cpp | 6 +-- oneflow/core/job/job_desc.cpp | 4 +- oneflow/core/job/job_desc.h | 5 ++- oneflow/core/job/parallel_desc.cpp | 3 +- oneflow/core/job/parallel_desc.h | 6 +-- oneflow/core/job/plan_util.cpp | 6 +-- oneflow/core/job/scope.cpp | 4 +- oneflow/core/job/scope.h | 5 ++- oneflow/core/job_rewriter/autotick.cpp | 13 ++++--- .../gradient_accumulation_rewrite_pass.cpp | 4 +- oneflow/core/job_rewriter/job_completer.cpp | 5 ++- .../set_default_variable_conf.cpp | 3 +- .../core/kernel/callback_notify_kernel.cpp | 2 +- oneflow/core/kernel/input_kernel.cpp | 2 +- oneflow/core/kernel/output_kernel.cpp | 4 +- oneflow/core/kernel/return_kernel.cpp | 4 +- .../core/kernel/wait_and_send_ids_kernel.cpp | 4 +- oneflow/core/operator/op_conf_symbol.h | 6 +-- .../core/operator/op_node_signature_desc.h | 4 +- oneflow/core/operator/output_op.cpp | 7 ++-- oneflow/core/rpc/lib/grpc.cpp | 4 +- oneflow/core/vm/id_generator.cpp | 8 ++-- oneflow/core/vm/oneflow_vm.cpp | 4 +- oneflow/core/vm/string_symbol.h | 6 +-- oneflow/core/vm/vm_util.cpp | 3 +- oneflow/user/data/coco_data_reader.cpp | 4 +- oneflow/user/data/ofrecord_dataset.h | 5 ++- .../user/kernels/gpt_data_loader_kernel.cpp | 6 +-- oneflow/user/ops/randperm_op.cpp | 6 +-- 47 files changed, 180 insertions(+), 115 deletions(-) create mode 100644 oneflow/core/common/multi_client.h diff --git a/oneflow/api/python/env/env.h b/oneflow/api/python/env/env.h index 9c467e28f3..09bac14cc8 100644 --- a/oneflow/api/python/env/env.h +++ b/oneflow/api/python/env/env.h @@ -18,12 +18,13 @@ limitations under the License. #include #include +#include "oneflow/core/common/multi_client.h" #include "oneflow/core/common/protobuf.h" -#include "oneflow/core/job/global_for.h" #include "oneflow/core/job/cluster.h" #include "oneflow/core/job/cluster_instruction.h" -#include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/env_global_objects_scope.h" +#include "oneflow/core/job/global_for.h" +#include "oneflow/core/job/resource_desc.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/rpc/include/base.h" @@ -45,19 +46,6 @@ inline Maybe EnableEagerEnvironment(bool enable_eager_execution) { return Maybe::Ok(); } -inline Maybe* IsMultiClientPtr() { return Global, MultiClient>::Get(); } - -inline Maybe IsMultiClient() { - CHECK_NOTNULL_OR_RETURN(IsMultiClientPtr()); - return *IsMultiClientPtr(); -} - -inline Maybe SetIsMultiClient(bool is_multi_client) { - CHECK_NOTNULL_OR_RETURN(IsMultiClientPtr()); - *IsMultiClientPtr() = is_multi_client; - return Maybe::Ok(); -} - inline Maybe IsEnvInited() { return Global::Get() != nullptr; } inline Maybe DestroyEnv() { diff --git a/oneflow/api/python/symbol/job_conf_symbol.cpp b/oneflow/api/python/symbol/job_conf_symbol.cpp index 47a5dbd311..e04200bbcf 100644 --- a/oneflow/api/python/symbol/job_conf_symbol.cpp +++ b/oneflow/api/python/symbol/job_conf_symbol.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include #include +#include "oneflow/api/python/framework/throw.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/job_conf.cfg.h" @@ -36,7 +37,12 @@ ONEFLOW_API_PYBIND11_MODULE("", m) { return CreateJobConfSymbol(symbol_id, symbol_conf).GetPtrOrThrow(); })) .def_property_readonly("symbol_id", - [](const JobDesc& x) { return x.symbol_id().GetOrThrow(); }) + [](const JobDesc& x) { + if (!x.symbol_id().has_value()) { + THROW(RuntimeError) << "symbol_id not initialized"; + } + return CHECK_JUST(x.symbol_id()); + }) .def_property_readonly("data", &JobDesc::cfg_job_conf); } diff --git a/oneflow/api/python/symbol/op_conf_symbol.cpp b/oneflow/api/python/symbol/op_conf_symbol.cpp index 37b6b9c833..4701032e47 100644 --- a/oneflow/api/python/symbol/op_conf_symbol.cpp +++ b/oneflow/api/python/symbol/op_conf_symbol.cpp @@ -15,8 +15,10 @@ limitations under the License. */ #include #include +#include "oneflow/api/python/framework/throw.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/operator/op_conf_symbol.h" +#include "oneflow/core/common/maybe.h" namespace py = pybind11; @@ -25,7 +27,12 @@ namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("", m) { py::class_>(m, "OpConfSymbol") .def_property_readonly("symbol_id", - [](const OperatorConfSymbol& x) { return x.symbol_id().GetOrThrow(); }) + [](const OperatorConfSymbol& x) { + if (!x.symbol_id().has_value()) { + THROW(RuntimeError) << "symbol_id not initialized"; + } + return CHECK_JUST(x.symbol_id()); + }) .def_property_readonly("data", &OperatorConfSymbol::data); } diff --git a/oneflow/api/python/symbol/op_node_signature_symbol.cpp b/oneflow/api/python/symbol/op_node_signature_symbol.cpp index d7dfb10a9f..a4cd53398c 100644 --- a/oneflow/api/python/symbol/op_node_signature_symbol.cpp +++ b/oneflow/api/python/symbol/op_node_signature_symbol.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include #include +#include "oneflow/api/python/framework/throw.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/operator/op_node_signature_desc.h" #include "oneflow/core/operator/op_node_signature.pb.h" @@ -37,8 +38,13 @@ ONEFLOW_API_PYBIND11_MODULE("", m) { py::init([](int64_t symbol_id, const std::shared_ptr& symbol_conf) { return CreateScopeSymbol(symbol_id, symbol_conf).GetPtrOrThrow(); })) - .def_property_readonly( - "symbol_id", [](const OpNodeSignatureDesc& x) { return x.symbol_id().GetOrThrow(); }) + .def_property_readonly("symbol_id", + [](const OpNodeSignatureDesc& x) { + if (!x.symbol_id().has_value()) { + THROW(RuntimeError) << "symbol_id not initialized"; + } + return CHECK_JUST(x.symbol_id()); + }) .def("data", &OpNodeSignatureDesc::op_node_signature); } diff --git a/oneflow/api/python/symbol/placement_symbol.cpp b/oneflow/api/python/symbol/placement_symbol.cpp index eca558048e..d2b4b00e64 100644 --- a/oneflow/api/python/symbol/placement_symbol.cpp +++ b/oneflow/api/python/symbol/placement_symbol.cpp @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include "oneflow/api/python/framework/throw.h" #include "oneflow/api/python/framework/size.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/control/global_process_ctx.h" @@ -198,7 +199,12 @@ ONEFLOW_API_PYBIND11_MODULE("", m) { hierarchy); })) .def_property_readonly("symbol_id", - [](const ParallelDesc& x) { return x.symbol_id().GetOrThrow(); }) + [](const ParallelDesc& x) { + if (!x.symbol_id().has_value()) { + THROW(RuntimeError) << "symbol_id not initialized"; + } + return CHECK_JUST(x.symbol_id()); + }) .def_property_readonly("parallel_conf", &ParallelDesc::cfg_parallel_conf) .def_property_readonly("parallel_num", &ParallelDesc::parallel_num) .def_property_readonly("device_tag", &ParallelDesc::device_tag) diff --git a/oneflow/api/python/symbol/scope_symbol.cpp b/oneflow/api/python/symbol/scope_symbol.cpp index a41a041909..590fcc3e5f 100644 --- a/oneflow/api/python/symbol/scope_symbol.cpp +++ b/oneflow/api/python/symbol/scope_symbol.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include #include +#include "oneflow/api/python/framework/throw.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/job/scope.h" @@ -36,7 +37,13 @@ ONEFLOW_API_PYBIND11_MODULE("", m) { .def(py::init([](int64_t symbol_id, const std::shared_ptr& symbol_conf) { return CreateScopeSymbol(symbol_id, symbol_conf).GetPtrOrThrow(); })) - .def_property_readonly("symbol_id", [](const Scope& x) { return x.symbol_id().GetOrThrow(); }) + .def_property_readonly("symbol_id", + [](const Scope& x) { + if (!x.symbol_id().has_value()) { + THROW(RuntimeError) << "symbol_id not initialized"; + } + return CHECK_JUST(x.symbol_id()); + }) .def_property_readonly("_proto_str", [](const Scope& x) { return PbMessage2TxtString(x.scope_proto()); }) .def("auto_increment_id", &Scope::auto_increment_id) diff --git a/oneflow/api/python/symbol/string_symbol.cpp b/oneflow/api/python/symbol/string_symbol.cpp index 99d15ce521..c4beffacee 100644 --- a/oneflow/api/python/symbol/string_symbol.cpp +++ b/oneflow/api/python/symbol/string_symbol.cpp @@ -14,7 +14,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include "oneflow/api/python/framework/throw.h" #include "oneflow/api/python/of_api_registry.h" +#include "oneflow/core/common/maybe.h" #include "oneflow/core/vm/string_symbol.h" namespace py = pybind11; @@ -31,7 +33,12 @@ ONEFLOW_API_PYBIND11_MODULE("", m) { return CreateStringSymbol(symbol_id, data).GetPtrOrThrow(); })) .def_property_readonly("symbol_id", - [](const StringSymbol& x) { return x.symbol_id().GetOrThrow(); }) + [](const StringSymbol& x) { + if (!x.symbol_id().has_value()) { + THROW(RuntimeError) << "symbol_id not initialized"; + } + return CHECK_JUST(x.symbol_id()); + }) .def_property_readonly("data", &StringSymbol::data); } diff --git a/oneflow/core/common/multi_client.h b/oneflow/core/common/multi_client.h new file mode 100644 index 0000000000..4ff5b0dc97 --- /dev/null +++ b/oneflow/core/common/multi_client.h @@ -0,0 +1,39 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_COMMON_MULTICLIENT_H_ +#define ONEFLOW_CORE_COMMON_MULTICLIENT_H_ + +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/optional.h" +#include "oneflow/core/job/global_for.h" + +namespace oneflow { + +inline Optional* IsMultiClientPtr() { return Global, MultiClient>::Get(); } + +inline Maybe IsMultiClient() { + auto* opt = Global, MultiClient>::Get(); + return !opt || opt->value_or(true); +} + +inline Maybe SetIsMultiClient(bool is_multi_client) { + CHECK_NOTNULL_OR_RETURN(IsMultiClientPtr()); + *IsMultiClientPtr() = is_multi_client; + return Maybe::Ok(); +} +} // namespace oneflow + +#endif diff --git a/oneflow/core/device/cuda_util.cpp b/oneflow/core/device/cuda_util.cpp index b9650d6037..13f944123a 100644 --- a/oneflow/core/device/cuda_util.cpp +++ b/oneflow/core/device/cuda_util.cpp @@ -16,6 +16,7 @@ limitations under the License. #include #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/common/global.h" +#include "oneflow/core/common/multi_client.h" #include "oneflow/core/device/node_device_descriptor_manager.h" #include "oneflow/core/device/cuda_device_descriptor.h" #include "oneflow/core/rpc/include/global_process_ctx.h" @@ -180,7 +181,7 @@ void CublasMathModeGuard::SetMathMode(cublasMath_t new_mode) { int GetCudaDeviceIndex() { int cuda_device_index = 0; - if (CHECK_JUST(GlobalMultiClientEnv())) { + if (CHECK_JUST(IsMultiClient())) { cuda_device_index = GlobalProcessCtx::LocalRank(); } else { OF_CUDA_CHECK(cudaGetDevice(&cuda_device_index)); diff --git a/oneflow/core/eager/eager_oneflow.cpp b/oneflow/core/eager/eager_oneflow.cpp index b2773ccdc8..8390a6b2ec 100644 --- a/oneflow/core/eager/eager_oneflow.cpp +++ b/oneflow/core/eager/eager_oneflow.cpp @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/common/multi_client.h" #include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/eager/eager_oneflow.h" @@ -24,7 +25,6 @@ limitations under the License. #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/vm/string_symbol.h" #include "oneflow/core/eager/eager_symbol.cfg.h" -#include "oneflow/core/job/env_desc.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/scope.h" #include "oneflow/core/job/cluster_instruction.h" @@ -94,7 +94,7 @@ Maybe EagerOneflow::RunPhysicalInstruction(vm::InstructionMsgList* instruc Maybe EagerOneflow::RunLogicalInstruction(vm::InstructionMsgList* instruction_list, const vm::cfg::EagerSymbolList& eager_symbol_list) { - if (JUST(GlobalMultiClientEnv())) { + if (JUST(IsMultiClient())) { // NOTE(chengcheng): in Multi-Client LogicalRun will degenerate directly to PhysicalRun, // because each rank will process instructions ONLY from itself, NOT the master. return RunPhysicalInstruction(instruction_list, eager_symbol_list); diff --git a/oneflow/core/framework/instructions_builder.cpp b/oneflow/core/framework/instructions_builder.cpp index d458498f18..d9721f963a 100644 --- a/oneflow/core/framework/instructions_builder.cpp +++ b/oneflow/core/framework/instructions_builder.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include "oneflow/core/common/multi_client.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/symbol_storage_util.h" #include "oneflow/core/eager/eager_symbol.cfg.h" @@ -42,7 +43,6 @@ limitations under the License. #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/instruction_replay.h" -#include "oneflow/core/job/env_desc.h" namespace oneflow { @@ -1750,7 +1750,7 @@ InstructionsBuilder::GetMut2OperandBlobObjects( } Maybe LogicalRun(const std::function(InstructionsBuilder*)>& Build) { - if (JUST(GlobalMultiClientEnv())) { + if (JUST(IsMultiClient())) { // NOTE(chengcheng): in Multi-Client LogicalRun will degenerate directly to PhysicalRun, // because each rank will process instructions ONLY from itself, NOT the master. return PhysicalRun(Build); diff --git a/oneflow/core/framework/interpreter_test.cpp b/oneflow/core/framework/interpreter_test.cpp index 20475df3e8..147fa0ed77 100644 --- a/oneflow/core/framework/interpreter_test.cpp +++ b/oneflow/core/framework/interpreter_test.cpp @@ -33,7 +33,7 @@ namespace { class TestVirtualMachineScope { public: TestVirtualMachineScope(int64_t gpu_device_num, int64_t cpu_device_num) { - *Global, MultiClient>::Get() = false; + *Global, MultiClient>::Get() = false; test_resource_desc_scope_.reset(new vm::TestResourceDescScope(gpu_device_num, cpu_device_num)); virtual_machine_scope_.reset( new vm::VirtualMachineScope(Global::Get()->resource())); @@ -42,8 +42,7 @@ class TestVirtualMachineScope { ~TestVirtualMachineScope() { virtual_machine_scope_.reset(); test_resource_desc_scope_.reset(); - Global, MultiClient>::SetAllocated( - new Maybe(Error::InvalidValueError("is_multi_client is not set"))); + Global, MultiClient>::SetAllocated(new Optional()); } private: diff --git a/oneflow/core/framework/multi_client_session_context.cpp b/oneflow/core/framework/multi_client_session_context.cpp index 4a71f9f0f0..54c325728a 100644 --- a/oneflow/core/framework/multi_client_session_context.cpp +++ b/oneflow/core/framework/multi_client_session_context.cpp @@ -14,6 +14,8 @@ See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/common/buffer_manager.h" +#include "oneflow/core/common/multi_client.h" #include "oneflow/core/framework/multi_client_session_context.h" #include "oneflow/core/framework/load_library.h" #include "oneflow/core/job/version.h" @@ -27,7 +29,6 @@ limitations under the License. #include "oneflow/core/memory/memory_allocator.h" #include "oneflow/core/register/register_manager.h" #include "oneflow/user/summary/events_writer.h" -#include "oneflow/core/common/buffer_manager.h" #include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/memory/chunk_manager.h" #include "oneflow/core/vm/vm_util.h" @@ -55,7 +56,7 @@ int32_t GetGpuDeviceNum() { Maybe MultiClientSessionContext::TryInit(const ConfigProto& config_proto) { if (!is_inited_) { - CHECK_OR_RETURN(JUST(GlobalMultiClientEnv())); + CHECK_OR_RETURN(JUST(IsMultiClient())); DumpVersionInfo(); Resource resource = config_proto.resource(); diff --git a/oneflow/core/functional/impl/random_functor.cpp b/oneflow/core/functional/impl/random_functor.cpp index d5d3ead90e..0072a10d6d 100644 --- a/oneflow/core/functional/impl/random_functor.cpp +++ b/oneflow/core/functional/impl/random_functor.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/global.h" +#include "oneflow/core/common/multi_client.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/framework/attr_map.h" @@ -126,7 +127,7 @@ class ConsistentRandFunctor { const auto& distribution_state = std::make_shared(gen); const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple)); - if (!JUST(*Global, MultiClient>::Get())) { + if (!JUST(IsMultiClient())) { JUST(attrs.SetAttr("nd_sbp", nd_sbp->DebugString())); } auto result = JUST(OpInterpUtil::Dispatch( @@ -198,7 +199,7 @@ class ConsistentRandNFunctor { const auto& distribution_state = std::make_shared(gen); const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple)); - if (!JUST(*Global, MultiClient>::Get())) { + if (!JUST(IsMultiClient())) { JUST(attrs.SetAttr("nd_sbp", nd_sbp->DebugString())); } auto result = JUST(OpInterpUtil::Dispatch( diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 261ec4ffac..f5efe03966 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/task_graph.h" +#include "oneflow/core/common/multi_client.h" #include "oneflow/core/common/util.h" #include "oneflow/core/graph/inplace_lbi_graph.h" #include "oneflow/core/graph/id_serialization.h" @@ -26,7 +27,6 @@ limitations under the License. #include "oneflow/core/job/scope.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/job_rewriter/calculation_pass.h" -#include "oneflow/core/job/env_desc.h" #include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h" #include "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.h" #include "oneflow/core/graph/stream_index_getter_registry_manager.h" @@ -544,7 +544,7 @@ void TaskGraph::ConnectCtrlEdges(const std::vector& src_task_node } void TaskGraph::AddCtrlEdgeBetweenSrcDstTickAndInputOutputInSameRank() { - if (!CHECK_JUST(GlobalMultiClientEnv())) { return; } + if (!CHECK_JUST(IsMultiClient())) { return; } HashMap rank_id2src_tick; HashMap rank_id2dst_tick; HashMap> rank_id2input_output_nodes; diff --git a/oneflow/core/job/env_desc.cpp b/oneflow/core/job/env_desc.cpp index 064e4e13c5..9fb9699b4f 100644 --- a/oneflow/core/job/env_desc.cpp +++ b/oneflow/core/job/env_desc.cpp @@ -50,10 +50,4 @@ int64_t EnvDesc::GetMachineId(const std::string& addr) const { return machine_id; } -Maybe GlobalMultiClientEnv() { - Maybe* is_multi_client = Global, MultiClient>::Get(); - CHECK_NOTNULL_OR_RETURN(is_multi_client); - return *is_multi_client; -} - } // namespace oneflow diff --git a/oneflow/core/job/env_desc.h b/oneflow/core/job/env_desc.h index 9389471c45..4733202aa0 100644 --- a/oneflow/core/job/env_desc.h +++ b/oneflow/core/job/env_desc.h @@ -44,8 +44,6 @@ class EnvDesc final { EnvProto env_proto_; }; -Maybe GlobalMultiClientEnv(); - } // namespace oneflow #endif // ONEFLOW_CORE_JOB_CLUSTER_DESC_H_ diff --git a/oneflow/core/job/global_for.cpp b/oneflow/core/job/global_for.cpp index c666072016..93d4ff43fe 100644 --- a/oneflow/core/job/global_for.cpp +++ b/oneflow/core/job/global_for.cpp @@ -16,12 +16,12 @@ limitations under the License. #include "oneflow/core/job/global_for.h" #include "oneflow/core/common/error.h" #include "oneflow/core/common/global.h" +#include "oneflow/core/common/optional.h" #include "oneflow/core/common/util.h" namespace oneflow { COMMAND(Global::SetAllocated(new bool(false))); -COMMAND(Global, MultiClient>::SetAllocated( - new Maybe(Error::InvalidValueError("is_multi_client is not set")))); +COMMAND(Global, MultiClient>::SetAllocated(new Optional())); } // namespace oneflow diff --git a/oneflow/core/job/job_build_and_infer_ctx_mgr.cpp b/oneflow/core/job/job_build_and_infer_ctx_mgr.cpp index 2a94cda9c9..3e443c1fd7 100644 --- a/oneflow/core/job/job_build_and_infer_ctx_mgr.cpp +++ b/oneflow/core/job/job_build_and_infer_ctx_mgr.cpp @@ -15,10 +15,10 @@ limitations under the License. */ #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" +#include "oneflow/core/common/multi_client.h" +#include "oneflow/core/common/util.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/lazy_mode.h" -#include "oneflow/core/job/env_desc.h" -#include "oneflow/core/common/util.h" #include namespace oneflow { @@ -106,7 +106,7 @@ Maybe EagerJobBuildAndInferCtxMgr::VirtualCloseJob() { bool EagerExecutionEnabled() { return *Global::Get(); } Maybe GlobalJobBuildAndInferCtxMgr() { - if (JUST(GlobalMultiClientEnv())) { + if (JUST(IsMultiClient())) { return JUST(GlobalMaybe()); } else { // single-client diff --git a/oneflow/core/job/job_desc.cpp b/oneflow/core/job/job_desc.cpp index a8c67b5764..ff4b0c9b1d 100644 --- a/oneflow/core/job/job_desc.cpp +++ b/oneflow/core/job/job_desc.cpp @@ -43,14 +43,14 @@ void CheckFunctionConfig(const JobConfigProto& job_conf) { } // namespace JobDesc::JobDesc(const JobConfigProto& job_conf, int64_t job_id) - : job_conf_(job_conf), job_id_(job_id), symbol_id_(Error::SymbolIdUninitializedError()) { + : job_conf_(job_conf), job_id_(job_id), symbol_id_(NullOpt) { CHECK_JUST(Init()); Global::Get()->DumpCudnnConf(job_conf); } Maybe JobDesc::New(int64_t symbol_id, const JobConfigProto& job_conf) { auto job_desc = std::make_shared(job_conf); - job_desc->symbol_id_ = Maybe(symbol_id); + job_desc->symbol_id_ = symbol_id; return job_desc; } diff --git a/oneflow/core/job/job_desc.h b/oneflow/core/job/job_desc.h index eeb4621956..6fcccdb2f2 100644 --- a/oneflow/core/job/job_desc.h +++ b/oneflow/core/job/job_desc.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef ONEFLOW_CORE_JOB_JOB_DESC_H_ #define ONEFLOW_CORE_JOB_JOB_DESC_H_ +#include "oneflow/core/common/optional.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/job/dlnet_conf.pb.h" #include "oneflow/core/job/job.pb.h" @@ -41,7 +42,7 @@ class JobDesc final { ~JobDesc() = default; static Maybe New(int64_t symbol_id, const JobConfigProto& job_conf); - const Maybe& symbol_id() const { return symbol_id_; } + const Optional& symbol_id() const { return symbol_id_; } const std::shared_ptr& cfg_job_conf() const { return cfg_job_conf_; } // Common @@ -84,7 +85,7 @@ class JobDesc final { JobConfigProto job_conf_; int64_t job_id_; - Maybe symbol_id_; + Optional symbol_id_; // merge job_conf_ and cfg_job_conf_ after cfg::JobConfigProto taken as a constructor argument std::shared_ptr cfg_job_conf_; }; diff --git a/oneflow/core/job/parallel_desc.cpp b/oneflow/core/job/parallel_desc.cpp index a256ebaf75..099560f95d 100644 --- a/oneflow/core/job/parallel_desc.cpp +++ b/oneflow/core/job/parallel_desc.cpp @@ -71,8 +71,7 @@ Maybe ParseMachineAndDeviceIdList(const ParallelConf& parallel_conf) { return machine2device_list; } -ParallelDesc::ParallelDesc(const ParallelConf& user_conf) - : symbol_id_(Error::SymbolIdUninitializedError()) { +ParallelDesc::ParallelDesc(const ParallelConf& user_conf) : symbol_id_(NullOpt) { // NOLINT CHECK_JUST(MaybeInit(user_conf)); CHECK_JUST(CheckWithResourceDesc(*(Global::Get()))); } diff --git a/oneflow/core/job/parallel_desc.h b/oneflow/core/job/parallel_desc.h index 44b1f4656b..47de7a7b80 100644 --- a/oneflow/core/job/parallel_desc.h +++ b/oneflow/core/job/parallel_desc.h @@ -58,7 +58,7 @@ class ParallelDesc final { Maybe MaybeInit(const ParallelConf& user_conf); // Getters - const Maybe& symbol_id() const { return symbol_id_; } + const Optional& symbol_id() const { return symbol_id_; } bool containing_current_rank() const { return containing_current_rank_; } DeviceType device_type() const { return device_type_; } const std::string& device_tag() const { return parallel_conf_.device_tag(); } @@ -108,7 +108,7 @@ class ParallelDesc final { private: friend Maybe ParseMachineAndDeviceIdList(const ParallelConf& parallel_conf); - ParallelDesc() : symbol_id_(Error::SymbolIdUninitializedError()) {} + ParallelDesc() : symbol_id_(NullOpt) {} ParallelDesc(int64_t symbol_id) : symbol_id_(symbol_id) {} void ClearUp(); Maybe SetMachineIdAndDeviceIdsByParsingDeviceName(const std::string& device_name, @@ -117,7 +117,7 @@ class ParallelDesc final { Maybe CheckWithResourceDesc(const ResourceDesc& resource_desc); bool EqualsMachineId2SortedDevPhyIds(const ParallelDesc& rhs) const; - Maybe symbol_id_; + Optional symbol_id_; DeviceType device_type_; ParallelConf parallel_conf_; std::shared_ptr hierarchy_; diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index 031e5305e8..917ae60961 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -14,10 +14,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/constant.h" +#include "oneflow/core/common/multi_client.h" +#include "oneflow/core/common/str_util.h" #include "oneflow/core/job/plan_util.h" -#include "oneflow/core/job/env_desc.h" #include "oneflow/core/job/global_for.h" -#include "oneflow/core/common/str_util.h" #include "oneflow/core/graph/plan_task_graph.h" #include "oneflow/core/graph/boxing/collective_boxing_util.h" #include "oneflow/core/memory/chunk_manager.h" @@ -325,7 +325,7 @@ void PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan( } } - if (CHECK_JUST(GlobalMultiClientEnv())) { + if (CHECK_JUST(IsMultiClient())) { GenChunkForMultiNNGraphMemoryReuseInMultiClient(plan, &mem_block_id2mem_block); } else { CHECK(variable_op_names.empty()); diff --git a/oneflow/core/job/scope.cpp b/oneflow/core/job/scope.cpp index f1bde3e6a2..a26d0d53df 100644 --- a/oneflow/core/job/scope.cpp +++ b/oneflow/core/job/scope.cpp @@ -26,9 +26,7 @@ limitations under the License. namespace oneflow { Scope::Scope(const ScopeProto& scope_proto) - : auto_increment_id_(0), - symbol_id_(Error::SymbolIdUninitializedError()), - scope_proto_(scope_proto) { + : auto_increment_id_(0), symbol_id_(NullOpt), scope_proto_(scope_proto) { CHECK_OK(Init()) << scope_proto_.DebugString(); } diff --git a/oneflow/core/job/scope.h b/oneflow/core/job/scope.h index 24dc31c2a2..c3b82f3ae4 100644 --- a/oneflow/core/job/scope.h +++ b/oneflow/core/job/scope.h @@ -22,6 +22,7 @@ limitations under the License. #include "oneflow/core/job/job_desc.h" #include "oneflow/core/framework/attr_value.h" #include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/optional.h" #include "oneflow/core/common/symbol.h" namespace oneflow { @@ -40,7 +41,7 @@ class Scope final { ~Scope() = default; static Maybe New(int64_t symbol_id, const ScopeProto& scope_proto); - const Maybe& symbol_id() const { return symbol_id_; } + const Optional& symbol_id() const { return symbol_id_; } int64_t auto_increment_id() { return ++auto_increment_id_; } int64_t session_id() const { return scope_proto().session_id(); } const std::shared_ptr& job_desc_symbol() const { return job_desc_; } @@ -78,7 +79,7 @@ class Scope final { const AttrValue& GetAttrValue(const std::string& attr_name) const; int64_t auto_increment_id_; - Maybe symbol_id_; + Optional symbol_id_; const ScopeProto scope_proto_; std::shared_ptr job_desc_; Symbol placement_scope_; diff --git a/oneflow/core/job_rewriter/autotick.cpp b/oneflow/core/job_rewriter/autotick.cpp index efcf918abd..07b5546708 100644 --- a/oneflow/core/job_rewriter/autotick.cpp +++ b/oneflow/core/job_rewriter/autotick.cpp @@ -13,11 +13,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/common/container_util.h" +#include "oneflow/core/common/multi_client.h" +#include "oneflow/core/common/protobuf.h" #include "oneflow/core/job_rewriter/autotick.h" #include "oneflow/core/job/job_builder.h" #include "oneflow/core/job/critical_section_desc.h" -#include "oneflow/core/common/protobuf.h" -#include "oneflow/core/common/container_util.h" #include "oneflow/core/job/global_for.h" namespace oneflow { @@ -562,7 +563,7 @@ Maybe AutoSourceAndSinkTick( } Maybe SingleClientAutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder) { - if (JUST(*Global, MultiClient>::Get())) { return Maybe::Ok(); } + if (JUST(IsMultiClient())) { return Maybe::Ok(); } auto* critical_section = Global::Get()->AddCriticalSection(GlobalJobDesc().job_id()); critical_section->mutable_total_job_critical_section(); @@ -581,7 +582,7 @@ Maybe SingleClientAutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilde } Maybe MultiClientAutoSourceAndSinkTick(const OpGraph& op_graph, Job* job) { - if (!JUST(*Global, MultiClient>::Get())) { return Maybe::Ok(); } + if (!JUST(IsMultiClient())) { return Maybe::Ok(); } HashMap machine_id2src_op_name; HashMap machine_id2sink_op_name; { @@ -610,7 +611,7 @@ Maybe MultiClientAutoSourceAndSinkTick(const OpGraph& op_graph, Job* job) Maybe SingleClientAddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) { - if (JUST(*Global, MultiClient>::Get())) { return Maybe::Ok(); } + if (JUST(IsMultiClient())) { return Maybe::Ok(); } JUST(ForEachInputCriticalSectionOpNodes( op_graph, [&](const HashSet& op_nodes, @@ -623,7 +624,7 @@ Maybe SingleClientAddGlobalInputCriticalSections(const OpGraph& op_graph, Maybe SingleClientAddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) { - if (JUST(*Global, MultiClient>::Get())) { return Maybe::Ok(); } + if (JUST(IsMultiClient())) { return Maybe::Ok(); } JUST(ForEachOutputCriticalSectionOpNodes( op_graph, [&](const HashSet& op_nodes, diff --git a/oneflow/core/job_rewriter/gradient_accumulation_rewrite_pass.cpp b/oneflow/core/job_rewriter/gradient_accumulation_rewrite_pass.cpp index a83ca7e9c3..dad2f73cf9 100644 --- a/oneflow/core/job_rewriter/gradient_accumulation_rewrite_pass.cpp +++ b/oneflow/core/job_rewriter/gradient_accumulation_rewrite_pass.cpp @@ -13,9 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/common/multi_client.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/framework/framework.h" -#include "oneflow/core/job/env_desc.h" namespace oneflow { @@ -34,7 +34,7 @@ Maybe GradientAccumulationRewritePass::Apply(Job* job, JobPassCtx* ctx) co || job_conf.num_gradient_accumulation_steps() <= 1) { return Maybe::Ok(); } - const bool is_multi_client = CHECK_JUST(GlobalMultiClientEnv()); + const bool is_multi_client = CHECK_JUST(IsMultiClient()); const OpGraph op_graph(*job); JobBuilder job_builder(job); HashMap name2op_conf; diff --git a/oneflow/core/job_rewriter/job_completer.cpp b/oneflow/core/job_rewriter/job_completer.cpp index 974fa4a610..23755a5805 100644 --- a/oneflow/core/job_rewriter/job_completer.cpp +++ b/oneflow/core/job_rewriter/job_completer.cpp @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/common/multi_client.h" #include "oneflow/core/job_rewriter/job_completer.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job_rewriter/autograd.h" @@ -35,7 +36,7 @@ Maybe CheckOpGraph(const OpGraph& op_graph) { // NOTE(chengcheng): // in single-client source op is SourceTickOpConf, // in multi-client source op is WaitAndSendIdsOpConf_ - if (JUST(*Global, MultiClient>::Get())) { + if (JUST(IsMultiClient())) { CHECK_OR_RETURN(op_node->op().op_conf().has_wait_and_send_ids_conf()); } else { CHECK_OR_RETURN(op_node->op().op_conf().has_source_tick_conf()); @@ -49,7 +50,7 @@ Maybe CheckOpGraph(const OpGraph& op_graph) { // NOTE(chengcheng): // in single-client source op is SinkTickOpConf, // in multi-client source op is CallbackNotifyOpConf. - if (JUST(*Global, MultiClient>::Get())) { + if (JUST(IsMultiClient())) { CHECK_OR_RETURN(op_node->op().op_conf().has_callback_notify_conf()); } else { CHECK_OR_RETURN(op_node->op().op_conf().has_sink_tick_conf()); diff --git a/oneflow/core/job_rewriter/set_default_variable_conf.cpp b/oneflow/core/job_rewriter/set_default_variable_conf.cpp index 1b6a9f2030..d5e2b329f4 100644 --- a/oneflow/core/job_rewriter/set_default_variable_conf.cpp +++ b/oneflow/core/job_rewriter/set_default_variable_conf.cpp @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/common/multi_client.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/job_builder.h" @@ -31,7 +32,7 @@ class SetDefaultVariableConf final : public JobPass { } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { - if (JUST(*Global, MultiClient>::Get())) { + if (JUST(IsMultiClient())) { // NOTE(chengcheng): Multi-Client Variable is inited by Eager. return Maybe::Ok(); } diff --git a/oneflow/core/kernel/callback_notify_kernel.cpp b/oneflow/core/kernel/callback_notify_kernel.cpp index 57830194a0..60a16b46ed 100644 --- a/oneflow/core/kernel/callback_notify_kernel.cpp +++ b/oneflow/core/kernel/callback_notify_kernel.cpp @@ -37,7 +37,7 @@ template void CallbackNotifyKernel::ForwardDataContent(KernelContext* ctx) const { auto* buffer_mgr = Global>>::Get(); std::string buffer_name; - if (CHECK_JUST(*Global, MultiClient>::Get())) { + if (CHECK_JUST(*Global, MultiClient>::Get())) { CHECK(this->op_conf().callback_notify_conf().has_job_name()); buffer_name = GetCallbackNotifierBufferName(this->op_conf().callback_notify_conf().job_name()); } else { diff --git a/oneflow/core/kernel/input_kernel.cpp b/oneflow/core/kernel/input_kernel.cpp index df10278aac..12082d82db 100644 --- a/oneflow/core/kernel/input_kernel.cpp +++ b/oneflow/core/kernel/input_kernel.cpp @@ -31,7 +31,7 @@ class InputKernel final : public Kernel { private: void ForwardDataContent(KernelContext* ctx) const override { - if (CHECK_JUST(*Global, MultiClient>::Get())) { + if (CHECK_JUST(*Global, MultiClient>::Get())) { CHECK(this->op_conf().input_conf().has_job_name()); const auto& job_name = this->op_conf().input_conf().job_name(); const auto& op_name = this->op_conf().name(); diff --git a/oneflow/core/kernel/output_kernel.cpp b/oneflow/core/kernel/output_kernel.cpp index 819bb10e8a..f9533618e1 100644 --- a/oneflow/core/kernel/output_kernel.cpp +++ b/oneflow/core/kernel/output_kernel.cpp @@ -32,7 +32,7 @@ class OutputKernel final : public Kernel { }; void OutputKernel::ForwardDataContent(KernelContext* ctx) const { - if (CHECK_JUST(*Global, MultiClient>::Get())) { + if (CHECK_JUST(*Global, MultiClient>::Get())) { CHECK(this->op_conf().output_conf().has_job_name()); const auto& job_name = this->op_conf().output_conf().job_name(); const auto& op_name = this->op_conf().name(); @@ -51,7 +51,7 @@ void OutputKernel::ForwardDataContent(KernelContext* ctx) const { } void OutputKernel::ForwardHeader(KernelContext* ctx) const { - if (CHECK_JUST(*Global, MultiClient>::Get())) { + if (CHECK_JUST(*Global, MultiClient>::Get())) { // Do nothing. } else { ctx->BnInOp2Blob("out")->CopyHeaderFrom(ctx->BnInOp2Blob("in")); diff --git a/oneflow/core/kernel/return_kernel.cpp b/oneflow/core/kernel/return_kernel.cpp index 87d3a4baa1..3f50afc5fb 100644 --- a/oneflow/core/kernel/return_kernel.cpp +++ b/oneflow/core/kernel/return_kernel.cpp @@ -32,7 +32,7 @@ class ReturnKernel final : public Kernel { }; void ReturnKernel::ForwardDataContent(KernelContext* ctx) const { - if (CHECK_JUST(*Global, MultiClient>::Get())) { + if (CHECK_JUST(*Global, MultiClient>::Get())) { CHECK(this->op_conf().return_conf().has_job_name()); const auto& job_name = this->op_conf().return_conf().job_name(); const auto& op_name = this->op_conf().name(); @@ -52,7 +52,7 @@ void ReturnKernel::ForwardDataContent(KernelContext* ctx) const { } void ReturnKernel::ForwardHeader(KernelContext* ctx) const { - if (CHECK_JUST(*Global, MultiClient>::Get())) { + if (CHECK_JUST(*Global, MultiClient>::Get())) { // Do nothing. } else { ctx->BnInOp2Blob("out")->CopyHeaderFrom(ctx->BnInOp2Blob("in")); diff --git a/oneflow/core/kernel/wait_and_send_ids_kernel.cpp b/oneflow/core/kernel/wait_and_send_ids_kernel.cpp index c3362d3174..aa31649a96 100644 --- a/oneflow/core/kernel/wait_and_send_ids_kernel.cpp +++ b/oneflow/core/kernel/wait_and_send_ids_kernel.cpp @@ -31,7 +31,7 @@ void WaitAndSendIdsKernel::ForwardDataContent(KernelContext* ctx) const { auto* status = CHECK_NOTNULL(dynamic_cast(ctx->state().get())); const auto& conf = this->op_conf().wait_and_send_ids_conf(); if (status->out_idx_ >= status->out_num_) { - if (CHECK_JUST(*Global, MultiClient>::Get())) { + if (CHECK_JUST(*Global, MultiClient>::Get())) { CHECK(this->op_conf().wait_and_send_ids_conf().has_job_name()); const auto& job_name = this->op_conf().wait_and_send_ids_conf().job_name(); auto* buffer_mgr = Global>>::Get(); @@ -53,7 +53,7 @@ void WaitAndSendIdsKernel::ForwardDataContent(KernelContext* ctx) const { } } - if (CHECK_JUST(*Global, MultiClient>::Get())) { + if (CHECK_JUST(*Global, MultiClient>::Get())) { *ctx->BnInOp2Blob("out")->mut_dptr() = 0; } else { *ctx->BnInOp2Blob("out")->mut_dptr() = conf.id_list(status->in_id_).value(status->out_idx_); diff --git a/oneflow/core/operator/op_conf_symbol.h b/oneflow/core/operator/op_conf_symbol.h index 816a7cb044..0fc3a4f06d 100644 --- a/oneflow/core/operator/op_conf_symbol.h +++ b/oneflow/core/operator/op_conf_symbol.h @@ -17,7 +17,7 @@ limitations under the License. #define ONEFLOW_CORE_OPERATOR_OP_CONF_SYMBOL_H_ #include -#include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/optional.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/operator/op_conf.cfg.h" @@ -32,11 +32,11 @@ class OperatorConfSymbol final { ~OperatorConfSymbol() = default; const OperatorConf& op_conf() const { return op_conf_; } - const Maybe& symbol_id() const { return symbol_id_; } + const Optional& symbol_id() const { return symbol_id_; } std::shared_ptr data() const { return data_; } private: - Maybe symbol_id_; + Optional symbol_id_; OperatorConf op_conf_; std::shared_ptr data_; }; diff --git a/oneflow/core/operator/op_node_signature_desc.h b/oneflow/core/operator/op_node_signature_desc.h index fb5ef0964d..6fffbe54f3 100644 --- a/oneflow/core/operator/op_node_signature_desc.h +++ b/oneflow/core/operator/op_node_signature_desc.h @@ -31,7 +31,7 @@ class OpNodeSignatureDesc final { OpNodeSignatureDesc(OpNodeSignatureDesc&&) = delete; OpNodeSignatureDesc(int64_t symbol_id, const OpNodeSignature& op_node_signature); - const Maybe& symbol_id() const { return symbol_id_; } + const Optional& symbol_id() const { return symbol_id_; } const std::shared_ptr& op_node_signature() const { return op_node_signature_; } @@ -43,7 +43,7 @@ class OpNodeSignatureDesc final { Maybe LogicalBlobDesc4BnInOp(const std::string& bn_in_op) const; private: - Maybe symbol_id_; + Optional symbol_id_; std::shared_ptr op_node_signature_; HashMap> bn_in_op2blob_desc_; }; diff --git a/oneflow/core/operator/output_op.cpp b/oneflow/core/operator/output_op.cpp index 6a36bbde13..0ccebaf9bd 100644 --- a/oneflow/core/operator/output_op.cpp +++ b/oneflow/core/operator/output_op.cpp @@ -13,10 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/common/multi_client.h" +#include "oneflow/core/operator/interface_op_util.h" #include "oneflow/core/operator/output_op.h" #include "oneflow/core/job/sbp_signature_builder.h" #include "oneflow/core/job/env_desc.h" -#include "oneflow/core/operator/interface_op_util.h" namespace oneflow { @@ -31,7 +32,7 @@ Maybe OutputOp::InferLogicalOutBlobDescs( const std::function& BlobDesc4BnInOp, const ParallelDesc& parallel_desc) const { BlobDesc* out_blob_desc = BlobDesc4BnInOp("out"); - if (CHECK_JUST(GlobalMultiClientEnv())) { + if (CHECK_JUST(IsMultiClient())) { *out_blob_desc = *BlobDesc4BnInOp("in"); } else { JUST(InterfaceOpUtil::InferLogicalOutBlobDesc(op_conf().output_conf().blob_conf(), @@ -45,7 +46,7 @@ Maybe OutputOp::InferOutBlobDescs( const ParallelContext* parallel_ctx) const { const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in"); BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); - if (CHECK_JUST(GlobalMultiClientEnv())) { + if (CHECK_JUST(IsMultiClient())) { // NOTE(chengcheng): // In multi-client, in blob shape maybe changed and NOT equal with output_conf.blob_conf, // and the output op actually is return op (used in single-client) with NO blob conf. diff --git a/oneflow/core/rpc/lib/grpc.cpp b/oneflow/core/rpc/lib/grpc.cpp index 4e28321f11..9ed0edab22 100644 --- a/oneflow/core/rpc/lib/grpc.cpp +++ b/oneflow/core/rpc/lib/grpc.cpp @@ -15,10 +15,10 @@ limitations under the License. */ #ifdef RPC_BACKEND_GRPC -#include "oneflow/core/rpc/include/grpc.h" +#include "oneflow/core/common/multi_client.h" #include "oneflow/core/control/ctrl_bootstrap.h" #include "oneflow/core/control/ctrl_server.h" -#include "oneflow/core/job/env_desc.h" +#include "oneflow/core/rpc/include/grpc.h" namespace oneflow { diff --git a/oneflow/core/vm/id_generator.cpp b/oneflow/core/vm/id_generator.cpp index ca8a9c1e85..e302f95e30 100644 --- a/oneflow/core/vm/id_generator.cpp +++ b/oneflow/core/vm/id_generator.cpp @@ -13,16 +13,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/common/multi_client.h" +#include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/vm/id_generator.h" #include "oneflow/core/vm/id_util.h" -#include "oneflow/core/control/global_process_ctx.h" -#include "oneflow/core/job/env_desc.h" namespace oneflow { namespace vm { Maybe LogicalIdGenerator::NewSymbolId() { - if (JUST(GlobalMultiClientEnv())) { + if (JUST(IsMultiClient())) { // NOTE(chengcheng): in Multi-Client LogicalIdGenerator will degenerate directly to // PhysicalIdGenerator, because each rank will generate id ONLY from itself, NOT the master. return IdUtil::NewPhysicalSymbolId(GlobalProcessCtx::Rank()); @@ -32,7 +32,7 @@ Maybe LogicalIdGenerator::NewSymbolId() { } Maybe LogicalIdGenerator::NewObjectId() { - if (JUST(GlobalMultiClientEnv())) { + if (JUST(IsMultiClient())) { // NOTE(chengcheng): in Multi-Client LogicalIdGenerator will degenerate directly to // PhysicalIdGenerator, because each rank will generate id ONLY from itself, NOT the master. return IdUtil::NewPhysicalObjectId(GlobalProcessCtx::Rank()); diff --git a/oneflow/core/vm/oneflow_vm.cpp b/oneflow/core/vm/oneflow_vm.cpp index c1c9212a07..1d658b0576 100644 --- a/oneflow/core/vm/oneflow_vm.cpp +++ b/oneflow/core/vm/oneflow_vm.cpp @@ -40,7 +40,7 @@ Maybe ForEachThreadCtx(vm::VirtualMachine* vm, void GetSchedulerThreadInitializer(std::function* Initializer) { *Initializer = [&]() { - if (!CHECK_JUST(*Global, MultiClient>::Get())) { return; } + if (!CHECK_JUST(*Global, MultiClient>::Get())) { return; } CHECK_JUST(InitThisThreadUniqueConsistentId(kThreadConsistentIdScheduler, "scheduler")); }; } @@ -76,7 +76,7 @@ void GetWorkerThreadInitializer(intrusive::shared_ptr vm, stream_type_index2consistent_id[stream_type_index] = thread_consistent_id++; } *Initializer = [stream_type_index2consistent_id](vm::ThreadCtx* thread_ctx) { - if (!CHECK_JUST(*Global, MultiClient>::Get())) { return; } + if (!CHECK_JUST(*Global, MultiClient>::Get())) { return; } const auto& stream_type_index = GetStreamTypeIndex(thread_ctx); const auto& iter = stream_type_index2consistent_id.find(stream_type_index); if (iter != stream_type_index2consistent_id.end()) { diff --git a/oneflow/core/vm/string_symbol.h b/oneflow/core/vm/string_symbol.h index 20536c05e5..7a3eebb613 100644 --- a/oneflow/core/vm/string_symbol.h +++ b/oneflow/core/vm/string_symbol.h @@ -17,7 +17,7 @@ limitations under the License. #define ONEFLOW_CORE_VM_STRING_DESC_H_ #include -#include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/optional.h" namespace oneflow { @@ -29,11 +29,11 @@ class StringSymbol final { ~StringSymbol() = default; - const Maybe& symbol_id() const { return symbol_id_; } + const Optional& symbol_id() const { return symbol_id_; } const std::string& data() const { return data_; } private: - Maybe symbol_id_; + Optional symbol_id_; std::string data_; }; diff --git a/oneflow/core/vm/vm_util.cpp b/oneflow/core/vm/vm_util.cpp index 0584671213..d3b2088dcf 100644 --- a/oneflow/core/vm/vm_util.cpp +++ b/oneflow/core/vm/vm_util.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/blocking_counter.h" +#include "oneflow/core/common/multi_client.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/job/cluster_instruction.h" @@ -42,7 +43,7 @@ Maybe Run(vm::InstructionMsgList* instr_msg_list) { Maybe ClusterSync() { Maybe (*Run)(const std::function(InstructionsBuilder*)>& Build) = - JUST(*Global, MultiClient>::Get()) ? &PhysicalRun : &LogicalRun; + JUST(IsMultiClient()) ? &PhysicalRun : &LogicalRun; BlockingCounter bc(1); JUST(Run([&bc](InstructionsBuilder* builder) -> Maybe { JUST(builder->ComputeGlobalFrontSeqBarrier()); diff --git a/oneflow/user/data/coco_data_reader.cpp b/oneflow/user/data/coco_data_reader.cpp index 0483edd0ee..899a173d6d 100644 --- a/oneflow/user/data/coco_data_reader.cpp +++ b/oneflow/user/data/coco_data_reader.cpp @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/common/multi_client.h" #include "oneflow/user/data/coco_data_reader.h" #include "oneflow/user/data/coco_dataset.h" #include "oneflow/user/data/distributed_training_dataset.h" @@ -21,7 +22,6 @@ limitations under the License. #include "oneflow/core/persistence/file_system.h" #include "oneflow/core/persistence/persistent_in_stream.h" #include "oneflow/core/rpc/include/global_process_ctx.h" -#include "oneflow/core/job/env_desc.h" namespace oneflow { namespace data { @@ -37,7 +37,7 @@ COCODataReader::COCODataReader(user_op::KernelInitContext* ctx) : DataReaderAttr>("nd_sbp"); - if (nd_sbp_str_vec.empty() && CHECK_JUST(GlobalMultiClientEnv())) { + if (nd_sbp_str_vec.empty() && CHECK_JUST(IsMultiClient())) { parallel_id = GlobalProcessCtx::Rank(); parallel_num = GlobalProcessCtx::WorldSize(); } else { diff --git a/oneflow/user/data/ofrecord_dataset.h b/oneflow/user/data/ofrecord_dataset.h index 6505ce34fb..f3e3c980cf 100644 --- a/oneflow/user/data/ofrecord_dataset.h +++ b/oneflow/user/data/ofrecord_dataset.h @@ -16,14 +16,15 @@ limitations under the License. #ifndef ONEFLOW_USER_DATA_OFRECORD_DATASET_H_ #define ONEFLOW_USER_DATA_OFRECORD_DATASET_H_ -#include "oneflow/user/data/dataset.h" #include "oneflow/core/common/balanced_splitter.h" +#include "oneflow/core/common/multi_client.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/framework/op_kernel.h" #include "oneflow/core/persistence/persistent_in_stream.h" #include "oneflow/core/job/job_set.pb.h" #include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/job/env_desc.h" +#include "oneflow/user/data/dataset.h" namespace oneflow { namespace data { @@ -61,7 +62,7 @@ class OFRecordDataset final : public Dataset { auto nd_sbp_str_vec = ctx->Attr>("nd_sbp"); // NOTE(zwx): OFRecordDataset is not consistent since attr nd_sbp is empty, // we assume that it works in DDP - if (nd_sbp_str_vec.empty() && CHECK_JUST(GlobalMultiClientEnv())) { is_local = true; } + if (nd_sbp_str_vec.empty() && CHECK_JUST(IsMultiClient())) { is_local = true; } } if (is_local) { parallel_id_ = GlobalProcessCtx::Rank(); diff --git a/oneflow/user/kernels/gpt_data_loader_kernel.cpp b/oneflow/user/kernels/gpt_data_loader_kernel.cpp index af8844cd90..b1379e47ec 100644 --- a/oneflow/user/kernels/gpt_data_loader_kernel.cpp +++ b/oneflow/user/kernels/gpt_data_loader_kernel.cpp @@ -13,11 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/common/multi_client.h" +#include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/framework/framework.h" #include "oneflow/user/data/gpt_dataset.h" -#include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/rpc/include/global_process_ctx.h" -#include "oneflow/core/job/env_desc.h" namespace oneflow { @@ -74,7 +74,7 @@ class GPTDataLoader final : public OpKernelState { // NOTE(zwx): GPTDataLoader is not consistent since attr nd_sbp is empty, // we assume that it works in DDP auto nd_sbp_str_vec = ctx->Attr>("nd_sbp"); - if (nd_sbp_str_vec.empty() && CHECK_JUST(GlobalMultiClientEnv())) { + if (nd_sbp_str_vec.empty() && CHECK_JUST(IsMultiClient())) { num_shards_ = GlobalProcessCtx::WorldSize(); shard_index_ = GlobalProcessCtx::Rank(); } else { diff --git a/oneflow/user/ops/randperm_op.cpp b/oneflow/user/ops/randperm_op.cpp index 6696064331..4d1dc49a69 100644 --- a/oneflow/user/ops/randperm_op.cpp +++ b/oneflow/user/ops/randperm_op.cpp @@ -13,10 +13,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #include "oneflow/core/framework/framework.h" -#include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/global.h" +#include "oneflow/core/common/multi_client.h" +#include "oneflow/core/common/protobuf.h" #include "oneflow/core/job/global_for.h" namespace oneflow { @@ -41,7 +41,7 @@ REGISTER_NO_GRAD_USER_OP("randperm") Maybe InferRandpermNdSbp(user_op::InferNdSbpFnContext* ctx) { cfg::NdSbp* out = ctx->NdSbp4ArgNameAndIndex("out", 0); - if (JUST(*Global, MultiClient>::Get())) { + if (JUST(IsMultiClient())) { const auto& pb_str = ctx->user_op_conf().attr("nd_sbp"); NdSbp pb; CHECK_OR_RETURN(TxtString2PbMessage(pb_str, &pb)); -- GitLab