Unverified commit d76513b3, authored by Jinhui Yuan, committed by GitHub

Refine runtime (#1108)

* only master machine saves plan and has event logger

* separate Data, Persistence, Cache, Log FileSystem config

* refine

* only specify data and snapshot path conf

* forbid multiple machines from using localfs as snapshot fs

* networkfs as localfs

* refine

* Store log to snapshot (#1109)

* use machine id, drop machine name

* ensure setting machine id

* allow save snapshot to localfs for distributed training (#1113)

* Snapshot to master (#1116)

* allow save snapshot to localfs for distributed training

* fix mdSave to master for model parallel

* fix review comment issues

* add sanity check for machine id

* rm useless comments

* update example

* Dev refine runtime add log stream mgr (#1142)

* add LogStreamMgr

* refine and refactor OutStream=>LogStream

* bugfix

* use LogStreamMgr to write graph, dot, plan, profile and proto

* refine

* simplify, remove LogStreamMgr (#1243)

* simplify, remove LogStreamMgr

* TeePersistentLogStream add static factory (#1244)
Parent 0da0646c
 machine {
   addr: "127.0.0.1"
   port: 7099
-  name: "first"
+  id: 0
 }
 gpu_device_num: 2
-globalfs_conf {
+data_fs_conf {
+  localfs_conf {
+  }
+}
+snapshot_fs_conf {
   localfs_conf {
   }
 }
......
@@ -5,7 +5,7 @@ placement_group {
   }
   parallel_conf {
     policy: kDataParallel
-    device_name: "first:cpu:0-1"
+    device_name: "0:cpu:0-1"
   }
 }
@@ -25,6 +25,6 @@ placement_group {
   }
   parallel_conf {
     policy: kDataParallel
-    device_name: "first:gpu:0-1"
+    device_name: "0:gpu:0-1"
   }
 }
-globalfs_conf {
+data_fs_conf {
+  localfs_conf {
+  }
+}
+snapshot_fs_conf {
   localfs_conf {
   }
 }
......
@@ -4,7 +4,7 @@ placement_group {
   }
   parallel_conf {
     policy: kDataParallel
-    device_name: "first:cpu:0-1"
+    device_name: "0:cpu:0-1"
   }
 }
@@ -24,6 +24,6 @@ placement_group {
   }
   parallel_conf {
     policy: kDataParallel
-    device_name: "first:gpu:0-1"
+    device_name: "0:gpu:0-1"
   }
 }
@@ -7,6 +7,8 @@
 #include <sys/sysinfo.h>
 #endif
+DEFINE_int64(this_machine_id, -1, "");
 namespace oneflow {
 #define DEFINE_ONEFLOW_STR2INT_CAST(dst_type, cast_func) \
@@ -86,4 +88,9 @@ size_t GetAvailableCpuMemSize() {
   return 0;
 }
+std::string LogDir() {
+  std::string v = FLAGS_log_dir + "/" + std::to_string(FLAGS_this_machine_id);
+  return v;
+}
 }  // namespace oneflow
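Note: with the new this_machine_id flag, every machine writes its logs into its own subdirectory of the glog log_dir. A minimal sketch of the resulting path (the flag values below are illustrative, not taken from this commit):

  // Assumed flags, for illustration only:
  //   --log_dir=/tmp/oneflow_log
  //   --this_machine_id=1
  std::string dir = LogDir();  // returns "/tmp/oneflow_log/1"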
@@ -26,6 +26,7 @@
 #include "oneflow/core/common/meta_util.hpp"
 DECLARE_string(log_dir);
+DECLARE_int64(this_machine_id);
 namespace std {
 template<typename T0, typename T1>
@@ -121,10 +122,7 @@ inline std::string NewUniqueId() {
   return std::to_string(id++);
 }
-inline const std::string& LogDir() {
-  static std::string v = FLAGS_log_dir;
-  return v;
-}
+std::string LogDir();
 template<typename K, typename V>
 void EraseIf(HashMap<K, V>* hash_map, std::function<bool(typename HashMap<K, V>::iterator)> cond) {
......
@@ -4,7 +4,7 @@
 #include <stack>
 #include "oneflow/core/common/str_util.h"
 #include "oneflow/core/graph/node.h"
-#include "oneflow/core/persistence/persistent_out_stream.h"
+#include "oneflow/core/persistence/tee_persistent_log_stream.h"
 namespace oneflow {
@@ -161,15 +161,13 @@ void Graph<NodeType, EdgeType>::ToDotWithStream(StreamT& out_stream) {
 template<typename NodeType, typename EdgeType>
 void Graph<NodeType, EdgeType>::ToDotWithFilePath(const std::string& file_path) {
-  std::string dir_name = Dirname(file_path);
-  if (!LocalFS()->IsDirectory(dir_name)) { LocalFS()->RecursivelyCreateDir(dir_name); }
-  PersistentOutStream out_stream(LocalFS(), file_path);
-  ToDotWithStream(out_stream);
+  auto log_stream = TeePersistentLogStream::Create(file_path);
+  ToDotWithStream(log_stream);
 }
 template<typename NodeType, typename EdgeType>
 void Graph<NodeType, EdgeType>::ToDotWithAutoFilePath() {
-  std::string file_path = LogDir() + "/dot/" + TypeName() + "/" + NewUniqueId() + ".dot";
+  std::string file_path = JoinPath("dot", TypeName(), NewUniqueId() + ".dot");
   ToDotWithFilePath(file_path);
 }
......
@@ -330,7 +330,7 @@ void LogicalGraph::BuildLossPrintStruct() {
   std::shared_ptr<Operator> loss_print_op = ConstructOp(loss_print_op_conf);
   ParallelConf loss_print_pr_conf;
   loss_print_pr_conf.set_policy(kDataParallel);
-  loss_print_pr_conf.add_device_name(Global<JobDesc>::Get()->MachineName4MachineId(0) + ":cpu:1");
+  loss_print_pr_conf.add_device_name("0:cpu:1");
   LossPrintLogicalNode* loss_print_logical = NewNode<LossPrintLogicalNode>();
   loss_print_logical->mut_op_vec() = {loss_print_op};
   loss_print_logical->mut_parallel_desc().reset(new ParallelDesc(loss_print_pr_conf));
@@ -363,8 +363,7 @@ void LogicalGraph::BuildAccuracyPrintStruct() {
   std::shared_ptr<Operator> accuracy_print_op = ConstructOp(accuracy_print_op_conf);
   ParallelConf accuracy_print_pr_conf;
   accuracy_print_pr_conf.set_policy(kDataParallel);
-  accuracy_print_pr_conf.add_device_name(Global<JobDesc>::Get()->MachineName4MachineId(0)
-                                         + ":cpu:1");
+  accuracy_print_pr_conf.add_device_name("0:cpu:1");
   AccuracyPrintLogicalNode* accuracy_print_logical = NewNode<AccuracyPrintLogicalNode>();
   accuracy_print_logical->mut_op_vec() = {accuracy_print_op};
   accuracy_print_logical->mut_parallel_desc().reset(new ParallelDesc(accuracy_print_pr_conf));
@@ -544,10 +543,13 @@ MdSaveLogicalNode* LogicalGraph::BuildMdSaveStruct(const ForwardLogicalNode* fw_
   auto md_save_logical = NewNode<MdSaveLogicalNode>();
   md_save_logical->mut_op_vec() = {model_save_op};
   auto md_save_pr_desc = new ParallelDesc(*(fw_logical->parallel_desc()));
+  md_save_pr_desc->set_device_type(DeviceType::kCPU);
   if (fw_logical->parallel_desc()->policy() == ParallelPolicy::kDataParallel) {
     md_save_pr_desc->RandomSelectOneDeviceAndRemoveTheOthers();
   }
-  md_save_pr_desc->set_device_type(DeviceType::kCPU);
+  if (Global<JobDesc>::Get()->write_snapshot_to_master()) {
+    md_save_pr_desc->UseCPUDevicesOnMaster();
+  }
   md_save_logical->mut_parallel_desc().reset(md_save_pr_desc);
   Connect<LogicalNode>(need_save_logical, NewEdge(), md_save_logical);
   return md_save_logical;
......
#include "oneflow/core/job/compiler.h" #include "oneflow/core/job/compiler.h"
#include "oneflow/core/persistence/tee_persistent_log_stream.h"
#include "oneflow/core/device/cudnn_conv_ctx_cache.h" #include "oneflow/core/device/cudnn_conv_ctx_cache.h"
namespace oneflow { namespace oneflow {
...@@ -6,11 +7,11 @@ namespace oneflow { ...@@ -6,11 +7,11 @@ namespace oneflow {
namespace { namespace {
void ToDotFile(const Plan& plan, const std::string& filepath) { void ToDotFile(const Plan& plan, const std::string& filepath) {
PersistentOutStream out_stream(LocalFS(), filepath); auto log_stream = TeePersistentLogStream::Create(filepath);
out_stream << "digraph {\n"; log_stream << "digraph {\n";
HashSet<int64_t> regst_desc_ids; HashSet<int64_t> regst_desc_ids;
for (const TaskProto& task_proto : plan.task()) { for (const TaskProto& task_proto : plan.task()) {
out_stream << "task" << std::to_string(task_proto.task_id()) << "[label=\"" log_stream << "task" << std::to_string(task_proto.task_id()) << "[label=\""
<< std::to_string(task_proto.task_id()) << "\\n" << std::to_string(task_proto.task_id()) << "\\n"
<< std::to_string(task_proto.machine_id()) << ":" << std::to_string(task_proto.machine_id()) << ":"
<< std::to_string(task_proto.thrd_id()) << ":" << std::to_string(task_proto.thrd_id()) << ":"
...@@ -23,23 +24,23 @@ void ToDotFile(const Plan& plan, const std::string& filepath) { ...@@ -23,23 +24,23 @@ void ToDotFile(const Plan& plan, const std::string& filepath) {
} }
} }
for (const int64_t regst_task_id : regst_desc_ids) { for (const int64_t regst_task_id : regst_desc_ids) {
out_stream << "regst_desc" << std::to_string(regst_task_id) << "[label=\"" log_stream << "regst_desc" << std::to_string(regst_task_id) << "[label=\""
<< std::to_string(regst_task_id) << "\", shape=box];\n"; << std::to_string(regst_task_id) << "\", shape=box];\n";
} }
for (const TaskProto& task_proto : plan.task()) { for (const TaskProto& task_proto : plan.task()) {
for (const auto& pair : task_proto.produced_regst_desc()) { for (const auto& pair : task_proto.produced_regst_desc()) {
out_stream << "task" << std::to_string(task_proto.task_id()) << "->regst_desc" log_stream << "task" << std::to_string(task_proto.task_id()) << "->regst_desc"
<< std::to_string(pair.second.regst_desc_id()) << "[label=\"" << pair.first << std::to_string(pair.second.regst_desc_id()) << "[label=\"" << pair.first
<< "\"];\n"; << "\"];\n";
} }
for (const auto& pair : task_proto.consumed_regst_desc_id()) { for (const auto& pair : task_proto.consumed_regst_desc_id()) {
for (int64_t regst_desc_id : pair.second.regst_desc_id()) { for (int64_t regst_desc_id : pair.second.regst_desc_id()) {
out_stream << "regst_desc" << std::to_string(regst_desc_id) << "->task" log_stream << "regst_desc" << std::to_string(regst_desc_id) << "->task"
<< std::to_string(task_proto.task_id()) << "[label=\"" << pair.first << "\"];\n"; << std::to_string(task_proto.task_id()) << "[label=\"" << pair.first << "\"];\n";
} }
} }
} }
out_stream << "}\n"; log_stream << "}\n";
} }
} // namespace } // namespace
...@@ -120,7 +121,7 @@ Plan Compiler::DoCompile() { ...@@ -120,7 +121,7 @@ Plan Compiler::DoCompile() {
}); });
plan.set_total_mbn_num(total_mbn_num); plan.set_total_mbn_num(total_mbn_num);
GenNetTopo(&plan); GenNetTopo(&plan);
ToDotFile(plan, JoinPath(LogDir(), "/dot/plan.dot")); ToDotFile(plan, "/dot/plan.dot");
#ifdef WITH_CUDA #ifdef WITH_CUDA
Global<CudnnConvCtxCache>::Delete(); Global<CudnnConvCtxCache>::Delete();
#endif #endif
......
@@ -27,19 +27,22 @@ message PredictConf {
 message LocalFsConf {
 }
+message NetworkFsConf {
+}
 message HdfsConf {
   required string namenode = 1;
 }
-message GlobalFSConf {
+message FileSystemConf {
   oneof fs_type {
     LocalFsConf localfs_conf = 1;
-    HdfsConf hdfs_conf = 2;
+    NetworkFsConf networkfs_conf = 2;
+    HdfsConf hdfs_conf = 3;
   }
 }
 message OtherConf {
-  required GlobalFSConf globalfs_conf = 1;
   required int64 piece_size = 2;
   required int32 data_part_num = 3; // piece_size % data_part_num = 0
@@ -64,6 +67,9 @@ message OtherConf {
   optional bool enable_write_snapshot = 130 [default = true];
   optional bool enable_blob_mem_sharing = 140 [default = true];
+  required FileSystemConf data_fs_conf = 121;
+  required FileSystemConf snapshot_fs_conf = 122;
   oneof JobType {
     TrainConf train_conf = 200;
     PredictConf predict_conf = 201;
......
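Note: with GlobalFSConf split into data_fs_conf and snapshot_fs_conf, a job can read training data and write snapshots through different file systems. A hedged sketch using the generated protobuf C++ API (field names come from the proto above; the namenode address is made up):

  // Illustration only, not part of this commit.
  OtherConf other_conf;
  other_conf.mutable_data_fs_conf()->mutable_hdfs_conf()->set_namenode("hdfs://namenode:9000");
  other_conf.mutable_snapshot_fs_conf()->mutable_localfs_conf();  // snapshots go to the master's local fs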
@@ -5,15 +5,6 @@
 namespace oneflow {
-int64_t JobDesc::MachineID4MachineName(const std::string& machine_name) const {
-  auto it = machine_name2machine_id_.find(machine_name);
-  CHECK(it != machine_name2machine_id_.end()) << "Undefined machine name: " << machine_name;
-  return it->second;
-}
-const std::string& JobDesc::MachineName4MachineId(int64_t machine_id) const {
-  return machine_id2machine_name_.at(machine_id);
-}
 int64_t JobDesc::piece_num_of_experiment_phase() const {
   return job_conf_.other().piece_num_of_experiment_phase();
 }
@@ -87,6 +78,11 @@ float JobDesc::L2() const {
 int32_t JobDesc::DataPartNum() const { return job_conf_.other().data_part_num(); }
+const FileSystemConf& JobDesc::data_fs_conf() const { return job_conf_.other().data_fs_conf(); }
+const FileSystemConf& JobDesc::snapshot_fs_conf() const {
+  return job_conf_.other().snapshot_fs_conf();
+}
 JobDesc::JobDesc(const std::string& job_conf_filepath) {
   if (TryParseProtoFromTextFile(job_conf_filepath, &job_conf_) == false) {
     JobConf2 job_conf;
@@ -96,7 +92,7 @@ JobDesc::JobDesc(const std::string& job_conf_filepath) {
     ParseProtoFromTextFile(job_conf.placement(), job_conf_.mutable_placement());
     ParseProtoFromTextFile(job_conf.other(), job_conf_.mutable_other());
   }
+  SanityCheck();
   SplitDecodeOps();
   AddRecordLoadOps();
 #ifndef WITH_RDMA
@@ -123,12 +119,13 @@ JobDesc::JobDesc(const std::string& job_conf_filepath) {
 #ifndef WITH_CUDA
   CHECK_EQ(job_conf_.resource().gpu_device_num(), 0);
 #endif
+}
+void JobDesc::SanityCheck() {
   int64_t machine_num = job_conf_.resource().machine_size();
-  for (int64_t i = 0; i < machine_num; ++i) {
-    const std::string& machine_name = job_conf_.resource().machine(i).name();
-    CHECK(machine_name2machine_id_.emplace(machine_name, i).second);
-    CHECK(machine_id2machine_name_.emplace(i, machine_name).second);
-  }
+  FOR_RANGE(int64_t, i, 0, machine_num) { CHECK_EQ(job_conf_.resource().machine(i).id(), i); }
+  CHECK_GE(FLAGS_this_machine_id, 0);
+  CHECK_LT(FLAGS_this_machine_id, machine_num);
 }
 void JobDesc::SplitDecodeOps() {
......
@@ -47,16 +47,15 @@ class JobDesc final {
   size_t rdma_recv_msg_buf_byte() const;
   bool collect_act_event() const { return job_conf_.other().collect_act_event(); }
   bool enable_mem_sharing() const { return job_conf_.other().enable_mem_sharing(); }
+  const FileSystemConf& data_fs_conf() const;
+  const FileSystemConf& snapshot_fs_conf() const;
   bool enable_write_snapshot() const {
     return IsTrain() && job_conf_.other().enable_write_snapshot();
   }
+  bool write_snapshot_to_master() const { return snapshot_fs_conf().has_localfs_conf(); }
   bool enable_blob_mem_sharing() const { return job_conf_.other().enable_blob_mem_sharing(); }
   int64_t reduce_group_size() const { return job_conf_.other().reduce_group_size(); }
-  // machine_name <-> machine_id
-  int64_t MachineID4MachineName(const std::string& machine_name) const;
-  const std::string& MachineName4MachineId(int64_t machine_id) const;
   // Train conf
   const std::string& MdSaveSnapshotsPath() const;
   int32_t NumOfBatchesInSnapshot() const;
@@ -73,13 +72,11 @@ class JobDesc final {
  private:
   friend class Global<JobDesc>;
   JobDesc(const std::string& job_conf_filepath);
+  void SanityCheck();
   void SplitDecodeOps();
   void AddRecordLoadOps();
   JobConf1 job_conf_;
-  HashMap<std::string, int64_t> machine_name2machine_id_;
-  HashMap<int64_t, std::string> machine_id2machine_name_;
 };
 }  // namespace oneflow
......
@@ -7,9 +7,7 @@ std::string MachineCtx::GetCtrlAddr(int64_t machine_id) const {
   return mchn.addr() + ":" + std::to_string(mchn.port());
 }
-MachineCtx::MachineCtx(const std::string& this_mchn_name) {
-  this_machine_id_ = Global<JobDesc>::Get()->MachineID4MachineName(this_mchn_name);
-  LOG(INFO) << "this machine name: " << this_mchn_name;
+MachineCtx::MachineCtx(int64_t this_mchn_id) : this_machine_id_(this_mchn_id) {
   LOG(INFO) << "this machine id: " << this_machine_id_;
 }
......
@@ -19,7 +19,7 @@ class MachineCtx final {
  private:
   friend class Global<MachineCtx>;
-  MachineCtx(const std::string& this_mchn_name);
+  MachineCtx(int64_t this_mchn_id);
   int64_t this_machine_id_;
 };
......
@@ -10,6 +10,7 @@
 #include "oneflow/core/job/plan.pb.h"
 #include "oneflow/core/job/runtime.h"
 #include "oneflow/core/job/available_memory_desc.pb.h"
+#include "oneflow/core/persistence/tee_persistent_log_stream.h"
 #include "oneflow/core/persistence/file_system.h"
 #include "oneflow/core/actor/act_event_logger.h"
@@ -150,10 +151,10 @@ bool HasRelayPlacement() {
   for (const PlacementGroup& p_group : placement.placement_group()) {
     const ParallelConf& p_conf = p_group.parallel_conf();
     for (const std::string& device_name : p_conf.device_name()) {
-      std::string mchn_name;
+      int64_t mchn_id;
       std::string device_tag;
       std::string device_id_str;
-      ParseDeviceNameConf(device_name, &mchn_name, &device_tag, &device_id_str);
+      ParseDeviceNameConf(device_name, &mchn_id, &device_tag, &device_id_str);
       if (device_tag == "cpu") {
         break;
       } else if (device_tag == "gpu") {
@@ -177,16 +178,16 @@ class Oneflow final {
   OF_DISALLOW_COPY_AND_MOVE(Oneflow);
   ~Oneflow() = default;
-  Oneflow(const std::string& job_conf_filepath, const std::string& this_mchn_name);
+  Oneflow(const std::string& job_conf_filepath, int64_t this_mchn_id);
  private:
   std::unique_ptr<CtrlServer> ctrl_server_;
 };
-Oneflow::Oneflow(const std::string& job_conf_filepath, const std::string& this_mchn_name) {
+Oneflow::Oneflow(const std::string& job_conf_filepath, int64_t this_mchn_id) {
   // New All Global
   Global<JobDesc>::New(job_conf_filepath);
-  Global<MachineCtx>::New(this_mchn_name);
+  Global<MachineCtx>::New(this_mchn_id);
   const MachineCtx* machine_ctx = Global<MachineCtx>::Get();
   bool DoProfile =
       machine_ctx->IsThisMachineMaster() && Global<JobDesc>::Get()->collect_act_event();
@@ -216,15 +217,15 @@ Oneflow::Oneflow(const std::string& job_conf_filepath, const std::string& this_m
     PullPlan("mem_shared_plan", &mem_shared_plan);
   }
   OF_BARRIER();
-  PrintProtoToTextFile(naive_plan, JoinPath(LogDir(), "naive_plan"));
-  PrintProtoToTextFile(mem_shared_plan, JoinPath(LogDir(), "mem_shared_plan"));
+  TeePersistentLogStream::Create("naive_plan")->Write(naive_plan);
+  TeePersistentLogStream::Create("mem_shared_plan")->Write(mem_shared_plan);
   LOG(INFO) << "push_pull_plan:" << GetCurTime() - start;
   if (HasRelayPlacement()) {
     // Experiment Runtime
     { Runtime experiment_run(mem_shared_plan, true); }
     // Improve
     if (machine_ctx->IsThisMachineMaster()) {
-      PrintProtoToTextFile(amd, JoinPath(LogDir(), "available_mem_desc"));
+      TeePersistentLogStream::Create("available_mem_desc")->Write(amd);
       CHECK_GT(amd.machine_amd_size(), 0);
       improved_plan = Improver().Improve(
           amd, naive_plan, JoinPath(LogDir(), ActEventLogger::experiment_act_event_bin_filename()));
@@ -233,7 +234,7 @@ Oneflow::Oneflow(const std::string& job_conf_filepath, const std::string& this_m
       PullPlan("improved_plan", &improved_plan);
     }
     OF_BARRIER();
-    PrintProtoToTextFile(improved_plan, JoinPath(LogDir(), "improved_plan"));
+    TeePersistentLogStream::Create("improved_plan")->Write(improved_plan);
     Global<CtrlClient>::Get()->Clear();
     OF_BARRIER();
   } else {
@@ -257,16 +258,16 @@ Oneflow::Oneflow(const std::string& job_conf_filepath, const std::string& this_m
 }  // namespace oneflow
 DEFINE_string(job_conf, "", "");
-DEFINE_string(this_machine_name, "", "");
 int main(int argc, char** argv) {
   using namespace oneflow;
   google::InitGoogleLogging(argv[0]);
   gflags::SetVersionString(BuildVersionString());
   gflags::ParseCommandLineFlags(&argc, &argv, true);
+  CHECK_GE(FLAGS_this_machine_id, 0);
   LocalFS()->RecursivelyCreateDirIfNotExist(LogDir());
   RedirectStdoutAndStderrToGlogDir();
-  { Oneflow flow(FLAGS_job_conf, FLAGS_this_machine_name); }
+  { Oneflow flow(FLAGS_job_conf, FLAGS_this_machine_id); }
   CloseStdoutAndStderr();
   return 0;
 }
#include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/common/util.h"
namespace oneflow { namespace oneflow {
void ParseDeviceNameConf(const std::string& device_name, std::string* mchn_name, void ParseDeviceNameConf(const std::string& device_name, int64_t* mchn_id, std::string* device_tag,
std::string* device_tag, std::string* device_id_str) { std::string* device_id_str) {
size_t second_delimiter_pos = device_name.rfind(":"); size_t second_delimiter_pos = device_name.rfind(":");
CHECK_NE(second_delimiter_pos, std::string::npos); CHECK_NE(second_delimiter_pos, std::string::npos);
size_t first_delimiter_pos = device_name.rfind(":", second_delimiter_pos - 1); size_t first_delimiter_pos = device_name.rfind(":", second_delimiter_pos - 1);
CHECK_NE(first_delimiter_pos, std::string::npos); CHECK_NE(first_delimiter_pos, std::string::npos);
*mchn_name = device_name.substr(0, first_delimiter_pos); *mchn_id = oneflow_cast<int64_t>(device_name.substr(0, first_delimiter_pos));
*device_tag = *device_tag =
device_name.substr(first_delimiter_pos + 1, second_delimiter_pos - first_delimiter_pos - 1); device_name.substr(first_delimiter_pos + 1, second_delimiter_pos - first_delimiter_pos - 1);
*device_id_str = device_name.substr(second_delimiter_pos + 1); *device_id_str = device_name.substr(second_delimiter_pos + 1);
...@@ -16,14 +17,14 @@ void ParseDeviceNameConf(const std::string& device_name, std::string* mchn_name, ...@@ -16,14 +17,14 @@ void ParseDeviceNameConf(const std::string& device_name, std::string* mchn_name,
ParallelDesc::ParallelDesc(const ParallelConf& user_conf) { ParallelDesc::ParallelDesc(const ParallelConf& user_conf) {
policy_ = user_conf.policy(); policy_ = user_conf.policy();
HashSet<std::string> machine_name_set; HashSet<int64_t> machine_id_set;
device_type_ = DeviceType::kInvalidDevice; device_type_ = DeviceType::kInvalidDevice;
for (const std::string& device_name : user_conf.device_name()) { for (const std::string& device_name : user_conf.device_name()) {
std::string mchn_name; int64_t mchn_id;
std::string device_tag; std::string device_tag;
std::string device_id_str; std::string device_id_str;
ParseDeviceNameConf(device_name, &mchn_name, &device_tag, &device_id_str); ParseDeviceNameConf(device_name, &mchn_id, &device_tag, &device_id_str);
machine_name_set.insert(mchn_name); machine_id_set.insert(mchn_id);
if (device_tag == "cpu") { if (device_tag == "cpu") {
CHECK_STREQ(device_tag.c_str(), "cpu"); CHECK_STREQ(device_tag.c_str(), "cpu");
CHECK(device_type_ == DeviceType::kInvalidDevice || device_type_ == DeviceType::kCPU); CHECK(device_type_ == DeviceType::kInvalidDevice || device_type_ == DeviceType::kCPU);
...@@ -33,8 +34,7 @@ ParallelDesc::ParallelDesc(const ParallelConf& user_conf) { ...@@ -33,8 +34,7 @@ ParallelDesc::ParallelDesc(const ParallelConf& user_conf) {
CHECK(device_type_ == DeviceType::kInvalidDevice || device_type_ == DeviceType::kGPU); CHECK(device_type_ == DeviceType::kInvalidDevice || device_type_ == DeviceType::kGPU);
device_type_ = DeviceType::kGPU; device_type_ = DeviceType::kGPU;
} }
int64_t machine_id = Global<JobDesc>::Get()->MachineID4MachineName(mchn_name); sorted_machine_ids_.push_back(mchn_id);
sorted_machine_ids_.push_back(machine_id);
int64_t minus_pos = device_id_str.find("-"); int64_t minus_pos = device_id_str.find("-");
if (minus_pos == std::string::npos) { if (minus_pos == std::string::npos) {
device_id_str = device_id_str + "-" + device_id_str; device_id_str = device_id_str + "-" + device_id_str;
...@@ -47,11 +47,11 @@ ParallelDesc::ParallelDesc(const ParallelConf& user_conf) { ...@@ -47,11 +47,11 @@ ParallelDesc::ParallelDesc(const ParallelConf& user_conf) {
if (device_type_ == DeviceType::kGPU) { if (device_type_ == DeviceType::kGPU) {
CHECK_LT(dev_phy_id, Global<JobDesc>::Get()->GpuDeviceNum()); CHECK_LT(dev_phy_id, Global<JobDesc>::Get()->GpuDeviceNum());
} }
machine_id2sorted_dev_phy_ids_[machine_id].push_back(dev_phy_id); machine_id2sorted_dev_phy_ids_[mchn_id].push_back(dev_phy_id);
} }
} }
ClearUp(); ClearUp();
CheckValidity(); SanityCheck();
} }
void ParallelDesc::RemoveNeedlessDevice(const std::string& op_name, int32_t max_device_num) { void ParallelDesc::RemoveNeedlessDevice(const std::string& op_name, int32_t max_device_num) {
...@@ -104,6 +104,13 @@ void ParallelDesc::RandomSelectOneDeviceAndRemoveTheOthers() { ...@@ -104,6 +104,13 @@ void ParallelDesc::RandomSelectOneDeviceAndRemoveTheOthers() {
parallel_num_ = 1; parallel_num_ = 1;
} }
void ParallelDesc::UseCPUDevicesOnMaster() {
sorted_machine_ids_ = {0};
std::vector<int64_t> sorted_dev_phy_ids(parallel_num_);
std::iota(sorted_dev_phy_ids.begin(), sorted_dev_phy_ids.end(), 0);
machine_id2sorted_dev_phy_ids_ = {{0, sorted_dev_phy_ids}};
}
bool ParallelDesc::Equal(const ParallelDesc& rhs) const { bool ParallelDesc::Equal(const ParallelDesc& rhs) const {
return device_type_ == rhs.device_type_ && policy_ == rhs.policy_ return device_type_ == rhs.device_type_ && policy_ == rhs.policy_
&& sorted_machine_ids_ == rhs.sorted_machine_ids_ && sorted_machine_ids_ == rhs.sorted_machine_ids_
...@@ -124,7 +131,7 @@ void ParallelDesc::ClearUp() { ...@@ -124,7 +131,7 @@ void ParallelDesc::ClearUp() {
SortAndRemoveDuplication(&sorted_machine_ids_); SortAndRemoveDuplication(&sorted_machine_ids_);
} }
void ParallelDesc::CheckValidity() { void ParallelDesc::SanityCheck() {
device_num_of_each_machine_ = -1; device_num_of_each_machine_ = -1;
for (auto& pair : machine_id2sorted_dev_phy_ids_) { for (auto& pair : machine_id2sorted_dev_phy_ids_) {
if (device_num_of_each_machine_ == -1) { if (device_num_of_each_machine_ == -1) {
......
@@ -8,8 +8,8 @@
 namespace oneflow {
-void ParseDeviceNameConf(const std::string& device_name, std::string* mchn_name,
-                         std::string* device_tag, std::string* device_id_str);
+void ParseDeviceNameConf(const std::string& device_name, int64_t* mchn_id, std::string* device_tag,
+                         std::string* device_id_str);
 class ParallelDesc {
  public:
@@ -35,6 +35,7 @@ class ParallelDesc {
   void RemoveNeedlessDevice(const std::string& op_name, int32_t max_device_num);
   void RemoveNeedlessDevice(int32_t max_device_num) { RemoveNeedlessDevice("", max_device_num); }
   void RandomSelectOneDeviceAndRemoveTheOthers();
+  void UseCPUDevicesOnMaster();
   //
   bool Equal(const ParallelDesc& rhs) const;
@@ -42,7 +43,7 @@ class ParallelDesc {
  private:
   void ClearUp();
-  void CheckValidity();
+  void SanityCheck();
   DeviceType device_type_;
   ParallelPolicy policy_;
......
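Note: device_name strings now start with the numeric machine id instead of a machine name. A small usage sketch of the updated parser (the device name below is illustrative):

  int64_t mchn_id = -1;
  std::string device_tag;
  std::string device_id_str;
  // "0:gpu:0-1" means machine 0, gpu devices 0 through 1.
  ParseDeviceNameConf("0:gpu:0-1", &mchn_id, &device_tag, &device_id_str);
  // mchn_id == 0, device_tag == "gpu", device_id_str == "0-1"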
#include "oneflow/core/job/profiler.h" #include "oneflow/core/job/profiler.h"
#include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/job_desc.h"
#include "oneflow/core/persistence/persistent_out_stream.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h"
#include "oneflow/core/common/str_util.h" #include "oneflow/core/common/str_util.h"
#include "oneflow/core/actor/act_event_logger.h" #include "oneflow/core/actor/act_event_logger.h"
...@@ -84,7 +84,7 @@ void Profiler::Profile(const Plan& plan, const std::string& act_event_filepath) ...@@ -84,7 +84,7 @@ void Profiler::Profile(const Plan& plan, const std::string& act_event_filepath)
[](const ProfileInfoPair& lhs, const ProfileInfoPair& rhs) { [](const ProfileInfoPair& lhs, const ProfileInfoPair& rhs) {
return lhs.second.CalcBottleNeckScore() > rhs.second.CalcBottleNeckScore(); return lhs.second.CalcBottleNeckScore() > rhs.second.CalcBottleNeckScore();
}); });
PersistentOutStream out_stream(LocalFS(), JoinPath(LogDir(), "oneflow.profile")); auto log_stream = TeePersistentLogStream::Create("oneflow.profile");
double mdupdt_act_interval = 0.0; double mdupdt_act_interval = 0.0;
int32_t mdupdt_task_num = 0; int32_t mdupdt_task_num = 0;
for (const ProfileInfoPair& pair : profile_info_vec) { for (const ProfileInfoPair& pair : profile_info_vec) {
...@@ -93,10 +93,10 @@ void Profiler::Profile(const Plan& plan, const std::string& act_event_filepath) ...@@ -93,10 +93,10 @@ void Profiler::Profile(const Plan& plan, const std::string& act_event_filepath)
mdupdt_act_interval += pair.second.avg_act_interval(); mdupdt_act_interval += pair.second.avg_act_interval();
} }
} }
out_stream << "time_of_one_batch:" << std::to_string(mdupdt_act_interval / mdupdt_task_num) log_stream << "time_of_one_batch:" << std::to_string(mdupdt_act_interval / mdupdt_task_num)
<< "\n"; << "\n";
for (const ProfileInfoPair& pair : profile_info_vec) { for (const ProfileInfoPair& pair : profile_info_vec) {
out_stream << "actor_id:" << std::to_string(pair.first) log_stream << "actor_id:" << std::to_string(pair.first)
<< " act_num: " << std::to_string(pair.second.act_num()) << " act_num: " << std::to_string(pair.second.act_num())
<< " avg_act_time:" << std::to_string(pair.second.avg_act_time()) << " avg_act_time:" << std::to_string(pair.second.avg_act_time())
<< " avg_act_interval:" << std::to_string(pair.second.avg_act_interval()) << " avg_act_interval:" << std::to_string(pair.second.avg_act_interval())
......
@@ -4,7 +4,7 @@ package oneflow;
 message Machine {
   required string addr = 1; // domain name or ip
   required int32 port = 2;
-  required string name = 3;
+  required int64 id = 3;
 }
 enum DeviceType {
......
@@ -89,7 +89,8 @@ void Runtime::NewAllGlobal(const Plan& plan, bool is_experiment_phase) {
     }
   }
   Global<RuntimeCtx>::New(piece_num, is_experiment_phase);
-  if (Global<RuntimeCtx>::Get()->NeedCollectActEvent()) {
+  if (Global<MachineCtx>::Get()->IsThisMachineMaster()
+      && Global<RuntimeCtx>::Get()->NeedCollectActEvent()) {
     Global<ActEventLogger>::New(is_experiment_phase);
   }
   if (job_desc->TotalMachineNum() > 1) {
......
@@ -263,11 +263,11 @@ KU_IF_METHOD InitializeWithDir(DeviceCtx* ctx, int32_t part_id, int32_t part_num
   int64_t blob_size = blob->ByteSizeOfDataContentField();
   int64_t byte_size_of_each_dim = num_in_each_dim * sizeof(T);
   std::string file_path = JoinPath(model_dir, bn_in_op);
-  uint64_t file_size = GlobalFS()->GetFileSize(file_path);
+  uint64_t file_size = SnapshotFS()->GetFileSize(file_path);
   CHECK_EQ(file_size, dim_num * byte_size_of_each_dim);
   BalancedSplitter splitter = BalancedSplitter(dim_num, part_num);
   int64_t begin_pos = splitter.At(part_id).begin() * byte_size_of_each_dim;
-  PersistentInStream in_stream(GlobalFS(), file_path, begin_pos);
+  PersistentInStream in_stream(SnapshotFS(), file_path, begin_pos);
   in_stream.Read(blob->mut_dptr<char>(), blob_size);
 }
......
@@ -7,13 +7,13 @@ namespace oneflow {
 void PrintKernel::VirtualKernelInit(const ParallelContext* parallel_ctx) {
   const auto& conf = op_conf().print_conf();
   const std::string& root_path = conf.print_dir();
-  OfCallOnce(root_path, GlobalFS(), &fs::FileSystem::RecursivelyCreateDir);
+  OfCallOnce(root_path, SnapshotFS(), &fs::FileSystem::RecursivelyCreateDir);
   int32_t part_name_suffix_length = conf.part_name_suffix_length();
   std::string num = std::to_string(parallel_ctx->parallel_id());
   int32_t zero_count = std::max(part_name_suffix_length - static_cast<int32_t>(num.length()), 0);
   std::string file_path =
       JoinPath(root_path, conf.part_name_prefix() + std::string(zero_count, '0') + num);
-  out_stream_.reset(new PersistentOutStream(GlobalFS(), file_path));
+  out_stream_.reset(new PersistentOutStream(SnapshotFS(), file_path));
 }
 void PrintKernel::Forward(const KernelCtx& ctx,
......
@@ -22,13 +22,10 @@ void RecordLoadKernel::VirtualKernelInit(const ParallelContext* parallel_ctx) {
     data_paths.push_back(JoinPath(data_dir, part_name_prefix + std::string(zero_count, '0') + num));
   }
   if (Global<JobDesc>::Get()->IsTrain()) {
-    if (Global<JobDesc>::Get()->save_downloaded_file_to_local_fs() && GlobalFS() != LocalFS()) {
-      in_stream_.reset(new PersistentInStream(GlobalFS(), data_paths, true, true));
-    } else {
-      in_stream_.reset(new PersistentInStream(GlobalFS(), data_paths, true, false));
-    }
+    in_stream_.reset(new PersistentInStream(
+        DataFS(), data_paths, true, Global<JobDesc>::Get()->save_downloaded_file_to_local_fs()));
   } else {
-    in_stream_.reset(new PersistentInStream(GlobalFS(), data_paths, false, false));
+    in_stream_.reset(new PersistentInStream(DataFS(), data_paths, false, false));
   }
   int64_t global_piece_size = Global<JobDesc>::Get()->PieceSize();
   CHECK_EQ(global_piece_size % parallel_ctx->parallel_num(), 0);
......
@@ -90,21 +90,6 @@ void FileSystem::RecursivelyCreateDir(const std::string& dirname) {
   }
 }
-struct GlobalFSConstructor {
-  GlobalFSConstructor() {
-    const GlobalFSConf& gfs_conf = Global<JobDesc>::Get()->other_conf().globalfs_conf();
-    if (gfs_conf.has_localfs_conf()) {
-      CHECK_EQ(Global<JobDesc>::Get()->resource().machine().size(), 1);
-      gfs = LocalFS();
-    } else if (gfs_conf.has_hdfs_conf()) {
-      gfs = new HadoopFileSystem(gfs_conf.hdfs_conf());
-    } else {
-      UNIMPLEMENTED();
-    }
-  }
-  FileSystem* gfs;
-};
 }  // namespace fs
 fs::FileSystem* LocalFS() {
@@ -116,9 +101,25 @@ fs::FileSystem* LocalFS() {
   return fs;
 }
-fs::FileSystem* GlobalFS() {
-  static fs::GlobalFSConstructor gfs_constructor;
-  return gfs_constructor.gfs;
+fs::FileSystem* NetworkFS() { return LocalFS(); }
+fs::FileSystem* HadoopFS(const HdfsConf& hdfs_conf) {
+  static fs::FileSystem* fs = new fs::HadoopFileSystem(hdfs_conf);
+  return fs;
+}
+fs::FileSystem* GetFS(const FileSystemConf& file_system_conf) {
+  if (file_system_conf.has_localfs_conf()) {
+    return LocalFS();
+  } else if (file_system_conf.has_networkfs_conf()) {
+    return NetworkFS();
+  } else if (file_system_conf.has_hdfs_conf()) {
+    return HadoopFS(file_system_conf.hdfs_conf());
+  } else {
+    UNIMPLEMENTED();
+  }
 }
+fs::FileSystem* DataFS() { return GetFS(Global<JobDesc>::Get()->data_fs_conf()); }
+fs::FileSystem* SnapshotFS() { return GetFS(Global<JobDesc>::Get()->snapshot_fs_conf()); }
 }  // namespace oneflow
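Note: GlobalFS() is gone; call sites now choose a file system per purpose. A short sketch of the new dispatch, mirroring the functions above:

  // DataFS() resolves data_fs_conf, SnapshotFS() resolves snapshot_fs_conf;
  // both go through GetFS(), which returns LocalFS(), NetworkFS() or HadoopFS().
  fs::FileSystem* data_fs = DataFS();
  fs::FileSystem* snapshot_fs = SnapshotFS();
  // Equivalent explicit form:
  fs::FileSystem* same_fs = GetFS(Global<JobDesc>::Get()->snapshot_fs_conf());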
@@ -3,6 +3,7 @@
 #include "oneflow/core/common/platform.h"
 #include "oneflow/core/common/util.h"
+#include "oneflow/core/job/job_conf.pb.h"
 namespace oneflow {
@@ -153,8 +154,10 @@ class FileSystem {
 }  // namespace fs
 fs::FileSystem* LocalFS();
-fs::FileSystem* GlobalFS();
+fs::FileSystem* GetFS(const FileSystemConf& file_system_conf);
+fs::FileSystem* DataFS();
+fs::FileSystem* SnapshotFS();
 }  // namespace oneflow
 #endif  // ONEFLOW_CORE_PERSISTENCE_FILE_SYSTEM_H_
@@ -6,7 +6,7 @@
 namespace oneflow {
 Snapshot::Snapshot(const std::string& snapshot_root_path) {
-  CHECK(GlobalFS()->IsDirectory(snapshot_root_path));
+  CHECK(SnapshotFS()->IsDirectory(snapshot_root_path));
   root_path_ = snapshot_root_path;
 }
@@ -18,13 +18,13 @@ std::unique_ptr<PersistentOutStream> Snapshot::GetOutStream(const LogicalBlobId&
                                                             int32_t part_id) {
   // op_name_dir
   std::string op_name_dir = JoinPath(root_path_, lbi.op_name());
-  OfCallOnce(op_name_dir, GlobalFS(), &fs::FileSystem::CreateDir);
+  OfCallOnce(op_name_dir, SnapshotFS(), &fs::FileSystem::CreateDir);
   // bn_in_op_tmp_dir
   std::string bn_in_op_tmp_dir = JoinPath(op_name_dir, lbi.blob_name() + "_tmp4a58");
-  OfCallOnce(bn_in_op_tmp_dir, GlobalFS(), &fs::FileSystem::CreateDir);
+  OfCallOnce(bn_in_op_tmp_dir, SnapshotFS(), &fs::FileSystem::CreateDir);
   // part_file
   std::string part_file = JoinPath(bn_in_op_tmp_dir, "part_" + std::to_string(part_id));
-  return std::make_unique<PersistentOutStream>(GlobalFS(), part_file);
+  return std::make_unique<PersistentOutStream>(SnapshotFS(), part_file);
 }
 void Snapshot::OnePartDone(const LogicalBlobId& lbi, int32_t part_id, int32_t part_num) {
@@ -41,12 +41,12 @@ void Snapshot::ConcatLbnFile(const LogicalBlobId& lbi, int32_t part_num,
   std::vector<char> buffer(Global<JobDesc>::Get()->persistence_buf_byte());
   std::string part_dir = JoinPath(root_path_, lbi.op_name(), lbi.blob_name() + "_tmp4a58");
   {
-    PersistentOutStream out_stream(GlobalFS(), concat_file);
+    PersistentOutStream out_stream(SnapshotFS(), concat_file);
     for (int32_t i = 0; i < part_num; ++i) {
       std::unique_ptr<fs::RandomAccessFile> part_file;
       std::string part_file_path = JoinPath(part_dir, "part_" + std::to_string(i));
-      GlobalFS()->NewRandomAccessFile(part_file_path, &part_file);
-      uint64_t part_file_size = GlobalFS()->GetFileSize(part_file_path);
+      SnapshotFS()->NewRandomAccessFile(part_file_path, &part_file);
+      uint64_t part_file_size = SnapshotFS()->GetFileSize(part_file_path);
       uint64_t offset = 0;
       while (offset < part_file_size) {
         uint64_t n = std::min(buffer.size(), part_file_size - offset);
@@ -54,15 +54,15 @@ void Snapshot::ConcatLbnFile(const LogicalBlobId& lbi, int32_t part_num,
         out_stream.Write(buffer.data(), n);
         offset += n;
       }
-      GlobalFS()->DelFile(part_file_path);
+      SnapshotFS()->DelFile(part_file_path);
     }
   }
-  GlobalFS()->DeleteDir(part_dir);
+  SnapshotFS()->DeleteDir(part_dir);
   std::string snapshot_done_path = JoinPath(root_path_, "snapshot_done");
   int32_t snapshot_done_cnt = Global<CtrlClient>::Get()->IncreaseCount(snapshot_done_path);
   if (snapshot_done_cnt == Global<SnapshotMgr>::Get()->total_mbn_num()) {
     Global<CtrlClient>::Get()->EraseCount(snapshot_done_path);
-    PersistentOutStream out_stream(GlobalFS(), snapshot_done_path);
+    PersistentOutStream out_stream(SnapshotFS(), snapshot_done_path);
   }
 }
......
@@ -7,7 +7,7 @@ namespace oneflow {
 SnapshotMgr::SnapshotMgr(const Plan& plan) {
   if (Global<JobDesc>::Get()->enable_write_snapshot()) {
     model_save_snapshots_path_ = Global<JobDesc>::Get()->MdSaveSnapshotsPath();
-    OfCallOnce(model_save_snapshots_path_, GlobalFS(), &fs::FileSystem::MakeEmptyDir);
+    OfCallOnce(model_save_snapshots_path_, SnapshotFS(), &fs::FileSystem::MakeEmptyDir);
   }
   const std::string& load_path = Global<JobDesc>::Get()->MdLoadSnapshotPath();
   if (load_path != "") { readable_snapshot_.reset(new Snapshot(load_path)); }
@@ -21,7 +21,7 @@ Snapshot* SnapshotMgr::GetWriteableSnapshot(int64_t snapshot_id) {
   if (it == snapshot_id2writeable_snapshot_.end()) {
     std::string snapshot_root_path =
         JoinPath(model_save_snapshots_path_, "snapshot_" + std::to_string(snapshot_id));
-    OfCallOnce(snapshot_root_path, GlobalFS(), &fs::FileSystem::CreateDirIfNotExist);
+    OfCallOnce(snapshot_root_path, SnapshotFS(), &fs::FileSystem::CreateDirIfNotExist);
     std::unique_ptr<Snapshot> ret(new Snapshot(snapshot_root_path));
     auto emplace_ret = snapshot_id2writeable_snapshot_.emplace(snapshot_id, std::move(ret));
     it = emplace_ret.first;
......
@@ -8,19 +8,21 @@ namespace oneflow {
 TEST(Snapshot, write_and_read) {
   JobDescProto jb_desc_proto;
   auto job_conf = jb_desc_proto.mutable_job_conf();
-  auto gfs_conf = job_conf->mutable_global_fs_conf();
-  gfs_conf->set_allocated_localfs_conf(new LocalFsConf);
+  auto job_other = job_conf->mutable_other();
+  auto snapshot_path_conf = job_conf->mutable_other()->mutable_snapshot_path_conf();
+  snapshot_path_conf->set_allocated_localfs_conf(new LocalFsConf);
   auto resource = jb_desc_proto.mutable_resource();
   resource->add_machine();
   Global<JobDesc>::Get()->InitFromProto(jb_desc_proto);
+  fs::FileSystem* snapshot_fs = GetFS(Global<JobDesc>::Get()->snapshot_path_conf());
   std::string current_dir = GetCwd();
   StringReplace(&current_dir, '\\', '/');
   std::string snapshot_root_path = JoinPath(current_dir, "/tmp_snapshot_test_asdfasdf");
-  if (GlobalFS()->IsDirectory(snapshot_root_path)) {
-    ASSERT_TRUE(GlobalFS()->ListDir(snapshot_root_path).empty());
+  if (snapshot_fs->IsDirectory(snapshot_root_path)) {
+    ASSERT_TRUE(snapshot_fs->ListDir(snapshot_root_path).empty());
   } else {
-    GlobalFS()->CreateDir(snapshot_root_path);
+    snapshot_fs->CreateDir(snapshot_root_path);
   }
   std::string key = "key/name";
@@ -41,13 +43,13 @@ TEST(Snapshot, write_and_read) {
   // read
   {
     auto read_stream_ptr =
-        std::make_unique<NormalPersistentInStream>(GlobalFS(), JoinPath(snapshot_root_path, key));
+        std::make_unique<NormalPersistentInStream>(snapshot_fs, JoinPath(snapshot_root_path, key));
     std::string content;
     read_stream_ptr->ReadLine(&content);
     ASSERT_EQ(content, "ab");
   }
-  GlobalFS()->RecursivelyDeleteDir(snapshot_root_path);
-  ASSERT_TRUE(!GlobalFS()->IsDirectory(snapshot_root_path));
+  snapshot_fs->RecursivelyDeleteDir(snapshot_root_path);
+  ASSERT_TRUE(!snapshot_fs->IsDirectory(snapshot_root_path));
 }
 }  // namespace oneflow
#include "oneflow/core/persistence/tee_persistent_log_stream.h"
#include "oneflow/core/common/str_util.h"
#include <google/protobuf/text_format.h>
namespace oneflow {
TeePersistentLogStream::TeePersistentLogStream(const std::string& path) {
destinations_.emplace_back(LocalFS(), LogDir());
branches_.reserve(destinations_.size());
for (const auto& destination : destinations_) {
branches_.emplace_back(std::make_unique<PersistentOutStream>(
destination.mut_file_system(), JoinPath(destination.base_dir(), path)));
}
}
TeePersistentLogStream::~TeePersistentLogStream() { Flush(); }
std::unique_ptr<TeePersistentLogStream> TeePersistentLogStream::Create(const std::string& path) {
auto stream_ptr = new TeePersistentLogStream(path);
return std::unique_ptr<TeePersistentLogStream>(stream_ptr);
}
void TeePersistentLogStream::Flush() {
for (const auto& branch : branches_) { branch->Flush(); }
};
void TeePersistentLogStream::Write(const char* s, size_t n) {
for (const auto& branch : branches_) { branch->Write(s, n); }
};
void TeePersistentLogStream::Write(const std::string& str) { this->Write(str.data(), str.size()); }
void TeePersistentLogStream::Write(const PbMessage& proto) {
std::string output;
google::protobuf::TextFormat::PrintToString(proto, &output);
this->Write(output);
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_PERSISTENCE_TEE_PERSISTENT_LOG_STREAM_H_
#define ONEFLOW_CORE_PERSISTENCE_TEE_PERSISTENT_LOG_STREAM_H_

#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/persistence/persistent_out_stream.h"

namespace oneflow {

class LogStreamDestination final {
 public:
  LogStreamDestination(fs::FileSystem* file_system, const std::string& base_dir)
      : file_system_(file_system), base_dir_(base_dir) {}
  ~LogStreamDestination() = default;

  fs::FileSystem* mut_file_system() const { return file_system_; };
  const std::string& base_dir() const { return base_dir_; };

 private:
  fs::FileSystem* file_system_;
  std::string base_dir_;
};

class TeePersistentLogStream final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(TeePersistentLogStream);
  ~TeePersistentLogStream();

  void Write(const char* s, size_t n);
  void Write(const std::string& str);
  void Write(const PbMessage& proto);

  static std::unique_ptr<TeePersistentLogStream> Create(const std::string& path);

 private:
  explicit TeePersistentLogStream(const std::string& path);
  void Flush();
  std::vector<LogStreamDestination> destinations_;
  std::vector<std::unique_ptr<PersistentOutStream>> branches_;
};

inline TeePersistentLogStream& operator<<(TeePersistentLogStream& log_stream,
                                          const std::string& s) {
  log_stream.Write(s.c_str(), s.size());
  return log_stream;
}

inline std::unique_ptr<TeePersistentLogStream>& operator<<(
    std::unique_ptr<TeePersistentLogStream>& log_stream, const std::string& s) {
  log_stream->Write(s.c_str(), s.size());
  return log_stream;
}

}  // namespace oneflow

#endif  // ONEFLOW_CORE_PERSISTENCE_TEE_PERSISTENT_LOG_STREAM_H_
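Note: a condensed sketch of how TeePersistentLogStream is used by the call sites in this commit; paths are relative to LogDir():

  // Dump a protobuf message as text (naive_plan is a Plan proto in oneflow.cpp):
  TeePersistentLogStream::Create("naive_plan")->Write(naive_plan);
  // Stream free-form text, e.g. a GraphViz dot file:
  auto log_stream = TeePersistentLogStream::Create("/dot/plan.dot");
  log_stream << "digraph {\n";
  log_stream << "}\n";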