Commit 33868c01 authored by Shiyuan Shang-Guan, committed by Jinhui Yuan

refine model update conf (#1240)

* refine model update conf

* make todo

* add primary_lr and secondary_lr


Former-commit-id: 5ccd29d7
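
In brief, the diff below:

* removes learning_rate from NormalModelUpdateOpUserConf and adds a required primary_lr (plus an optional secondary_lr that falls back to primary_lr when unset) to TrainConf
* replaces the global l1/l2 with per-blob weight_l1/bias_l1/weight_l2/bias_l2
* stamps learning_rate/l1/l2 onto each NormalModelUpdateOpConf in NormalMdUpdtCompTaskNode::BuildExecGphAndRegst, so the kernel reads them from its own op conf instead of from JobDesc globals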
Parent b3286301
@@ -11,8 +11,8 @@ data_part_num: 6
 train_conf {
   batch_size: 6000
   total_batch_num: 100
+  primary_lr: 0.01
   model_update_conf {
-    learning_rate: 0.01
     naive_conf {
     }
   }
......
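
Note that for existing job configs this is a breaking change: learning_rate is no longer a field of model_update_conf and must move up to train_conf as primary_lr, exactly as the hunk above shows.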
@@ -82,6 +82,23 @@ void NormalMdUpdtCompTaskNode::BuildExecGphAndRegst() {
   if (Global<JobDesc>::Get()->IsTrain()) {
     *(op_conf.mutable_normal_mdupdt_conf()->mutable_user_conf()) =
         Global<JobDesc>::Get()->other_conf().train_conf().model_update_conf();
+    float primary_lr = Global<JobDesc>::Get()->primary_lr();
+    float secondary_lr = Global<JobDesc>::Get()->secondary_lr();
+    if (secondary_lr < 0) { secondary_lr = primary_lr; }
+    if (lbi.blob_name() == "weight") {
+      op_conf.mutable_normal_mdupdt_conf()->set_learning_rate(primary_lr);
+      op_conf.mutable_normal_mdupdt_conf()->set_l1(Global<JobDesc>::Get()->weight_l1());
+      op_conf.mutable_normal_mdupdt_conf()->set_l2(Global<JobDesc>::Get()->weight_l2());
+    } else if (lbi.blob_name() == "bias") {
+      op_conf.mutable_normal_mdupdt_conf()->set_learning_rate(secondary_lr);
+      op_conf.mutable_normal_mdupdt_conf()->set_l1(Global<JobDesc>::Get()->bias_l1());
+      op_conf.mutable_normal_mdupdt_conf()->set_l2(Global<JobDesc>::Get()->bias_l2());
+    } else {
+      op_conf.mutable_normal_mdupdt_conf()->set_learning_rate(primary_lr);
+      op_conf.mutable_normal_mdupdt_conf()->set_l1(0);
+      op_conf.mutable_normal_mdupdt_conf()->set_l2(0);
+    }
   }
   std::shared_ptr<Operator> model_update_op = ConstructOp(op_conf);
   model_update_node = mut_exec_gph().NewNode();
......
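
The dispatch above keys on the blob name: "weight" blobs get primary_lr with weight_l1/weight_l2, "bias" blobs get secondary_lr with bias_l1/bias_l2, and any other model blob gets primary_lr with no regularization. A minimal standalone sketch of that selection, with invented names and not part of the commit:

#include <string>

// Hypothetical helper mirroring NormalMdUpdtCompTaskNode's per-blob choice.
struct BlobHyperParams {
  float learning_rate;
  float l1;
  float l2;
};

BlobHyperParams SelectHyperParams(const std::string& blob_name, float primary_lr,
                                  float secondary_lr, float weight_l1, float weight_l2,
                                  float bias_l1, float bias_l2) {
  if (secondary_lr < 0) { secondary_lr = primary_lr; }  // unset secondary_lr follows primary_lr
  if (blob_name == "weight") { return {primary_lr, weight_l1, weight_l2}; }
  if (blob_name == "bias") { return {secondary_lr, bias_l1, bias_l2}; }
  return {primary_lr, 0.f, 0.f};  // other blobs: primary rate, no regularization
}

Excluding biases from L1/L2 regularization and letting them track a separate rate is a common practice; the split into primary and secondary rates makes that expressible in the job conf.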
@@ -15,10 +15,14 @@ message TrainConf {
   required int32 num_of_batches_in_snapshot = 5;
   optional InitializerConf default_initializer_conf = 100;
-  optional float l1 = 101 [default = 0];
-  optional float l2 = 102 [default = 0];
-  optional int64 piece_num_of_print_loss = 104 [default = -1];
-  optional int64 piece_num_of_print_accuracy = 105 [default = -1];
+  required float primary_lr = 101;
+  optional float secondary_lr = 102 [default = -1];
+  optional float weight_l1 = 103 [default = 0];
+  optional float bias_l1 = 104 [default = 0];
+  optional float weight_l2 = 105 [default = 0];
+  optional float bias_l2 = 106 [default = 0];
+  optional int64 piece_num_of_print_loss = 107 [default = -1];
+  optional int64 piece_num_of_print_accuracy = 108 [default = -1];
 }
 message PredictConf {
......
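
TrainConf reuses field numbers 101/102 (formerly l1/l2) for primary_lr/secondary_lr and shifts the print-interval fields from 104/105 to 107/108. Since job confs are written in protobuf text format (see the train_conf hunk above), the practical consequence is that configs still naming l1 or l2 fail to parse, rather than any wire-format ambiguity.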
@@ -66,14 +66,29 @@ int64_t JobDesc::NumOfPiecesInBatch() const {
   CHECK_EQ(BatchSize() % PieceSize(), 0);
   return BatchSize() / PieceSize();
 }
-float JobDesc::L1() const {
+float JobDesc::primary_lr() const {
   CHECK(IsTrain());
-  return job_conf_.other().train_conf().l1();
+  return job_conf_.other().train_conf().primary_lr();
 }
-float JobDesc::L2() const {
+float JobDesc::secondary_lr() const {
   CHECK(IsTrain());
+  return job_conf_.other().train_conf().secondary_lr();
+}
+float JobDesc::weight_l1() const {
+  CHECK(IsTrain());
+  return job_conf_.other().train_conf().weight_l1();
+}
+float JobDesc::bias_l1() const {
+  CHECK(IsTrain());
+  return job_conf_.other().train_conf().bias_l1();
+}
+float JobDesc::weight_l2() const {
+  CHECK(IsTrain());
+  return job_conf_.other().train_conf().weight_l2();
+}
+float JobDesc::bias_l2() const {
+  CHECK(IsTrain());
-  return job_conf_.other().train_conf().l2();
+  return job_conf_.other().train_conf().bias_l2();
 }
 int32_t JobDesc::DataPartNum() const { return job_conf_.other().data_part_num(); }
......
@@ -65,8 +65,12 @@ class JobDesc final {
   int32_t PieceNumOfPrintAccuracy() const;
   int64_t BatchSize() const;
   int64_t NumOfPiecesInBatch() const;
-  float L1() const;
-  float L2() const;
+  float primary_lr() const;
+  float secondary_lr() const;
+  float weight_l1() const;
+  float bias_l1() const;
+  float weight_l2() const;
+  float bias_l2() const;
   int32_t DataPartNum() const;
 private:
......
@@ -11,7 +11,7 @@ void NormalMdUpdateKernel<device_type, T>::Forward(
   int64_t next_model_vid = *reinterpret_cast<int64_t*>(ctx.other);
   int64_t cur_batch_num = next_model_vid - 1;
   const NormalModelUpdateOpUserConf& conf = this->op_conf().normal_mdupdt_conf().user_conf();
-  double learning_rate = conf.learning_rate();
+  float learning_rate = this->op_conf().normal_mdupdt_conf().learning_rate();
   if (TriggerWarmup(conf, learning_rate, cur_batch_num)) {
     learning_rate = GetWarmupLearningRate(conf.warmup_conf(), learning_rate, cur_batch_num);
   } else if (conf.has_learning_rate_decay()) {
@@ -19,8 +19,8 @@ void NormalMdUpdateKernel<device_type, T>::Forward(
     GetDecayedLearningRate(conf.learning_rate_decay(), learning_rate, cur_batch_num);
   }
   int64_t batch_size = Global<JobDesc>::Get()->BatchSize();
-  float l1 = Global<JobDesc>::Get()->L1();
-  float l2 = Global<JobDesc>::Get()->L2();
+  float l1 = this->op_conf().normal_mdupdt_conf().l1();
+  float l2 = this->op_conf().normal_mdupdt_conf().l2();
   UpdateModel(ctx.device_ctx, batch_size, static_cast<T>(learning_rate), static_cast<T>(l1),
               static_cast<T>(l2), next_model_vid, BnInOp2Blob);
 }
......
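
UpdateModel itself is device- and optimizer-specific and not shown in this diff. Purely for orientation, a naive SGD step consistent with these three parameters might look like the following sketch; the batch averaging and the sign/weight-decay regularization terms are assumptions, not code from the repository:

#include <cstdint>

// Hypothetical naive SGD update using the per-op learning_rate/l1/l2 fields above.
void NaiveSgdUpdate(float* model, const float* model_diff, int64_t n,
                    int64_t batch_size, float learning_rate, float l1, float l2) {
  for (int64_t i = 0; i < n; ++i) {
    float grad = model_diff[i] / batch_size;             // assumption: average over the batch
    grad += l1 * ((model[i] > 0.f) - (model[i] < 0.f));  // assumption: L1 via sign(w)
    grad += l2 * model[i];                               // assumption: L2 weight decay
    model[i] -= learning_rate * grad;
  }
}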
@@ -365,9 +365,8 @@ message WarmupConf {
 }
 message NormalModelUpdateOpUserConf {
-  optional double learning_rate = 1 [default = 0.01];
-  optional LearningRateDecayConf learning_rate_decay = 2;
-  optional WarmupConf warmup_conf = 3;
+  optional LearningRateDecayConf learning_rate_decay = 1;
+  optional WarmupConf warmup_conf = 2;
   oneof normal_mdupdt {
     NaiveModelUpdateConf naive_conf = 1000;
     MomentumModelUpdateConf momentum_conf = 1001;
@@ -379,6 +378,9 @@ message NormalModelUpdateOpConf {
   required NormalModelUpdateOpUserConf user_conf = 1;
   required string model_diff = 2;
   required string model = 3;
+  required float learning_rate = 4;
+  required float l1 = 5;
+  required float l2 = 6;
 }
 message AccumulateOpConf {
......
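
With learning_rate, l1, and l2 now required fields of NormalModelUpdateOpConf, each model-update op carries its own effective hyperparameters, chosen once at graph-build time; NormalModelUpdateOpUserConf keeps only the learning-rate decay and warmup schedules, which continue to apply on top of whatever base rate the op was given.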