提交 3d5244c8 编写于 作者: Li Xinqi 提交者: Jinhui Yuan

moving model (#1234)

* moving model

* moving_model => forward_model

* add todo commit

* two model save node

* let md_updt actor handle forward_model

* remove useless code

* rename local variable


Former-commit-id: baa146bd
上级 33868c01
......@@ -6,6 +6,12 @@ namespace oneflow {
void NormalMdUpdtCompActor::VirtualCompActorInit(const TaskProto& task_proto) {
model_regst_desc_id_ = Name2SoleRegstDescId("model");
const_model_regst_desc_id_ = Name2SoleRegstDescId("const_model");
int64_t forward_model_regst_desc_id = Name2SoleRegstDescId("forward_model");
if (forward_model_regst_desc_id != -1) {
forward_model_regst_ = GetCurWriteableRegst(forward_model_regst_desc_id);
} else {
forward_model_regst_ = nullptr;
}
init_remaining_cnt_ = 0;
if (model_regst_desc_id_ != -1) { init_remaining_cnt_ += 1; }
if (const_model_regst_desc_id_ != -1) {
......@@ -13,7 +19,9 @@ void NormalMdUpdtCompActor::VirtualCompActorInit(const TaskProto& task_proto) {
DecreaseActualWriteableProducedDataRegstDescNum(1);
}
next_model_version_id_ = 0;
related_save_model_actor_id_ = task_proto.related_save_model_task_id();
for (int64_t model_save_related_actor_id : task_proto.related_save_model_task_ids()) {
related_save_model_actor_ids_.insert(model_save_related_actor_id);
}
related_init_model_actor_id_ = task_proto.related_init_model_task_id();
pre_model_regst_ = nullptr;
OF_SET_MSG_HANDLER(&NormalMdUpdtCompActor::HandlerInitModelAndConstModel);
......@@ -35,12 +43,25 @@ void NormalMdUpdtCompActor::Act() {
bool need_save_model = NeedModelSave(next_model_version_id_ - 1);
bool need_send_model = next_model_version_id_ < job_desc->TotalBatchNum();
AsyncSendRegstMsgToConsumer(RegstPreProcess, [&](int64_t actor_id) {
return (need_save_model && actor_id == related_save_model_actor_id_)
|| (need_send_model && actor_id != related_save_model_actor_id_);
bool is_saving_related =
related_save_model_actor_ids_.find(actor_id) != related_save_model_actor_ids_.end();
return (need_save_model && is_saving_related) || (need_send_model && !is_saving_related);
});
if (need_save_model && forward_model_regst_ != nullptr) {
AsyncSendRegstMsgToConsumer([&](Regst* regst) { return regst == forward_model_regst_; });
}
next_model_version_id_ += 1;
}
// Returns the act count associated with one output of the register descriptor
// `regst_desc_id`.
//
// For every register except the forward-model register the answer is 1. For
// the forward-model register it is the smaller of
//   * TotalBatchNum() - (act_id of the forward-model register + 1), and
//   * NumOfBatchesInSnapshot(),
// both read from the global JobDesc.
int64_t NormalMdUpdtCompActor::ActNumForEachOutput(int64_t regst_desc_id) const {
  const auto* job_desc = Global<JobDesc>::Get();
  // Short-circuit keeps the dereference safe when no forward-model register exists.
  const bool targets_forward_model =
      forward_model_regst_ != nullptr && regst_desc_id == forward_model_regst_->regst_desc_id();
  if (!targets_forward_model) { return 1; }
  const int64_t batches_after_current_act =
      job_desc->TotalBatchNum() - (forward_model_regst_->act_id() + 1);
  const int64_t batches_per_snapshot = job_desc->NumOfBatchesInSnapshot();
  return std::min<int64_t>(batches_after_current_act, batches_per_snapshot);
}
void NormalMdUpdtCompActor::InitRegstBySendToFw(int64_t regst_desc_id) {
if (regst_desc_id == -1) { return; }
Regst* regst = GetCurWriteableRegst(regst_desc_id);
......@@ -48,11 +69,30 @@ void NormalMdUpdtCompActor::InitRegstBySendToFw(int64_t regst_desc_id) {
Global<ActorMsgBus>::Get()->SendMsg(msg);
}
// Initializes the blobs held by the forward-model register.
//
// Does nothing when this actor has no forward-model register. Otherwise, for
// each exec kernel of the actor, calls the kernel's InitModelAndConstBuf with a
// default kernel context, this actor's parallel context, the readable snapshot
// from the global SnapshotMgr, and a lookup that maps a blob name in the op to
// the matching blob inside the forward-model register.
void NormalMdUpdtCompActor::InitModelAndConstBuf() {
  // TODO move the initiation of model and const model from fw op into this function
  if (forward_model_regst_ == nullptr) { return; }
  for (const ExecKernel& ek : exec_kernel_vec()) {
    KernelCtx kernel_ctx = GenDefaultKernelCtx();
    ek.kernel->InitModelAndConstBuf(
        kernel_ctx, parallel_ctx(), Global<SnapshotMgr>::Get()->GetReadableSnapshot(),
        [&](const std::string& bn_in_op) -> Blob* {
          // forward_model_regst_ is known to be non-null here (guard above), so
          // the lookup needs no additional null check.
          const LogicalBlobId& lbi = ek.kernel->BnInOp2Lbi(bn_in_op);
          return forward_model_regst_->GetBlobByLbi(lbi);
        });
  }
}
int NormalMdUpdtCompActor::HandlerInitModelAndConstModel(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kInitModel);
InitRegstBySendToFw(model_regst_desc_id_);
InitRegstBySendToFw(const_model_regst_desc_id_);
InitModelAndConstBuf();
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
init_remaining_cnt_ -= 1;
} else {
......
......@@ -18,16 +18,18 @@ class NormalMdUpdtCompActor final : public CompActor {
return {true, {}};
}
bool CheckOutputActId(int64_t regst_desc_id) const override;
void InitModelAndConstBuf();
void InitRegstBySendToFw(int64_t regst_desc_id);
int HandlerInitModelAndConstModel(const ActorMsg&);
int HandlerSendInitialModel(const ActorMsg&);
int64_t ActNumForEachOutput(int64_t regst_desc_id) const override;
int64_t model_regst_desc_id_;
int64_t const_model_regst_desc_id_;
Regst* forward_model_regst_;
int8_t init_remaining_cnt_;
int64_t next_model_version_id_;
int64_t related_save_model_actor_id_;
HashSet<int64_t> related_save_model_actor_ids_;
int64_t related_init_model_actor_id_;
Regst* pre_model_regst_;
};
......
......@@ -560,7 +560,13 @@ NormalMdUpdtLogicalNode* LogicalGraph::BuildNormalMdUpdtAndMdSaveStruct(
NormalMdUpdtLogicalNode* md_updt_logical = NewNode<NormalMdUpdtLogicalNode>();
md_updt_logical->mut_parallel_desc() = fw_logical->parallel_desc();
if (Global<JobDesc>::Get()->enable_write_snapshot()) {
// for model
BuildMdSaveStruct(fw_logical, md_updt_logical);
// TODO: remove the following ugly hard coded `if'
if (Global<JobDesc>::Get()->other_conf().train_conf().model_update_conf().has_momentum_conf()) {
// for forward_model
BuildMdSaveStruct(fw_logical, md_updt_logical);
}
}
return md_updt_logical;
}
......
......@@ -22,9 +22,12 @@ void NormalMdUpdtCompTaskNode::ProduceAllRegstsAndBindEdges() {
int32_t max_model_regst = 1;
auto model_regst = ProduceRegst("model", false, 1, max_model_regst);
auto const_model_regst = ProduceRegst("const_model", false, 1, 1);
auto forward_model_regst = ProduceRegst("forward_model", false, 1, 1);
ProduceRegst("processed_model_diff", false, 1, 1);
ProduceRegst("data_tmp", false, 1, 1);
related_init_model_task_id_ = -1;
std::list<std::pair<std::string, std::shared_ptr<RegstDesc>>> model_to_save{
{"model", model_regst}, {"forward_model", forward_model_regst}};
for (TaskEdge* out_edge : out_edges()) {
TaskNode* dst_node = out_edge->dst_node();
if (IsForwardTaskType(dst_node->GetTaskType()) || IsBackwardTaskType(dst_node->GetTaskType())) {
......@@ -36,7 +39,8 @@ void NormalMdUpdtCompTaskNode::ProduceAllRegstsAndBindEdges() {
related_init_model_task_id_ = fw_node->task_id();
}
} else {
out_edge->AddRegst("model", model_regst);
out_edge->AddRegst(model_to_save.front().first, model_to_save.front().second);
model_to_save.pop_front();
}
}
}
......@@ -75,7 +79,7 @@ void NormalMdUpdtCompTaskNode::BuildExecGphAndRegst() {
ExecEdge* exec_edge = nullptr;
processed_model_diff_regst->ForEachLbi([&](const LogicalBlobId& lbi) {
OperatorConf op_conf;
op_conf.set_name("md_update_" + lbi.op_name() + "_" + lbi.blob_name());
op_conf.set_name("model_update-" + lbi.op_name() + "-" + lbi.blob_name());
op_conf.set_device_type(logical_node()->parallel_desc()->device_type());
op_conf.mutable_normal_mdupdt_conf()->set_model_diff(lbi.op_name() + '/' + lbi.blob_name());
op_conf.mutable_normal_mdupdt_conf()->set_model(lbi.op_name() + '/' + lbi.blob_name());
......@@ -112,6 +116,8 @@ void NormalMdUpdtCompTaskNode::BuildExecGphAndRegst() {
model_update_node->BindBnWithRegst(model_update_op->SoleIbn(), processed_model_diff_regst);
model_update_node->BindBnWithRegst(model_update_op->SoleObn(), GetProducedRegst("model"));
model_update_node->AddBnToRegstAndBindIt(&Operator::data_tmp_bns, GetProducedRegst("data_tmp"));
model_update_node->AddBnToRegstAndBindIt(&Operator::forward_model_bns,
GetProducedRegst("forward_model"));
});
mut_exec_gph().TopoForEachNode([this](ExecNode* node) { node->InferBlobDescs(parallel_ctx()); });
}
......@@ -119,6 +125,7 @@ void NormalMdUpdtCompTaskNode::BuildExecGphAndRegst() {
// Locks the produced registers named "processed_model_diff", "data_tmp" and
// "forward_model", in that order.
void NormalMdUpdtCompTaskNode::LockRegsts() {
  const char* const regst_names_to_lock[] = {"processed_model_diff", "data_tmp", "forward_model"};
  for (const char* regst_name : regst_names_to_lock) { GetProducedRegst(regst_name)->Lock(); }
}
void NormalMdUpdtCompTaskNode::ToProto(TaskProto* task_proto) {
......@@ -129,8 +136,7 @@ void NormalMdUpdtCompTaskNode::ToProto(TaskProto* task_proto) {
} else if (IsBackwardTaskType(node->GetTaskType())) {
// do nothing
} else {
CHECK_EQ(task_proto->related_save_model_task_id(), -1);
task_proto->set_related_save_model_task_id(node->task_id());
task_proto->add_related_save_model_task_ids(node->task_id());
}
});
task_proto->set_related_init_model_task_id(related_init_model_task_id_);
......
......@@ -68,6 +68,6 @@ message TaskProto {
// compute task
optional ParallelContext parallel_ctx = 1000; // CompTask
optional int64 random_seed = 1001; // ForwardCompTask
optional int64 related_save_model_task_id = 1002 [default = -1];
repeated int64 related_save_model_task_ids = 1002;
optional int64 related_init_model_task_id = 1003 [default = -1];
};
......@@ -46,7 +46,6 @@ class Kernel {
virtual void InitModelBlobsWithDir(DeviceCtx* ctx, int32_t part_id, int32_t part_num,
const std::string& model_load_dir,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {}
virtual ActivationType GetActivationType() const { return ActivationType::kNone; }
virtual void Forward(const KernelCtx& ctx,
......
......@@ -3,6 +3,26 @@
namespace oneflow {
// Initialization path used when no model directory is loaded: fills the
// "momentum" blob through a constant initializer whose value is 0.
//
// `random_seed_gen` is not consulted, since a constant initializer is used.
// The literal 0 passed to InitializeWithConf is forwarded unchanged; its
// meaning is defined by KernelUtil (presumably a seed -- TODO confirm).
template<DeviceType device_type, typename T>
void MomentumMdUpdateKernel<device_type, T>::InitModelBlobsWithRandomSeed(
    DeviceCtx* ctx, std::mt19937* random_seed_gen,
    std::function<Blob*(const std::string&)> BnInOp2Blob) const {
  InitializerConf zero_fill_conf;
  zero_fill_conf.mutable_constant_conf()->set_value(0.0f);
  Blob* momentum_blob = BnInOp2Blob("momentum");
  KernelUtil<device_type, T>::InitializeWithConf(ctx, zero_fill_conf, 0, momentum_blob);
}
// Initialization path used when a model directory is given: asks KernelUtil to
// initialize the "momentum" blob from `model_load_dir`, forwarding `part_id`
// and `part_num` together with the blob's first dimension (shape().At(0)) and
// the element count of the remaining dimensions (shape().Count(1)).
template<DeviceType device_type, typename T>
void MomentumMdUpdateKernel<device_type, T>::InitModelBlobsWithDir(
    DeviceCtx* ctx, int32_t part_id, int32_t part_num, const std::string& model_load_dir,
    std::function<Blob*(const std::string&)> BnInOp2Blob) const {
  Blob* const momentum = BnInOp2Blob("momentum");
  const auto& momentum_shape = momentum->shape();
  KernelUtil<device_type, T>::InitializeWithDir(ctx, part_id, part_num, model_load_dir, momentum,
                                                "momentum", momentum_shape.At(0),
                                                momentum_shape.Count(1));
}
template<DeviceType device_type, typename T>
void MomentumMdUpdateKernel<device_type, T>::UpdateModel(
DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, int64_t next_model_vid,
......
......@@ -12,6 +12,14 @@ class MomentumMdUpdateKernel final : public NormalMdUpdateKernel<device_type, T>
MomentumMdUpdateKernel() = default;
~MomentumMdUpdateKernel() = default;
protected:
void InitModelBlobsWithRandomSeed(
DeviceCtx* ctx, std::mt19937* random_seed_gen,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
void InitModelBlobsWithDir(DeviceCtx* ctx, int32_t part_id, int32_t part_num,
const std::string& model_load_dir,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
private:
void UpdateModel(DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2,
int64_t next_model_vid,
......
......@@ -168,24 +168,24 @@ void NormalizationKernel<device_type, T>::InitModelBlobsWithDir(
const auto& conf = this->op_conf().normalization_conf();
if (conf.scale()) {
Blob* gamma_blob = BnInOp2Blob("gamma");
KernelUtil<device_type, T>::InitializeWithDir(ctx, 0, part_num, model_load_dir, gamma_blob,
"gamma", gamma_blob->shape().At(0),
KernelUtil<device_type, T>::InitializeWithDir(ctx, part_id, part_num, model_load_dir,
gamma_blob, "gamma", gamma_blob->shape().At(0),
gamma_blob->shape().Count(1));
}
if (conf.center()) {
Blob* beta_blob = BnInOp2Blob("beta");
KernelUtil<device_type, T>::InitializeWithDir(ctx, 0, part_num, model_load_dir, beta_blob,
KernelUtil<device_type, T>::InitializeWithDir(ctx, part_id, part_num, model_load_dir, beta_blob,
"beta", beta_blob->shape().At(0),
beta_blob->shape().Count(1));
}
Blob* mean_blob = BnInOp2Blob("moving_mean");
KernelUtil<device_type, T>::InitializeWithDir(ctx, 0, part_num, model_load_dir, mean_blob,
KernelUtil<device_type, T>::InitializeWithDir(ctx, part_id, part_num, model_load_dir, mean_blob,
"moving_mean", mean_blob->shape().At(0),
mean_blob->shape().Count(1));
Blob* variance_blob = BnInOp2Blob("moving_variance");
KernelUtil<device_type, T>::InitializeWithDir(ctx, 0, part_num, model_load_dir, variance_blob,
"moving_variance", variance_blob->shape().At(0),
variance_blob->shape().Count(1));
KernelUtil<device_type, T>::InitializeWithDir(
ctx, part_id, part_num, model_load_dir, variance_blob, "moving_variance",
variance_blob->shape().At(0), variance_blob->shape().Count(1));
}
template<DeviceType device_type, typename T>
......
......@@ -2,7 +2,7 @@
namespace oneflow {
void MomentumModelUpdateOp::MdUpdtVirtualInitFromOpConf() { EnrollDataTmpBn("momentum"); }
void MomentumModelUpdateOp::MdUpdtVirtualInitFromOpConf() { EnrollForwardModelBn("momentum"); }
void MomentumModelUpdateOp::InferBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册