diff --git a/examples/mnist/1_machine_2_gpu.resource b/examples/mnist/1_machine_2_gpu.resource index f3aff8820d040c53875e7ad68e1c471ce0409bfc..2f0e3efe8efba62a4e6f43f2ce6c7763243592e7 100644 --- a/examples/mnist/1_machine_2_gpu.resource +++ b/examples/mnist/1_machine_2_gpu.resource @@ -1,7 +1,7 @@ machine { addr: "127.0.0.1" port: 7099 - name: "first" + id: 0 } gpu_device_num: 2 diff --git a/examples/mnist/predict.other b/examples/mnist/predict.other index 2573b86d1fa38657ad54cdd68ad6c1e384259c2c..3567362399ebd92c55b0edfc7aef473be6e5dace 100644 --- a/examples/mnist/predict.other +++ b/examples/mnist/predict.other @@ -1,4 +1,8 @@ -globalfs_conf { +data_fs_conf { + localfs_conf { + } +} +snapshot_fs_conf { localfs_conf { } } diff --git a/examples/mnist/predict.placement b/examples/mnist/predict.placement index ef53618c92e73c2098e04599d6a5710ef35e7a77..fd3a7ea85ad603aa05af8e5f4327419a4f0077e8 100644 --- a/examples/mnist/predict.placement +++ b/examples/mnist/predict.placement @@ -5,7 +5,7 @@ placement_group { } parallel_conf { policy: kDataParallel - device_name: "first:cpu:0-1" + device_name: "0:cpu:0-1" } } @@ -25,6 +25,6 @@ placement_group { } parallel_conf { policy: kDataParallel - device_name: "first:gpu:0-1" + device_name: "0:gpu:0-1" } } diff --git a/examples/mnist/train.other b/examples/mnist/train.other index e680c8593f9841f8230f90fb53b5d9a3409b2e86..107eab323a6e06d70daa0a43cf90d426440a9771 100644 --- a/examples/mnist/train.other +++ b/examples/mnist/train.other @@ -1,4 +1,8 @@ -globalfs_conf { +data_fs_conf { + localfs_conf { + } +} +snapshot_fs_conf { localfs_conf { } } diff --git a/examples/mnist/train.placement b/examples/mnist/train.placement index 8f8cc002bfbc399759ae58002d841cd5dcdf9fbb..2a91392c48081bb16164f63ba340cce4a175343b 100644 --- a/examples/mnist/train.placement +++ b/examples/mnist/train.placement @@ -4,7 +4,7 @@ placement_group { } parallel_conf { policy: kDataParallel - device_name: "first:cpu:0-1" + device_name: "0:cpu:0-1" } } @@ -24,6 +24,6 @@ placement_group { } parallel_conf { policy: kDataParallel - device_name: "first:gpu:0-1" + device_name: "0:gpu:0-1" } } diff --git a/oneflow/core/common/util.cpp b/oneflow/core/common/util.cpp index 33bca071d95d01c06e59bf52fc71c353dbb2de9a..0b703abc301b8b8e849fa3bda5b7ec7ce8d92ca8 100644 --- a/oneflow/core/common/util.cpp +++ b/oneflow/core/common/util.cpp @@ -7,6 +7,8 @@ #include #endif +DEFINE_int64(this_machine_id, -1, ""); + namespace oneflow { #define DEFINE_ONEFLOW_STR2INT_CAST(dst_type, cast_func) \ @@ -86,4 +88,9 @@ size_t GetAvailableCpuMemSize() { return 0; } +std::string LogDir() { + std::string v = FLAGS_log_dir + "/" + std::to_string(FLAGS_this_machine_id); + return v; +} + } // namespace oneflow diff --git a/oneflow/core/common/util.h b/oneflow/core/common/util.h index 79ba7e5e2a1181b176649b16f418315f32cb2cf8..a87d073a4486312b0afd568764f4e50e2e3e4451 100644 --- a/oneflow/core/common/util.h +++ b/oneflow/core/common/util.h @@ -26,6 +26,7 @@ #include "oneflow/core/common/meta_util.hpp" DECLARE_string(log_dir); +DECLARE_int64(this_machine_id); namespace std { template @@ -121,10 +122,7 @@ inline std::string NewUniqueId() { return std::to_string(id++); } -inline const std::string& LogDir() { - static std::string v = FLAGS_log_dir; - return v; -} +std::string LogDir(); template void EraseIf(HashMap* hash_map, std::function::iterator)> cond) { diff --git a/oneflow/core/graph/graph.h b/oneflow/core/graph/graph.h index f295dc1b9985e71e5f081b7022f71c781d060ca2..7b8de5479a85748d8bf5f06c873ea9aabbe53983 100644 --- a/oneflow/core/graph/graph.h +++ b/oneflow/core/graph/graph.h @@ -4,7 +4,7 @@ #include #include "oneflow/core/common/str_util.h" #include "oneflow/core/graph/node.h" -#include "oneflow/core/persistence/persistent_out_stream.h" +#include "oneflow/core/persistence/tee_persistent_log_stream.h" namespace oneflow { @@ -161,15 +161,13 @@ void Graph::ToDotWithStream(StreamT& out_stream) { template void Graph::ToDotWithFilePath(const std::string& file_path) { - std::string dir_name = Dirname(file_path); - if (!LocalFS()->IsDirectory(dir_name)) { LocalFS()->RecursivelyCreateDir(dir_name); } - PersistentOutStream out_stream(LocalFS(), file_path); - ToDotWithStream(out_stream); + auto log_stream = TeePersistentLogStream::Create(file_path); + ToDotWithStream(log_stream); } template void Graph::ToDotWithAutoFilePath() { - std::string file_path = LogDir() + "/dot/" + TypeName() + "/" + NewUniqueId() + ".dot"; + std::string file_path = JoinPath("dot", TypeName(), NewUniqueId() + ".dot"); ToDotWithFilePath(file_path); } diff --git a/oneflow/core/graph/logical_graph.cpp b/oneflow/core/graph/logical_graph.cpp index 38894d9ee2192866c482398d609d62f481037837..e34899e1c3701c916950c64dc50e5b9e64e7d92b 100644 --- a/oneflow/core/graph/logical_graph.cpp +++ b/oneflow/core/graph/logical_graph.cpp @@ -330,7 +330,7 @@ void LogicalGraph::BuildLossPrintStruct() { std::shared_ptr loss_print_op = ConstructOp(loss_print_op_conf); ParallelConf loss_print_pr_conf; loss_print_pr_conf.set_policy(kDataParallel); - loss_print_pr_conf.add_device_name(Global::Get()->MachineName4MachineId(0) + ":cpu:1"); + loss_print_pr_conf.add_device_name("0:cpu:1"); LossPrintLogicalNode* loss_print_logical = NewNode(); loss_print_logical->mut_op_vec() = {loss_print_op}; loss_print_logical->mut_parallel_desc().reset(new ParallelDesc(loss_print_pr_conf)); @@ -363,8 +363,7 @@ void LogicalGraph::BuildAccuracyPrintStruct() { std::shared_ptr accuracy_print_op = ConstructOp(accuracy_print_op_conf); ParallelConf accuracy_print_pr_conf; accuracy_print_pr_conf.set_policy(kDataParallel); - accuracy_print_pr_conf.add_device_name(Global::Get()->MachineName4MachineId(0) - + ":cpu:1"); + accuracy_print_pr_conf.add_device_name("0:cpu:1"); AccuracyPrintLogicalNode* accuracy_print_logical = NewNode(); accuracy_print_logical->mut_op_vec() = {accuracy_print_op}; accuracy_print_logical->mut_parallel_desc().reset(new ParallelDesc(accuracy_print_pr_conf)); @@ -544,10 +543,13 @@ MdSaveLogicalNode* LogicalGraph::BuildMdSaveStruct(const ForwardLogicalNode* fw_ auto md_save_logical = NewNode(); md_save_logical->mut_op_vec() = {model_save_op}; auto md_save_pr_desc = new ParallelDesc(*(fw_logical->parallel_desc())); + md_save_pr_desc->set_device_type(DeviceType::kCPU); if (fw_logical->parallel_desc()->policy() == ParallelPolicy::kDataParallel) { md_save_pr_desc->RandomSelectOneDeviceAndRemoveTheOthers(); } - md_save_pr_desc->set_device_type(DeviceType::kCPU); + if (Global::Get()->write_snapshot_to_master()) { + md_save_pr_desc->UseCPUDevicesOnMaster(); + } md_save_logical->mut_parallel_desc().reset(md_save_pr_desc); Connect(need_save_logical, NewEdge(), md_save_logical); return md_save_logical; diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index dff291eefc72a5c5577bbab204b544e950f9a7be..1c651f142591a32dc9623f7e2f3d447df1be4ba3 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -1,4 +1,5 @@ #include "oneflow/core/job/compiler.h" +#include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/device/cudnn_conv_ctx_cache.h" namespace oneflow { @@ -6,11 +7,11 @@ namespace oneflow { namespace { void ToDotFile(const Plan& plan, const std::string& filepath) { - PersistentOutStream out_stream(LocalFS(), filepath); - out_stream << "digraph {\n"; + auto log_stream = TeePersistentLogStream::Create(filepath); + log_stream << "digraph {\n"; HashSet regst_desc_ids; for (const TaskProto& task_proto : plan.task()) { - out_stream << "task" << std::to_string(task_proto.task_id()) << "[label=\"" + log_stream << "task" << std::to_string(task_proto.task_id()) << "[label=\"" << std::to_string(task_proto.task_id()) << "\\n" << std::to_string(task_proto.machine_id()) << ":" << std::to_string(task_proto.thrd_id()) << ":" @@ -23,23 +24,23 @@ void ToDotFile(const Plan& plan, const std::string& filepath) { } } for (const int64_t regst_task_id : regst_desc_ids) { - out_stream << "regst_desc" << std::to_string(regst_task_id) << "[label=\"" + log_stream << "regst_desc" << std::to_string(regst_task_id) << "[label=\"" << std::to_string(regst_task_id) << "\", shape=box];\n"; } for (const TaskProto& task_proto : plan.task()) { for (const auto& pair : task_proto.produced_regst_desc()) { - out_stream << "task" << std::to_string(task_proto.task_id()) << "->regst_desc" + log_stream << "task" << std::to_string(task_proto.task_id()) << "->regst_desc" << std::to_string(pair.second.regst_desc_id()) << "[label=\"" << pair.first << "\"];\n"; } for (const auto& pair : task_proto.consumed_regst_desc_id()) { for (int64_t regst_desc_id : pair.second.regst_desc_id()) { - out_stream << "regst_desc" << std::to_string(regst_desc_id) << "->task" + log_stream << "regst_desc" << std::to_string(regst_desc_id) << "->task" << std::to_string(task_proto.task_id()) << "[label=\"" << pair.first << "\"];\n"; } } } - out_stream << "}\n"; + log_stream << "}\n"; } } // namespace @@ -120,7 +121,7 @@ Plan Compiler::DoCompile() { }); plan.set_total_mbn_num(total_mbn_num); GenNetTopo(&plan); - ToDotFile(plan, JoinPath(LogDir(), "/dot/plan.dot")); + ToDotFile(plan, "/dot/plan.dot"); #ifdef WITH_CUDA Global::Delete(); #endif diff --git a/oneflow/core/job/job_conf.proto b/oneflow/core/job/job_conf.proto index 87324361d5a9759ce22f48c81e3922822160717b..91df5dc811d4448d333a6969dedeb96a5050023d 100644 --- a/oneflow/core/job/job_conf.proto +++ b/oneflow/core/job/job_conf.proto @@ -27,19 +27,22 @@ message PredictConf { message LocalFsConf { } +message NetworkFsConf { +} + message HdfsConf { required string namenode = 1; } -message GlobalFSConf { +message FileSystemConf { oneof fs_type { LocalFsConf localfs_conf = 1; - HdfsConf hdfs_conf = 2; + NetworkFsConf networkfs_conf = 2; + HdfsConf hdfs_conf = 3; } } message OtherConf { - required GlobalFSConf globalfs_conf = 1; required int64 piece_size = 2; required int32 data_part_num = 3; // piece_size % data_part_num = 0 @@ -64,6 +67,9 @@ message OtherConf { optional bool enable_write_snapshot = 130 [default = true]; optional bool enable_blob_mem_sharing = 140 [default = true]; + required FileSystemConf data_fs_conf = 121; + required FileSystemConf snapshot_fs_conf = 122; + oneof JobType { TrainConf train_conf = 200; PredictConf predict_conf = 201; diff --git a/oneflow/core/job/job_desc.cpp b/oneflow/core/job/job_desc.cpp index 0f81c8d2f71b5b0c5b1c71754a68528518af0f8a..ba3f91ca533373a2e6055deb56360cc118ea3900 100644 --- a/oneflow/core/job/job_desc.cpp +++ b/oneflow/core/job/job_desc.cpp @@ -5,15 +5,6 @@ namespace oneflow { -int64_t JobDesc::MachineID4MachineName(const std::string& machine_name) const { - auto it = machine_name2machine_id_.find(machine_name); - CHECK(it != machine_name2machine_id_.end()) << "Undefined machine name: " << machine_name; - return it->second; -} -const std::string& JobDesc::MachineName4MachineId(int64_t machine_id) const { - return machine_id2machine_name_.at(machine_id); -} - int64_t JobDesc::piece_num_of_experiment_phase() const { return job_conf_.other().piece_num_of_experiment_phase(); } @@ -87,6 +78,11 @@ float JobDesc::L2() const { int32_t JobDesc::DataPartNum() const { return job_conf_.other().data_part_num(); } +const FileSystemConf& JobDesc::data_fs_conf() const { return job_conf_.other().data_fs_conf(); } +const FileSystemConf& JobDesc::snapshot_fs_conf() const { + return job_conf_.other().snapshot_fs_conf(); +} + JobDesc::JobDesc(const std::string& job_conf_filepath) { if (TryParseProtoFromTextFile(job_conf_filepath, &job_conf_) == false) { JobConf2 job_conf; @@ -96,7 +92,7 @@ JobDesc::JobDesc(const std::string& job_conf_filepath) { ParseProtoFromTextFile(job_conf.placement(), job_conf_.mutable_placement()); ParseProtoFromTextFile(job_conf.other(), job_conf_.mutable_other()); } - + SanityCheck(); SplitDecodeOps(); AddRecordLoadOps(); #ifndef WITH_RDMA @@ -123,12 +119,13 @@ JobDesc::JobDesc(const std::string& job_conf_filepath) { #ifndef WITH_CUDA CHECK_EQ(job_conf_.resource().gpu_device_num(), 0); #endif +} + +void JobDesc::SanityCheck() { int64_t machine_num = job_conf_.resource().machine_size(); - for (int64_t i = 0; i < machine_num; ++i) { - const std::string& machine_name = job_conf_.resource().machine(i).name(); - CHECK(machine_name2machine_id_.emplace(machine_name, i).second); - CHECK(machine_id2machine_name_.emplace(i, machine_name).second); - } + FOR_RANGE(int64_t, i, 0, machine_num) { CHECK_EQ(job_conf_.resource().machine(i).id(), i); } + CHECK_GE(FLAGS_this_machine_id, 0); + CHECK_LT(FLAGS_this_machine_id, machine_num); } void JobDesc::SplitDecodeOps() { diff --git a/oneflow/core/job/job_desc.h b/oneflow/core/job/job_desc.h index 6f6a69c19fb6f6dca6bd31a0c7dea76b07c2929e..e1ced00cd63f85420e76168bf33624670d8335f6 100644 --- a/oneflow/core/job/job_desc.h +++ b/oneflow/core/job/job_desc.h @@ -47,16 +47,15 @@ class JobDesc final { size_t rdma_recv_msg_buf_byte() const; bool collect_act_event() const { return job_conf_.other().collect_act_event(); } bool enable_mem_sharing() const { return job_conf_.other().enable_mem_sharing(); } + const FileSystemConf& data_fs_conf() const; + const FileSystemConf& snapshot_fs_conf() const; bool enable_write_snapshot() const { return IsTrain() && job_conf_.other().enable_write_snapshot(); } + bool write_snapshot_to_master() const { return snapshot_fs_conf().has_localfs_conf(); } bool enable_blob_mem_sharing() const { return job_conf_.other().enable_blob_mem_sharing(); } int64_t reduce_group_size() const { return job_conf_.other().reduce_group_size(); } - // machine_name <-> machine_id - int64_t MachineID4MachineName(const std::string& machine_name) const; - const std::string& MachineName4MachineId(int64_t machine_id) const; - // Train conf const std::string& MdSaveSnapshotsPath() const; int32_t NumOfBatchesInSnapshot() const; @@ -73,13 +72,11 @@ class JobDesc final { private: friend class Global; JobDesc(const std::string& job_conf_filepath); + void SanityCheck(); void SplitDecodeOps(); void AddRecordLoadOps(); JobConf1 job_conf_; - - HashMap machine_name2machine_id_; - HashMap machine_id2machine_name_; }; } // namespace oneflow diff --git a/oneflow/core/job/machine_context.cpp b/oneflow/core/job/machine_context.cpp index 235d917846b0c4e4efd01a5f32e76ce34836e349..97bf0062bfa1c1a39dd73055a8d9831bdb2be92d 100644 --- a/oneflow/core/job/machine_context.cpp +++ b/oneflow/core/job/machine_context.cpp @@ -7,9 +7,7 @@ std::string MachineCtx::GetCtrlAddr(int64_t machine_id) const { return mchn.addr() + ":" + std::to_string(mchn.port()); } -MachineCtx::MachineCtx(const std::string& this_mchn_name) { - this_machine_id_ = Global::Get()->MachineID4MachineName(this_mchn_name); - LOG(INFO) << "this machine name: " << this_mchn_name; +MachineCtx::MachineCtx(int64_t this_mchn_id) : this_machine_id_(this_mchn_id) { LOG(INFO) << "this machine id: " << this_machine_id_; } diff --git a/oneflow/core/job/machine_context.h b/oneflow/core/job/machine_context.h index cc5c8e6b7f0e3a20c51136b7021f5a0b3d8f2577..d08acc90ae6540dfaf67ade657f488e2acf65e1b 100644 --- a/oneflow/core/job/machine_context.h +++ b/oneflow/core/job/machine_context.h @@ -19,7 +19,7 @@ class MachineCtx final { private: friend class Global; - MachineCtx(const std::string& this_mchn_name); + MachineCtx(int64_t this_mchn_id); int64_t this_machine_id_; }; diff --git a/oneflow/core/job/oneflow.cpp b/oneflow/core/job/oneflow.cpp index b00ae6f206216f0545a3f172d5eb3f670f01c13e..8821a93b5f97fe1d85257ebf95294f0c2f93df62 100644 --- a/oneflow/core/job/oneflow.cpp +++ b/oneflow/core/job/oneflow.cpp @@ -10,6 +10,7 @@ #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/job/runtime.h" #include "oneflow/core/job/available_memory_desc.pb.h" +#include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/persistence/file_system.h" #include "oneflow/core/actor/act_event_logger.h" @@ -150,10 +151,10 @@ bool HasRelayPlacement() { for (const PlacementGroup& p_group : placement.placement_group()) { const ParallelConf& p_conf = p_group.parallel_conf(); for (const std::string& device_name : p_conf.device_name()) { - std::string mchn_name; + int64_t mchn_id; std::string device_tag; std::string device_id_str; - ParseDeviceNameConf(device_name, &mchn_name, &device_tag, &device_id_str); + ParseDeviceNameConf(device_name, &mchn_id, &device_tag, &device_id_str); if (device_tag == "cpu") { break; } else if (device_tag == "gpu") { @@ -177,16 +178,16 @@ class Oneflow final { OF_DISALLOW_COPY_AND_MOVE(Oneflow); ~Oneflow() = default; - Oneflow(const std::string& job_conf_filepath, const std::string& this_mchn_name); + Oneflow(const std::string& job_conf_filepath, int64_t this_mchn_id); private: std::unique_ptr ctrl_server_; }; -Oneflow::Oneflow(const std::string& job_conf_filepath, const std::string& this_mchn_name) { +Oneflow::Oneflow(const std::string& job_conf_filepath, int64_t this_mchn_id) { // New All Global Global::New(job_conf_filepath); - Global::New(this_mchn_name); + Global::New(this_mchn_id); const MachineCtx* machine_ctx = Global::Get(); bool DoProfile = machine_ctx->IsThisMachineMaster() && Global::Get()->collect_act_event(); @@ -216,15 +217,15 @@ Oneflow::Oneflow(const std::string& job_conf_filepath, const std::string& this_m PullPlan("mem_shared_plan", &mem_shared_plan); } OF_BARRIER(); - PrintProtoToTextFile(naive_plan, JoinPath(LogDir(), "naive_plan")); - PrintProtoToTextFile(mem_shared_plan, JoinPath(LogDir(), "mem_shared_plan")); + TeePersistentLogStream::Create("naive_plan")->Write(naive_plan); + TeePersistentLogStream::Create("mem_shared_plan")->Write(mem_shared_plan); LOG(INFO) << "push_pull_plan:" << GetCurTime() - start; if (HasRelayPlacement()) { // Experiment Runtime { Runtime experiment_run(mem_shared_plan, true); } // Improve if (machine_ctx->IsThisMachineMaster()) { - PrintProtoToTextFile(amd, JoinPath(LogDir(), "available_mem_desc")); + TeePersistentLogStream::Create("available_mem_desc")->Write(amd); CHECK_GT(amd.machine_amd_size(), 0); improved_plan = Improver().Improve( amd, naive_plan, JoinPath(LogDir(), ActEventLogger::experiment_act_event_bin_filename())); @@ -233,7 +234,7 @@ Oneflow::Oneflow(const std::string& job_conf_filepath, const std::string& this_m PullPlan("improved_plan", &improved_plan); } OF_BARRIER(); - PrintProtoToTextFile(improved_plan, JoinPath(LogDir(), "improved_plan")); + TeePersistentLogStream::Create("improved_plan")->Write(improved_plan); Global::Get()->Clear(); OF_BARRIER(); } else { @@ -257,16 +258,16 @@ Oneflow::Oneflow(const std::string& job_conf_filepath, const std::string& this_m } // namespace oneflow DEFINE_string(job_conf, "", ""); -DEFINE_string(this_machine_name, "", ""); int main(int argc, char** argv) { using namespace oneflow; google::InitGoogleLogging(argv[0]); gflags::SetVersionString(BuildVersionString()); gflags::ParseCommandLineFlags(&argc, &argv, true); + CHECK_GE(FLAGS_this_machine_id, 0); LocalFS()->RecursivelyCreateDirIfNotExist(LogDir()); RedirectStdoutAndStderrToGlogDir(); - { Oneflow flow(FLAGS_job_conf, FLAGS_this_machine_name); } + { Oneflow flow(FLAGS_job_conf, FLAGS_this_machine_id); } CloseStdoutAndStderr(); return 0; } diff --git a/oneflow/core/job/parallel_desc.cpp b/oneflow/core/job/parallel_desc.cpp index 315a33d45e1d11affeb4791a12e6b982ad93cf59..e1c5a0ef7284a731d46bd6617fef15c6a2e9c945 100644 --- a/oneflow/core/job/parallel_desc.cpp +++ b/oneflow/core/job/parallel_desc.cpp @@ -1,14 +1,15 @@ #include "oneflow/core/job/parallel_desc.h" +#include "oneflow/core/common/util.h" namespace oneflow { -void ParseDeviceNameConf(const std::string& device_name, std::string* mchn_name, - std::string* device_tag, std::string* device_id_str) { +void ParseDeviceNameConf(const std::string& device_name, int64_t* mchn_id, std::string* device_tag, + std::string* device_id_str) { size_t second_delimiter_pos = device_name.rfind(":"); CHECK_NE(second_delimiter_pos, std::string::npos); size_t first_delimiter_pos = device_name.rfind(":", second_delimiter_pos - 1); CHECK_NE(first_delimiter_pos, std::string::npos); - *mchn_name = device_name.substr(0, first_delimiter_pos); + *mchn_id = oneflow_cast(device_name.substr(0, first_delimiter_pos)); *device_tag = device_name.substr(first_delimiter_pos + 1, second_delimiter_pos - first_delimiter_pos - 1); *device_id_str = device_name.substr(second_delimiter_pos + 1); @@ -16,14 +17,14 @@ void ParseDeviceNameConf(const std::string& device_name, std::string* mchn_name, ParallelDesc::ParallelDesc(const ParallelConf& user_conf) { policy_ = user_conf.policy(); - HashSet machine_name_set; + HashSet machine_id_set; device_type_ = DeviceType::kInvalidDevice; for (const std::string& device_name : user_conf.device_name()) { - std::string mchn_name; + int64_t mchn_id; std::string device_tag; std::string device_id_str; - ParseDeviceNameConf(device_name, &mchn_name, &device_tag, &device_id_str); - machine_name_set.insert(mchn_name); + ParseDeviceNameConf(device_name, &mchn_id, &device_tag, &device_id_str); + machine_id_set.insert(mchn_id); if (device_tag == "cpu") { CHECK_STREQ(device_tag.c_str(), "cpu"); CHECK(device_type_ == DeviceType::kInvalidDevice || device_type_ == DeviceType::kCPU); @@ -33,8 +34,7 @@ ParallelDesc::ParallelDesc(const ParallelConf& user_conf) { CHECK(device_type_ == DeviceType::kInvalidDevice || device_type_ == DeviceType::kGPU); device_type_ = DeviceType::kGPU; } - int64_t machine_id = Global::Get()->MachineID4MachineName(mchn_name); - sorted_machine_ids_.push_back(machine_id); + sorted_machine_ids_.push_back(mchn_id); int64_t minus_pos = device_id_str.find("-"); if (minus_pos == std::string::npos) { device_id_str = device_id_str + "-" + device_id_str; @@ -47,11 +47,11 @@ ParallelDesc::ParallelDesc(const ParallelConf& user_conf) { if (device_type_ == DeviceType::kGPU) { CHECK_LT(dev_phy_id, Global::Get()->GpuDeviceNum()); } - machine_id2sorted_dev_phy_ids_[machine_id].push_back(dev_phy_id); + machine_id2sorted_dev_phy_ids_[mchn_id].push_back(dev_phy_id); } } ClearUp(); - CheckValidity(); + SanityCheck(); } void ParallelDesc::RemoveNeedlessDevice(const std::string& op_name, int32_t max_device_num) { @@ -104,6 +104,13 @@ void ParallelDesc::RandomSelectOneDeviceAndRemoveTheOthers() { parallel_num_ = 1; } +void ParallelDesc::UseCPUDevicesOnMaster() { + sorted_machine_ids_ = {0}; + std::vector sorted_dev_phy_ids(parallel_num_); + std::iota(sorted_dev_phy_ids.begin(), sorted_dev_phy_ids.end(), 0); + machine_id2sorted_dev_phy_ids_ = {{0, sorted_dev_phy_ids}}; +} + bool ParallelDesc::Equal(const ParallelDesc& rhs) const { return device_type_ == rhs.device_type_ && policy_ == rhs.policy_ && sorted_machine_ids_ == rhs.sorted_machine_ids_ @@ -124,7 +131,7 @@ void ParallelDesc::ClearUp() { SortAndRemoveDuplication(&sorted_machine_ids_); } -void ParallelDesc::CheckValidity() { +void ParallelDesc::SanityCheck() { device_num_of_each_machine_ = -1; for (auto& pair : machine_id2sorted_dev_phy_ids_) { if (device_num_of_each_machine_ == -1) { diff --git a/oneflow/core/job/parallel_desc.h b/oneflow/core/job/parallel_desc.h index 5042fc44c4c0ab24ad957fb6509d3f2ede2cf0bc..224d24816a9180ab0f672628e840e522148a6841 100644 --- a/oneflow/core/job/parallel_desc.h +++ b/oneflow/core/job/parallel_desc.h @@ -8,8 +8,8 @@ namespace oneflow { -void ParseDeviceNameConf(const std::string& device_name, std::string* mchn_name, - std::string* device_tag, std::string* device_id_str); +void ParseDeviceNameConf(const std::string& device_name, int64_t* mchn_id, std::string* device_tag, + std::string* device_id_str); class ParallelDesc { public: @@ -35,6 +35,7 @@ class ParallelDesc { void RemoveNeedlessDevice(const std::string& op_name, int32_t max_device_num); void RemoveNeedlessDevice(int32_t max_device_num) { RemoveNeedlessDevice("", max_device_num); } void RandomSelectOneDeviceAndRemoveTheOthers(); + void UseCPUDevicesOnMaster(); // bool Equal(const ParallelDesc& rhs) const; @@ -42,7 +43,7 @@ class ParallelDesc { private: void ClearUp(); - void CheckValidity(); + void SanityCheck(); DeviceType device_type_; ParallelPolicy policy_; diff --git a/oneflow/core/job/profiler.cpp b/oneflow/core/job/profiler.cpp index 8b10ed028ad7ca446bc9b545b20b956ca7628b69..69df13b0e72e7341756aa63c3a3b044bb15ff8b2 100644 --- a/oneflow/core/job/profiler.cpp +++ b/oneflow/core/job/profiler.cpp @@ -1,6 +1,6 @@ #include "oneflow/core/job/profiler.h" #include "oneflow/core/job/job_desc.h" -#include "oneflow/core/persistence/persistent_out_stream.h" +#include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/actor/act_event_logger.h" @@ -84,7 +84,7 @@ void Profiler::Profile(const Plan& plan, const std::string& act_event_filepath) [](const ProfileInfoPair& lhs, const ProfileInfoPair& rhs) { return lhs.second.CalcBottleNeckScore() > rhs.second.CalcBottleNeckScore(); }); - PersistentOutStream out_stream(LocalFS(), JoinPath(LogDir(), "oneflow.profile")); + auto log_stream = TeePersistentLogStream::Create("oneflow.profile"); double mdupdt_act_interval = 0.0; int32_t mdupdt_task_num = 0; for (const ProfileInfoPair& pair : profile_info_vec) { @@ -93,10 +93,10 @@ void Profiler::Profile(const Plan& plan, const std::string& act_event_filepath) mdupdt_act_interval += pair.second.avg_act_interval(); } } - out_stream << "time_of_one_batch:" << std::to_string(mdupdt_act_interval / mdupdt_task_num) + log_stream << "time_of_one_batch:" << std::to_string(mdupdt_act_interval / mdupdt_task_num) << "\n"; for (const ProfileInfoPair& pair : profile_info_vec) { - out_stream << "actor_id:" << std::to_string(pair.first) + log_stream << "actor_id:" << std::to_string(pair.first) << " act_num: " << std::to_string(pair.second.act_num()) << " avg_act_time:" << std::to_string(pair.second.avg_act_time()) << " avg_act_interval:" << std::to_string(pair.second.avg_act_interval()) diff --git a/oneflow/core/job/resource.proto b/oneflow/core/job/resource.proto index ffb3117ae35983db1464e4fcc329b3d8435121e5..454b120a9a8dc80e9d2808e7f4773aec2ea087c2 100644 --- a/oneflow/core/job/resource.proto +++ b/oneflow/core/job/resource.proto @@ -4,7 +4,7 @@ package oneflow; message Machine { required string addr = 1; // domain name or ip required int32 port = 2; - required string name = 3; + required int64 id = 3; } enum DeviceType { diff --git a/oneflow/core/job/runtime.cpp b/oneflow/core/job/runtime.cpp index 730e5ee8ee46dee46d6248ef51f1593de8426423..cd39d4d979be63ec88d39efb1c3b1ebac22b92b6 100644 --- a/oneflow/core/job/runtime.cpp +++ b/oneflow/core/job/runtime.cpp @@ -89,7 +89,8 @@ void Runtime::NewAllGlobal(const Plan& plan, bool is_experiment_phase) { } } Global::New(piece_num, is_experiment_phase); - if (Global::Get()->NeedCollectActEvent()) { + if (Global::Get()->IsThisMachineMaster() + && Global::Get()->NeedCollectActEvent()) { Global::New(is_experiment_phase); } if (job_desc->TotalMachineNum() > 1) { diff --git a/oneflow/core/kernel/kernel_util.cpp b/oneflow/core/kernel/kernel_util.cpp index 4ebf58e84e83935d73325b739ae9b457bfb904fe..af5140207e46bbb7d7c4ef2f41d23fd9a5ac01d5 100644 --- a/oneflow/core/kernel/kernel_util.cpp +++ b/oneflow/core/kernel/kernel_util.cpp @@ -263,11 +263,11 @@ KU_IF_METHOD InitializeWithDir(DeviceCtx* ctx, int32_t part_id, int32_t part_num int64_t blob_size = blob->ByteSizeOfDataContentField(); int64_t byte_size_of_each_dim = num_in_each_dim * sizeof(T); std::string file_path = JoinPath(model_dir, bn_in_op); - uint64_t file_size = GlobalFS()->GetFileSize(file_path); + uint64_t file_size = SnapshotFS()->GetFileSize(file_path); CHECK_EQ(file_size, dim_num * byte_size_of_each_dim); BalancedSplitter splitter = BalancedSplitter(dim_num, part_num); int64_t begin_pos = splitter.At(part_id).begin() * byte_size_of_each_dim; - PersistentInStream in_stream(GlobalFS(), file_path, begin_pos); + PersistentInStream in_stream(SnapshotFS(), file_path, begin_pos); in_stream.Read(blob->mut_dptr(), blob_size); } diff --git a/oneflow/core/kernel/print_kernel.cpp b/oneflow/core/kernel/print_kernel.cpp index 8d505a499e42c3e5aa540b619d7a4732dedd488e..607c314251ab00c8cf68119b6bec689aae5150ac 100644 --- a/oneflow/core/kernel/print_kernel.cpp +++ b/oneflow/core/kernel/print_kernel.cpp @@ -7,13 +7,13 @@ namespace oneflow { void PrintKernel::VirtualKernelInit(const ParallelContext* parallel_ctx) { const auto& conf = op_conf().print_conf(); const std::string& root_path = conf.print_dir(); - OfCallOnce(root_path, GlobalFS(), &fs::FileSystem::RecursivelyCreateDir); + OfCallOnce(root_path, SnapshotFS(), &fs::FileSystem::RecursivelyCreateDir); int32_t part_name_suffix_length = conf.part_name_suffix_length(); std::string num = std::to_string(parallel_ctx->parallel_id()); int32_t zero_count = std::max(part_name_suffix_length - static_cast(num.length()), 0); std::string file_path = JoinPath(root_path, conf.part_name_prefix() + std::string(zero_count, '0') + num); - out_stream_.reset(new PersistentOutStream(GlobalFS(), file_path)); + out_stream_.reset(new PersistentOutStream(SnapshotFS(), file_path)); } void PrintKernel::Forward(const KernelCtx& ctx, diff --git a/oneflow/core/kernel/record_load_kernel.cpp b/oneflow/core/kernel/record_load_kernel.cpp index d5c1710c81ef614d8447abeafca5222b915f977a..600c5006d233f21e758cc7addf620d218ea36826 100644 --- a/oneflow/core/kernel/record_load_kernel.cpp +++ b/oneflow/core/kernel/record_load_kernel.cpp @@ -22,13 +22,10 @@ void RecordLoadKernel::VirtualKernelInit(const ParallelContext* parallel_ctx) { data_paths.push_back(JoinPath(data_dir, part_name_prefix + std::string(zero_count, '0') + num)); } if (Global::Get()->IsTrain()) { - if (Global::Get()->save_downloaded_file_to_local_fs() && GlobalFS() != LocalFS()) { - in_stream_.reset(new PersistentInStream(GlobalFS(), data_paths, true, true)); - } else { - in_stream_.reset(new PersistentInStream(GlobalFS(), data_paths, true, false)); - } + in_stream_.reset(new PersistentInStream( + DataFS(), data_paths, true, Global::Get()->save_downloaded_file_to_local_fs())); } else { - in_stream_.reset(new PersistentInStream(GlobalFS(), data_paths, false, false)); + in_stream_.reset(new PersistentInStream(DataFS(), data_paths, false, false)); } int64_t global_piece_size = Global::Get()->PieceSize(); CHECK_EQ(global_piece_size % parallel_ctx->parallel_num(), 0); diff --git a/oneflow/core/persistence/file_system.cpp b/oneflow/core/persistence/file_system.cpp index 96edff75b3f41740af694cd297338bcaa19acadc..48eaaff199f34fad7f8acc87860836cb67d8af2c 100644 --- a/oneflow/core/persistence/file_system.cpp +++ b/oneflow/core/persistence/file_system.cpp @@ -90,21 +90,6 @@ void FileSystem::RecursivelyCreateDir(const std::string& dirname) { } } -struct GlobalFSConstructor { - GlobalFSConstructor() { - const GlobalFSConf& gfs_conf = Global::Get()->other_conf().globalfs_conf(); - if (gfs_conf.has_localfs_conf()) { - CHECK_EQ(Global::Get()->resource().machine().size(), 1); - gfs = LocalFS(); - } else if (gfs_conf.has_hdfs_conf()) { - gfs = new HadoopFileSystem(gfs_conf.hdfs_conf()); - } else { - UNIMPLEMENTED(); - } - } - FileSystem* gfs; -}; - } // namespace fs fs::FileSystem* LocalFS() { @@ -116,9 +101,25 @@ fs::FileSystem* LocalFS() { return fs; } -fs::FileSystem* GlobalFS() { - static fs::GlobalFSConstructor gfs_constructor; - return gfs_constructor.gfs; +fs::FileSystem* NetworkFS() { return LocalFS(); } + +fs::FileSystem* HadoopFS(const HdfsConf& hdfs_conf) { + static fs::FileSystem* fs = new fs::HadoopFileSystem(hdfs_conf); + return fs; +} + +fs::FileSystem* GetFS(const FileSystemConf& file_system_conf) { + if (file_system_conf.has_localfs_conf()) { + return LocalFS(); + } else if (file_system_conf.has_networkfs_conf()) { + return NetworkFS(); + } else if (file_system_conf.has_hdfs_conf()) { + return HadoopFS(file_system_conf.hdfs_conf()); + } else { + UNIMPLEMENTED(); + } } +fs::FileSystem* DataFS() { return GetFS(Global::Get()->data_fs_conf()); } +fs::FileSystem* SnapshotFS() { return GetFS(Global::Get()->snapshot_fs_conf()); } } // namespace oneflow diff --git a/oneflow/core/persistence/file_system.h b/oneflow/core/persistence/file_system.h index 6644f7c7134248cb09f0dfc2526586058c152982..3ef591c06caa8d08a4a0501b001fd659004a7109 100644 --- a/oneflow/core/persistence/file_system.h +++ b/oneflow/core/persistence/file_system.h @@ -3,6 +3,7 @@ #include "oneflow/core/common/platform.h" #include "oneflow/core/common/util.h" +#include "oneflow/core/job/job_conf.pb.h" namespace oneflow { @@ -153,8 +154,10 @@ class FileSystem { } // namespace fs fs::FileSystem* LocalFS(); -fs::FileSystem* GlobalFS(); +fs::FileSystem* GetFS(const FileSystemConf& file_system_conf); +fs::FileSystem* DataFS(); +fs::FileSystem* SnapshotFS(); } // namespace oneflow #endif // ONEFLOW_CORE_PERSISTENCE_FILE_SYSTEM_H_ diff --git a/oneflow/core/persistence/snapshot.cpp b/oneflow/core/persistence/snapshot.cpp index b8ab84ca3d2c02382f4757ee68ed399591ee03b0..1a8c44e2fa2d8ac50626827caceddfba6ea8d25d 100644 --- a/oneflow/core/persistence/snapshot.cpp +++ b/oneflow/core/persistence/snapshot.cpp @@ -6,7 +6,7 @@ namespace oneflow { Snapshot::Snapshot(const std::string& snapshot_root_path) { - CHECK(GlobalFS()->IsDirectory(snapshot_root_path)); + CHECK(SnapshotFS()->IsDirectory(snapshot_root_path)); root_path_ = snapshot_root_path; } @@ -18,13 +18,13 @@ std::unique_ptr Snapshot::GetOutStream(const LogicalBlobId& int32_t part_id) { // op_name_dir std::string op_name_dir = JoinPath(root_path_, lbi.op_name()); - OfCallOnce(op_name_dir, GlobalFS(), &fs::FileSystem::CreateDir); + OfCallOnce(op_name_dir, SnapshotFS(), &fs::FileSystem::CreateDir); // bn_in_op_tmp_dir std::string bn_in_op_tmp_dir = JoinPath(op_name_dir, lbi.blob_name() + "_tmp4a58"); - OfCallOnce(bn_in_op_tmp_dir, GlobalFS(), &fs::FileSystem::CreateDir); + OfCallOnce(bn_in_op_tmp_dir, SnapshotFS(), &fs::FileSystem::CreateDir); // part_file std::string part_file = JoinPath(bn_in_op_tmp_dir, "part_" + std::to_string(part_id)); - return std::make_unique(GlobalFS(), part_file); + return std::make_unique(SnapshotFS(), part_file); } void Snapshot::OnePartDone(const LogicalBlobId& lbi, int32_t part_id, int32_t part_num) { @@ -41,12 +41,12 @@ void Snapshot::ConcatLbnFile(const LogicalBlobId& lbi, int32_t part_num, std::vector buffer(Global::Get()->persistence_buf_byte()); std::string part_dir = JoinPath(root_path_, lbi.op_name(), lbi.blob_name() + "_tmp4a58"); { - PersistentOutStream out_stream(GlobalFS(), concat_file); + PersistentOutStream out_stream(SnapshotFS(), concat_file); for (int32_t i = 0; i < part_num; ++i) { std::unique_ptr part_file; std::string part_file_path = JoinPath(part_dir, "part_" + std::to_string(i)); - GlobalFS()->NewRandomAccessFile(part_file_path, &part_file); - uint64_t part_file_size = GlobalFS()->GetFileSize(part_file_path); + SnapshotFS()->NewRandomAccessFile(part_file_path, &part_file); + uint64_t part_file_size = SnapshotFS()->GetFileSize(part_file_path); uint64_t offset = 0; while (offset < part_file_size) { uint64_t n = std::min(buffer.size(), part_file_size - offset); @@ -54,15 +54,15 @@ void Snapshot::ConcatLbnFile(const LogicalBlobId& lbi, int32_t part_num, out_stream.Write(buffer.data(), n); offset += n; } - GlobalFS()->DelFile(part_file_path); + SnapshotFS()->DelFile(part_file_path); } } - GlobalFS()->DeleteDir(part_dir); + SnapshotFS()->DeleteDir(part_dir); std::string snapshot_done_path = JoinPath(root_path_, "snapshot_done"); int32_t snapshot_done_cnt = Global::Get()->IncreaseCount(snapshot_done_path); if (snapshot_done_cnt == Global::Get()->total_mbn_num()) { Global::Get()->EraseCount(snapshot_done_path); - PersistentOutStream out_stream(GlobalFS(), snapshot_done_path); + PersistentOutStream out_stream(SnapshotFS(), snapshot_done_path); } } diff --git a/oneflow/core/persistence/snapshot_manager.cpp b/oneflow/core/persistence/snapshot_manager.cpp index b1034b7cd542412cee85b17161e818f1c20b9500..1891342a0762a37f428b90da052ee9a481425797 100644 --- a/oneflow/core/persistence/snapshot_manager.cpp +++ b/oneflow/core/persistence/snapshot_manager.cpp @@ -7,7 +7,7 @@ namespace oneflow { SnapshotMgr::SnapshotMgr(const Plan& plan) { if (Global::Get()->enable_write_snapshot()) { model_save_snapshots_path_ = Global::Get()->MdSaveSnapshotsPath(); - OfCallOnce(model_save_snapshots_path_, GlobalFS(), &fs::FileSystem::MakeEmptyDir); + OfCallOnce(model_save_snapshots_path_, SnapshotFS(), &fs::FileSystem::MakeEmptyDir); } const std::string& load_path = Global::Get()->MdLoadSnapshotPath(); if (load_path != "") { readable_snapshot_.reset(new Snapshot(load_path)); } @@ -21,7 +21,7 @@ Snapshot* SnapshotMgr::GetWriteableSnapshot(int64_t snapshot_id) { if (it == snapshot_id2writeable_snapshot_.end()) { std::string snapshot_root_path = JoinPath(model_save_snapshots_path_, "snapshot_" + std::to_string(snapshot_id)); - OfCallOnce(snapshot_root_path, GlobalFS(), &fs::FileSystem::CreateDirIfNotExist); + OfCallOnce(snapshot_root_path, SnapshotFS(), &fs::FileSystem::CreateDirIfNotExist); std::unique_ptr ret(new Snapshot(snapshot_root_path)); auto emplace_ret = snapshot_id2writeable_snapshot_.emplace(snapshot_id, std::move(ret)); it = emplace_ret.first; diff --git a/oneflow/core/persistence/snapshot_test.cpp b/oneflow/core/persistence/snapshot_test.cpp index aa6987f1932911ce1171162993df1c86ccddf1fd..81d98e2943bd1ac08ada9641fb94a1537574598a 100644 --- a/oneflow/core/persistence/snapshot_test.cpp +++ b/oneflow/core/persistence/snapshot_test.cpp @@ -8,19 +8,21 @@ namespace oneflow { TEST(Snapshot, write_and_read) { JobDescProto jb_desc_proto; auto job_conf = jb_desc_proto.mutable_job_conf(); - auto gfs_conf = job_conf->mutable_global_fs_conf(); - gfs_conf->set_allocated_localfs_conf(new LocalFsConf); + auto job_other = job_conf->mutable_other(); + auto snapshot_path_conf = job_conf->mutable_other()->mutable_snapshot_path_conf(); + snapshot_path_conf->set_allocated_localfs_conf(new LocalFsConf); auto resource = jb_desc_proto.mutable_resource(); resource->add_machine(); Global::Get()->InitFromProto(jb_desc_proto); + fs::FileSystem* snapshot_fs = GetFS(Global::Get()->snapshot_path_conf()); std::string current_dir = GetCwd(); StringReplace(¤t_dir, '\\', '/'); std::string snapshot_root_path = JoinPath(current_dir, "/tmp_snapshot_test_asdfasdf"); - if (GlobalFS()->IsDirectory(snapshot_root_path)) { - ASSERT_TRUE(GlobalFS()->ListDir(snapshot_root_path).empty()); + if (snapshot_fs->IsDirectory(snapshot_root_path)) { + ASSERT_TRUE(snapshot_fs->ListDir(snapshot_root_path).empty()); } else { - GlobalFS()->CreateDir(snapshot_root_path); + snapshot_fs->CreateDir(snapshot_root_path); } std::string key = "key/name"; @@ -41,13 +43,13 @@ TEST(Snapshot, write_and_read) { // read { auto read_stream_ptr = - std::make_unique(GlobalFS(), JoinPath(snapshot_root_path, key)); + std::make_unique(snapshot_fs, JoinPath(snapshot_root_path, key)); std::string content; read_stream_ptr->ReadLine(&content); ASSERT_EQ(content, "ab"); } - GlobalFS()->RecursivelyDeleteDir(snapshot_root_path); - ASSERT_TRUE(!GlobalFS()->IsDirectory(snapshot_root_path)); + snapshot_fs->RecursivelyDeleteDir(snapshot_root_path); + ASSERT_TRUE(!snapshot_fs->IsDirectory(snapshot_root_path)); } } // namespace oneflow diff --git a/oneflow/core/persistence/tee_persistent_log_stream.cpp b/oneflow/core/persistence/tee_persistent_log_stream.cpp new file mode 100644 index 0000000000000000000000000000000000000000..560a66ab5a64d4d505e41cd682dbc8bd81ee667e --- /dev/null +++ b/oneflow/core/persistence/tee_persistent_log_stream.cpp @@ -0,0 +1,39 @@ +#include "oneflow/core/persistence/tee_persistent_log_stream.h" +#include "oneflow/core/common/str_util.h" +#include + +namespace oneflow { + +TeePersistentLogStream::TeePersistentLogStream(const std::string& path) { + destinations_.emplace_back(LocalFS(), LogDir()); + branches_.reserve(destinations_.size()); + for (const auto& destination : destinations_) { + branches_.emplace_back(std::make_unique( + destination.mut_file_system(), JoinPath(destination.base_dir(), path))); + } +} + +TeePersistentLogStream::~TeePersistentLogStream() { Flush(); } + +std::unique_ptr TeePersistentLogStream::Create(const std::string& path) { + auto stream_ptr = new TeePersistentLogStream(path); + return std::unique_ptr(stream_ptr); +} + +void TeePersistentLogStream::Flush() { + for (const auto& branch : branches_) { branch->Flush(); } +}; + +void TeePersistentLogStream::Write(const char* s, size_t n) { + for (const auto& branch : branches_) { branch->Write(s, n); } +}; + +void TeePersistentLogStream::Write(const std::string& str) { this->Write(str.data(), str.size()); } + +void TeePersistentLogStream::Write(const PbMessage& proto) { + std::string output; + google::protobuf::TextFormat::PrintToString(proto, &output); + this->Write(output); +} + +} // namespace oneflow diff --git a/oneflow/core/persistence/tee_persistent_log_stream.h b/oneflow/core/persistence/tee_persistent_log_stream.h new file mode 100644 index 0000000000000000000000000000000000000000..efe21e30ddb659810c413115e1416e17e594fdd6 --- /dev/null +++ b/oneflow/core/persistence/tee_persistent_log_stream.h @@ -0,0 +1,54 @@ +#ifndef ONEFLOW_CORE_PERSISTENCE_TEE_PERSISTENT_LOG_STREAM_H_ +#define ONEFLOW_CORE_PERSISTENCE_TEE_PERSISTENT_LOG_STREAM_H_ + +#include "oneflow/core/common/protobuf.h" +#include "oneflow/core/persistence/persistent_out_stream.h" + +namespace oneflow { + +class LogStreamDestination final { + public: + LogStreamDestination(fs::FileSystem* file_system, const std::string& base_dir) + : file_system_(file_system), base_dir_(base_dir) {} + ~LogStreamDestination() = default; + fs::FileSystem* mut_file_system() const { return file_system_; }; + const std::string& base_dir() const { return base_dir_; }; + + private: + fs::FileSystem* file_system_; + std::string base_dir_; +}; + +class TeePersistentLogStream final { + public: + OF_DISALLOW_COPY_AND_MOVE(TeePersistentLogStream); + ~TeePersistentLogStream(); + + void Write(const char* s, size_t n); + void Write(const std::string& str); + void Write(const PbMessage& proto); + + static std::unique_ptr Create(const std::string& path); + + private: + explicit TeePersistentLogStream(const std::string& path); + void Flush(); + std::vector destinations_; + std::vector> branches_; +}; + +inline TeePersistentLogStream& operator<<(TeePersistentLogStream& log_stream, + const std::string& s) { + log_stream.Write(s.c_str(), s.size()); + return log_stream; +} + +inline std::unique_ptr& operator<<( + std::unique_ptr& log_stream, const std::string& s) { + log_stream->Write(s.c_str(), s.size()); + return log_stream; +} + +} // namespace oneflow + +#endif // ONEFLOW_CORE_PERSISTENCE_TEE_PERSISTENT_LOG_STREAM_H_