提交 5780964c 编写于 作者: W Wind5

Merge branch 'master' of github.com:Oneflow-Inc/oneflow into dev_chenjw

......@@ -33,6 +33,8 @@ void Actor::Init(const TaskProto& task_proto) {
for (const auto& pair : task_proto.subscribed_regst_desc_id()) {
CHECK(name2regst_desc_id_.emplace(pair.first, pair.second).second);
}
//
expected_piece_id_ = 0;
// Status of Produced Registers
for (const auto& regst : produced_regst_vec_) {
writeable_produced_regst_[regst->regst_desc_id()].push(regst.get());
......@@ -52,6 +54,7 @@ void Actor::WardKernel(
return regst->GetBlobPtrFromLbn(lbn);
});
}
expected_piece_id_ += 1;
}
void Actor::ForEachProducedRegst(std::function<void(Regst*)> func) {
......@@ -60,7 +63,7 @@ void Actor::ForEachProducedRegst(std::function<void(Regst*)> func) {
}
}
int Actor::TryOneReadDone(Regst* regst) {
int Actor::TryUpdtStateAsFromRegstReader(Regst* regst) {
auto reading_cnt_it = produced_regst2reading_cnt_.find(regst);
if (reading_cnt_it == produced_regst2reading_cnt_.end()) { return -1; }
CHECK_GE(reading_cnt_it->second, 1);
......
......@@ -39,8 +39,9 @@ class Actor {
return name2regst_desc_id_.at(name);
}
uint64_t expected_piece_id() const { return expected_piece_id_; }
// Status of Produced Registers
int TryOneReadDone(Regst* regst);
int TryUpdtStateAsFromRegstReader(Regst* regst);
Regst* GetCurWriteableRegst(uint64_t regst_desc_id);
Regst* GetCurWriteableRegst(const std::string& name);
void ForEachCurWriteableRegst(std::function<void(Regst*)> func);
......@@ -59,6 +60,7 @@ class Actor {
std::vector<std::unique_ptr<Regst>> produced_regst_vec_;
HashMap<std::string, uint64_t> name2regst_desc_id_;
uint64_t expected_piece_id_;
// Status of Produced Registers
HashMap<uint64_t, std::queue<Regst*>> writeable_produced_regst_; // <regst_desc_id, regst>
uint64_t writeable_produced_regst_desc_num_;
......
......@@ -8,7 +8,8 @@ namespace oneflow {
enum class ActorCmd {
kInitializeModel = 0,
kSendInitialModel
kSendInitialModel,
kStop
};
enum class ActorMsgType {
......
......@@ -12,7 +12,7 @@ void BoxingActor::Init(const TaskProto& task_proto) {
void BoxingActor::ProcessMsg(const ActorMsg& msg,
const ThreadContext& thread_ctx) {
KernelContext kernel_ctx;
if (TryOneReadDone(msg.regst_warpper()->regst_raw_ptr()) != 0) {
if (TryUpdtStateAsFromRegstReader(msg.regst_warpper()->regst_raw_ptr()) != 0) {
std::shared_ptr<RegstWarpper> regst_wp = msg.regst_warpper();
auto waiting_in_regst_it = waiting_in_regst_.find(regst_wp->piece_id());
if (waiting_in_regst_it == waiting_in_regst_.end()) {
......
......@@ -11,7 +11,7 @@ void CopyCommNetActor::Init(const TaskProto& task_proto) {
void CopyCommNetActor::ProcessMsg(const ActorMsg& msg,
const ThreadContext&) {
KernelContext kernel_ctx;
if (TryOneReadDone(msg.regst_warpper()->regst_raw_ptr()) != 0) {
if (TryUpdtStateAsFromRegstReader(msg.regst_warpper()->regst_raw_ptr()) != 0) {
waiting_in_regst_.push(std::move(msg.regst_warpper()));
}
if (!waiting_in_regst_.empty() && IsWriteReady()) {
......
......@@ -54,9 +54,17 @@ void MdUpdtCompActor::HandleBeforeSendInitialModel(
}
void MdUpdtCompActor::HandleForUpdateModel(
const ActorMsg&,
const KernelContext&) {
TODO();
const ActorMsg& actor_msg,
const KernelContext& kernel_ctx) {
if (actor_msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK(actor_msg.actor_cmd() == ActorCmd::kStop);
cur_handle_ = nullptr;
TODO();
} else if (actor_msg.msg_type() == ActorMsgType::kRegstMsg) {
TODO();
} else {
UNEXPECTED_RUN();
}
}
REGISTER_ACTOR(kMdUpdtCompTask, true, MdUpdtCompActor);
......
......@@ -11,7 +11,6 @@ message JobConf {
string strategy_filepath = 3;
string model_load_snapshot_path = 4;
string model_save_snapshots_path = 5;
uint32 batch_size = 6;
uint32 piece_size = 7;
bool is_train = 8;
FloatingPointType floating_point_type = 9;
......
......@@ -10,7 +10,6 @@ void JobDesc::InitFromJobConf(const JobConf& conf) {
ParseProtoFromTextFile(conf.strategy_filepath(), &strategy_);
md_load_snapshot_path_ = conf.model_load_snapshot_path();
md_save_snapshots_path_ = conf.model_save_snapshots_path();
batch_size_ = conf.batch_size();
piece_size_ = conf.piece_size();
is_train_ = conf.is_train();
floating_point_type_ = conf.floating_point_type();
......@@ -23,7 +22,6 @@ void JobDesc::InitFromProto(const JobDescProto& proto) {
strategy_ = proto.strategy();
md_load_snapshot_path_ = proto.model_load_snapshot_path();
md_save_snapshots_path_ = proto.model_save_snapshots_path();
batch_size_ = proto.batch_size();
piece_size_ = proto.piece_size();
is_train_ = proto.is_train();
floating_point_type_ = proto.floating_point_type();
......@@ -35,7 +33,6 @@ void JobDesc::ToProto(JobDescProto* proto) const {
*(proto->mutable_strategy()) = strategy_;
*(proto->mutable_model_load_snapshot_path()) = md_load_snapshot_path_;
*(proto->mutable_model_save_snapshots_path()) = md_save_snapshots_path_;
proto->set_batch_size(batch_size_);
proto->set_piece_size(piece_size_);
proto->set_is_train(is_train_);
proto->set_floating_point_type(floating_point_type_);
......
......@@ -45,7 +45,6 @@ class JobDesc final {
Strategy strategy_;
std::string md_load_snapshot_path_;
std::string md_save_snapshots_path_;
uint32_t batch_size_;
uint32_t piece_size_;
uint32_t num_of_piece_in_batch_;
uint32_t staleness_;
......
......@@ -12,7 +12,6 @@ message JobDescProto {
Strategy strategy = 3;
string model_load_snapshot_path = 4;
string model_save_snapshots_path = 5;
uint32 batch_size = 6;
uint32 piece_size = 7;
bool is_train = 8;
FloatingPointType floating_point_type = 9;
......
......@@ -41,7 +41,7 @@ void Snapshot::CheckAndConcat() {
CHECK_EQ(file_names[0], "0");
env_->RenameFile(file_path, concat_file_path);
}
return;
continue;
}
// if the children number is more than 1, every child must be a file,
// and the file name must be {0,1,2 ... n} , n is the part number
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册