diff --git a/doc/doc_ch/config.md b/doc/doc_ch/config.md index 92cb4ed6994d186a0591f4552150b8e8d31c7f15..3f8f7ff1674d63e721d7ad2ced31bf771b0183eb 100644 --- a/doc/doc_ch/config.md +++ b/doc/doc_ch/config.md @@ -42,6 +42,7 @@ | name | 优化器类名 | Adam | 目前支持`Momentum`,`Adam`,`RMSProp`, 见[ppocr/optimizer/optimizer.py](../../ppocr/optimizer/optimizer.py) | | beta1 | 设置一阶矩估计的指数衰减率 | 0.9 | \ | | beta2 | 设置二阶矩估计的指数衰减率 | 0.999 | \ | +| clip_norm | 所允许的二范数最大值 | | \ | | **lr** | 设置学习率decay方式 | - | \ | | name | 学习率decay类名 | Cosine | 目前支持`Linear`,`Cosine`,`Step`,`Piecewise`, 见[ppocr/optimizer/learning_rate.py](../../ppocr/optimizer/learning_rate.py) | | learning_rate | 基础学习率 | 0.001 | \ | diff --git a/doc/doc_en/config_en.md b/doc/doc_en/config_en.md index ada1678e4ed059bbd3af3a5fdee42afaed1fce01..28ebb6e830369447395c661cbcc76aaf067a91d9 100644 --- a/doc/doc_en/config_en.md +++ b/doc/doc_en/config_en.md @@ -41,6 +41,7 @@ Take rec_chinese_lite_train_v2.0.yml as an example | name | Optimizer class name | Adam | Currently supports`Momentum`,`Adam`,`RMSProp`, see [ppocr/optimizer/optimizer.py](../../ppocr/optimizer/optimizer.py) | | beta1 | Set the exponential decay rate for the 1st moment estimates | 0.9 | \ | | beta2 | Set the exponential decay rate for the 2nd moment estimates | 0.999 | \ | +| clip_norm | The maximum norm value | - | \ | | **lr** | Set the learning rate decay method | - | \ | | name | Learning rate decay class name | Cosine | Currently supports`Linear`,`Cosine`,`Step`,`Piecewise`, see[ppocr/optimizer/learning_rate.py](../../ppocr/optimizer/learning_rate.py) | | learning_rate | Set the base learning rate | 0.001 | \ | diff --git a/ppocr/optimizer/__init__.py b/ppocr/optimizer/__init__.py index 6413ae959200c25d6c17b1ec93217c0e8b0bf269..c729103a700a59764bda4f53dd68d3958172ca57 100644 --- a/ppocr/optimizer/__init__.py +++ b/ppocr/optimizer/__init__.py @@ -16,8 +16,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals - import copy +import paddle __all__ = ['build_optimizer'] @@ -49,7 +49,13 @@ def build_optimizer(config, epochs, step_each_epoch, parameters): # step3 build optimizer optim_name = config.pop('name') + if 'clip_norm' in config: + clip_norm = config.pop('clip_norm') + grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm) + else: + grad_clip = None optim = getattr(optimizer, optim_name)(learning_rate=lr, weight_decay=reg, + grad_clip=grad_clip, **config) return optim(parameters), lr diff --git a/ppocr/optimizer/optimizer.py b/ppocr/optimizer/optimizer.py index 2519e4e309f651dbbaebecfe8533c3eb393d47cb..8215b92d8c8d05c2b3c2e95ac989bf4ea011310b 100644 --- a/ppocr/optimizer/optimizer.py +++ b/ppocr/optimizer/optimizer.py @@ -30,18 +30,25 @@ class Momentum(object): regularization (WeightDecayRegularizer, optional) - The strategy of regularization. """ - def __init__(self, learning_rate, momentum, weight_decay=None, **args): + def __init__(self, + learning_rate, + momentum, + weight_decay=None, + grad_clip=None, + **args): super(Momentum, self).__init__() self.learning_rate = learning_rate self.momentum = momentum self.weight_decay = weight_decay + self.grad_clip = grad_clip def __call__(self, parameters): opt = optim.Momentum( learning_rate=self.learning_rate, momentum=self.momentum, - parameters=parameters, - weight_decay=self.weight_decay) + weight_decay=self.weight_decay, + grad_clip=self.grad_clip, + parameters=parameters) return opt @@ -96,10 +103,11 @@ class RMSProp(object): def __init__(self, learning_rate, - momentum, + momentum=0.0, rho=0.95, epsilon=1e-6, weight_decay=None, + grad_clip=None, **args): super(RMSProp, self).__init__() self.learning_rate = learning_rate @@ -107,6 +115,7 @@ class RMSProp(object): self.rho = rho self.epsilon = epsilon self.weight_decay = weight_decay + self.grad_clip = grad_clip def __call__(self, parameters): opt = optim.RMSProp( @@ -115,5 +124,6 @@ class RMSProp(object): rho=self.rho, epsilon=self.epsilon, weight_decay=self.weight_decay, + grad_clip=self.grad_clip, parameters=parameters) return opt