Unverified commit c736c538 authored by guo ran, committed by GitHub

adam_bias_correction_learning_rate (#3763)

* adam_bias_correction_learning_rate

* refine

* refine

* fix
Co-authored-by: oneflow-bot <69100618+oneflow-bot@users.noreply.github.com>
Parent bce7d519
......@@ -18,6 +18,39 @@ limitations under the License.
namespace oneflow {
struct AdamBiasCorrectionLearningRateCacheKey {
float beta1;
float beta2;
std::string lr_lbn;
std::string step_lbn;
ParallelConf parallel_conf;
};
bool operator==(const AdamBiasCorrectionLearningRateCacheKey& lhs,
const AdamBiasCorrectionLearningRateCacheKey& rhs) {
return (lhs.beta1 == rhs.beta1) && (lhs.beta2 == rhs.beta2) && (lhs.lr_lbn == rhs.lr_lbn)
&& (lhs.step_lbn == rhs.step_lbn) && (lhs.parallel_conf == rhs.parallel_conf);
}
} // namespace oneflow
namespace std {
template<>
struct hash<oneflow::AdamBiasCorrectionLearningRateCacheKey> {
size_t operator()(const oneflow::AdamBiasCorrectionLearningRateCacheKey& key) const {
const auto& str_hash = std::hash<std::string>();
const auto& float_hash = std::hash<float>();
const auto& parallel_conf_hash = std::hash<oneflow::ParallelConf>();
return float_hash(key.beta1) ^ float_hash(key.beta2) ^ str_hash(key.lr_lbn)
^ str_hash(key.step_lbn) ^ parallel_conf_hash(key.parallel_conf);
}
};
} // namespace std
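// Note: XOR-combining the member hashes keeps this specialization short, but it
// is order-insensitive (e.g. swapping the beta1/beta2 values yields the same
// hash). operator== above still disambiguates collisions, and the cache holds
// only a handful of keys per job, so the weaker hash is acceptable here.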
namespace oneflow {
namespace {
std::string GenVariableOutputLbn(const OperatorConf& op_conf) {
......@@ -44,8 +77,37 @@ void SetScalarShapeAndSbpConf(OperatorConf* op_conf) {
CHECK_NE(op_conf->name(), std::string(""));
}
void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_conf,
JobBuilder* job_builder, const LogicalBlobId& diff_lbi_of_var_out) {
class AdamBiasCorrectionLearningRateState final : public JobPassState {
public:
AdamBiasCorrectionLearningRateState() {}
~AdamBiasCorrectionLearningRateState() override = default;
std::string GetLbn(float beta1, float beta2, std::string lr_lbn, std::string step_lbn,
ParallelConf parallel_conf,
std::function<std::string()> AddAdamBiasCorrectionLearningRateOp) {
AdamBiasCorrectionLearningRateCacheKey cache_key;
cache_key.beta1 = beta1;
cache_key.beta2 = beta2;
cache_key.lr_lbn = lr_lbn;
cache_key.step_lbn = step_lbn;
cache_key.parallel_conf = parallel_conf;
const auto& iter = key2lbn_.find(cache_key);
if (iter != key2lbn_.end()) {
return iter->second;
} else {
std::string lbn = AddAdamBiasCorrectionLearningRateOp();
key2lbn_.emplace(cache_key, lbn);
return lbn;
}
}
private:
HashMap<AdamBiasCorrectionLearningRateCacheKey, std::string> key2lbn_;
};
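// The state is a per-job memo table: GetLbn returns the cached output lbn for
// an identical (beta1, beta2, lr_lbn, step_lbn, parallel_conf) key and only
// invokes AddAdamBiasCorrectionLearningRateOp on a miss, so all variables that
// share one optimizer config reuse a single scaled-learning-rate op.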
void GenerateOptimizerOpConf(JobPassCtx* ctx, const VariableOp& op,
const ParallelConf& parallel_conf, JobBuilder* job_builder,
const LogicalBlobId& diff_lbi_of_var_out) {
const auto& train_conf = job_builder->job().job_conf().train_conf();
const NormalModelUpdateOpUserConf& model_update_conf = train_conf.model_update_conf();
......@@ -73,31 +135,60 @@ void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_
} else {
UNIMPLEMENTED();
}
const std::string& train_step_lbn = train_conf.train_step_lbn();
const std::string& primary_lr_lbn = train_conf.primary_lr_lbn();
std::string lr_lbn;
if (do_bias_correction) {
const std::string& job_pass_state_key = "adam_bias_correction_learning_rate";
const bool has_state =
CHECK_JUST(ctx->HasState<AdamBiasCorrectionLearningRateState>(job_pass_state_key));
if (!has_state) {
ctx->ResetState(job_pass_state_key, std::make_unique<AdamBiasCorrectionLearningRateState>());
}
auto* state =
CHECK_JUST(ctx->MutableState<AdamBiasCorrectionLearningRateState>(job_pass_state_key));
ParallelConf bias_correction_parallel_conf;
const auto& lr_parallel_conf = job_builder->ParallelConf4Lbi(GenLogicalBlobId(primary_lr_lbn));
const auto& train_step_parallel_conf =
job_builder->ParallelConf4Lbi(GenLogicalBlobId(train_step_lbn));
if (lr_parallel_conf == train_step_parallel_conf) {
bias_correction_parallel_conf = lr_parallel_conf;
} else {
bias_correction_parallel_conf = parallel_conf;
}
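// Placement choice: if learning_rate and train_step already share a placement,
// the new op joins them there; otherwise it falls back to the variable's own
// parallel_conf.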
auto AddAdamBiasCorrectionLearningRateOp = [&]() -> std::string {
user_op::UserOpConfWrapperBuilder op_builder(op.op_name()
+ "_adam_bias_correction_learning_rate");
const auto adam_bias_correction_learning_rate_op =
op_builder.OpTypeName("adam_bias_correction_learning_rate")
.Input("learning_rate", primary_lr_lbn)
.Input("train_step", train_step_lbn)
.Attr<float>("beta1", beta1)
.Attr<float>("beta2", beta2)
.Output("out")
.ScopeSymbolId(op.op_conf().scope_symbol_id())
.Build();
job_builder->AddOps(bias_correction_parallel_conf,
{adam_bias_correction_learning_rate_op.op_conf()});
return adam_bias_correction_learning_rate_op.output("out", 0);
};
lr_lbn = state->GetLbn(beta1, beta2, primary_lr_lbn, train_step_lbn,
bias_correction_parallel_conf, AddAdamBiasCorrectionLearningRateOp);
} else {
lr_lbn = primary_lr_lbn;
}
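// At this point lr_lbn carries either the raw primary learning rate or the
// bias-corrected one, lr * sqrt(1 - beta2^(t+1)) / (1 - beta1^(t+1)). The
// correction is thus folded into the "learning_rate" input, which is why this
// commit drops adam_update's do_bias_correction kernel branch and its
// beta1_t/beta2_t helper variables (the removed lines below).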
adam_update_op_builder.OpTypeName("adam_update")
.Input("model", GenLogicalBlobName(op.BnInOp2Lbi("out")))
.Input("model_diff", GenLogicalBlobName(diff_lbi_of_var_out))
.Input("learning_rate", train_conf.primary_lr_lbn())
.Input("learning_rate", lr_lbn)
.Input("m", GenVariableOutputLbn(m_var))
.Input("v", GenVariableOutputLbn(v_var))
.Attr<float>("beta1", beta1)
.Attr<float>("beta2", beta2)
.Attr<float>("epsilon", epsilon)
.Attr<bool>("do_bias_correction", do_bias_correction)
.Attr<float>("weight_decay", GetOptimizerWeightDecayRate(model_update_conf, op))
.ScopeSymbolId(op.op_conf().scope_symbol_id());
if (do_bias_correction) {
OperatorConf beta1_t_var;
OperatorConf beta2_t_var;
beta1_t_var = GenerateAdamHelperVariableOpConf(op, "beta1_t", beta1);
SetScalarShapeAndSbpConf(&beta1_t_var);
beta2_t_var = GenerateAdamHelperVariableOpConf(op, "beta2_t", beta2);
SetScalarShapeAndSbpConf(&beta2_t_var);
job_builder->AddOps(parallel_conf, {beta1_t_var, beta2_t_var});
adam_update_op_builder.Input("beta1_t", GenVariableOutputLbn(beta1_t_var));
adam_update_op_builder.Input("beta2_t", GenVariableOutputLbn(beta2_t_var));
}
const auto adam_update_op = adam_update_op_builder.Build();
job_builder->AddOps(parallel_conf, {adam_update_op.op_conf()});
}
......
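For reference, the adam_bias_correction_learning_rate op computes the standard Adam bias-correction factor. In LaTeX, with t the zero-based train_step (hence the t+1 exponent):

    \hat{\alpha}_t = \alpha \cdot \frac{\sqrt{1 - \beta_2^{t+1}}}{1 - \beta_1^{t+1}}

The CPU and GPU kernels later in this diff implement exactly this with pow and sqrt; the factor starts well below 1 and tends to 1 as t grows, so the correction mainly affects early training steps.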
......@@ -162,17 +162,11 @@ Maybe<void> FuseUpdateOpsPass::Apply(const OpGraph& op_graph, JobBuilder* job_bu
fused_op_builder.Input("momentum", user_op_conf.input("momentum", 0))
.Attr<float>("beta", user_op_conf.attr<float>("beta"));
} else if (user_op_conf.op_type_name() == "adam_update") {
const bool do_bias_correction = user_op_conf.attr<bool>("do_bias_correction");
fused_op_builder.Input("m", user_op_conf.input("m", 0))
.Input("v", user_op_conf.input("v", 0))
.Attr<float>("beta1", user_op_conf.attr<float>("beta1"))
.Attr<float>("beta2", user_op_conf.attr<float>("beta2"))
.Attr<float>("epsilon", user_op_conf.attr<float>("epsilon"))
.Attr<bool>("do_bias_correction", do_bias_correction);
if (do_bias_correction) {
fused_op_builder.Input("beta1_t", user_op_conf.input("beta1_t", 0))
.Input("beta2_t", user_op_conf.input("beta2_t", 0));
}
.Attr<float>("epsilon", user_op_conf.attr<float>("epsilon"));
} else {
UNIMPLEMENTED();
}
......
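Note: with the correction pre-applied to the learning_rate input, the fused-update pass above (and the indexed-slices rewrite pass further below) no longer branches on do_bias_correction nor wires beta1_t/beta2_t inputs into the fused op.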
......@@ -92,14 +92,7 @@ class GenerateBackwardAndOptimizerOpConfs final : public JobPass {
bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().IsTrain(); }
Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;
Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {
if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }
const OpGraph op_graph(*job);
JobBuilder job_builder(job);
return Apply(op_graph, &job_builder);
}
Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override;
};
void FilterModelLbi2DiffLbi(const OpGraph& op_graph,
......@@ -115,26 +108,28 @@ void FilterModelLbi2DiffLbi(const OpGraph& op_graph,
}
}
Maybe<void> GenerateBackwardAndOptimizerOpConfs::Apply(const OpGraph& op_graph,
JobBuilder* job_builder) const {
Maybe<void> GenerateBackwardAndOptimizerOpConfs::Apply(Job* job, JobPassCtx* ctx) const {
if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }
const OpGraph op_graph(*job);
JobBuilder job_builder(job);
LogicalBlobId total_loss_instance_num;
HashMap<LogicalBlobId, LogicalBlobId> lbi2diff_lbi;
JUST(AutoGrad(op_graph, job_builder, &lbi2diff_lbi));
JUST(AutoGrad(op_graph, &job_builder, &lbi2diff_lbi));
HashMap<LogicalBlobId, LogicalBlobId> model_lbi2model_diff_lbi;
FilterModelLbi2DiffLbi(op_graph, lbi2diff_lbi, &model_lbi2model_diff_lbi);
AddDiffStaticShapeCast(op_graph, job_builder, &model_lbi2model_diff_lbi);
AddDiffParallelCast(op_graph, job_builder, &model_lbi2model_diff_lbi);
JUST(ScaleModelDiffByLossInstanceNum(op_graph, job_builder, &model_lbi2model_diff_lbi));
ScaleModelDiffByLossScale(op_graph, job_builder, &model_lbi2model_diff_lbi);
AddDiffStaticShapeCast(op_graph, &job_builder, &model_lbi2model_diff_lbi);
AddDiffParallelCast(op_graph, &job_builder, &model_lbi2model_diff_lbi);
JUST(ScaleModelDiffByLossInstanceNum(op_graph, &job_builder, &model_lbi2model_diff_lbi));
ScaleModelDiffByLossScale(op_graph, &job_builder, &model_lbi2model_diff_lbi);
const NormalModelUpdateOpUserConf& model_update_conf =
job_builder->job().job_conf().train_conf().model_update_conf();
RegularizeGradient(op_graph, job_builder, &model_lbi2model_diff_lbi);
job->job_conf().train_conf().model_update_conf();
RegularizeGradient(op_graph, &job_builder, &model_lbi2model_diff_lbi);
if (model_update_conf.has_clip_conf()) {
ClipGradient(op_graph, job_builder, &model_lbi2model_diff_lbi, model_update_conf.clip_conf());
ClipGradient(op_graph, &job_builder, &model_lbi2model_diff_lbi, model_update_conf.clip_conf());
}
AddOptimizerOpConf(op_graph, job_builder, model_lbi2model_diff_lbi);
UpdateJobHelperConfProducedLbi2ConsumedDiffLbi(lbi2diff_lbi, job_builder);
UpdateOpSbpSignatureHint(op_graph, job_builder);
AddOptimizerOpConf(ctx, op_graph, &job_builder, model_lbi2model_diff_lbi);
UpdateJobHelperConfProducedLbi2ConsumedDiffLbi(lbi2diff_lbi, &job_builder);
UpdateOpSbpSignatureHint(op_graph, &job_builder);
return Maybe<void>::Ok();
}
......
......@@ -115,17 +115,11 @@ Maybe<void> IndexedSlicesOptimizerRewritePass::Apply(const OpGraph& op_graph,
indexed_slices_op_builder.Input("momentum", user_op_conf.input("momentum", 0))
.Attr<float>("beta", user_op_conf.attr<float>("beta"));
} else if (user_op_conf.op_type_name() == "adam_update") {
const bool do_bias_correction = user_op_conf.attr<bool>("do_bias_correction");
indexed_slices_op_builder.Input("m", user_op_conf.input("m", 0))
.Input("v", user_op_conf.input("v", 0))
.Attr<float>("beta1", user_op_conf.attr<float>("beta1"))
.Attr<float>("beta2", user_op_conf.attr<float>("beta2"))
.Attr<float>("epsilon", user_op_conf.attr<float>("epsilon"))
.Attr<bool>("do_bias_correction", do_bias_correction);
if (do_bias_correction) {
indexed_slices_op_builder.Input("beta1_t", user_op_conf.input("beta1_t", 0))
.Input("beta2_t", user_op_conf.input("beta2_t", 0));
}
.Attr<float>("epsilon", user_op_conf.attr<float>("epsilon"));
} else {
return;
}
......
......@@ -69,11 +69,9 @@ class JobPassCtx {
}
template<typename T>
Maybe<const T&> HasState(const std::string& key) const {
Maybe<bool> HasState(const std::string& key) const {
const auto& iter = key2state_.find(key);
CHECK_OR_RETURN(iter != key2state_.end());
const T* ptr = dynamic_cast<T*>(iter->second.get());
return ptr != nullptr;
return (iter != key2state_.end());
}
Maybe<void> ResetState(const std::string& key, std::unique_ptr<JobPassState>&& state) {
......
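HasState<T> now returns Maybe<bool> key presence (no longer inspecting the stored type or CHECK-failing on a missing key), which enables the lazy get-or-create pattern used by the Adam pass. A minimal sketch, using only calls that appear in this diff:

    const std::string key = "adam_bias_correction_learning_rate";
    if (!CHECK_JUST(ctx->HasState<AdamBiasCorrectionLearningRateState>(key))) {
      ctx->ResetState(key, std::make_unique<AdamBiasCorrectionLearningRateState>());
    }
    auto* state =
        CHECK_JUST(ctx->MutableState<AdamBiasCorrectionLearningRateState>(key));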
......@@ -43,8 +43,9 @@ void SetScalarShapeAndSbpConf(OperatorConf* op_conf) {
CHECK_NE(op_conf->name(), std::string(""));
}
void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_conf,
JobBuilder* job_builder, const LogicalBlobId& diff_lbi_of_var_out) {
void GenerateOptimizerOpConf(JobPassCtx* ctx, const VariableOp& op,
const ParallelConf& parallel_conf, JobBuilder* job_builder,
const LogicalBlobId& diff_lbi_of_var_out) {
const auto& train_conf = job_builder->job().job_conf().train_conf();
const NormalModelUpdateOpUserConf& model_update_conf = train_conf.model_update_conf();
OperatorConf m_var = GenerateLAMBHelperVariableOpConf(op, "m", 0.f);
......
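This LAMB hunk and those that follow (momentum, naive, rmsprop, sgd, plus the dispatch in optimizer.cpp/optimizer.h) are mechanical: they thread the new JobPassCtx* ctx parameter through GenerateOptimizerOpConfWrapperStruct, GenerateOptimizerOpConfIf, AddOptimizerOpConf, and every per-optimizer GenerateOptimizerOpConf; only the Adam path actually uses ctx.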
......@@ -19,8 +19,9 @@ namespace oneflow {
namespace {
void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_conf,
JobBuilder* job_builder, const LogicalBlobId& diff_lbi_of_var_out) {
void GenerateOptimizerOpConf(JobPassCtx* ctx, const VariableOp& op,
const ParallelConf& parallel_conf, JobBuilder* job_builder,
const LogicalBlobId& diff_lbi_of_var_out) {
OperatorConf momentum_var(op.op_conf());
InitializerConf constant_initializer;
constant_initializer.mutable_constant_conf()->set_value(0.f);
......
......@@ -21,8 +21,9 @@ namespace oneflow {
namespace {
void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_conf,
JobBuilder* job_builder, const LogicalBlobId& diff_lbi_of_var_out) {
void GenerateOptimizerOpConf(JobPassCtx* ctx, const VariableOp& op,
const ParallelConf& parallel_conf, JobBuilder* job_builder,
const LogicalBlobId& diff_lbi_of_var_out) {
const auto& train_conf = job_builder->job().job_conf().train_conf();
const NormalModelUpdateOpUserConf& model_update_conf = train_conf.model_update_conf();
......
......@@ -18,22 +18,23 @@ limitations under the License.
namespace oneflow {
void GenerateOptimizerOpConfWrapperStruct::Call(const VariableOp& var_op,
void GenerateOptimizerOpConfWrapperStruct::Call(JobPassCtx* ctx, const VariableOp& var_op,
const ParallelConf& parallel_conf,
JobBuilder* job_builder,
const LogicalBlobId& diff_lbi_of_var_out) const {
(*func_)(var_op, parallel_conf, job_builder, diff_lbi_of_var_out);
(*func_)(ctx, var_op, parallel_conf, job_builder, diff_lbi_of_var_out);
}
void GenerateOptimizerOpConfIf(const VariableOp& var_op, const ParallelConf& parallel_conf,
JobBuilder* job_builder, const LogicalBlobId& diff_lbi_of_var_out) {
void GenerateOptimizerOpConfIf(JobPassCtx* ctx, const VariableOp& var_op,
const ParallelConf& parallel_conf, JobBuilder* job_builder,
const LogicalBlobId& diff_lbi_of_var_out) {
const auto& train_conf = GlobalJobDesc().job_conf().train_conf();
auto optimizer_case = train_conf.model_update_conf().normal_mdupdt_case();
auto* obj = NewObj<int32_t, GenerateOptimizerOpConfWrapperStruct>(optimizer_case);
obj->Call(var_op, parallel_conf, job_builder, diff_lbi_of_var_out);
obj->Call(ctx, var_op, parallel_conf, job_builder, diff_lbi_of_var_out);
}
void AddOptimizerOpConf(const OpGraph& op_graph, JobBuilder* job_builder,
void AddOptimizerOpConf(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder,
const HashMap<LogicalBlobId, LogicalBlobId>& lbi2diff_lbi) {
op_graph.ForEachNode([&](OpNode* op_node) {
const VariableOp* var_op = dynamic_cast<const VariableOp*>(&op_node->op());
......@@ -42,7 +43,7 @@ void AddOptimizerOpConf(const OpGraph& op_graph, JobBuilder* job_builder,
LogicalBlobId diff_lbi_of_var_out = lbi2diff_lbi.at(var_op->BnInOp2Lbi(var_op->SoleObn()));
const auto& parallel_desc = op_node->parallel_desc();
GenerateOptimizerOpConfIf(*var_op, parallel_desc.parallel_conf(), job_builder,
GenerateOptimizerOpConfIf(ctx, *var_op, parallel_desc.parallel_conf(), job_builder,
diff_lbi_of_var_out);
});
}
......
......@@ -18,10 +18,11 @@ limitations under the License.
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/operator/variable_op.h"
#include "oneflow/core/job_rewriter/job_pass.h"
namespace oneflow {
void AddOptimizerOpConf(const OpGraph& op_graph, JobBuilder* job_builder,
void AddOptimizerOpConf(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder,
const HashMap<LogicalBlobId, LogicalBlobId>& lbi2diff_lbi);
float GetOptimizerWeightDecayRate(const NormalModelUpdateOpUserConf& model_update_conf,
......@@ -33,11 +34,11 @@ void ConstructMdUpdtOpConf(const VariableOp& op, const LogicalBlobId& diff_lbi_o
class GenerateOptimizerOpConfWrapperStruct final {
public:
using Func = std::function<void(const VariableOp&, const ParallelConf&, JobBuilder*,
using Func = std::function<void(JobPassCtx*, const VariableOp&, const ParallelConf&, JobBuilder*,
const LogicalBlobId&)>;
GenerateOptimizerOpConfWrapperStruct(const Func& f) : func_(std::make_unique<Func>(f)) {}
void Call(const VariableOp& op, const ParallelConf& parallel_conf, JobBuilder* job_builder,
const LogicalBlobId& diff_lbi_of_var_out) const;
void Call(JobPassCtx* ctx, const VariableOp& op, const ParallelConf& parallel_conf,
JobBuilder* job_builder, const LogicalBlobId& diff_lbi_of_var_out) const;
private:
const std::unique_ptr<const Func> func_;
......
......@@ -31,8 +31,9 @@ OperatorConf GenerateRmspropHelperVariableOpConf(const VariableOp& op, const std
return helper_variable_op;
}
void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_conf,
JobBuilder* job_builder, const LogicalBlobId& diff_lbi_of_var_out) {
void GenerateOptimizerOpConf(JobPassCtx* ctx, const VariableOp& op,
const ParallelConf& parallel_conf, JobBuilder* job_builder,
const LogicalBlobId& diff_lbi_of_var_out) {
OperatorConf mean_square_var(GenerateRmspropHelperVariableOpConf(op, "mean_square", 0.f));
OperatorConf mdupdt_op;
......
......@@ -20,8 +20,9 @@ namespace oneflow {
namespace {
void GenerateOptimizerOpConf(const VariableOp& op, const ParallelConf& parallel_conf,
JobBuilder* job_builder, const LogicalBlobId& diff_lbi_of_var_out) {
void GenerateOptimizerOpConf(JobPassCtx* ctx, const VariableOp& op,
const ParallelConf& parallel_conf, JobBuilder* job_builder,
const LogicalBlobId& diff_lbi_of_var_out) {
const auto& train_conf = job_builder->job().job_conf().train_conf();
const NormalModelUpdateOpUserConf& model_update_conf = train_conf.model_update_conf();
user_op::UserOpConfWrapperBuilder sgd_update_op_builder(op.op_name() + "_optimizer");
......
......@@ -133,24 +133,16 @@ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_INDEXED_SLICES_MOMENTUM_MODEL_UPDAT
template<typename T, typename G>
struct AdamUpdateKernelUtil<DeviceType::kCPU, T, G> {
static void Update(DeviceCtx* ctx, int64_t n, T scale, float l1, float l2, float beta1,
float beta2, float epsilon, bool do_bias_correction, float weight_decay,
const float* learning_rate, const T* scale_by_ptr, const G* model_diff,
T* model, T* m, T* v, T* beta1_t, T* beta2_t);
float beta2, float epsilon, float weight_decay, const float* learning_rate,
const T* scale_by_ptr, const G* model_diff, T* model, T* m, T* v);
};
template<typename T, typename G>
void AdamUpdateKernelUtil<DeviceType::kCPU, T, G>::Update(
DeviceCtx* ctx, int64_t n, T scale, float l1, float l2, float beta1, float beta2, float epsilon,
bool do_bias_correction, float weight_decay, const float* learning_rate, const T* scale_by_ptr,
const G* model_diff, T* model, T* m, T* v, T* beta1_t, T* beta2_t) {
float lr;
if (do_bias_correction) {
lr = *learning_rate * std::sqrt(1 - *beta2_t) / (1 - *beta1_t);
*beta1_t *= beta1;
*beta2_t *= beta2;
} else {
lr = *learning_rate;
}
float weight_decay, const float* learning_rate, const T* scale_by_ptr, const G* model_diff,
T* model, T* m, T* v) {
const float lr = *learning_rate;
if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }
FOR_RANGE(int64_t, i, 0, n) {
AdamUpdateFunctor<T, G>()(model_diff + i, model + i, m + i, v + i, scale, l1, l2, beta1, beta2,
......@@ -163,19 +155,11 @@ template struct AdamUpdateKernelUtil<DeviceType::kCPU, double, double>;
template<typename T, typename K, typename IDX>
struct IndexedSlicesAdamMdUpdateKernelUtil<DeviceType::kCPU, T, K, IDX> {
static void Update(DeviceCtx* ctx, float beta1, float beta2, float epsilon,
bool do_bias_correction, int64_t num_instance, int64_t feature_size,
int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,
const float* learning_rate, const K* indices, const T* values, T* model, T* m,
T* v, T* beta1_t, T* beta2_t) {
float lr;
if (do_bias_correction) {
lr = *learning_rate * std::sqrt(1 - *beta2_t) / (1 - *beta1_t);
*beta1_t *= beta1;
*beta2_t *= beta2;
} else {
lr = *learning_rate;
}
static void Update(DeviceCtx* ctx, float beta1, float beta2, float epsilon, int64_t num_instance,
int64_t feature_size, int64_t lower_bound, int64_t upper_bound,
const IDX* num_unique_instance, const float* learning_rate, const K* indices,
const T* values, T* model, T* m, T* v) {
const float lr = *learning_rate;
const int64_t n = *num_unique_instance * feature_size;
FOR_RANGE(int64_t, i, 0, n) {
const IDX indices_idx = i / feature_size;
......@@ -234,4 +218,20 @@ void LambUpdateKernelUtil<DeviceType::kCPU, T, G>::Update(
template struct LambUpdateKernelUtil<DeviceType::kCPU, float, float>;
template struct LambUpdateKernelUtil<DeviceType::kCPU, double, double>;
template<>
struct AdamBiasCorrectionLearningRateKernelUtil<DeviceType::kCPU> {
static void AdamBiasCorrectionLearningRate(DeviceCtx* ctx, float beta1, float beta2,
const float* learning_rate, const int64_t* train_step,
float* out);
};
void AdamBiasCorrectionLearningRateKernelUtil<DeviceType::kCPU>::AdamBiasCorrectionLearningRate(
DeviceCtx* ctx, float beta1, float beta2, const float* learning_rate, const int64_t* train_step,
float* out) {
const auto exponent = static_cast<double>(*train_step + 1);
const float beta1_power = static_cast<float>(std::pow(beta1, exponent));
const float beta2_power = static_cast<float>(std::pow(beta2, exponent));
*out = *learning_rate * std::sqrt(1 - beta2_power) / (1 - beta1_power);
}
} // namespace oneflow
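As a sanity check on the CPU implementation: with the default beta1 = 0.9 and beta2 = 0.999 at train_step = 0 (exponent 1), the factor is sqrt(1 - 0.999) / (1 - 0.9) = sqrt(0.001) / 0.1 ≈ 0.316, so the first update runs at roughly a third of the base learning rate, and the factor climbs toward 1 as training proceeds.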
......@@ -223,17 +223,20 @@ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_INDEXED_SLICES_MOMENTUM_MODEL_UPDAT
namespace {
__global__ void AdamBiasCorrectionLearningRateGpu(float beta1, float beta2,
const float* learning_rate,
const int64_t* train_step, float* out) {
const auto exponent = static_cast<double>(*train_step + 1);
const float beta1_power = static_cast<float>(pow(beta1, exponent));
const float beta2_power = static_cast<float>(pow(beta2, exponent));
*out = *learning_rate * sqrt(1 - beta2_power) / (1 - beta1_power);
}
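// The output is a single scalar, so the util below launches this kernel as a
// one-thread <<<1, 1, 0, stream>>> grid; each train step costs one pow/sqrt.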
template<typename T, typename G>
__global__ void AdamUpdateGpu(int64_t n, T scale, float l1, float l2, float beta1, float beta2,
float epsilon, bool do_bias_correction, float weight_decay,
const float* learning_rate, const T* scale_by_ptr,
const G* model_diff, T* model, T* m, T* v, T* beta1_t, T* beta2_t) {
float lr;
if (do_bias_correction) {
lr = *learning_rate * sqrt(1 - *beta2_t) / (1 - *beta1_t);
} else {
lr = *learning_rate;
}
float epsilon, float weight_decay, const float* learning_rate,
const T* scale_by_ptr, const G* model_diff, T* model, T* m, T* v) {
const float lr = *learning_rate;
if (scale_by_ptr != nullptr) { scale *= *scale_by_ptr; }
CUDA_1D_KERNEL_LOOP(i, n) {
AdamUpdateFunctor<T, G>()(model_diff + i, model + i, m + i, v + i, scale, l1, l2, beta1, beta2,
......@@ -249,18 +252,11 @@ __global__ void AdamUpdateBetaTGpu(const T beta1, const T beta2, T* beta1_t, T*
template<typename T, typename K, typename IDX>
__global__ void IndexedSlicesAdamUpdateGpu(float beta1, float beta2, float epsilon,
bool do_bias_correction, int64_t feature_size,
int64_t lower_bound, int64_t upper_bound,
const IDX* num_unique_instance,
int64_t feature_size, int64_t lower_bound,
int64_t upper_bound, const IDX* num_unique_instance,
const float* learning_rate, const K* indices,
const T* values, T* model, T* m, T* v, T* beta1_t,
T* beta2_t) {
float lr;
if (do_bias_correction) {
lr = *learning_rate * sqrt(1 - *beta2_t) / (1 - *beta1_t);
} else {
lr = *learning_rate;
}
const T* values, T* model, T* m, T* v) {
const float lr = *learning_rate;
const int64_t n = *num_unique_instance * feature_size;
CUDA_1D_KERNEL_LOOP(i, n) {
const IDX indices_idx = i / feature_size;
......@@ -299,40 +295,35 @@ __global__ void LambUpdateGpu(int64_t n, float weight_decay, const float* learni
template<typename T, typename G>
struct AdamUpdateKernelUtil<DeviceType::kGPU, T, G> {
static void Update(DeviceCtx* ctx, int64_t n, T scale, float l1, float l2, float beta1,
float beta2, float epsilon, bool do_bias_correction, float weight_decay,
const float* learning_rate, const T* scale_by_ptr, const G* model_diff,
T* model, T* m, T* v, T* beta1_t, T* beta2_t);
float beta2, float epsilon, float weight_decay, const float* learning_rate,
const T* scale_by_ptr, const G* model_diff, T* model, T* m, T* v);
};
template<typename T, typename G>
void AdamUpdateKernelUtil<DeviceType::kGPU, T, G>::Update(
DeviceCtx* ctx, int64_t n, T scale, float l1, float l2, float beta1, float beta2, float epsilon,
bool do_bias_correction, float weight_decay, const float* learning_rate, const T* scale_by_ptr,
const G* model_diff, T* model, T* m, T* v, T* beta1_t, T* beta2_t) {
float weight_decay, const float* learning_rate, const T* scale_by_ptr, const G* model_diff,
T* model, T* m, T* v) {
AdamUpdateGpu<T, G><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
n, scale, l1, l2, beta1, beta2, epsilon, do_bias_correction, weight_decay, learning_rate,
scale_by_ptr, model_diff, model, m, v, beta1_t, beta2_t);
if (do_bias_correction) {
AdamUpdateBetaTGpu<T><<<1, 1, 0, ctx->cuda_stream()>>>(beta1, beta2, beta1_t, beta2_t);
}
n, scale, l1, l2, beta1, beta2, epsilon, weight_decay, learning_rate, scale_by_ptr,
model_diff, model, m, v);
}
template<typename T>
struct AdamUpdateKernelUtil<DeviceType::kGPU, T, float16> {
static void Update(DeviceCtx* ctx, int64_t n, T scale, float l1, float l2, float beta1,
float beta2, float epsilon, bool do_bias_correction, float weight_decay,
const float* learning_rate, const T* scale_by_ptr, const float16* model_diff,
T* model, T* m, T* v, T* beta1_t, T* beta2_t);
float beta2, float epsilon, float weight_decay, const float* learning_rate,
const T* scale_by_ptr, const float16* model_diff, T* model, T* m, T* v);
};
template<typename T>
void AdamUpdateKernelUtil<DeviceType::kGPU, T, float16>::Update(
DeviceCtx* ctx, int64_t n, T scale, float l1, float l2, float beta1, float beta2, float epsilon,
bool do_bias_correction, float weight_decay, const float* learning_rate, const T* scale_by_ptr,
const float16* model_diff, T* model, T* m, T* v, T* beta1_t, T* beta2_t) {
float weight_decay, const float* learning_rate, const T* scale_by_ptr,
const float16* model_diff, T* model, T* m, T* v) {
AdamUpdateKernelUtil<DeviceType::kGPU, T, half>::Update(
ctx, n, scale, l1, l2, beta1, beta2, epsilon, do_bias_correction, weight_decay, learning_rate,
scale_by_ptr, reinterpret_cast<const half*>(model_diff), model, m, v, beta1_t, beta2_t);
ctx, n, scale, l1, l2, beta1, beta2, epsilon, weight_decay, learning_rate, scale_by_ptr,
reinterpret_cast<const half*>(model_diff), model, m, v);
}
template struct AdamUpdateKernelUtil<DeviceType::kGPU, float, float>;
......@@ -392,27 +383,21 @@ template struct LambUpdateKernelUtil<DeviceType::kGPU, float, float16>;
template<typename T, typename K, typename IDX>
struct IndexedSlicesAdamMdUpdateKernelUtil<DeviceType::kGPU, T, K, IDX> {
static void Update(DeviceCtx* ctx, float beta1, float beta2, float epsilon,
bool do_bias_correction, int64_t num_instance, int64_t feature_size,
int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,
const float* learning_rate, const K* indices, const T* values, T* model, T* m,
T* v, T* beta1_t, T* beta2_t);
static void Update(DeviceCtx* ctx, float beta1, float beta2, float epsilon, int64_t num_instance,
int64_t feature_size, int64_t lower_bound, int64_t upper_bound,
const IDX* num_unique_instance, const float* learning_rate, const K* indices,
const T* values, T* model, T* m, T* v);
};
template<typename T, typename K, typename IDX>
void IndexedSlicesAdamMdUpdateKernelUtil<DeviceType::kGPU, T, K, IDX>::Update(
DeviceCtx* ctx, float beta1, float beta2, float epsilon, bool do_bias_correction,
int64_t num_instance, int64_t feature_size, int64_t lower_bound, int64_t upper_bound,
const IDX* num_unique_instance, const float* learning_rate, const K* indices, const T* values,
T* model, T* m, T* v, T* beta1_t, T* beta2_t) {
IndexedSlicesAdamUpdateGpu<T, K><<<BlocksNum4ThreadsNum(num_instance * feature_size),
kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
beta1, beta2, epsilon, do_bias_correction, feature_size, lower_bound, upper_bound,
num_unique_instance, learning_rate, indices, values, model, m, v, beta1_t, beta2_t);
if (do_bias_correction) {
AdamUpdateBetaTGpu<T><<<1, 1, 0, ctx->cuda_stream()>>>(beta1, beta2, beta1_t, beta2_t);
}
DeviceCtx* ctx, float beta1, float beta2, float epsilon, int64_t num_instance,
int64_t feature_size, int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,
const float* learning_rate, const K* indices, const T* values, T* model, T* m, T* v) {
IndexedSlicesAdamUpdateGpu<T, K>
<<<BlocksNum4ThreadsNum(num_instance * feature_size), kCudaThreadsNumPerBlock, 0,
ctx->cuda_stream()>>>(beta1, beta2, epsilon, feature_size, lower_bound, upper_bound,
num_unique_instance, learning_rate, indices, values, model, m, v);
}
#define INSTANTIATE_INDEXED_SLICES_ADAM_MODEL_UPDATE_KERNEL_UTIL_GPU(val_type_pair, key_type_pair, \
......@@ -424,4 +409,18 @@ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_INDEXED_SLICES_ADAM_MODEL_UPDATE_KE
FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ, INT_DATA_TYPE_SEQ);
#undef INSTANTIATE_INDEXED_SLICES_ADAM_MODEL_UPDATE_KERNEL_UTIL_GPU
template<>
struct AdamBiasCorrectionLearningRateKernelUtil<DeviceType::kGPU> {
static void AdamBiasCorrectionLearningRate(DeviceCtx* ctx, float beta1, float beta2,
const float* learning_rate, const int64_t* train_step,
float* out);
};
void AdamBiasCorrectionLearningRateKernelUtil<DeviceType::kGPU>::AdamBiasCorrectionLearningRate(
DeviceCtx* ctx, float beta1, float beta2, const float* learning_rate, const int64_t* train_step,
float* out) {
AdamBiasCorrectionLearningRateGpu<<<1, 1, 0, ctx->cuda_stream()>>>(beta1, beta2, learning_rate,
train_step, out);
}
} // namespace oneflow
......@@ -148,18 +148,16 @@ struct IndexedSlicesMomentumMdUpdateKernelUtil {
template<DeviceType device_type, typename T, typename G>
struct AdamUpdateKernelUtil {
static void Update(DeviceCtx* ctx, int64_t n, T scale, float l1, float l2, float beta1,
float beta2, float epsilon, bool do_bias_correction, float weight_decay,
const float* learning_rate, const T* scale_by_ptr, const G* model_diff,
T* model, T* m, T* v, T* beta1_t, T* beta2_t);
float beta2, float epsilon, float weight_decay, const float* learning_rate,
const T* scale_by_ptr, const G* model_diff, T* model, T* m, T* v);
};
template<DeviceType device_type, typename T, typename K, typename IDX>
struct IndexedSlicesAdamMdUpdateKernelUtil {
static void Update(DeviceCtx* ctx, float beta1, float beta2, float epsilon,
bool do_bias_correction, int64_t num_instance, int64_t feature_size,
int64_t lower_bound, int64_t upper_bound, const IDX* num_unique_instance,
const float* learning_rate, const K* indices, const T* values, T* model, T* m,
T* v, T* beta1_t, T* beta2_t);
static void Update(DeviceCtx* ctx, float beta1, float beta2, float epsilon, int64_t num_instance,
int64_t feature_size, int64_t lower_bound, int64_t upper_bound,
const IDX* num_unique_instance, const float* learning_rate, const K* indices,
const T* values, T* model, T* m, T* v);
};
template<DeviceType device_type, typename T, typename G>
......@@ -171,6 +169,14 @@ struct LambUpdateKernelUtil {
T* norm_buffer, T* beta1_t, T* beta2_t);
};
template<DeviceType device_type>
struct AdamBiasCorrectionLearningRateKernelUtil {
public:
static void AdamBiasCorrectionLearningRate(DeviceCtx* ctx, float beta1, float beta2,
const float* learning_rate, const int64_t* train_step,
float* out);
};
#endif
} // namespace oneflow
......@@ -348,7 +348,6 @@ class AdamUpdateKernel final : public user_op::OpKernel {
const auto beta1 = ctx->Attr<float>("beta1");
const auto beta2 = ctx->Attr<float>("beta2");
const auto epsilon = ctx->Attr<float>("epsilon");
const auto do_bias_correction = ctx->Attr<bool>("do_bias_correction");
const auto weight_decay = ctx->Attr<float>("weight_decay");
const T* scale_by_ptr = nullptr;
if (ctx->user_op_conf().has_input("scale_by_tensor", 0)) {
......@@ -357,19 +356,10 @@ class AdamUpdateKernel final : public user_op::OpKernel {
CHECK_EQ(scale_by_tensor->shape().elem_cnt(), 1);
scale_by_ptr = scale_by_tensor->dptr<T>();
}
T* beta1_t_ptr = nullptr;
T* beta2_t_ptr = nullptr;
if (do_bias_correction) {
user_op::Tensor* beta1_t = ctx->Tensor4ArgNameAndIndex("beta1_t", 0);
beta1_t_ptr = beta1_t->mut_dptr<T>();
user_op::Tensor* beta2_t = ctx->Tensor4ArgNameAndIndex("beta2_t", 0);
beta2_t_ptr = beta2_t->mut_dptr<T>();
}
AdamUpdateKernelUtil<device_type, T, G>::Update(
ctx->device_ctx(), model->shape().elem_cnt(), static_cast<T>(scale), l1, l2, beta1, beta2,
epsilon, do_bias_correction, weight_decay, learning_rate->dptr<float>(), scale_by_ptr,
model_diff->dptr<G>(), model->mut_dptr<T>(), m->mut_dptr<T>(), v->mut_dptr<T>(),
beta1_t_ptr, beta2_t_ptr);
epsilon, weight_decay, learning_rate->dptr<float>(), scale_by_ptr, model_diff->dptr<G>(),
model->mut_dptr<T>(), m->mut_dptr<T>(), v->mut_dptr<T>());
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }
};
......@@ -413,15 +403,6 @@ class IndexedSlicesAdamUpdateKernel final : public user_op::OpKernel {
const auto beta1 = ctx->Attr<float>("beta1");
const auto beta2 = ctx->Attr<float>("beta2");
const auto epsilon = ctx->Attr<float>("epsilon");
const auto do_bias_correction = ctx->Attr<bool>("do_bias_correction");
T* beta1_t_ptr = nullptr;
T* beta2_t_ptr = nullptr;
if (do_bias_correction) {
user_op::Tensor* beta1_t = ctx->Tensor4ArgNameAndIndex("beta1_t", 0);
beta1_t_ptr = beta1_t->mut_dptr<T>();
user_op::Tensor* beta2_t = ctx->Tensor4ArgNameAndIndex("beta2_t", 0);
beta2_t_ptr = beta2_t->mut_dptr<T>();
}
auto* kernel_state = dynamic_cast<IndexedSlicesUpdateOpKernelState*>(state);
CHECK_NOTNULL(kernel_state);
CHECK_EQ(model->shape().At(0), kernel_state->upper() - kernel_state->lower());
......@@ -441,12 +422,12 @@ class IndexedSlicesAdamUpdateKernel final : public user_op::OpKernel {
buffer_manager.UniqueDiffIndicesPtr(), buffer_manager.UniqueDiffValuesPtr(),
buffer_manager.UniqueWorkspacePtr(), buffer_manager.UniqueWorkspaceBytes());
MdUpdateUtilT::Update(ctx->device_ctx(), beta1, beta2, epsilon, do_bias_correction, num_indices,
feature_size, kernel_state->lower(), kernel_state->upper(),
MdUpdateUtilT::Update(ctx->device_ctx(), beta1, beta2, epsilon, num_indices, feature_size,
kernel_state->lower(), kernel_state->upper(),
buffer_manager.NumUniqueDiffIndicesPtr(), learning_rate->dptr<float>(),
buffer_manager.UniqueDiffIndicesPtr(),
buffer_manager.UniqueDiffValuesPtr(), model->mut_dptr<T>(),
m->mut_dptr<T>(), v->mut_dptr<T>(), beta1_t_ptr, beta2_t_ptr);
m->mut_dptr<T>(), v->mut_dptr<T>());
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }
};
......@@ -567,6 +548,35 @@ REGISTER_LAMB_UPDATE_KERNEL(DeviceType::kGPU, float, float);
REGISTER_LAMB_UPDATE_KERNEL(DeviceType::kGPU, double, double);
#endif // WITH_CUDA
template<DeviceType device_type>
class AdamBiasCorrectionLearningRateKernel final : public user_op::OpKernel {
public:
AdamBiasCorrectionLearningRateKernel() = default;
~AdamBiasCorrectionLearningRateKernel() override = default;
private:
void Compute(user_op::KernelComputeContext* ctx) const override {
const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0);
const user_op::Tensor* train_step = ctx->Tensor4ArgNameAndIndex("train_step", 0);
user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
const auto beta1 = ctx->Attr<float>("beta1");
const auto beta2 = ctx->Attr<float>("beta2");
AdamBiasCorrectionLearningRateKernelUtil<device_type>::AdamBiasCorrectionLearningRate(
ctx->device_ctx(), beta1, beta2, learning_rate->dptr<float>(), train_step->dptr<int64_t>(),
out->mut_dptr<float>());
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return true; }
};
#define REGISTER_ADAM_BIAS_CORRECTION_LEARNING_RATE_KERNEL(device) \
REGISTER_USER_KERNEL("adam_bias_correction_learning_rate") \
.SetCreateFn<AdamBiasCorrectionLearningRateKernel<device>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device));
REGISTER_ADAM_BIAS_CORRECTION_LEARNING_RATE_KERNEL(DeviceType::kCPU)
#ifdef WITH_CUDA
REGISTER_ADAM_BIAS_CORRECTION_LEARNING_RATE_KERNEL(DeviceType::kGPU)
#endif // WITH_CUDA
} // namespace
} // namespace oneflow
......@@ -115,7 +115,6 @@ Maybe<void> InferIndexedSlicesMomentumUpdateTensorDesc(user_op::InferContext* ct
Maybe<void> InferAdamUpdateTensorDesc(user_op::InferContext* ctx) {
const user_op::TensorDesc* model = ctx->TensorDesc4ArgNameAndIndex("model", 0);
const DataType data_type = model->data_type();
const Shape& shape = model->shape();
const user_op::TensorDesc* model_diff = ctx->TensorDesc4ArgNameAndIndex("model_diff", 0);
CHECK_EQ_OR_RETURN(model_diff->shape(), shape);
......@@ -129,23 +128,11 @@ Maybe<void> InferAdamUpdateTensorDesc(user_op::InferContext* ctx) {
const auto* scale_by_tensor = ctx->TensorDesc4ArgNameAndIndex("scale_by_tensor", 0);
JUST(CheckScalarTensorDesc(scale_by_tensor, model->data_type()));
}
if (ctx->Attr<bool>("do_bias_correction")) {
CHECK_OR_RETURN(ctx->user_op_conf().has_input("beta1_t", 0));
CHECK_OR_RETURN(ctx->user_op_conf().has_input("beta2_t", 0));
const user_op::TensorDesc* beta1_t = ctx->TensorDesc4ArgNameAndIndex("beta1_t", 0);
const user_op::TensorDesc* beta2_t = ctx->TensorDesc4ArgNameAndIndex("beta2_t", 0);
JUST(CheckScalarTensorDesc(beta1_t, data_type));
JUST(CheckScalarTensorDesc(beta2_t, data_type));
} else {
CHECK_OR_RETURN(!ctx->user_op_conf().has_input("beta1_t", 0));
CHECK_OR_RETURN(!ctx->user_op_conf().has_input("beta2_t", 0));
}
return Maybe<void>::Ok();
}
Maybe<void> InferIndexedSlicesAdamUpdateTensorDesc(user_op::InferContext* ctx) {
const user_op::TensorDesc* model = ctx->TensorDesc4ArgNameAndIndex("model", 0);
const DataType data_type = model->data_type();
const user_op::TensorDesc* model_diff_indices =
ctx->TensorDesc4ArgNameAndIndex("model_diff_indices", 0);
const user_op::TensorDesc* model_diff_values =
......@@ -153,17 +140,6 @@ Maybe<void> InferIndexedSlicesAdamUpdateTensorDesc(user_op::InferContext* ctx) {
JUST(CheckIndexedSlicesModelDiffDesc(model, model_diff_indices, model_diff_values));
const user_op::TensorDesc* learning_rate = ctx->TensorDesc4ArgNameAndIndex("learning_rate", 0);
JUST(CheckLearningRateTenserDesc(learning_rate));
if (ctx->Attr<bool>("do_bias_correction")) {
CHECK_OR_RETURN(ctx->user_op_conf().has_input("beta1_t", 0));
CHECK_OR_RETURN(ctx->user_op_conf().has_input("beta2_t", 0));
const user_op::TensorDesc* beta1_t = ctx->TensorDesc4ArgNameAndIndex("beta1_t", 0);
const user_op::TensorDesc* beta2_t = ctx->TensorDesc4ArgNameAndIndex("beta2_t", 0);
JUST(CheckScalarTensorDesc(beta1_t, data_type));
JUST(CheckScalarTensorDesc(beta2_t, data_type));
} else {
CHECK_OR_RETURN(!ctx->user_op_conf().has_input("beta1_t", 0));
CHECK_OR_RETURN(!ctx->user_op_conf().has_input("beta2_t", 0));
}
return Maybe<void>::Ok();
}
......@@ -208,10 +184,6 @@ void AdamInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifie
SetInputArgModifierMutable(GetInputArgModifierFn, "model", 0);
SetInputArgModifierMutable(GetInputArgModifierFn, "m", 0);
SetInputArgModifierMutable(GetInputArgModifierFn, "v", 0);
if (conf.attr<bool>("do_bias_correction")) {
SetInputArgModifierMutable(GetInputArgModifierFn, "beta1_t", 0);
SetInputArgModifierMutable(GetInputArgModifierFn, "beta2_t", 0);
}
}
void LambInputArgModifyFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,
......@@ -359,15 +331,12 @@ REGISTER_USER_OP("adam_update")
.OptionalInput("scale_by_tensor")
.Input("m")
.Input("v")
.OptionalInput("beta1_t")
.OptionalInput("beta2_t")
.Attr<double>("scale", 1.0)
.Attr<float>("l1", 0.0)
.Attr<float>("l2", 0.0)
.Attr<float>("beta1", 0.9)
.Attr<float>("beta2", 0.999)
.Attr<float>("epsilon", 1e-8)
.Attr<bool>("do_bias_correction", false)
.Attr<float>("weight_decay", 0.0)
.SetTensorDescInferFn(InferAdamUpdateTensorDesc)
.SetBatchAxisInferFn(user_op::BatchAxisInferFnUtil::NaiveInferBatchAxis)
......@@ -393,12 +362,9 @@ REGISTER_USER_OP("indexed_slices_adam_update")
.Input("learning_rate")
.Input("m")
.Input("v")
.OptionalInput("beta1_t")
.OptionalInput("beta2_t")
.Attr<float>("beta1", 0.9)
.Attr<float>("beta2", 0.999)
.Attr<float>("epsilon", 1e-8)
.Attr<bool>("do_bias_correction", false)
.SetTensorDescInferFn(InferIndexedSlicesAdamUpdateTensorDesc)
.SetBatchAxisInferFn(user_op::BatchAxisInferFnUtil::NaiveInferBatchAxis)
.SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
......@@ -408,10 +374,6 @@ REGISTER_USER_OP("indexed_slices_adam_update")
std::vector<user_op::OpArg> broadcast_args;
broadcast_args.emplace_back("learning_rate", 0);
broadcast_args.emplace_back("model_diff_indices", 0);
if (ctx->Attr<bool>("do_bias_correction")) {
broadcast_args.emplace_back("beta1_t", 0);
broadcast_args.emplace_back("beta2_t", 0);
}
ctx->NewBuilder()
.Broadcast(broadcast_args)
.Broadcast(user_op::OpArg("model_diff_values", 0))
......@@ -454,6 +416,23 @@ REGISTER_USER_OP("lamb_update")
// every bn has sbp broadcast signature
.SetInputArgModifyFn(LambInputArgModifyFn);
REGISTER_USER_OP("adam_bias_correction_learning_rate")
.Input("learning_rate")
.Input("train_step")
.Output("out")
.Attr<float>("beta1", 0.9)
.Attr<float>("beta2", 0.999)
.SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
*ctx->TensorDesc4ArgNameAndIndex("out", 0) =
*ctx->TensorDesc4ArgNameAndIndex("learning_rate", 0);
return Maybe<void>::Ok();
})
.SetBatchAxisInferFn([](user_op::BatchAxisContext* ctx) -> Maybe<void> {
ctx->BatchAxis4ArgNameAndIndex("out", 0)->clear_value();
return Maybe<void>::Ok();
});
// every bn has sbp broadcast signature
} // namespace
} // namespace oneflow
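The registration above copies learning_rate's tensor desc onto out (a scalar float) and clears out's batch axis; beta1/beta2 default to 0.9/0.999, matching adam_update's defaults, so the op composes with the optimizer pass without extra attributes.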