diff --git a/oneflow/core/actor/normal_model_update_compute_actor.cpp b/oneflow/core/actor/normal_model_update_compute_actor.cpp index bd91ad66aa9f93077dc6cc159d3c403c70ef8417..443b0a3ad4c14208afa67bf45b343c6fe34e67d4 100644 --- a/oneflow/core/actor/normal_model_update_compute_actor.cpp +++ b/oneflow/core/actor/normal_model_update_compute_actor.cpp @@ -6,6 +6,12 @@ namespace oneflow { void NormalMdUpdtCompActor::VirtualCompActorInit(const TaskProto& task_proto) { model_regst_desc_id_ = Name2SoleRegstDescId("model"); const_model_regst_desc_id_ = Name2SoleRegstDescId("const_model"); + int64_t forward_model_regst_desc_id = Name2SoleRegstDescId("forward_model"); + if (forward_model_regst_desc_id != -1) { + forward_model_regst_ = GetCurWriteableRegst(forward_model_regst_desc_id); + } else { + forward_model_regst_ = nullptr; + } init_remaining_cnt_ = 0; if (model_regst_desc_id_ != -1) { init_remaining_cnt_ += 1; } if (const_model_regst_desc_id_ != -1) { @@ -13,7 +19,9 @@ void NormalMdUpdtCompActor::VirtualCompActorInit(const TaskProto& task_proto) { DecreaseActualWriteableProducedDataRegstDescNum(1); } next_model_version_id_ = 0; - related_save_model_actor_id_ = task_proto.related_save_model_task_id(); + for (int64_t model_save_related_actor_id : task_proto.related_save_model_task_ids()) { + related_save_model_actor_ids_.insert(model_save_related_actor_id); + } related_init_model_actor_id_ = task_proto.related_init_model_task_id(); pre_model_regst_ = nullptr; OF_SET_MSG_HANDLER(&NormalMdUpdtCompActor::HandlerInitModelAndConstModel); @@ -35,12 +43,25 @@ void NormalMdUpdtCompActor::Act() { bool need_save_model = NeedModelSave(next_model_version_id_ - 1); bool need_send_model = next_model_version_id_ < job_desc->TotalBatchNum(); AsyncSendRegstMsgToConsumer(RegstPreProcess, [&](int64_t actor_id) { - return (need_save_model && actor_id == related_save_model_actor_id_) - || (need_send_model && actor_id != related_save_model_actor_id_); + bool is_saving_related = + related_save_model_actor_ids_.find(actor_id) != related_save_model_actor_ids_.end(); + return (need_save_model && is_saving_related) || (need_send_model && !is_saving_related); }); + if (need_save_model && forward_model_regst_ != nullptr) { + AsyncSendRegstMsgToConsumer([&](Regst* regst) { return regst == forward_model_regst_; }); + } next_model_version_id_ += 1; } +int64_t NormalMdUpdtCompActor::ActNumForEachOutput(int64_t regst_desc_id) const { + const auto* job_desc = Global::Get(); + if (forward_model_regst_ != nullptr && regst_desc_id == forward_model_regst_->regst_desc_id()) { + return std::min(job_desc->TotalBatchNum() - (forward_model_regst_->act_id() + 1), + job_desc->NumOfBatchesInSnapshot()); + } + return 1; +} + void NormalMdUpdtCompActor::InitRegstBySendToFw(int64_t regst_desc_id) { if (regst_desc_id == -1) { return; } Regst* regst = GetCurWriteableRegst(regst_desc_id); @@ -48,11 +69,30 @@ void NormalMdUpdtCompActor::InitRegstBySendToFw(int64_t regst_desc_id) { Global::Get()->SendMsg(msg); } +void NormalMdUpdtCompActor::InitModelAndConstBuf() { + // TODO move the initiation of model and const model from fw op into this function + if (forward_model_regst_ == nullptr) { return; } + for (const ExecKernel& ek : exec_kernel_vec()) { + KernelCtx kernel_ctx = GenDefaultKernelCtx(); + ek.kernel->InitModelAndConstBuf(kernel_ctx, parallel_ctx(), + Global::Get()->GetReadableSnapshot(), + [&](const std::string& bn_in_op) { + const LogicalBlobId& lbi = ek.kernel->BnInOp2Lbi(bn_in_op); + Blob* blob = nullptr; + if (forward_model_regst_) { + blob = forward_model_regst_->GetBlobByLbi(lbi); + } + return blob; + }); + } +} + int NormalMdUpdtCompActor::HandlerInitModelAndConstModel(const ActorMsg& msg) { if (msg.msg_type() == ActorMsgType::kCmdMsg) { CHECK_EQ(msg.actor_cmd(), ActorCmd::kInitModel); InitRegstBySendToFw(model_regst_desc_id_); InitRegstBySendToFw(const_model_regst_desc_id_); + InitModelAndConstBuf(); } else if (msg.msg_type() == ActorMsgType::kRegstMsg) { init_remaining_cnt_ -= 1; } else { diff --git a/oneflow/core/actor/normal_model_update_compute_actor.h b/oneflow/core/actor/normal_model_update_compute_actor.h index 7318fe10af7b872edc3c99e0a23ef7f3cf25b39b..306c4be5dc99e38a8c573ec159a8f4fb9103789b 100644 --- a/oneflow/core/actor/normal_model_update_compute_actor.h +++ b/oneflow/core/actor/normal_model_update_compute_actor.h @@ -18,16 +18,18 @@ class NormalMdUpdtCompActor final : public CompActor { return {true, {}}; } bool CheckOutputActId(int64_t regst_desc_id) const override; - + void InitModelAndConstBuf(); void InitRegstBySendToFw(int64_t regst_desc_id); int HandlerInitModelAndConstModel(const ActorMsg&); int HandlerSendInitialModel(const ActorMsg&); + int64_t ActNumForEachOutput(int64_t regst_desc_id) const override; int64_t model_regst_desc_id_; int64_t const_model_regst_desc_id_; + Regst* forward_model_regst_; int8_t init_remaining_cnt_; int64_t next_model_version_id_; - int64_t related_save_model_actor_id_; + HashSet related_save_model_actor_ids_; int64_t related_init_model_actor_id_; Regst* pre_model_regst_; }; diff --git a/oneflow/core/graph/logical_graph.cpp b/oneflow/core/graph/logical_graph.cpp index 7cc1e06f50cd1f31ca93f108bd12ec4fd6b8bbb0..4f255091f2035a5cacc765428e185a1e681903fd 100644 --- a/oneflow/core/graph/logical_graph.cpp +++ b/oneflow/core/graph/logical_graph.cpp @@ -560,7 +560,13 @@ NormalMdUpdtLogicalNode* LogicalGraph::BuildNormalMdUpdtAndMdSaveStruct( NormalMdUpdtLogicalNode* md_updt_logical = NewNode(); md_updt_logical->mut_parallel_desc() = fw_logical->parallel_desc(); if (Global::Get()->enable_write_snapshot()) { + // for model BuildMdSaveStruct(fw_logical, md_updt_logical); + // TODO: remove the following ugly hard coded `if' + if (Global::Get()->other_conf().train_conf().model_update_conf().has_momentum_conf()) { + // for forward_model + BuildMdSaveStruct(fw_logical, md_updt_logical); + } } return md_updt_logical; } diff --git a/oneflow/core/graph/normal_model_update_compute_task_node.cpp b/oneflow/core/graph/normal_model_update_compute_task_node.cpp index 495743f54e49f75fd4c9a529db3163fdad1cde2b..6f8888c5f13a6a8e2715a80560125704b142d711 100644 --- a/oneflow/core/graph/normal_model_update_compute_task_node.cpp +++ b/oneflow/core/graph/normal_model_update_compute_task_node.cpp @@ -22,9 +22,12 @@ void NormalMdUpdtCompTaskNode::ProduceAllRegstsAndBindEdges() { int32_t max_model_regst = 1; auto model_regst = ProduceRegst("model", false, 1, max_model_regst); auto const_model_regst = ProduceRegst("const_model", false, 1, 1); + auto forward_model_regst = ProduceRegst("forward_model", false, 1, 1); ProduceRegst("processed_model_diff", false, 1, 1); ProduceRegst("data_tmp", false, 1, 1); related_init_model_task_id_ = -1; + std::list>> model_to_save{ + {"model", model_regst}, {"forward_model", forward_model_regst}}; for (TaskEdge* out_edge : out_edges()) { TaskNode* dst_node = out_edge->dst_node(); if (IsForwardTaskType(dst_node->GetTaskType()) || IsBackwardTaskType(dst_node->GetTaskType())) { @@ -36,7 +39,8 @@ void NormalMdUpdtCompTaskNode::ProduceAllRegstsAndBindEdges() { related_init_model_task_id_ = fw_node->task_id(); } } else { - out_edge->AddRegst("model", model_regst); + out_edge->AddRegst(model_to_save.front().first, model_to_save.front().second); + model_to_save.pop_front(); } } } @@ -75,7 +79,7 @@ void NormalMdUpdtCompTaskNode::BuildExecGphAndRegst() { ExecEdge* exec_edge = nullptr; processed_model_diff_regst->ForEachLbi([&](const LogicalBlobId& lbi) { OperatorConf op_conf; - op_conf.set_name("md_update_" + lbi.op_name() + "_" + lbi.blob_name()); + op_conf.set_name("model_update-" + lbi.op_name() + "-" + lbi.blob_name()); op_conf.set_device_type(logical_node()->parallel_desc()->device_type()); op_conf.mutable_normal_mdupdt_conf()->set_model_diff(lbi.op_name() + '/' + lbi.blob_name()); op_conf.mutable_normal_mdupdt_conf()->set_model(lbi.op_name() + '/' + lbi.blob_name()); @@ -112,6 +116,8 @@ void NormalMdUpdtCompTaskNode::BuildExecGphAndRegst() { model_update_node->BindBnWithRegst(model_update_op->SoleIbn(), processed_model_diff_regst); model_update_node->BindBnWithRegst(model_update_op->SoleObn(), GetProducedRegst("model")); model_update_node->AddBnToRegstAndBindIt(&Operator::data_tmp_bns, GetProducedRegst("data_tmp")); + model_update_node->AddBnToRegstAndBindIt(&Operator::forward_model_bns, + GetProducedRegst("forward_model")); }); mut_exec_gph().TopoForEachNode([this](ExecNode* node) { node->InferBlobDescs(parallel_ctx()); }); } @@ -119,6 +125,7 @@ void NormalMdUpdtCompTaskNode::BuildExecGphAndRegst() { void NormalMdUpdtCompTaskNode::LockRegsts() { GetProducedRegst("processed_model_diff")->Lock(); GetProducedRegst("data_tmp")->Lock(); + GetProducedRegst("forward_model")->Lock(); } void NormalMdUpdtCompTaskNode::ToProto(TaskProto* task_proto) { @@ -129,8 +136,7 @@ void NormalMdUpdtCompTaskNode::ToProto(TaskProto* task_proto) { } else if (IsBackwardTaskType(node->GetTaskType())) { // do nothing } else { - CHECK_EQ(task_proto->related_save_model_task_id(), -1); - task_proto->set_related_save_model_task_id(node->task_id()); + task_proto->add_related_save_model_task_ids(node->task_id()); } }); task_proto->set_related_init_model_task_id(related_init_model_task_id_); diff --git a/oneflow/core/job/task.proto b/oneflow/core/job/task.proto index 43f841cb432b60e4d38bd6a81116f4ae9e7b0319..ac43f8ac0c91183719c2f87b649219a9792c054c 100644 --- a/oneflow/core/job/task.proto +++ b/oneflow/core/job/task.proto @@ -68,6 +68,6 @@ message TaskProto { // compute task optional ParallelContext parallel_ctx = 1000; // CompTask optional int64 random_seed = 1001; // ForwardCompTask - optional int64 related_save_model_task_id = 1002 [default = -1]; + repeated int64 related_save_model_task_ids = 1002; optional int64 related_init_model_task_id = 1003 [default = -1]; }; diff --git a/oneflow/core/kernel/kernel.h b/oneflow/core/kernel/kernel.h index fa8f3db28dc68a83de7c0ffe96a89559a25a5a62..32cf707836af7a6713838c2223e7ff802708b641 100644 --- a/oneflow/core/kernel/kernel.h +++ b/oneflow/core/kernel/kernel.h @@ -46,7 +46,6 @@ class Kernel { virtual void InitModelBlobsWithDir(DeviceCtx* ctx, int32_t part_id, int32_t part_num, const std::string& model_load_dir, std::function BnInOp2Blob) const {} - virtual ActivationType GetActivationType() const { return ActivationType::kNone; } virtual void Forward(const KernelCtx& ctx, diff --git a/oneflow/core/kernel/momentum_model_update_kernel.cpp b/oneflow/core/kernel/momentum_model_update_kernel.cpp index 7768402737851e6d95e9b5985c2ad9f99f5e9784..079de1316a1839c8095ac59cada473f03655b983 100644 --- a/oneflow/core/kernel/momentum_model_update_kernel.cpp +++ b/oneflow/core/kernel/momentum_model_update_kernel.cpp @@ -3,6 +3,26 @@ namespace oneflow { +template +void MomentumMdUpdateKernel::InitModelBlobsWithRandomSeed( + DeviceCtx* ctx, std::mt19937* random_seed_gen, + std::function BnInOp2Blob) const { + InitializerConf momentum_initializer_conf; + momentum_initializer_conf.mutable_constant_conf()->set_value(0.0f); + KernelUtil::InitializeWithConf(ctx, momentum_initializer_conf, 0, + BnInOp2Blob("momentum")); +} + +template +void MomentumMdUpdateKernel::InitModelBlobsWithDir( + DeviceCtx* ctx, int32_t part_id, int32_t part_num, const std::string& model_load_dir, + std::function BnInOp2Blob) const { + Blob* momentum_blob = BnInOp2Blob("momentum"); + KernelUtil::InitializeWithDir( + ctx, part_id, part_num, model_load_dir, momentum_blob, "momentum", + momentum_blob->shape().At(0), momentum_blob->shape().Count(1)); +} + template void MomentumMdUpdateKernel::UpdateModel( DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, int64_t next_model_vid, diff --git a/oneflow/core/kernel/momentum_model_update_kernel.h b/oneflow/core/kernel/momentum_model_update_kernel.h index 0d978d92601922f9facab2bb79b424be1c95d8a5..7b42d56f51974e0a895f785046e2c172ca2e6f3d 100644 --- a/oneflow/core/kernel/momentum_model_update_kernel.h +++ b/oneflow/core/kernel/momentum_model_update_kernel.h @@ -12,6 +12,14 @@ class MomentumMdUpdateKernel final : public NormalMdUpdateKernel MomentumMdUpdateKernel() = default; ~MomentumMdUpdateKernel() = default; + protected: + void InitModelBlobsWithRandomSeed( + DeviceCtx* ctx, std::mt19937* random_seed_gen, + std::function BnInOp2Blob) const override; + void InitModelBlobsWithDir(DeviceCtx* ctx, int32_t part_id, int32_t part_num, + const std::string& model_load_dir, + std::function BnInOp2Blob) const override; + private: void UpdateModel(DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, int64_t next_model_vid, diff --git a/oneflow/core/kernel/normalization_kernel.cpp b/oneflow/core/kernel/normalization_kernel.cpp index 80e56cf6b7c655e707e80bc04e44c4536e521178..e2ad528ea30f5fee3d676bfca2456e3c57c6c379 100644 --- a/oneflow/core/kernel/normalization_kernel.cpp +++ b/oneflow/core/kernel/normalization_kernel.cpp @@ -168,24 +168,24 @@ void NormalizationKernel::InitModelBlobsWithDir( const auto& conf = this->op_conf().normalization_conf(); if (conf.scale()) { Blob* gamma_blob = BnInOp2Blob("gamma"); - KernelUtil::InitializeWithDir(ctx, 0, part_num, model_load_dir, gamma_blob, - "gamma", gamma_blob->shape().At(0), + KernelUtil::InitializeWithDir(ctx, part_id, part_num, model_load_dir, + gamma_blob, "gamma", gamma_blob->shape().At(0), gamma_blob->shape().Count(1)); } if (conf.center()) { Blob* beta_blob = BnInOp2Blob("beta"); - KernelUtil::InitializeWithDir(ctx, 0, part_num, model_load_dir, beta_blob, + KernelUtil::InitializeWithDir(ctx, part_id, part_num, model_load_dir, beta_blob, "beta", beta_blob->shape().At(0), beta_blob->shape().Count(1)); } Blob* mean_blob = BnInOp2Blob("moving_mean"); - KernelUtil::InitializeWithDir(ctx, 0, part_num, model_load_dir, mean_blob, + KernelUtil::InitializeWithDir(ctx, part_id, part_num, model_load_dir, mean_blob, "moving_mean", mean_blob->shape().At(0), mean_blob->shape().Count(1)); Blob* variance_blob = BnInOp2Blob("moving_variance"); - KernelUtil::InitializeWithDir(ctx, 0, part_num, model_load_dir, variance_blob, - "moving_variance", variance_blob->shape().At(0), - variance_blob->shape().Count(1)); + KernelUtil::InitializeWithDir( + ctx, part_id, part_num, model_load_dir, variance_blob, "moving_variance", + variance_blob->shape().At(0), variance_blob->shape().Count(1)); } template diff --git a/oneflow/core/operator/momentum_model_update_op.cpp b/oneflow/core/operator/momentum_model_update_op.cpp index 83a29b04cc6313da862fe2d6256b5c98b48918f4..f39cc3892e34b9ac032f8a50d00d1969c81c487d 100644 --- a/oneflow/core/operator/momentum_model_update_op.cpp +++ b/oneflow/core/operator/momentum_model_update_op.cpp @@ -2,7 +2,7 @@ namespace oneflow { -void MomentumModelUpdateOp::MdUpdtVirtualInitFromOpConf() { EnrollDataTmpBn("momentum"); } +void MomentumModelUpdateOp::MdUpdtVirtualInitFromOpConf() { EnrollForwardModelBn("momentum"); } void MomentumModelUpdateOp::InferBlobDescs( std::function GetBlobDesc4BnInOp,