提交 3d5244c8 编写于 作者: L Li Xinqi 提交者: Jinhui Yuan

moving model (#1234)

* moving model

* moving_model => forward_model

* add todo comment

* two model save node

* let md_updt actor handle forward_model

* remove useless code

* rename local variable


Former-commit-id: baa146bd
上级 33868c01
...@@ -6,6 +6,12 @@ namespace oneflow { ...@@ -6,6 +6,12 @@ namespace oneflow {
void NormalMdUpdtCompActor::VirtualCompActorInit(const TaskProto& task_proto) { void NormalMdUpdtCompActor::VirtualCompActorInit(const TaskProto& task_proto) {
model_regst_desc_id_ = Name2SoleRegstDescId("model"); model_regst_desc_id_ = Name2SoleRegstDescId("model");
const_model_regst_desc_id_ = Name2SoleRegstDescId("const_model"); const_model_regst_desc_id_ = Name2SoleRegstDescId("const_model");
int64_t forward_model_regst_desc_id = Name2SoleRegstDescId("forward_model");
if (forward_model_regst_desc_id != -1) {
forward_model_regst_ = GetCurWriteableRegst(forward_model_regst_desc_id);
} else {
forward_model_regst_ = nullptr;
}
init_remaining_cnt_ = 0; init_remaining_cnt_ = 0;
if (model_regst_desc_id_ != -1) { init_remaining_cnt_ += 1; } if (model_regst_desc_id_ != -1) { init_remaining_cnt_ += 1; }
if (const_model_regst_desc_id_ != -1) { if (const_model_regst_desc_id_ != -1) {
...@@ -13,7 +19,9 @@ void NormalMdUpdtCompActor::VirtualCompActorInit(const TaskProto& task_proto) { ...@@ -13,7 +19,9 @@ void NormalMdUpdtCompActor::VirtualCompActorInit(const TaskProto& task_proto) {
DecreaseActualWriteableProducedDataRegstDescNum(1); DecreaseActualWriteableProducedDataRegstDescNum(1);
} }
next_model_version_id_ = 0; next_model_version_id_ = 0;
related_save_model_actor_id_ = task_proto.related_save_model_task_id(); for (int64_t model_save_related_actor_id : task_proto.related_save_model_task_ids()) {
related_save_model_actor_ids_.insert(model_save_related_actor_id);
}
related_init_model_actor_id_ = task_proto.related_init_model_task_id(); related_init_model_actor_id_ = task_proto.related_init_model_task_id();
pre_model_regst_ = nullptr; pre_model_regst_ = nullptr;
OF_SET_MSG_HANDLER(&NormalMdUpdtCompActor::HandlerInitModelAndConstModel); OF_SET_MSG_HANDLER(&NormalMdUpdtCompActor::HandlerInitModelAndConstModel);
...@@ -35,12 +43,25 @@ void NormalMdUpdtCompActor::Act() { ...@@ -35,12 +43,25 @@ void NormalMdUpdtCompActor::Act() {
bool need_save_model = NeedModelSave(next_model_version_id_ - 1); bool need_save_model = NeedModelSave(next_model_version_id_ - 1);
bool need_send_model = next_model_version_id_ < job_desc->TotalBatchNum(); bool need_send_model = next_model_version_id_ < job_desc->TotalBatchNum();
AsyncSendRegstMsgToConsumer(RegstPreProcess, [&](int64_t actor_id) { AsyncSendRegstMsgToConsumer(RegstPreProcess, [&](int64_t actor_id) {
return (need_save_model && actor_id == related_save_model_actor_id_) bool is_saving_related =
|| (need_send_model && actor_id != related_save_model_actor_id_); related_save_model_actor_ids_.find(actor_id) != related_save_model_actor_ids_.end();
return (need_save_model && is_saving_related) || (need_send_model && !is_saving_related);
}); });
if (need_save_model && forward_model_regst_ != nullptr) {
AsyncSendRegstMsgToConsumer([&](Regst* regst) { return regst == forward_model_regst_; });
}
next_model_version_id_ += 1; next_model_version_id_ += 1;
} }
int64_t NormalMdUpdtCompActor::ActNumForEachOutput(int64_t regst_desc_id) const {
const auto* job_desc = Global<JobDesc>::Get();
if (forward_model_regst_ != nullptr && regst_desc_id == forward_model_regst_->regst_desc_id()) {
return std::min<int64_t>(job_desc->TotalBatchNum() - (forward_model_regst_->act_id() + 1),
job_desc->NumOfBatchesInSnapshot());
}
return 1;
}
void NormalMdUpdtCompActor::InitRegstBySendToFw(int64_t regst_desc_id) { void NormalMdUpdtCompActor::InitRegstBySendToFw(int64_t regst_desc_id) {
if (regst_desc_id == -1) { return; } if (regst_desc_id == -1) { return; }
Regst* regst = GetCurWriteableRegst(regst_desc_id); Regst* regst = GetCurWriteableRegst(regst_desc_id);
...@@ -48,11 +69,30 @@ void NormalMdUpdtCompActor::InitRegstBySendToFw(int64_t regst_desc_id) { ...@@ -48,11 +69,30 @@ void NormalMdUpdtCompActor::InitRegstBySendToFw(int64_t regst_desc_id) {
Global<ActorMsgBus>::Get()->SendMsg(msg); Global<ActorMsgBus>::Get()->SendMsg(msg);
} }
// Initializes the forward_model blobs of every kernel in this actor, reading
// from the current readable snapshot when one is available.
// No-op when this actor produces no forward_model regst.
void NormalMdUpdtCompActor::InitModelAndConstBuf() {
  // TODO: move the initialization of model and const model from the fw op into this function
  if (forward_model_regst_ == nullptr) { return; }
  for (const ExecKernel& ek : exec_kernel_vec()) {
    KernelCtx kernel_ctx = GenDefaultKernelCtx();
    ek.kernel->InitModelAndConstBuf(kernel_ctx, parallel_ctx(),
                                    Global<SnapshotMgr>::Get()->GetReadableSnapshot(),
                                    [&](const std::string& bn_in_op) -> Blob* {
                                      // Map the kernel-local blob name to its logical blob id,
                                      // then look the blob up in the forward_model regst.
                                      // forward_model_regst_ is non-null here: the early return
                                      // above guarantees it, so no per-call null check is needed.
                                      const LogicalBlobId& lbi = ek.kernel->BnInOp2Lbi(bn_in_op);
                                      return forward_model_regst_->GetBlobByLbi(lbi);
                                    });
  }
}
int NormalMdUpdtCompActor::HandlerInitModelAndConstModel(const ActorMsg& msg) { int NormalMdUpdtCompActor::HandlerInitModelAndConstModel(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) { if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kInitModel); CHECK_EQ(msg.actor_cmd(), ActorCmd::kInitModel);
InitRegstBySendToFw(model_regst_desc_id_); InitRegstBySendToFw(model_regst_desc_id_);
InitRegstBySendToFw(const_model_regst_desc_id_); InitRegstBySendToFw(const_model_regst_desc_id_);
InitModelAndConstBuf();
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) { } else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
init_remaining_cnt_ -= 1; init_remaining_cnt_ -= 1;
} else { } else {
......
...@@ -18,16 +18,18 @@ class NormalMdUpdtCompActor final : public CompActor { ...@@ -18,16 +18,18 @@ class NormalMdUpdtCompActor final : public CompActor {
return {true, {}}; return {true, {}};
} }
bool CheckOutputActId(int64_t regst_desc_id) const override; bool CheckOutputActId(int64_t regst_desc_id) const override;
void InitModelAndConstBuf();
void InitRegstBySendToFw(int64_t regst_desc_id); void InitRegstBySendToFw(int64_t regst_desc_id);
int HandlerInitModelAndConstModel(const ActorMsg&); int HandlerInitModelAndConstModel(const ActorMsg&);
int HandlerSendInitialModel(const ActorMsg&); int HandlerSendInitialModel(const ActorMsg&);
int64_t ActNumForEachOutput(int64_t regst_desc_id) const override;
int64_t model_regst_desc_id_; int64_t model_regst_desc_id_;
int64_t const_model_regst_desc_id_; int64_t const_model_regst_desc_id_;
Regst* forward_model_regst_;
int8_t init_remaining_cnt_; int8_t init_remaining_cnt_;
int64_t next_model_version_id_; int64_t next_model_version_id_;
int64_t related_save_model_actor_id_; HashSet<int64_t> related_save_model_actor_ids_;
int64_t related_init_model_actor_id_; int64_t related_init_model_actor_id_;
Regst* pre_model_regst_; Regst* pre_model_regst_;
}; };
......
...@@ -560,7 +560,13 @@ NormalMdUpdtLogicalNode* LogicalGraph::BuildNormalMdUpdtAndMdSaveStruct( ...@@ -560,7 +560,13 @@ NormalMdUpdtLogicalNode* LogicalGraph::BuildNormalMdUpdtAndMdSaveStruct(
NormalMdUpdtLogicalNode* md_updt_logical = NewNode<NormalMdUpdtLogicalNode>(); NormalMdUpdtLogicalNode* md_updt_logical = NewNode<NormalMdUpdtLogicalNode>();
md_updt_logical->mut_parallel_desc() = fw_logical->parallel_desc(); md_updt_logical->mut_parallel_desc() = fw_logical->parallel_desc();
if (Global<JobDesc>::Get()->enable_write_snapshot()) { if (Global<JobDesc>::Get()->enable_write_snapshot()) {
// for model
BuildMdSaveStruct(fw_logical, md_updt_logical); BuildMdSaveStruct(fw_logical, md_updt_logical);
// TODO: remove the following ugly hard coded `if'
if (Global<JobDesc>::Get()->other_conf().train_conf().model_update_conf().has_momentum_conf()) {
// for forward_model
BuildMdSaveStruct(fw_logical, md_updt_logical);
}
} }
return md_updt_logical; return md_updt_logical;
} }
......
...@@ -22,9 +22,12 @@ void NormalMdUpdtCompTaskNode::ProduceAllRegstsAndBindEdges() { ...@@ -22,9 +22,12 @@ void NormalMdUpdtCompTaskNode::ProduceAllRegstsAndBindEdges() {
int32_t max_model_regst = 1; int32_t max_model_regst = 1;
auto model_regst = ProduceRegst("model", false, 1, max_model_regst); auto model_regst = ProduceRegst("model", false, 1, max_model_regst);
auto const_model_regst = ProduceRegst("const_model", false, 1, 1); auto const_model_regst = ProduceRegst("const_model", false, 1, 1);
auto forward_model_regst = ProduceRegst("forward_model", false, 1, 1);
ProduceRegst("processed_model_diff", false, 1, 1); ProduceRegst("processed_model_diff", false, 1, 1);
ProduceRegst("data_tmp", false, 1, 1); ProduceRegst("data_tmp", false, 1, 1);
related_init_model_task_id_ = -1; related_init_model_task_id_ = -1;
std::list<std::pair<std::string, std::shared_ptr<RegstDesc>>> model_to_save{
{"model", model_regst}, {"forward_model", forward_model_regst}};
for (TaskEdge* out_edge : out_edges()) { for (TaskEdge* out_edge : out_edges()) {
TaskNode* dst_node = out_edge->dst_node(); TaskNode* dst_node = out_edge->dst_node();
if (IsForwardTaskType(dst_node->GetTaskType()) || IsBackwardTaskType(dst_node->GetTaskType())) { if (IsForwardTaskType(dst_node->GetTaskType()) || IsBackwardTaskType(dst_node->GetTaskType())) {
...@@ -36,7 +39,8 @@ void NormalMdUpdtCompTaskNode::ProduceAllRegstsAndBindEdges() { ...@@ -36,7 +39,8 @@ void NormalMdUpdtCompTaskNode::ProduceAllRegstsAndBindEdges() {
related_init_model_task_id_ = fw_node->task_id(); related_init_model_task_id_ = fw_node->task_id();
} }
} else { } else {
out_edge->AddRegst("model", model_regst); out_edge->AddRegst(model_to_save.front().first, model_to_save.front().second);
model_to_save.pop_front();
} }
} }
} }
...@@ -75,7 +79,7 @@ void NormalMdUpdtCompTaskNode::BuildExecGphAndRegst() { ...@@ -75,7 +79,7 @@ void NormalMdUpdtCompTaskNode::BuildExecGphAndRegst() {
ExecEdge* exec_edge = nullptr; ExecEdge* exec_edge = nullptr;
processed_model_diff_regst->ForEachLbi([&](const LogicalBlobId& lbi) { processed_model_diff_regst->ForEachLbi([&](const LogicalBlobId& lbi) {
OperatorConf op_conf; OperatorConf op_conf;
op_conf.set_name("md_update_" + lbi.op_name() + "_" + lbi.blob_name()); op_conf.set_name("model_update-" + lbi.op_name() + "-" + lbi.blob_name());
op_conf.set_device_type(logical_node()->parallel_desc()->device_type()); op_conf.set_device_type(logical_node()->parallel_desc()->device_type());
op_conf.mutable_normal_mdupdt_conf()->set_model_diff(lbi.op_name() + '/' + lbi.blob_name()); op_conf.mutable_normal_mdupdt_conf()->set_model_diff(lbi.op_name() + '/' + lbi.blob_name());
op_conf.mutable_normal_mdupdt_conf()->set_model(lbi.op_name() + '/' + lbi.blob_name()); op_conf.mutable_normal_mdupdt_conf()->set_model(lbi.op_name() + '/' + lbi.blob_name());
...@@ -112,6 +116,8 @@ void NormalMdUpdtCompTaskNode::BuildExecGphAndRegst() { ...@@ -112,6 +116,8 @@ void NormalMdUpdtCompTaskNode::BuildExecGphAndRegst() {
model_update_node->BindBnWithRegst(model_update_op->SoleIbn(), processed_model_diff_regst); model_update_node->BindBnWithRegst(model_update_op->SoleIbn(), processed_model_diff_regst);
model_update_node->BindBnWithRegst(model_update_op->SoleObn(), GetProducedRegst("model")); model_update_node->BindBnWithRegst(model_update_op->SoleObn(), GetProducedRegst("model"));
model_update_node->AddBnToRegstAndBindIt(&Operator::data_tmp_bns, GetProducedRegst("data_tmp")); model_update_node->AddBnToRegstAndBindIt(&Operator::data_tmp_bns, GetProducedRegst("data_tmp"));
model_update_node->AddBnToRegstAndBindIt(&Operator::forward_model_bns,
GetProducedRegst("forward_model"));
}); });
mut_exec_gph().TopoForEachNode([this](ExecNode* node) { node->InferBlobDescs(parallel_ctx()); }); mut_exec_gph().TopoForEachNode([this](ExecNode* node) { node->InferBlobDescs(parallel_ctx()); });
} }
...@@ -119,6 +125,7 @@ void NormalMdUpdtCompTaskNode::BuildExecGphAndRegst() { ...@@ -119,6 +125,7 @@ void NormalMdUpdtCompTaskNode::BuildExecGphAndRegst() {
void NormalMdUpdtCompTaskNode::LockRegsts() { void NormalMdUpdtCompTaskNode::LockRegsts() {
GetProducedRegst("processed_model_diff")->Lock(); GetProducedRegst("processed_model_diff")->Lock();
GetProducedRegst("data_tmp")->Lock(); GetProducedRegst("data_tmp")->Lock();
GetProducedRegst("forward_model")->Lock();
} }
void NormalMdUpdtCompTaskNode::ToProto(TaskProto* task_proto) { void NormalMdUpdtCompTaskNode::ToProto(TaskProto* task_proto) {
...@@ -129,8 +136,7 @@ void NormalMdUpdtCompTaskNode::ToProto(TaskProto* task_proto) { ...@@ -129,8 +136,7 @@ void NormalMdUpdtCompTaskNode::ToProto(TaskProto* task_proto) {
} else if (IsBackwardTaskType(node->GetTaskType())) { } else if (IsBackwardTaskType(node->GetTaskType())) {
// do nothing // do nothing
} else { } else {
CHECK_EQ(task_proto->related_save_model_task_id(), -1); task_proto->add_related_save_model_task_ids(node->task_id());
task_proto->set_related_save_model_task_id(node->task_id());
} }
}); });
task_proto->set_related_init_model_task_id(related_init_model_task_id_); task_proto->set_related_init_model_task_id(related_init_model_task_id_);
......
...@@ -68,6 +68,6 @@ message TaskProto { ...@@ -68,6 +68,6 @@ message TaskProto {
// compute task // compute task
optional ParallelContext parallel_ctx = 1000; // CompTask optional ParallelContext parallel_ctx = 1000; // CompTask
optional int64 random_seed = 1001; // ForwardCompTask optional int64 random_seed = 1001; // ForwardCompTask
optional int64 related_save_model_task_id = 1002 [default = -1]; repeated int64 related_save_model_task_ids = 1002;
optional int64 related_init_model_task_id = 1003 [default = -1]; optional int64 related_init_model_task_id = 1003 [default = -1];
}; };
...@@ -46,7 +46,6 @@ class Kernel { ...@@ -46,7 +46,6 @@ class Kernel {
virtual void InitModelBlobsWithDir(DeviceCtx* ctx, int32_t part_id, int32_t part_num, virtual void InitModelBlobsWithDir(DeviceCtx* ctx, int32_t part_id, int32_t part_num,
const std::string& model_load_dir, const std::string& model_load_dir,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {} std::function<Blob*(const std::string&)> BnInOp2Blob) const {}
virtual ActivationType GetActivationType() const { return ActivationType::kNone; } virtual ActivationType GetActivationType() const { return ActivationType::kNone; }
virtual void Forward(const KernelCtx& ctx, virtual void Forward(const KernelCtx& ctx,
......
...@@ -3,6 +3,26 @@ ...@@ -3,6 +3,26 @@
namespace oneflow { namespace oneflow {
// Initializes the "momentum" blob to all zeros via a constant-value initializer.
// The random seed generator is accepted for interface conformance but unused here.
template<DeviceType device_type, typename T>
void MomentumMdUpdateKernel<device_type, T>::InitModelBlobsWithRandomSeed(
    DeviceCtx* ctx, std::mt19937* random_seed_gen,
    std::function<Blob*(const std::string&)> BnInOp2Blob) const {
  InitializerConf zero_init_conf;
  zero_init_conf.mutable_constant_conf()->set_value(0.0f);
  Blob* momentum_blob = BnInOp2Blob("momentum");
  KernelUtil<device_type, T>::InitializeWithConf(ctx, zero_init_conf, 0, momentum_blob);
}
// Loads the "momentum" blob from model_load_dir for this part (part_id of part_num).
// NOTE(review): the At(0)/Count(1) split presumably partitions along the outermost
// dimension — confirm against InitializeWithDir's contract.
template<DeviceType device_type, typename T>
void MomentumMdUpdateKernel<device_type, T>::InitModelBlobsWithDir(
    DeviceCtx* ctx, int32_t part_id, int32_t part_num, const std::string& model_load_dir,
    std::function<Blob*(const std::string&)> BnInOp2Blob) const {
  Blob* momentum_blob = BnInOp2Blob("momentum");
  const int64_t dim0_size = momentum_blob->shape().At(0);
  const int64_t inner_size = momentum_blob->shape().Count(1);
  KernelUtil<device_type, T>::InitializeWithDir(ctx, part_id, part_num, model_load_dir,
                                                momentum_blob, "momentum", dim0_size, inner_size);
}
template<DeviceType device_type, typename T> template<DeviceType device_type, typename T>
void MomentumMdUpdateKernel<device_type, T>::UpdateModel( void MomentumMdUpdateKernel<device_type, T>::UpdateModel(
DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, int64_t next_model_vid, DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, int64_t next_model_vid,
......
...@@ -12,6 +12,14 @@ class MomentumMdUpdateKernel final : public NormalMdUpdateKernel<device_type, T> ...@@ -12,6 +12,14 @@ class MomentumMdUpdateKernel final : public NormalMdUpdateKernel<device_type, T>
MomentumMdUpdateKernel() = default; MomentumMdUpdateKernel() = default;
~MomentumMdUpdateKernel() = default; ~MomentumMdUpdateKernel() = default;
protected:
void InitModelBlobsWithRandomSeed(
DeviceCtx* ctx, std::mt19937* random_seed_gen,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
void InitModelBlobsWithDir(DeviceCtx* ctx, int32_t part_id, int32_t part_num,
const std::string& model_load_dir,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
private: private:
void UpdateModel(DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, void UpdateModel(DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2,
int64_t next_model_vid, int64_t next_model_vid,
......
...@@ -168,24 +168,24 @@ void NormalizationKernel<device_type, T>::InitModelBlobsWithDir( ...@@ -168,24 +168,24 @@ void NormalizationKernel<device_type, T>::InitModelBlobsWithDir(
const auto& conf = this->op_conf().normalization_conf(); const auto& conf = this->op_conf().normalization_conf();
if (conf.scale()) { if (conf.scale()) {
Blob* gamma_blob = BnInOp2Blob("gamma"); Blob* gamma_blob = BnInOp2Blob("gamma");
KernelUtil<device_type, T>::InitializeWithDir(ctx, 0, part_num, model_load_dir, gamma_blob, KernelUtil<device_type, T>::InitializeWithDir(ctx, part_id, part_num, model_load_dir,
"gamma", gamma_blob->shape().At(0), gamma_blob, "gamma", gamma_blob->shape().At(0),
gamma_blob->shape().Count(1)); gamma_blob->shape().Count(1));
} }
if (conf.center()) { if (conf.center()) {
Blob* beta_blob = BnInOp2Blob("beta"); Blob* beta_blob = BnInOp2Blob("beta");
KernelUtil<device_type, T>::InitializeWithDir(ctx, 0, part_num, model_load_dir, beta_blob, KernelUtil<device_type, T>::InitializeWithDir(ctx, part_id, part_num, model_load_dir, beta_blob,
"beta", beta_blob->shape().At(0), "beta", beta_blob->shape().At(0),
beta_blob->shape().Count(1)); beta_blob->shape().Count(1));
} }
Blob* mean_blob = BnInOp2Blob("moving_mean"); Blob* mean_blob = BnInOp2Blob("moving_mean");
KernelUtil<device_type, T>::InitializeWithDir(ctx, 0, part_num, model_load_dir, mean_blob, KernelUtil<device_type, T>::InitializeWithDir(ctx, part_id, part_num, model_load_dir, mean_blob,
"moving_mean", mean_blob->shape().At(0), "moving_mean", mean_blob->shape().At(0),
mean_blob->shape().Count(1)); mean_blob->shape().Count(1));
Blob* variance_blob = BnInOp2Blob("moving_variance"); Blob* variance_blob = BnInOp2Blob("moving_variance");
KernelUtil<device_type, T>::InitializeWithDir(ctx, 0, part_num, model_load_dir, variance_blob, KernelUtil<device_type, T>::InitializeWithDir(
"moving_variance", variance_blob->shape().At(0), ctx, part_id, part_num, model_load_dir, variance_blob, "moving_variance",
variance_blob->shape().Count(1)); variance_blob->shape().At(0), variance_blob->shape().Count(1));
} }
template<DeviceType device_type, typename T> template<DeviceType device_type, typename T>
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
namespace oneflow { namespace oneflow {
void MomentumModelUpdateOp::MdUpdtVirtualInitFromOpConf() { EnrollDataTmpBn("momentum"); } void MomentumModelUpdateOp::MdUpdtVirtualInitFromOpConf() { EnrollForwardModelBn("momentum"); }
void MomentumModelUpdateOp::InferBlobDescs( void MomentumModelUpdateOp::InferBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册