Commit 9f22ecaa authored by Shiyuan Shang-Guan, committed by Jinhui Yuan

separate model for update (#1232)

* update each blob of the packed blob separately in ModelUpdate

* keep blob descs in regsts consistent along bw->md_diff_acc->shared_md_diff_add->md_update->fw

* copy lbi2blob_descs from model

* add shared_model_diff_add kernel

* refine model_update actor and kernel

* rm useless TODO

* add shared_model_diff_add kernel

* refine code


Former-commit-id: 11408363
Parent a4461f07
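
Taken together, the new update path is: sum every incoming model-diff regst into one processed_model_diff regst (the new shared_model_diff_add op), then run one model-update op per logical blob, writing into the model regst in place. A minimal standalone sketch of that data flow, with hypothetical helper names (SumDiffs, NaiveUpdateBlob) standing in for the real kernels:

    #include <cstddef>
    #include <vector>

    struct Blob { std::vector<float> data; };

    // shared_model_diff_add: processed_diff = in_0 + in_1 + ... + in_{n-1}
    void SumDiffs(const std::vector<const Blob*>& ins, Blob* processed_diff) {
      for (std::size_t i = 0; i < processed_diff->data.size(); ++i) {
        float acc = 0.f;
        for (const Blob* in : ins) { acc += in->data[i]; }
        processed_diff->data[i] = acc;
      }
    }

    // one update op per logical blob: the model is updated in place,
    // so no separate pre_model blob is needed any more
    void NaiveUpdateBlob(const Blob& diff, float lr, Blob* model) {
      for (std::size_t i = 0; i < model->data.size(); ++i) {
        model->data[i] -= lr * diff.data[i];
      }
    }
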
......@@ -27,9 +27,7 @@ void NormalMdUpdtCompActor::Act() {
Regst* cur_model_regst = GetCurWriteableRegst(model_regst_desc_id_);
cur_model_regst->set_model_version_id(next_model_version_id_);
KernelCtx kernel_ctx = GenDefaultKernelCtx();
std::tuple<int64_t, const Blob*> other_val(next_model_version_id_,
pre_model_regst_->packed_blob());
kernel_ctx.other = &other_val;
kernel_ctx.other = &next_model_version_id_;
pre_model_regst_ = cur_model_regst;
AsyncLaunchKernel(kernel_ctx);
const JobDesc* job_desc = Global<JobDesc>::Get();
......
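
With the model updated in place, the actor no longer threads the previous packed model blob through the kernel context: ctx.other now carries just the next model version id, and the NormalMdUpdateKernel::Forward hunk further down reads it back. Both sides of that handshake, condensed from the diff:

    // actor side (NormalMdUpdtCompActor::Act)
    kernel_ctx.other = &next_model_version_id_;

    // kernel side (NormalMdUpdateKernel::Forward)
    int64_t next_model_vid = *reinterpret_cast<int64_t*>(ctx.other);
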
......@@ -525,15 +525,11 @@ void LogicalGraph::AddReduceScatterAddGatherNodes(LogicalNode* src, LogicalNode*
void LogicalGraph::SetupNormalMdUpdtOp() {
ForEachLogicalNode<NormalMdUpdtLogicalNode>([](NormalMdUpdtLogicalNode* node) {
if (node->in_edges().size() < 1) { return; }
// Add shared_model_diff_add_op
OperatorConf op_conf;
op_conf.set_name("md_update_" + NewUniqueId());
op_conf.set_name("md_diff_add_" + NewUniqueId());
op_conf.set_device_type(node->parallel_desc()->device_type());
NormalModelUpdateOpConf* mdupdt_conf = op_conf.mutable_normal_mdupdt_conf();
const JobDesc* job_desc = Global<JobDesc>::Get();
if (Global<JobDesc>::Get()->IsTrain()) {
*(mdupdt_conf->mutable_user_conf()) = job_desc->other_conf().train_conf().model_update_conf();
}
mdupdt_conf->set_in_num(node->in_edges().size());
op_conf.mutable_shared_model_diff_add_conf()->set_in_num(node->in_edges().size());
node->mut_op_vec() = {ConstructOp(op_conf)};
});
}
......
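
The logical node now carries only the diff-summing op; the model_update_conf that used to be filled in here moves to the per-blob update ops created later in BuildExecGphAndRegst. The shape of the change, as an op-graph sketch:

    // before: one packed update op per NormalMdUpdtLogicalNode
    //   in_0 .. in_{n-1} --> normal_mdupdt --> model (whole packed blob)
    //
    // after: the node holds only shared_model_diff_add; per-blob update
    // ops are added at task-graph build time
    //   in_0 .. in_{n-1} --> shared_model_diff_add --> processed_model_diff
    //   processed_model_diff --> md_update_<op>_<blob> (one per lbi) --> model
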
......@@ -22,6 +22,7 @@ void NormalMdUpdtCompTaskNode::ProduceAllRegstsAndBindEdges() {
int32_t max_model_regst = 1;
auto model_regst = ProduceRegst("model", false, 1, max_model_regst);
auto const_model_regst = ProduceRegst("const_model", false, 1, 1);
ProduceRegst("processed_model_diff", false, 1, 1);
ProduceRegst("data_tmp", false, 1, 1);
related_init_model_task_id_ = -1;
for (TaskEdge* out_edge : out_edges()) {
......@@ -56,18 +57,52 @@ bool NormalMdUpdtCompTaskNode::IsReadyForBuild() {
void NormalMdUpdtCompTaskNode::BuildExecGphAndRegst() {
if (!IsTrainable()) { return; }
ExecNode* node = mut_exec_gph().NewNode();
node->mut_op() = logical_node()->SoleOp();
ExecNode* shared_model_diff_add_node = mut_exec_gph().NewNode();
shared_model_diff_add_node->mut_op() = logical_node()->SoleOp();
size_t ibn_idx = 0;
for (const auto& pair : consumed_regsts()) {
node->BindBnWithRegst(node->op()->input_bns().Get(ibn_idx++), pair.second.front());
shared_model_diff_add_node->BindBnWithRegst(
shared_model_diff_add_node->op()->input_bns().Get(ibn_idx++), pair.second.front());
}
node->BindBnWithRegst(node->op()->SoleObn(), GetProducedRegst("model"));
node->AddBnToRegstAndBindIt(&Operator::data_tmp_bns, GetProducedRegst("data_tmp"));
node->InferBlobDescs(nullptr);
std::shared_ptr<RegstDesc> processed_model_diff_regst = GetProducedRegst("processed_model_diff");
shared_model_diff_add_node->BindBnWithRegst(logical_node()->SoleOp()->SoleObn(),
processed_model_diff_regst);
// "model" regst is already bound with lbis and locked by the corresponding
// NormalForwardCompTaskNode
processed_model_diff_regst->CopyBlobDescFrom(GetProducedRegst("model").get());
ExecNode* model_update_node = nullptr;
ExecEdge* exec_edge = nullptr;
processed_model_diff_regst->ForEachLbi([&](const LogicalBlobId& lbi) {
OperatorConf op_conf;
op_conf.set_name("md_update_" + lbi.op_name() + "_" + lbi.blob_name());
op_conf.set_device_type(logical_node()->parallel_desc()->device_type());
op_conf.mutable_normal_mdupdt_conf()->set_model_diff(lbi.op_name() + '/' + lbi.blob_name());
op_conf.mutable_normal_mdupdt_conf()->set_model(lbi.op_name() + '/' + lbi.blob_name());
if (Global<JobDesc>::Get()->IsTrain()) {
*(op_conf.mutable_normal_mdupdt_conf()->mutable_user_conf()) =
Global<JobDesc>::Get()->other_conf().train_conf().model_update_conf();
}
std::shared_ptr<Operator> model_update_op = ConstructOp(op_conf);
model_update_node = mut_exec_gph().NewNode();
model_update_node->mut_op() = model_update_op;
exec_edge = mut_exec_gph().NewEdge();
exec_edge->set_lbi(lbi);
exec_edge->mut_src_bn() = lbi.blob_name();
exec_edge->mut_dst_bn() = model_update_op->SoleIbn();
Connect(shared_model_diff_add_node, exec_edge, model_update_node);
model_update_node->BindBnWithRegst(model_update_op->SoleIbn(), processed_model_diff_regst);
model_update_node->BindBnWithRegst(model_update_op->SoleObn(), GetProducedRegst("model"));
model_update_node->AddBnToRegstAndBindIt(&Operator::data_tmp_bns, GetProducedRegst("data_tmp"));
});
mut_exec_gph().TopoForEachNode([this](ExecNode* node) { node->InferBlobDescs(parallel_ctx()); });
}
void NormalMdUpdtCompTaskNode::LockRegsts() { GetProducedRegst("data_tmp")->Lock(); }
void NormalMdUpdtCompTaskNode::LockRegsts() {
GetProducedRegst("processed_model_diff")->Lock();
GetProducedRegst("data_tmp")->Lock();
}
void NormalMdUpdtCompTaskNode::ToProto(TaskProto* task_proto) {
CompTaskNode::ToProto(task_proto);
......
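
A worked example of one ForEachLbi iteration above, with illustrative lbi values:

    // For lbi { op_name: "conv1", blob_name: "weight" } the loop creates:
    //   op name:         md_update_conv1_weight
    //   conf model_diff: "conv1/weight"  (read from processed_model_diff)
    //   conf model:      "conv1/weight"  (written into the "model" regst)
    // plus one exec edge shared_model_diff_add -> md_update_conv1_weight
    // whose src_bn is "weight" and whose dst_bn is the update op's sole
    // input bn.
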
......@@ -5,9 +5,9 @@ namespace oneflow {
template<DeviceType device_type, typename T>
void MomentumMdUpdateKernel<device_type, T>::UpdateModel(
DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, const Blob* pre_model_blob,
const Blob* model_diff_blob, int64_t next_model_vid,
DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, int64_t next_model_vid,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* model_diff_blob = BnInOp2Blob("model_diff");
Blob* model_blob = BnInOp2Blob("model");
Blob* momentum_blob = BnInOp2Blob("momentum");
float beta = this->op_conf().normal_mdupdt_conf().user_conf().momentum_conf().beta();
......@@ -15,19 +15,18 @@ void MomentumMdUpdateKernel<device_type, T>::UpdateModel(
MomentumMdUpdateKernelUtil<device_type, T>::UpdateModel(
ctx, model_blob->shape().elem_cnt(), batch_size, static_cast<T>(beta), learning_rate, l1, l2,
model_diff_blob->dptr<T>(), pre_model_blob->dptr<T>(), momentum_blob->mut_dptr<T>(),
model_blob->mut_dptr<T>());
model_diff_blob->dptr<T>(), model_blob->mut_dptr<T>(), momentum_blob->mut_dptr<T>());
}
template<typename T>
class MomentumMdUpdateKernelUtil<DeviceType::kCPU, T> final {
public:
static void UpdateModel(DeviceCtx*, int64_t n, int64_t batch_size, T beta, T learning_rate, T l1,
T l2, const T* model_diff, const T* pre_model, T* momentum, T* model) {
T l2, const T* model_diff, T* model, T* momentum) {
for (int64_t i = 0; i != n; ++i) {
T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, pre_model[i]);
T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, model[i]);
momentum[i] = beta * momentum[i] - learning_rate * reg_diff;
model[i] = pre_model[i] + momentum[i];
model[i] = model[i] + momentum[i];
}
}
};
......
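
Per element, with w = model, v = momentum, \eta = learning_rate, \beta = beta, and \tilde{g}_i = RegularizeDiff(model_diff_i, batch_size, l1, l2, w_i), the in-place momentum step reads:

    v_i \leftarrow \beta\, v_i - \eta\, \tilde{g}_i
    w_i \leftarrow w_i + v_i

Assuming the model regst still holds the previous version's values when the kernel runs, the arithmetic is identical to the removed pre_model variant; only the source of the regularization term changed from pre_model[i] to model[i].
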
......@@ -8,11 +8,11 @@ namespace {
template<typename T>
__global__ void UpdateModelGpu(int64_t n, int64_t batch_size, T beta, T learning_rate, T l1, T l2,
const T* model_diff, const T* pre_model, T* momentum, T* model) {
const T* model_diff, T* model, T* momentum) {
CUDA_1D_KERNEL_LOOP(i, n) {
T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, pre_model[i]);
T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, model[i]);
momentum[i] = beta * momentum[i] - learning_rate * reg_diff;
model[i] = pre_model[i] + momentum[i];
model[i] = model[i] + momentum[i];
}
}
......@@ -22,10 +22,9 @@ template<typename T>
class MomentumMdUpdateKernelUtil<DeviceType::kGPU, T> final {
public:
static void UpdateModel(DeviceCtx* ctx, int64_t n, int64_t batch_size, T beta, T learning_rate,
const T l1, const T l2, const T* model_diff, const T* pre_model,
T* momentum, T* model) {
const T l1, const T l2, const T* model_diff, T* model, T* momentum) {
UpdateModelGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
n, batch_size, beta, learning_rate, l1, l2, model_diff, pre_model, momentum, model);
n, batch_size, beta, learning_rate, l1, l2, model_diff, model, momentum);
}
};
......
......@@ -14,7 +14,7 @@ class MomentumMdUpdateKernel final : public NormalMdUpdateKernel<device_type, T>
private:
void UpdateModel(DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2,
const Blob* pre_model_blob, const Blob* model_diff_blob, int64_t next_model_vid,
int64_t next_model_vid,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
};
......@@ -22,7 +22,7 @@ template<DeviceType device_type, typename T>
class MomentumMdUpdateKernelUtil final {
public:
static void UpdateModel(DeviceCtx*, int64_t n, int64_t batch_size, T beta, T learning_rate, T l1,
T l2, const T* model_diff, const T* pre_model, T* momentum, T* model);
T l2, const T* model_diff, T* model, T* momentum);
};
DECLARE_MDUPDT_KERNEL_CREATOR(Momentum);
......
......@@ -5,25 +5,24 @@ namespace oneflow {
template<DeviceType device_type, typename T>
void NaiveMdUpdateKernel<device_type, T>::UpdateModel(
DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, const Blob* pre_model_blob,
const Blob* model_diff_blob, int64_t next_model_vid,
DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, int64_t next_model_vid,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* model_diff_blob = BnInOp2Blob("model_diff");
Blob* model_blob = BnInOp2Blob("model");
if (pre_model_blob != model_blob) { model_blob->CopyDataContentFrom(ctx, pre_model_blob); }
// model = model - alpha * model_diff
NaiveMdUpdateKernelUtil<device_type, T>::UpdateModel(
ctx, model_blob->shape().elem_cnt(), batch_size, learning_rate, l1, l2,
model_diff_blob->dptr<T>(), pre_model_blob->dptr<T>(), model_blob->mut_dptr<T>());
model_diff_blob->dptr<T>(), model_blob->mut_dptr<T>());
}
template<typename T>
class NaiveMdUpdateKernelUtil<DeviceType::kCPU, T> final {
public:
static void UpdateModel(DeviceCtx*, const int64_t n, int64_t batch_size, T learning_rate, T l1,
T l2, const T* model_diff, const T* pre_model, T* model) {
T l2, const T* model_diff, T* model) {
for (int64_t i = 0; i != n; ++i) {
T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, pre_model[i]);
model[i] = pre_model[i] - learning_rate * reg_diff;
T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, model[i]);
model[i] = model[i] - learning_rate * reg_diff;
}
}
};
......
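
The naive step is plain SGD on the regularized diff; because the update is in place, the old pre_model copy (model_blob->CopyDataContentFrom) is dropped:

    w_i \leftarrow w_i - \eta\, \tilde{g}_i
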
......@@ -8,10 +8,10 @@ namespace {
template<typename T>
__global__ void UpdateModelGpu(const int64_t n, int64_t batch_size, T learning_rate, T l1, T l2,
const T* model_diff, const T* pre_model, T* model) {
const T* model_diff, T* model) {
CUDA_1D_KERNEL_LOOP(i, n) {
T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, pre_model[i]);
model[i] = pre_model[i] - learning_rate * reg_diff;
T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, model[i]);
model[i] = model[i] - learning_rate * reg_diff;
}
}
......@@ -21,9 +21,9 @@ template<typename T>
class NaiveMdUpdateKernelUtil<DeviceType::kGPU, T> final {
public:
static void UpdateModel(DeviceCtx* ctx, const int64_t n, int64_t batch_size, T learning_rate,
T l1, T l2, const T* model_diff, const T* pre_model, T* model) {
T l1, T l2, const T* model_diff, T* model) {
UpdateModelGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
n, batch_size, learning_rate, l1, l2, model_diff, pre_model, model);
n, batch_size, learning_rate, l1, l2, model_diff, model);
}
};
......
......@@ -14,7 +14,7 @@ class NaiveMdUpdateKernel final : public NormalMdUpdateKernel<device_type, T> {
private:
void UpdateModel(DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2,
const Blob* pre_model_blob, const Blob* model_diff_blob, int64_t next_model_vid,
int64_t next_model_vid,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
};
......@@ -22,7 +22,7 @@ template<DeviceType device_type, typename T>
class NaiveMdUpdateKernelUtil final {
public:
static void UpdateModel(DeviceCtx*, int64_t n, int64_t batch_size, T learning_rate, T l1, T l2,
const T* model_diff, const T* pre_model, T* model);
const T* model_diff, T* model);
};
DECLARE_MDUPDT_KERNEL_CREATOR(Naive);
......
......@@ -8,8 +8,7 @@ namespace oneflow {
template<DeviceType device_type, typename T>
void NormalMdUpdateKernel<device_type, T>::Forward(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
auto tpl = reinterpret_cast<std::tuple<int64_t, const Blob*>*>(ctx.other);
int64_t next_model_vid = std::get<0>(*tpl);
int64_t next_model_vid = *reinterpret_cast<int64_t*>(ctx.other);
int64_t cur_batch_num = next_model_vid - 1;
const NormalModelUpdateOpUserConf& conf = this->op_conf().normal_mdupdt_conf().user_conf();
double learning_rate = conf.learning_rate();
......@@ -19,19 +18,11 @@ void NormalMdUpdateKernel<device_type, T>::Forward(
learning_rate =
GetDecayedLearningRate(conf.learning_rate_decay(), learning_rate, cur_batch_num);
}
const OpAttribute& op_attribute = this->kernel_conf().op_attribute();
Blob* in_0 = BnInOp2Blob(op_attribute.input_bns(0));
FOR_RANGE(size_t, i, 1, op_attribute.input_bns().size()) {
Blob* in_i = BnInOp2Blob(op_attribute.input_bns(i));
KernelUtil<device_type, T>::Axpy(ctx.device_ctx, in_0->shape().elem_cnt(), static_cast<T>(1.0),
in_i->dptr<T>(), 1, in_0->mut_dptr<T>(), 1);
}
int64_t batch_size = Global<JobDesc>::Get()->BatchSize();
float l1 = Global<JobDesc>::Get()->L1();
float l2 = Global<JobDesc>::Get()->L2();
UpdateModel(ctx.device_ctx, batch_size, static_cast<T>(learning_rate), static_cast<T>(l1),
static_cast<T>(l2), std::get<1>(*tpl), in_0, next_model_vid, BnInOp2Blob);
static_cast<T>(l2), next_model_vid, BnInOp2Blob);
}
#define INSTANTIATE_KERNEL(device_type, data_type_pair) \
......
......@@ -17,7 +17,6 @@ class NormalMdUpdateKernel : public KernelIf<device_type> {
protected:
NormalMdUpdateKernel() = default;
virtual void UpdateModel(DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2,
const Blob* pre_model_blob, const Blob* model_diff_blob,
int64_t next_model_vid,
std::function<Blob*(const std::string&)> BnInOp2Blob) const = 0;
};
......
......@@ -5,9 +5,9 @@ namespace oneflow {
template<DeviceType device_type, typename T>
void RMSPropMdUpdateKernel<device_type, T>::UpdateModel(
DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, const Blob* pre_model_blob,
const Blob* model_diff_blob, int64_t next_model_vid,
DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2, int64_t next_model_vid,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* model_diff_blob = BnInOp2Blob("model_diff");
Blob* model_blob = BnInOp2Blob("model");
Blob* mean_square_blob = BnInOp2Blob("mean_square");
const RMSPropModelUpdateConf& conf =
......@@ -17,20 +17,19 @@ void RMSPropMdUpdateKernel<device_type, T>::UpdateModel(
RMSPropMdUpdateKernelUtil<device_type, T>::UpdateModel(
ctx, model_blob->shape().elem_cnt(), batch_size, learning_rate, static_cast<T>(decay_rate),
static_cast<T>(conf.epsilon()), l1, l2, pre_model_blob->dptr<T>(), model_blob->mut_dptr<T>(),
mean_square_blob->mut_dptr<T>(), model_diff_blob->dptr<T>());
static_cast<T>(conf.epsilon()), l1, l2, model_diff_blob->dptr<T>(), model_blob->mut_dptr<T>(),
mean_square_blob->mut_dptr<T>());
}
template<typename T>
class RMSPropMdUpdateKernelUtil<DeviceType::kCPU, T> final {
public:
static void UpdateModel(DeviceCtx*, int64_t n, int64_t batch_size, T learning_rate, T decay_rate,
T epsilon, T l1, T l2, const T* pre_model, T* model, T* mean_square,
const T* model_diff) {
T epsilon, T l1, T l2, const T* model_diff, T* model, T* mean_square) {
for (int64_t i = 0; i < n; ++i) {
T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, pre_model[i]);
T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, model[i]);
mean_square[i] = (1 - decay_rate) * reg_diff * reg_diff + decay_rate * mean_square[i];
model[i] = pre_model[i] - learning_rate * reg_diff / std::sqrt(mean_square[i] + epsilon);
model[i] = model[i] - learning_rate * reg_diff / std::sqrt(mean_square[i] + epsilon);
}
}
};
......
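
With s = mean_square, \rho = decay_rate, and \epsilon as in the conf, the per-element RMSProp step now reads:

    s_i \leftarrow (1-\rho)\, \tilde{g}_i^{\,2} + \rho\, s_i
    w_i \leftarrow w_i - \eta\, \tilde{g}_i / \sqrt{s_i + \epsilon}
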
......@@ -8,12 +8,12 @@ namespace {
template<typename T>
__global__ void UpdateModelGpu(int64_t n, int64_t batch_size, T learning_rate, T decay_rate,
T epsilon, T l1, T l2, const T* pre_model, T* model, T* mean_square,
const T* model_diff) {
T epsilon, T l1, T l2, const T* model_diff, T* model,
T* mean_square) {
CUDA_1D_KERNEL_LOOP(i, n) {
T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, pre_model[i]);
T reg_diff = RegularizeDiff(model_diff[i], batch_size, l1, l2, model[i]);
mean_square[i] = (1 - decay_rate) * reg_diff * reg_diff + decay_rate * mean_square[i];
model[i] = pre_model[i] - learning_rate * reg_diff / std::sqrt(mean_square[i] + epsilon);
model[i] = model[i] - learning_rate * reg_diff / std::sqrt(mean_square[i] + epsilon);
}
}
......@@ -23,11 +23,10 @@ template<typename T>
class RMSPropMdUpdateKernelUtil<DeviceType::kGPU, T> final {
public:
static void UpdateModel(DeviceCtx* ctx, int64_t n, int64_t batch_size, T learning_rate,
T decay_rate, T epsilon, T l1, T l2, const T* pre_model, T* model,
T* mean_square, const T* model_diff) {
T decay_rate, T epsilon, T l1, T l2, const T* model_diff, T* model,
T* mean_square) {
UpdateModelGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
n, batch_size, learning_rate, decay_rate, epsilon, l1, l2, pre_model, model, mean_square,
model_diff);
n, batch_size, learning_rate, decay_rate, epsilon, l1, l2, model_diff, model, mean_square);
}
};
......
......@@ -14,7 +14,7 @@ class RMSPropMdUpdateKernel final : public NormalMdUpdateKernel<device_type, T>
private:
void UpdateModel(DeviceCtx* ctx, int64_t batch_size, T learning_rate, T l1, T l2,
const Blob* pre_model_blob, const Blob* model_diff_blob, int64_t next_model_vid,
int64_t next_model_vid,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
};
......@@ -22,10 +22,9 @@ template<DeviceType device_type, typename T>
class RMSPropMdUpdateKernelUtil final {
public:
// mean_square = (1 - decay_rate) * model_diff ^ 2 + decay_rate * mean_square
// model = pre_model - learning_rate * model_diff / sqrt(mean_square + epsilon)
// model = model - learning_rate * model_diff / sqrt(mean_square + epsilon)
static void UpdateModel(DeviceCtx*, int64_t n, int64_t batch_size, T learning_rate, T decay_rate,
T epsilon, T l1, T l2, const T* pre_model, T* model, T* mean_square,
const T* model_diff);
T epsilon, T l1, T l2, const T* model_diff, T* model, T* mean_square);
};
DECLARE_MDUPDT_KERNEL_CREATOR(RMSProp);
......
#include "oneflow/core/kernel/shared_model_diff_add_kernel.h"
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/core/kernel/kernel_common.hpp"
namespace oneflow {
template<DeviceType device_type, typename T>
void SharedModelDiffAddKernel<device_type, T>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const PbRpf<std::string>& ibns = this->op_attribute().input_bns();
size_t in_num = ibns.size();
if (in_num == 0) return;
Blob* out_blob = BnInOp2Blob(this->op_attribute().output_bns(0));
auto in_blob = [&](int32_t idx) { return BnInOp2Blob(this->op_attribute().input_bns(idx)); };
static const int kWidth = 8;
int r = in_num % kWidth;
if (r) {
tuple_switch(r, tp_,
AdditionFunction<true, device_type, T, decltype(this)>{
out_blob, std::move(BnInOp2Blob), ctx.device_ctx, 0, this});
}
for (; r < in_num; r += kWidth) {
Addition<device_type, T>(ctx.device_ctx, out_blob, out_blob, in_blob(r), in_blob(r + 1),
in_blob(r + 2), in_blob(r + 3), in_blob(r + 4), in_blob(r + 5),
in_blob(r + 6), in_blob(r + 7));
}
}
template<DeviceType device_type, typename T>
const PbMessage& SharedModelDiffAddKernel<device_type, T>::GetCustomizedOpConf() const {
return this->op_conf().shared_model_diff_add_conf();
}
ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kSharedModelDiffAddConf, SharedModelDiffAddKernel,
ARITHMETIC_DATA_TYPE_SEQ);
} // namespace oneflow
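
The accumulation above is blocked: the leading in_num % 8 blobs are folded first via tuple_switch (dispatching to an AdditionFunction of that arity), then the remaining blobs are added in fixed groups of eight so the variadic Addition helper is instantiated once. A scalar sketch of the same blocking, assuming the remainder pass initializes out before the grouped passes accumulate onto it:

    #include <cstddef>
    #include <vector>

    // Scalar model of the width-8 blocked accumulation in ForwardDataContent.
    void BlockedSum(const std::vector<std::vector<float>>& ins,
                    std::vector<float>* out) {
      const std::size_t kWidth = 8;
      std::size_t r = ins.size() % kWidth;
      // remainder pass (the tuple_switch role): out = in_0 + ... + in_{r-1}
      for (std::size_t i = 0; i < out->size(); ++i) {
        float acc = 0.f;
        for (std::size_t j = 0; j < r; ++j) { acc += ins[j][i]; }
        (*out)[i] = acc;
      }
      // grouped passes: out += in_r + ... + in_{r+7}, eight at a time
      for (; r < ins.size(); r += kWidth) {
        for (std::size_t i = 0; i < out->size(); ++i) {
          float acc = 0.f;
          for (std::size_t j = r; j < r + kWidth; ++j) { acc += ins[j][i]; }
          (*out)[i] += acc;
        }
      }
    }
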
#ifndef ONEFLOW_CORE_KERNEL_SHARED_MODEL_DIFF_ADD_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_SHARED_MODEL_DIFF_ADD_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
template<DeviceType device_type, typename T>
class SharedModelDiffAddKernel final : public KernelIf<device_type> {
public:
OF_DISALLOW_COPY_AND_MOVE(SharedModelDiffAddKernel);
SharedModelDiffAddKernel() = default;
~SharedModelDiffAddKernel() = default;
private:
void ForwardDataContent(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
const PbMessage& GetCustomizedOpConf() const override;
decltype(make_tuple_from_sequence<7>()) tp_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_SHARED_MODEL_DIFF_ADD_KERNEL_H_
......@@ -16,8 +16,6 @@ class MomentumModelUpdateOp final : public NormalModelUpdtOp {
private:
void MdUpdtVirtualInitFromOpConf() override;
LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); }
LogicalBlobId obn2lbi(const std::string& output_bn) const override { return GenPackedLbi(); }
};
} // namespace oneflow
......
......@@ -13,10 +13,6 @@ class NaiveModelUpdateOp final : public NormalModelUpdtOp {
void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override {}
private:
LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); }
LogicalBlobId obn2lbi(const std::string& output_bn) const override { return GenPackedLbi(); }
};
} // namespace oneflow
......
......@@ -5,9 +5,7 @@
namespace oneflow {
void NormalModelUpdtOp::InitFromOpConf() {
FOR_RANGE(int32_t, i, 0, op_conf().normal_mdupdt_conf().in_num()) {
EnrollInputBn("in_" + std::to_string(i), false);
}
EnrollInputBn("model_diff", false);
EnrollOutputBn("model", false);
MdUpdtVirtualInitFromOpConf();
}
......@@ -16,6 +14,13 @@ const PbMessage& NormalModelUpdtOp::GetCustomizedConf() const {
return op_conf().normal_mdupdt_conf();
}
LogicalBlobId NormalModelUpdtOp::obn2lbi(const std::string& output_bn) const {
const google::protobuf::Descriptor* desc = GetCustomizedConf().GetDescriptor();
const google::protobuf::FieldDescriptor* fd = desc->FindFieldByName(output_bn);
CHECK(fd);
return GenLogicalBlobId(GetValFromCustomizedConf<std::string>(output_bn));
}
REGISTER_OP_CREATOR(OperatorConf::kNormalMdupdtConf, [](const OperatorConf& op_conf) -> Operator* {
return NewObj<NormalModelUpdtOp>(op_conf.normal_mdupdt_conf().user_conf().normal_mdupdt_case());
});
......
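
obn2lbi for NormalModelUpdtOp thus resolves through the conf string rather than returning a packed lbi; the CHECK(fd) guards that the output bn names an actual field of the conf. An illustrative resolution:

    // conf:   normal_mdupdt_conf { model: "conv1/weight" ... }
    // call:   obn2lbi("model")
    // result: LogicalBlobId { op_name: "conv1", blob_name: "weight" },
    //         i.e. GenLogicalBlobId splits the "op_name/blob_name" string
    //         that the task node wrote into the conf.
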
......@@ -18,6 +18,7 @@ class NormalModelUpdtOp : public Operator {
virtual void MdUpdtVirtualInitFromOpConf() {}
private:
LogicalBlobId obn2lbi(const std::string& output_bn) const override;
};
} // namespace oneflow
......
......@@ -377,7 +377,8 @@ message NormalModelUpdateOpUserConf {
message NormalModelUpdateOpConf {
required NormalModelUpdateOpUserConf user_conf = 1;
required int32 in_num = 2;
required string model_diff = 2;
required string model = 3;
}
message AccumulateOpConf {
......@@ -469,6 +470,10 @@ message MaximumOpConf {
required string out = 2;
}
message SharedModelDiffAddOpConf {
required int32 in_num = 1;
}
message LocalResponseNormalizationOpConf {
required string in = 1;
required string out = 2;
......@@ -671,6 +676,7 @@ message OperatorConf {
AccumulateOpConf accumulate_conf = 115;
NormalModelUpdateOpConf normal_mdupdt_conf = 116;
ModelSaveOpConf model_save_conf = 117;
SharedModelDiffAddOpConf shared_model_diff_add_conf = 118;
// domain op
TransposeOpConf transpose_conf = 201;
......
......@@ -16,8 +16,6 @@ class RMSPropModelUpdateOp final : public NormalModelUpdtOp {
private:
void MdUpdtVirtualInitFromOpConf() override;
LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); }
LogicalBlobId obn2lbi(const std::string& output_bn) const override { return GenPackedLbi(); }
};
} // namespace oneflow
......
#include "oneflow/core/operator/shared_model_diff_add_op.h"
namespace oneflow {
void SharedModelDiffAddOp::InitFromOpConf() {
CHECK(op_conf().has_shared_model_diff_add_conf());
FOR_RANGE(int32_t, i, 0, op_conf().shared_model_diff_add_conf().in_num()) {
EnrollInputBn("in_" + std::to_string(i), false);
}
EnrollOutputBn("processed_model_diff", false);
}
const PbMessage& SharedModelDiffAddOp::GetCustomizedConf() const {
return op_conf().shared_model_diff_add_conf();
}
void SharedModelDiffAddOp::InferBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
const BlobDesc* in_0_blob_desc = GetBlobDesc4BnInOp(input_bns().Get(0));
FOR_RANGE(int32_t, i, 1, input_bns().size()) {
CHECK(*in_0_blob_desc == *GetBlobDesc4BnInOp(input_bns().Get(i)));
}
}
REGISTER_OP(OperatorConf::kSharedModelDiffAddConf, SharedModelDiffAddOp);
} // namespace oneflow
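
Note that InferBlobDescs only asserts every input blob desc equals in_0's and sets no output desc: the task node fills the processed_model_diff desc by copying from the locked "model" regst (the CopyBlobDescFrom call earlier in this commit). The op's resulting interface for shared_model_diff_add_conf { in_num: 3 }, as an illustration:

    // inputs:  in_0, in_1, in_2      (ibn2lbi: GenPackedLbi)
    // output:  processed_model_diff  (obn2lbi: GenPackedLbi; blob desc
    //          taken from the "model" regst via CopyBlobDescFrom)
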
#ifndef ONEFLOW_CORE_OPERATOR_SHARED_MODEL_DIFF_ADD_OP_H_
#define ONEFLOW_CORE_OPERATOR_SHARED_MODEL_DIFF_ADD_OP_H_
#include "oneflow/core/operator/operator.h"
namespace oneflow {
class SharedModelDiffAddOp final : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(SharedModelDiffAddOp);
SharedModelDiffAddOp() = default;
~SharedModelDiffAddOp() = default;
void InitFromOpConf() override;
void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
const PbMessage& GetCustomizedConf() const override;
private:
LogicalBlobId ibn2lbi(const std::string& input_bn) const override { return GenPackedLbi(); }
LogicalBlobId obn2lbi(const std::string& output_bn) const override { return GenPackedLbi(); }
};
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_SHARED_MODEL_DIFF_ADD_OP_H_