Commit 8b6f374f authored by Yibing Liu, committed by GitHub

Merge pull request #2216 from kuke/enable_grad_clipping_dev

Enable the setting of global gradient clipping threshold
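This change makes it possible to set a single, global gradient clipping threshold in the optimization settings, in addition to the existing per-parameter threshold. For context, a minimal sketch of how a trainer config could enable it, assuming the `settings()` helper from `paddle.trainer_config_helpers` that the Python hunks below extend; the optimizer choice and the value 10.0 are purely illustrative, not part of this commit.

```python
# Illustrative trainer-config sketch, not part of this commit.
from paddle.trainer_config_helpers import *

settings(
    batch_size=128,
    learning_rate=1e-3,
    learning_method=MomentumOptimizer(momentum=0.9),
    # New: clip every gradient to [-10, 10] unless a parameter sets its
    # own (local) gradient_clipping_threshold, which takes precedence.
    gradient_clipping_threshold=10.0)
```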
......
@@ -161,6 +161,7 @@ void AdaDeltaParameterOptimizer::update(const VectorPtr vecs[],
                                         const ParameterConfig& config,
                                         size_t sparseId) const {
+  CHECK(sparseId == -1LU) << "Sparse update is not supported";
   BaseMatrix& value = *vecs[PARAMETER_VALUE];
   BaseMatrix& grad = *vecs[PARAMETER_GRADIENT];
   BaseMatrix& mom = *vecs[PARAMETER_MOMENTUM];
......
@@ -265,6 +266,7 @@ void AdamParameterOptimizer::update(const VectorPtr vecs[],
                                     const ParameterConfig& config,
                                     size_t sparseId) const {
+  CHECK(sparseId == -1UL) << "Sparse update is not supported";
   real beta1_power = std::pow(beta1_, step_);
   real beta2_power = std::pow(beta2_, step_);
   real learningRate = config.learning_rate() * learningRate_;
......
@@ -303,18 +305,25 @@ void AdamaxParameterOptimizer::update(const VectorPtr vecs[],
 void OptimizerWithGradientClipping::update(const VectorPtr vecs[],
                                            const ParameterConfig& config,
                                            size_t sparseId) const {
+  real globalThreshold = optConfig_.gradient_clipping_threshold();
+  real localThreshold = config.gradient_clipping_threshold();
+  // Use the local gradient clipping threshold if it is enabled,
+  // otherwise use the global one.
+  real threshold = localThreshold > 0.0f ? localThreshold : globalThreshold;
+  std::string field = localThreshold > 0.0f ? "local" : "global";
   real maxAbsGrad = vecs[PARAMETER_GRADIENT]->getAbsMax();
-  if (maxAbsGrad > config.gradient_clipping_threshold()) {
+  if (maxAbsGrad > threshold) {
     if (FLAGS_log_clipping) {
       real avgAbsGrad = vecs[PARAMETER_GRADIENT]->getAbsSum() /
                         vecs[PARAMETER_GRADIENT]->getSize();
-      LOG(INFO) << "parameter=" << config.name() << " need clipping,"
-                << " max grad=" << maxAbsGrad << " avg grad=" << avgAbsGrad;
+      LOG(INFO) << "parameter=" << config.name() << " needs clipping by "
+                << field << " threshold=" << threshold
+                << ", max grad=" << maxAbsGrad << ", avg grad=" << avgAbsGrad;
     }
-    vecs[PARAMETER_GRADIENT]->clip(-config.gradient_clipping_threshold(),
-                                   config.gradient_clipping_threshold());
+    vecs[PARAMETER_GRADIENT]->clip(-threshold, threshold);
   }
   optimizer_->update(vecs, config, sparseId);
 }
......
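The hunk above is the core of the change: a positive per-parameter (local) threshold still wins, and the new global threshold from the optimization config is used as the fallback; clipping is element-wise and only happens when the largest absolute gradient exceeds the chosen threshold. A rough NumPy sketch of the same rule follows; the function name and arguments are illustrative, not PaddlePaddle API.

```python
# Illustrative sketch of the threshold selection and clipping rule above.
import numpy as np

def clip_gradient(grad, local_threshold, global_threshold):
    # A positive local (per-parameter) threshold takes precedence;
    # otherwise fall back to the global one from the optimization config.
    threshold = local_threshold if local_threshold > 0.0 else global_threshold
    # In the C++ code the wrapper is only installed when some threshold is
    # positive (see the create() hunk below); the extra guard here just
    # keeps the sketch standalone.
    if threshold > 0.0 and np.abs(grad).max() > threshold:
        grad = np.clip(grad, -threshold, threshold)
    return grad

g = np.array([3.0, -25.0, 7.5])
print(clip_gradient(g, local_threshold=0.0, global_threshold=10.0))
# The -25.0 entry is clipped to -10.0; the others are left unchanged.
```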
......
@@ -131,7 +131,8 @@ ParameterOptimizer* OptimizerWithRegularizer::create(
     bool inPserver) {
   ParameterOptimizer* optimizer =
       ParameterOptimizer::create(optConfig, inPserver);
-  if (paraConfig.gradient_clipping_threshold() > 0.0f &&
+  if ((optConfig.gradient_clipping_threshold() > 0.0f ||
+       paraConfig.gradient_clipping_threshold() > 0.0f) &&
       !dynamic_cast<AddOptimizer*>(optimizer)) {
     optimizer = new OptimizerWithGradientClipping(optConfig, optimizer);
   }
......
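With the create() change above, the clipping wrapper is now installed whenever either the global threshold in OptimizationConfig or the parameter's own threshold is positive (and the wrapped optimizer is not an AddOptimizer), so a purely global setting takes effect without any per-parameter configuration.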
......
@@ -167,6 +167,7 @@ public:
     }
     parameterTypes_.push_back(type);
   }
+  real getLearningRate() const { return learningRate_; }
   virtual void setNoDecay() { applyDecay_ = false; }
......
@@ -201,6 +202,7 @@ protected:
    * so, if lr change in StartBatch, please assign to learningRate_
    */
   real learningRate_;
+
   std::unique_ptr<LearningRateScheduler> learningRateScheduler_;
   int64_t pass_;  // current training pass (starting from 0)
   bool firstTime_;
......
......
@@ -128,6 +128,9 @@ message OptimizationConfig {
   // when async_lagged_grad_discard_ratio * num_gradient_servers commit passed,
   // current async gradient will be discard silently.
   optional double async_lagged_grad_discard_ratio = 37 [default = 1.5];
+
+  // global threshold for gradient clipping
+  optional double gradient_clipping_threshold = 38 [default = 0.0];
 };
 message TrainerConfig {
......
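The new proto field defaults to 0.0, which the C++ side treats as "no global clipping", so existing configurations that never set gradient_clipping_threshold behave exactly as before.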
......
@@ -3377,6 +3377,7 @@ settings = dict(
     algorithm='async_sgd',
     async_lagged_grad_discard_ratio=1.5,
     learning_method='momentum',
+    gradient_clipping_threshold=None,
     num_batches_per_send_parameter=None,
     num_batches_per_get_parameter=None,
     center_parameter_update_method=None,
......
......
@@ -408,7 +408,8 @@ def settings(batch_size,
     args = [
         'batch_size', 'learning_rate', 'learning_rate_decay_a',
-        'learning_rate_decay_b', 'learning_rate_schedule', 'learning_rate_args'
+        'learning_rate_decay_b', 'learning_rate_schedule', 'learning_rate_args',
+        'gradient_clipping_threshold'
     ]
     kwargs = dict()
     kwargs['algorithm'] = algorithm
......
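On the Python side the change is plumbing: gradient_clipping_threshold joins the settings defaults and the list of settings() keyword arguments that are copied into the generated optimization config, alongside the learning-rate options.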