diff --git a/oneflow/core/actor/normal_model_update_compute_actor.cpp b/oneflow/core/actor/normal_model_update_compute_actor.cpp index 5d9c3d8cf09b5f8e1c63bb067120ae5eb1c91d82..bd91ad66aa9f93077dc6cc159d3c403c70ef8417 100644 --- a/oneflow/core/actor/normal_model_update_compute_actor.cpp +++ b/oneflow/core/actor/normal_model_update_compute_actor.cpp @@ -27,9 +27,7 @@ void NormalMdUpdtCompActor::Act() { Regst* cur_model_regst = GetCurWriteableRegst(model_regst_desc_id_); cur_model_regst->set_model_version_id(next_model_version_id_); KernelCtx kernel_ctx = GenDefaultKernelCtx(); - std::tuple other_val(next_model_version_id_, - pre_model_regst_->packed_blob()); - kernel_ctx.other = &other_val; + kernel_ctx.other = &next_model_version_id_; pre_model_regst_ = cur_model_regst; AsyncLaunchKernel(kernel_ctx); const JobDesc* job_desc = Global::Get(); diff --git a/oneflow/core/graph/logical_graph.cpp b/oneflow/core/graph/logical_graph.cpp index a73bac014cc5437905cab1df0d02d059844ce900..38894d9ee2192866c482398d609d62f481037837 100644 --- a/oneflow/core/graph/logical_graph.cpp +++ b/oneflow/core/graph/logical_graph.cpp @@ -525,15 +525,11 @@ void LogicalGraph::AddReduceScatterAddGatherNodes(LogicalNode* src, LogicalNode* void LogicalGraph::SetupNormalMdUpdtOp() { ForEachLogicalNode([](NormalMdUpdtLogicalNode* node) { if (node->in_edges().size() < 1) { return; } + // Add shared_model_diff_add_op OperatorConf op_conf; - op_conf.set_name("md_update_" + NewUniqueId()); + op_conf.set_name("md_diff_add_" + NewUniqueId()); op_conf.set_device_type(node->parallel_desc()->device_type()); - NormalModelUpdateOpConf* mdupdt_conf = op_conf.mutable_normal_mdupdt_conf(); - const JobDesc* job_desc = Global::Get(); - if (Global::Get()->IsTrain()) { - *(mdupdt_conf->mutable_user_conf()) = job_desc->other_conf().train_conf().model_update_conf(); - } - mdupdt_conf->set_in_num(node->in_edges().size()); + op_conf.mutable_shared_model_diff_add_conf()->set_in_num(node->in_edges().size()); node->mut_op_vec() = {ConstructOp(op_conf)}; }); } diff --git a/oneflow/core/graph/normal_model_update_compute_task_node.cpp b/oneflow/core/graph/normal_model_update_compute_task_node.cpp index 02ec24fa9d51192c9c1c1eefaff9b4a60e4fc4da..1ec344cfa154f0c1708eda80387d789b817f431c 100644 --- a/oneflow/core/graph/normal_model_update_compute_task_node.cpp +++ b/oneflow/core/graph/normal_model_update_compute_task_node.cpp @@ -22,6 +22,7 @@ void NormalMdUpdtCompTaskNode::ProduceAllRegstsAndBindEdges() { int32_t max_model_regst = 1; auto model_regst = ProduceRegst("model", false, 1, max_model_regst); auto const_model_regst = ProduceRegst("const_model", false, 1, 1); + ProduceRegst("processed_model_diff", false, 1, 1); ProduceRegst("data_tmp", false, 1, 1); related_init_model_task_id_ = -1; for (TaskEdge* out_edge : out_edges()) { @@ -56,18 +57,52 @@ bool NormalMdUpdtCompTaskNode::IsReadyForBuild() { void NormalMdUpdtCompTaskNode::BuildExecGphAndRegst() { if (!IsTrainable()) { return; } - ExecNode* node = mut_exec_gph().NewNode(); - node->mut_op() = logical_node()->SoleOp(); + ExecNode* shared_model_diff_add_node = mut_exec_gph().NewNode(); + shared_model_diff_add_node->mut_op() = logical_node()->SoleOp(); size_t ibn_idx = 0; for (const auto& pair : consumed_regsts()) { - node->BindBnWithRegst(node->op()->input_bns().Get(ibn_idx++), pair.second.front()); + shared_model_diff_add_node->BindBnWithRegst( + shared_model_diff_add_node->op()->input_bns().Get(ibn_idx++), pair.second.front()); } - node->BindBnWithRegst(node->op()->SoleObn(), GetProducedRegst("model")); - node->AddBnToRegstAndBindIt(&Operator::data_tmp_bns, GetProducedRegst("data_tmp")); - node->InferBlobDescs(nullptr); + std::shared_ptr processed_model_diff_regst = GetProducedRegst("processed_model_diff"); + shared_model_diff_add_node->BindBnWithRegst(logical_node()->SoleOp()->SoleObn(), + processed_model_diff_regst); + // "model" regst is already bound with lbis and locked by the corresponding + // NormalForwardCompTaskNode + processed_model_diff_regst->CopyBlobDescFrom(GetProducedRegst("model").get()); + + ExecNode* model_update_node = nullptr; + ExecEdge* exec_edge = nullptr; + processed_model_diff_regst->ForEachLbi([&](const LogicalBlobId& lbi) { + OperatorConf op_conf; + op_conf.set_name("md_update_" + lbi.op_name() + "_" + lbi.blob_name()); + op_conf.set_device_type(logical_node()->parallel_desc()->device_type()); + op_conf.mutable_normal_mdupdt_conf()->set_model_diff(lbi.op_name() + '/' + lbi.blob_name()); + op_conf.mutable_normal_mdupdt_conf()->set_model(lbi.op_name() + '/' + lbi.blob_name()); + if (Global::Get()->IsTrain()) { + *(op_conf.mutable_normal_mdupdt_conf()->mutable_user_conf()) = + Global::Get()->other_conf().train_conf().model_update_conf(); + } + std::shared_ptr model_update_op = ConstructOp(op_conf); + model_update_node = mut_exec_gph().NewNode(); + model_update_node->mut_op() = model_update_op; + exec_edge = mut_exec_gph().NewEdge(); + exec_edge->set_lbi(lbi); + exec_edge->mut_src_bn() = lbi.blob_name(); + exec_edge->mut_dst_bn() = model_update_op->SoleIbn(); + Connect(shared_model_diff_add_node, exec_edge, model_update_node); + + model_update_node->BindBnWithRegst(model_update_op->SoleIbn(), processed_model_diff_regst); + model_update_node->BindBnWithRegst(model_update_op->SoleObn(), GetProducedRegst("model")); + model_update_node->AddBnToRegstAndBindIt(&Operator::data_tmp_bns, GetProducedRegst("data_tmp")); + }); + mut_exec_gph().TopoForEachNode([this](ExecNode* node) { node->InferBlobDescs(parallel_ctx()); }); } -void NormalMdUpdtCompTaskNode::LockRegsts() { GetProducedRegst("data_tmp")->Lock(); } +void NormalMdUpdtCompTaskNode::LockRegsts() { + GetProducedRegst("processed_model_diff")->Lock(); + GetProducedRegst("data_tmp")->Lock(); +} void NormalMdUpdtCompTaskNode::ToProto(TaskProto* task_proto) { CompTaskNode::ToProto(task_proto); diff --git a/oneflow/core/kernel/momentum_model_update_kernel.cpp b/oneflow/core/kernel/momentum_model_update_kernel.cpp index 3ebbaf5f53d9b7fc68a369080976cff43f6a7e4d..7768402737851e6d95e9b5985c2ad9f99f5e9784 100644 --- a/oneflow/core/kernel/momentum_model_update_kernel.cpp +++ b/oneflow/core/kernel/momentum_model_update_kernel.cpp @@ -5,9 +5,9 @@ namespace oneflow { template void MomentumMdUpdateKernel::UpdateModel( - DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, const Blob* pre_model_blob, - const Blob* model_diff_blob, int64_t next_model_vid, + DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, int64_t next_model_vid, std::function BnInOp2Blob) const { + const Blob* model_diff_blob = BnInOp2Blob("model_diff"); Blob* model_blob = BnInOp2Blob("model"); Blob* momentum_blob = BnInOp2Blob("momentum"); float beta = this->op_conf().normal_mdupdt_conf().user_conf().momentum_conf().beta(); @@ -15,19 +15,18 @@ void MomentumMdUpdateKernel::UpdateModel( MomentumMdUpdateKernelUtil::UpdateModel( ctx, model_blob->shape().elem_cnt(), batch_size, static_cast(beta), learning_rate, l1, l2, - model_diff_blob->dptr(), pre_model_blob->dptr(), momentum_blob->mut_dptr(), - model_blob->mut_dptr()); + model_diff_blob->dptr(), model_blob->mut_dptr(), momentum_blob->mut_dptr()); } template class MomentumMdUpdateKernelUtil final { public: static void UpdateModel(DeviceCtx*, int64_t n, int64_t batch_size, T beta, T learning_rate, T l1, - T l2, const T* model_diff, const T* pre_model, T* momentum, T* model) { + T l2, const T* model_diff, T* model, T* momentum) { for (int64_t i = 0; i != n; ++i) { - T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, pre_model[i]); + T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, model[i]); momentum[i] = beta * momentum[i] - learning_rate * reg_diff; - model[i] = pre_model[i] + momentum[i]; + model[i] = model[i] + momentum[i]; } } }; diff --git a/oneflow/core/kernel/momentum_model_update_kernel.cu b/oneflow/core/kernel/momentum_model_update_kernel.cu index 47c81a2a9e5c664725793b8f27616daa04527808..71a89fe7440e7b543a87de95b773e9cfc95267dd 100644 --- a/oneflow/core/kernel/momentum_model_update_kernel.cu +++ b/oneflow/core/kernel/momentum_model_update_kernel.cu @@ -8,11 +8,11 @@ namespace { template __global__ void UpdateModelGpu(int64_t n, int64_t batch_size, T beta, T learning_rate, T l1, T l2, - const T* model_diff, const T* pre_model, T* momentum, T* model) { + const T* model_diff, T* model, T* momentum) { CUDA_1D_KERNEL_LOOP(i, n) { - T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, pre_model[i]); + T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, model[i]); momentum[i] = beta * momentum[i] - learning_rate * reg_diff; - model[i] = pre_model[i] + momentum[i]; + model[i] = model[i] + momentum[i]; } } @@ -22,10 +22,9 @@ template class MomentumMdUpdateKernelUtil final { public: static void UpdateModel(DeviceCtx* ctx, int64_t n, int64_t batch_size, T beta, T learning_rate, - const T l1, const T l2, const T* model_diff, const T* pre_model, - T* momentum, T* model) { + const T l1, const T l2, const T* model_diff, T* model, T* momentum) { UpdateModelGpu<<cuda_stream()>>>( - n, batch_size, beta, learning_rate, l1, l2, model_diff, pre_model, momentum, model); + n, batch_size, beta, learning_rate, l1, l2, model_diff, model, momentum); } }; diff --git a/oneflow/core/kernel/momentum_model_update_kernel.h b/oneflow/core/kernel/momentum_model_update_kernel.h index 8598026b9cf23bc3bf4b43c804b2e29facfe0a97..0d978d92601922f9facab2bb79b424be1c95d8a5 100644 --- a/oneflow/core/kernel/momentum_model_update_kernel.h +++ b/oneflow/core/kernel/momentum_model_update_kernel.h @@ -14,7 +14,7 @@ class MomentumMdUpdateKernel final : public NormalMdUpdateKernel private: void UpdateModel(DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, - const Blob* pre_model_blob, const Blob* model_diff_blob, int64_t next_model_vid, + int64_t next_model_vid, std::function BnInOp2Blob) const override; }; @@ -22,7 +22,7 @@ template class MomentumMdUpdateKernelUtil final { public: static void UpdateModel(DeviceCtx*, int64_t n, int64_t batch_size, T beta, T learning_rate, T l1, - T l2, const T* model_diff, const T* pre_model, T* momentum, T* model); + T l2, const T* model_diff, T* model, T* momentum); }; DECLARE_MDUPDT_KERNEL_CREATOR(Momentum); diff --git a/oneflow/core/kernel/naive_model_update_kernel.cpp b/oneflow/core/kernel/naive_model_update_kernel.cpp index 6ab441d317150b6cc8f9969f9c0433c193bf85c6..1255449fc52c48b0e86d2e6a4a59328e5d05461b 100644 --- a/oneflow/core/kernel/naive_model_update_kernel.cpp +++ b/oneflow/core/kernel/naive_model_update_kernel.cpp @@ -5,25 +5,24 @@ namespace oneflow { template void NaiveMdUpdateKernel::UpdateModel( - DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, const Blob* pre_model_blob, - const Blob* model_diff_blob, int64_t next_model_vid, + DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, int64_t next_model_vid, std::function BnInOp2Blob) const { + const Blob* model_diff_blob = BnInOp2Blob("model_diff"); Blob* model_blob = BnInOp2Blob("model"); - if (pre_model_blob != model_blob) { model_blob->CopyDataContentFrom(ctx, pre_model_blob); } // model = model - alpha * model_diff NaiveMdUpdateKernelUtil::UpdateModel( ctx, model_blob->shape().elem_cnt(), batch_size, learning_rate, l1, l2, - model_diff_blob->dptr(), pre_model_blob->dptr(), model_blob->mut_dptr()); + model_diff_blob->dptr(), model_blob->mut_dptr()); } template class NaiveMdUpdateKernelUtil final { public: static void UpdateModel(DeviceCtx*, const int64_t n, int64_t batch_size, T learning_rate, T l1, - T l2, const T* model_diff, const T* pre_model, T* model) { + T l2, const T* model_diff, T* model) { for (int64_t i = 0; i != n; ++i) { - T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, pre_model[i]); - model[i] = pre_model[i] - learning_rate * reg_diff; + T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, model[i]); + model[i] = model[i] - learning_rate * reg_diff; } } }; diff --git a/oneflow/core/kernel/naive_model_update_kernel.cu b/oneflow/core/kernel/naive_model_update_kernel.cu index 76407c6abee22d3ea404dd42105b4474c8381130..fbdebae14bd00430c428a55b86aeef0e25d97ef8 100644 --- a/oneflow/core/kernel/naive_model_update_kernel.cu +++ b/oneflow/core/kernel/naive_model_update_kernel.cu @@ -8,10 +8,10 @@ namespace { template __global__ void UpdateModelGpu(const int64_t n, int64_t batch_size, T learning_rate, T l1, T l2, - const T* model_diff, const T* pre_model, T* model) { + const T* model_diff, T* model) { CUDA_1D_KERNEL_LOOP(i, n) { - T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, pre_model[i]); - model[i] = pre_model[i] - learning_rate * reg_diff; + T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, model[i]); + model[i] = model[i] - learning_rate * reg_diff; } } @@ -21,9 +21,9 @@ template class NaiveMdUpdateKernelUtil final { public: static void UpdateModel(DeviceCtx* ctx, const int64_t n, int64_t batch_size, T learning_rate, - T l1, T l2, const T* model_diff, const T* pre_model, T* model) { + T l1, T l2, const T* model_diff, T* model) { UpdateModelGpu<<cuda_stream()>>>( - n, batch_size, learning_rate, l1, l2, model_diff, pre_model, model); + n, batch_size, learning_rate, l1, l2, model_diff, model); } }; diff --git a/oneflow/core/kernel/naive_model_update_kernel.h b/oneflow/core/kernel/naive_model_update_kernel.h index ab06530796fbd6305d4dc01797ceaf10e7273266..510f834670352bee9a3698336d8473a3b4d99005 100644 --- a/oneflow/core/kernel/naive_model_update_kernel.h +++ b/oneflow/core/kernel/naive_model_update_kernel.h @@ -14,7 +14,7 @@ class NaiveMdUpdateKernel final : public NormalMdUpdateKernel { private: void UpdateModel(DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, - const Blob* pre_model_blob, const Blob* model_diff_blob, int64_t next_model_vid, + int64_t next_model_vid, std::function BnInOp2Blob) const override; }; @@ -22,7 +22,7 @@ template class NaiveMdUpdateKernelUtil final { public: static void UpdateModel(DeviceCtx*, int64_t n, int64_t batch_size, T learning_rate, T l1, T l2, - const T* model_diff, const T* pre_model, T* model); + const T* model_diff, T* model); }; DECLARE_MDUPDT_KERNEL_CREATOR(Naive); diff --git a/oneflow/core/kernel/normal_model_update_kernel.cpp b/oneflow/core/kernel/normal_model_update_kernel.cpp index 62e4dc27a4eb61c729594e1041004556a9e93f9b..7d6130606eaa6ed72424b83bd61bc656bdef3d62 100644 --- a/oneflow/core/kernel/normal_model_update_kernel.cpp +++ b/oneflow/core/kernel/normal_model_update_kernel.cpp @@ -8,8 +8,7 @@ namespace oneflow { template void NormalMdUpdateKernel::Forward( const KernelCtx& ctx, std::function BnInOp2Blob) const { - auto tpl = reinterpret_cast*>(ctx.other); - int64_t next_model_vid = std::get<0>(*tpl); + int64_t next_model_vid = *reinterpret_cast(ctx.other); int64_t cur_batch_num = next_model_vid - 1; const NormalModelUpdateOpUserConf& conf = this->op_conf().normal_mdupdt_conf().user_conf(); double learning_rate = conf.learning_rate(); @@ -19,19 +18,11 @@ void NormalMdUpdateKernel::Forward( learning_rate = GetDecayedLearningRate(conf.learning_rate_decay(), learning_rate, cur_batch_num); } - const OpAttribute& op_attribute = this->kernel_conf().op_attribute(); - Blob* in_0 = BnInOp2Blob(op_attribute.input_bns(0)); - FOR_RANGE(size_t, i, 1, op_attribute.input_bns().size()) { - Blob* in_i = BnInOp2Blob(op_attribute.input_bns(i)); - KernelUtil::Axpy(ctx.device_ctx, in_0->shape().elem_cnt(), static_cast(1.0), - in_i->dptr(), 1, in_0->mut_dptr(), 1); - } - int64_t batch_size = Global::Get()->BatchSize(); float l1 = Global::Get()->L1(); float l2 = Global::Get()->L2(); UpdateModel(ctx.device_ctx, batch_size, static_cast(learning_rate), static_cast(l1), - static_cast(l2), std::get<1>(*tpl), in_0, next_model_vid, BnInOp2Blob); + static_cast(l2), next_model_vid, BnInOp2Blob); } #define INSTANTIATE_KERNEL(device_type, data_type_pair) \ diff --git a/oneflow/core/kernel/normal_model_update_kernel.h b/oneflow/core/kernel/normal_model_update_kernel.h index 3692dde651fc3a819c004335558243d641fbb12c..83ba8ca62dce23d89959d642b0fd9d783924376a 100644 --- a/oneflow/core/kernel/normal_model_update_kernel.h +++ b/oneflow/core/kernel/normal_model_update_kernel.h @@ -17,7 +17,6 @@ class NormalMdUpdateKernel : public KernelIf { protected: NormalMdUpdateKernel() = default; virtual void UpdateModel(DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, - const Blob* pre_model_blob, const Blob* model_diff_blob, int64_t next_model_vid, std::function BnInOp2Blob) const = 0; }; diff --git a/oneflow/core/kernel/rmsprop_model_update_kernel.cpp b/oneflow/core/kernel/rmsprop_model_update_kernel.cpp index de51e578be8543d1e25a00b7fba65d5919a76ee6..be39d547f7f52d127a6ad40bf3b9cef9bedfd902 100644 --- a/oneflow/core/kernel/rmsprop_model_update_kernel.cpp +++ b/oneflow/core/kernel/rmsprop_model_update_kernel.cpp @@ -5,9 +5,9 @@ namespace oneflow { template void RMSPropMdUpdateKernel::UpdateModel( - DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, const Blob* pre_model_blob, - const Blob* model_diff_blob, int64_t next_model_vid, + DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, int64_t next_model_vid, std::function BnInOp2Blob) const { + const Blob* model_diff_blob = BnInOp2Blob("model_diff"); Blob* model_blob = BnInOp2Blob("model"); Blob* mean_square_blob = BnInOp2Blob("mean_square"); const RMSPropModelUpdateConf& conf = @@ -17,20 +17,19 @@ void RMSPropMdUpdateKernel::UpdateModel( RMSPropMdUpdateKernelUtil::UpdateModel( ctx, model_blob->shape().elem_cnt(), batch_size, learning_rate, static_cast(decay_rate), - static_cast(conf.epsilon()), l1, l2, pre_model_blob->dptr(), model_blob->mut_dptr(), - mean_square_blob->mut_dptr(), model_diff_blob->dptr()); + static_cast(conf.epsilon()), l1, l2, model_diff_blob->dptr(), model_blob->mut_dptr(), + mean_square_blob->mut_dptr()); } template class RMSPropMdUpdateKernelUtil final { public: static void UpdateModel(DeviceCtx*, int64_t n, int64_t batch_size, T learning_rate, T decay_rate, - T epsilon, T l1, T l2, const T* pre_model, T* model, T* mean_square, - const T* model_diff) { + T epsilon, T l1, T l2, const T* model_diff, T* model, T* mean_square) { for (int64_t i = 0; i < n; ++i) { - T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, pre_model[i]); + T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, model[i]); mean_square[i] = (1 - decay_rate) * reg_diff * reg_diff + decay_rate * mean_square[i]; - model[i] = pre_model[i] - learning_rate * reg_diff / std::sqrt(mean_square[i] + epsilon); + model[i] = model[i] - learning_rate * reg_diff / std::sqrt(mean_square[i] + epsilon); } } }; diff --git a/oneflow/core/kernel/rmsprop_model_update_kernel.cu b/oneflow/core/kernel/rmsprop_model_update_kernel.cu index c3121f2da3f14ed80bfb676ff7e7543b45e95041..2412562afc6bb35380b122618e3a64046a27a1b3 100644 --- a/oneflow/core/kernel/rmsprop_model_update_kernel.cu +++ b/oneflow/core/kernel/rmsprop_model_update_kernel.cu @@ -8,12 +8,12 @@ namespace { template __global__ void UpdateModelGpu(int64_t n, int64_t batch_size, T learning_rate, T decay_rate, - T epsilon, T l1, T l2, const T* pre_model, T* model, T* mean_square, - const T* model_diff) { + T epsilon, T l1, T l2, const T* model_diff, T* model, + T* mean_square) { CUDA_1D_KERNEL_LOOP(i, n) { - T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, pre_model[i]); + T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, model[i]); mean_square[i] = (1 - decay_rate) * reg_diff * reg_diff + decay_rate * mean_square[i]; - model[i] = pre_model[i] - learning_rate * reg_diff / std::sqrt(mean_square[i] + epsilon); + model[i] = model[i] - learning_rate * reg_diff / std::sqrt(mean_square[i] + epsilon); } } @@ -23,11 +23,10 @@ template class RMSPropMdUpdateKernelUtil final { public: static void UpdateModel(DeviceCtx* ctx, int64_t n, int64_t batch_size, T learning_rate, - T decay_rate, T epsilon, T l1, T l2, const T* pre_model, T* model, - T* mean_square, const T* model_diff) { + T decay_rate, T epsilon, T l1, T l2, const T* model_diff, T* model, + T* mean_square) { UpdateModelGpu<<cuda_stream()>>>( - n, batch_size, learning_rate, decay_rate, epsilon, l1, l2, pre_model, model, mean_square, - model_diff); + n, batch_size, learning_rate, decay_rate, epsilon, l1, l2, model_diff, model, mean_square); } }; diff --git a/oneflow/core/kernel/rmsprop_model_update_kernel.h b/oneflow/core/kernel/rmsprop_model_update_kernel.h index 797a350f32ae5193bfe92e9bf7f1e43423f3dfcc..a555a58927a794dfb21ff08829583011423a60df 100644 --- a/oneflow/core/kernel/rmsprop_model_update_kernel.h +++ b/oneflow/core/kernel/rmsprop_model_update_kernel.h @@ -14,7 +14,7 @@ class RMSPropMdUpdateKernel final : public NormalMdUpdateKernel private: void UpdateModel(DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, - const Blob* pre_model_blob, const Blob* model_diff_blob, int64_t next_model_vid, + int64_t next_model_vid, std::function BnInOp2Blob) const override; }; @@ -22,10 +22,9 @@ template class RMSPropMdUpdateKernelUtil final { public: // mean_square = (1 - decay_rate) * model_diff ^ 2 + decay_rate * mean_square - // model = pre_model - learning_rate * model_diff / sqrt(mean_square + epsilon) + // model = model - learning_rate * model_diff / sqrt(mean_square + epsilon) static void UpdateModel(DeviceCtx*, int64_t n, int64_t batch_size, T learning_rate, T decay_rate, - T epsilon, T l1, T l2, const T* pre_model, T* model, T* mean_square, - const T* model_diff); + T epsilon, T l1, T l2, const T* model_diff, T* model, T* mean_square); }; DECLARE_MDUPDT_KERNEL_CREATOR(RMSProp); diff --git a/oneflow/core/kernel/shared_model_diff_add_kernel.cpp b/oneflow/core/kernel/shared_model_diff_add_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..69c41de070602e9fef231b5f394ed792a1bb1ae9 --- /dev/null +++ b/oneflow/core/kernel/shared_model_diff_add_kernel.cpp @@ -0,0 +1,36 @@ +#include "oneflow/core/kernel/shared_model_diff_add_kernel.h" +#include "oneflow/core/kernel/kernel_util.h" +#include "oneflow/core/kernel/kernel_common.hpp" + +namespace oneflow { + +template +void SharedModelDiffAddKernel::ForwardDataContent( + const KernelCtx& ctx, std::function BnInOp2Blob) const { + const PbRpf& ibns = this->op_attribute().input_bns(); + size_t in_num = ibns.size(); + if (in_num == 0) return; + Blob* out_blob = BnInOp2Blob(this->op_attribute().output_bns(0)); + auto in_blob = [&](int32_t idx) { return BnInOp2Blob(this->op_attribute().input_bns(idx)); }; + static const int kWidth = 8; + int r = in_num % kWidth; + if (r) { + tuple_switch(r, tp_, + AdditionFunction{ + out_blob, std::move(BnInOp2Blob), ctx.device_ctx, 0, this}); + } + for (; r < in_num; r += kWidth) { + Addition(ctx.device_ctx, out_blob, out_blob, in_blob(r), in_blob(r + 1), + in_blob(r + 2), in_blob(r + 3), in_blob(r + 4), in_blob(r + 5), + in_blob(r + 6), in_blob(r + 7)); + } +} + +template +const PbMessage& SharedModelDiffAddKernel::GetCustomizedOpConf() const { + return this->op_conf().shared_model_diff_add_conf(); +} + +ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kSharedModelDiffAddConf, SharedModelDiffAddKernel, + ARITHMETIC_DATA_TYPE_SEQ); +} // namespace oneflow diff --git a/oneflow/core/kernel/shared_model_diff_add_kernel.h b/oneflow/core/kernel/shared_model_diff_add_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..568bc9a7b2b860c08e69fd9889ae53de7e0d7dc0 --- /dev/null +++ b/oneflow/core/kernel/shared_model_diff_add_kernel.h @@ -0,0 +1,25 @@ +#ifndef ONEFLOW_CORE_KERNEL_SHARED_MODEL_DIFF_ADD_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_SHARED_MODEL_DIFF_ADD_KERNEL_H_ + +#include "oneflow/core/kernel/kernel.h" + +namespace oneflow { + +template +class SharedModelDiffAddKernel final : public KernelIf { + public: + OF_DISALLOW_COPY_AND_MOVE(SharedModelDiffAddKernel); + SharedModelDiffAddKernel() = default; + ~SharedModelDiffAddKernel() = default; + + private: + void ForwardDataContent(const KernelCtx& ctx, + std::function BnInOp2Blob) const override; + const PbMessage& GetCustomizedOpConf() const override; + + decltype(make_tuple_from_sequence<7>()) tp_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_SHARED_MODEL_DIFF_ADD_KERNEL_H_ diff --git a/oneflow/core/operator/momentum_model_update_op.h b/oneflow/core/operator/momentum_model_update_op.h index 2a7962b17ddb3c8e04a1b7e4b468319e1d90e4ad..7be6e82f85b9daca548fb8c358082b96bc4dd92b 100644 --- a/oneflow/core/operator/momentum_model_update_op.h +++ b/oneflow/core/operator/momentum_model_update_op.h @@ -16,8 +16,6 @@ class MomentumModelUpdateOp final : public NormalModelUpdtOp { private: void MdUpdtVirtualInitFromOpConf() override; - LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); } - LogicalBlobId obn2lbi(const std::string& output_bn) const override { return GenPackedLbi(); } }; } // namespace oneflow diff --git a/oneflow/core/operator/naive_model_update_op.h b/oneflow/core/operator/naive_model_update_op.h index 2de800053b3da07fe607e379682c0a9afc9851d3..2582f20aa6921f6f1d6c00c3c68755524c67e0e5 100644 --- a/oneflow/core/operator/naive_model_update_op.h +++ b/oneflow/core/operator/naive_model_update_op.h @@ -13,10 +13,6 @@ class NaiveModelUpdateOp final : public NormalModelUpdtOp { void InferBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const override {} - - private: - LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); } - LogicalBlobId obn2lbi(const std::string& output_bn) const override { return GenPackedLbi(); } }; } // namespace oneflow diff --git a/oneflow/core/operator/normal_model_update_op.cpp b/oneflow/core/operator/normal_model_update_op.cpp index e5d4c40d833b850088186390adaa3b41b89b0689..af442739d9eb3dd6adde2c36f639be6f967e15a3 100644 --- a/oneflow/core/operator/normal_model_update_op.cpp +++ b/oneflow/core/operator/normal_model_update_op.cpp @@ -5,9 +5,7 @@ namespace oneflow { void NormalModelUpdtOp::InitFromOpConf() { - FOR_RANGE(int32_t, i, 0, op_conf().normal_mdupdt_conf().in_num()) { - EnrollInputBn("in_" + std::to_string(i), false); - } + EnrollInputBn("model_diff", false); EnrollOutputBn("model", false); MdUpdtVirtualInitFromOpConf(); } @@ -16,6 +14,13 @@ const PbMessage& NormalModelUpdtOp::GetCustomizedConf() const { return op_conf().normal_mdupdt_conf(); } +LogicalBlobId NormalModelUpdtOp::obn2lbi(const std::string& output_bn) const { + const google::protobuf::Descriptor* desc = GetCustomizedConf().GetDescriptor(); + const google::protobuf::FieldDescriptor* fd = desc->FindFieldByName(output_bn); + CHECK(fd); + return GenLogicalBlobId(GetValFromCustomizedConf(output_bn)); +} + REGISTER_OP_CREATOR(OperatorConf::kNormalMdupdtConf, [](const OperatorConf& op_conf) -> Operator* { return NewObj(op_conf.normal_mdupdt_conf().user_conf().normal_mdupdt_case()); }); diff --git a/oneflow/core/operator/normal_model_update_op.h b/oneflow/core/operator/normal_model_update_op.h index c2ea2ba5c5489b987f84559ade7c88cdac5efd18..160c20e1a9277cd9ed70c9700de2c16bf8c57a28 100644 --- a/oneflow/core/operator/normal_model_update_op.h +++ b/oneflow/core/operator/normal_model_update_op.h @@ -18,6 +18,7 @@ class NormalModelUpdtOp : public Operator { virtual void MdUpdtVirtualInitFromOpConf() {} private: + LogicalBlobId obn2lbi(const std::string& output_bn) const override; }; } // namespace oneflow diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto index 0f5e93ea1988c6cdb89aeef959c611b781e2320f..586225eca87e25a7ddc19f194d701fb1e2c8d6cb 100644 --- a/oneflow/core/operator/op_conf.proto +++ b/oneflow/core/operator/op_conf.proto @@ -377,7 +377,8 @@ message NormalModelUpdateOpUserConf { message NormalModelUpdateOpConf { required NormalModelUpdateOpUserConf user_conf = 1; - required int32 in_num = 2; + required string model_diff = 2; + required string model = 3; } message AccumulateOpConf { @@ -469,6 +470,10 @@ message MaximumOpConf { required string out = 2; } +message SharedModelDiffAddOpConf { + required int32 in_num = 1; +} + message LocalResponseNormalizationOpConf { required string in = 1; required string out = 2; @@ -671,6 +676,7 @@ message OperatorConf { AccumulateOpConf accumulate_conf = 115; NormalModelUpdateOpConf normal_mdupdt_conf = 116; ModelSaveOpConf model_save_conf = 117; + SharedModelDiffAddOpConf shared_model_diff_add_conf = 118; // domain op TransposeOpConf transpose_conf = 201; diff --git a/oneflow/core/operator/rmsprop_model_update_op.h b/oneflow/core/operator/rmsprop_model_update_op.h index e84982c9b6055706f1cb405d41eb49b25df33ae2..52c80395ae9ee607be1972688ebe078be3469e24 100644 --- a/oneflow/core/operator/rmsprop_model_update_op.h +++ b/oneflow/core/operator/rmsprop_model_update_op.h @@ -16,8 +16,6 @@ class RMSPropModelUpdateOp final : public NormalModelUpdtOp { private: void MdUpdtVirtualInitFromOpConf() override; - LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); } - LogicalBlobId obn2lbi(const std::string& output_bn) const override { return GenPackedLbi(); } }; } // namespace oneflow diff --git a/oneflow/core/operator/shared_model_diff_add_op.cpp b/oneflow/core/operator/shared_model_diff_add_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3d9d33c437be02a1245b97275a0e0ea58b4a501a --- /dev/null +++ b/oneflow/core/operator/shared_model_diff_add_op.cpp @@ -0,0 +1,27 @@ +#include "oneflow/core/operator/shared_model_diff_add_op.h" + +namespace oneflow { + +void SharedModelDiffAddOp::InitFromOpConf() { + CHECK(op_conf().has_shared_model_diff_add_conf()); + FOR_RANGE(int32_t, i, 0, op_conf().shared_model_diff_add_conf().in_num()) { + EnrollInputBn("in_" + std::to_string(i), false); + } + EnrollOutputBn("processed_model_diff", false); +} +const PbMessage& SharedModelDiffAddOp::GetCustomizedConf() const { + return op_conf().shared_model_diff_add_conf(); +} + +void SharedModelDiffAddOp::InferBlobDescs( + std::function GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + const BlobDesc* in_0_blob_desc = GetBlobDesc4BnInOp(input_bns().Get(0)); + FOR_RANGE(int32_t, i, 1, input_bns().size()) { + CHECK(*in_0_blob_desc == *GetBlobDesc4BnInOp(input_bns().Get(i))); + } +} + +REGISTER_OP(OperatorConf::kSharedModelDiffAddConf, SharedModelDiffAddOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/shared_model_diff_add_op.h b/oneflow/core/operator/shared_model_diff_add_op.h new file mode 100644 index 0000000000000000000000000000000000000000..61d6b36a49db76f07774f478a1bb4a86d84369b6 --- /dev/null +++ b/oneflow/core/operator/shared_model_diff_add_op.h @@ -0,0 +1,26 @@ +#ifndef ONEFLOW_CORE_OPERATOR_SHARED_MODEL_DIFF_ADD_OP_H_ +#define ONEFLOW_CORE_OPERATOR_SHARED_MODEL_DIFF_ADD_OP_H_ + +#include "oneflow/core/operator/operator.h" + +namespace oneflow { + +class SharedModelDiffAddOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(SharedModelDiffAddOp); + SharedModelDiffAddOp() = default; + ~SharedModelDiffAddOp() = default; + + void InitFromOpConf() override; + void InferBlobDescs(std::function GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + const PbMessage& GetCustomizedConf() const override; + + private: + LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); } + LogicalBlobId obn2lbi(const std::string& output_bn) const override { return GenPackedLbi(); } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_SHARED_MODEL_DIFF_ADD_OP_H_