From 5e5ea853bfbe1ac168dfb4e26a164ba4834a6acd Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Tue, 4 Aug 2020 03:43:26 +0000 Subject: [PATCH] add piecewise decay --- ppocr/data/rec/img_tools.py | 2 +- ppocr/optimizer.py | 19 +++++++++++++++---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/ppocr/data/rec/img_tools.py b/ppocr/data/rec/img_tools.py index d41abd9b..0835603b 100755 --- a/ppocr/data/rec/img_tools.py +++ b/ppocr/data/rec/img_tools.py @@ -360,7 +360,7 @@ def process_image(img, text = char_ops.encode(label) if len(text) == 0 or len(text) > max_text_length: logger.info( - "Warning in ppocr/data/rec/img_tools.py:line362: Wrong data type." + "Warning in ppocr/data/rec/img_tools.py: Wrong data type." "Excepted string with length between 1 and {}, but " "got '{}'. Label is '{}'".format(max_text_length, len(text), label)) diff --git a/ppocr/optimizer.py b/ppocr/optimizer.py index f5f49583..eb1037c2 100755 --- a/ppocr/optimizer.py +++ b/ppocr/optimizer.py @@ -36,17 +36,28 @@ def AdamDecay(params, parameter_list=None): l2_decay = params.get("l2_decay", 0.0) if 'decay' in params: + supported_decay_mode = ["cosine_decay", "piecewise_decay"] params = params['decay'] decay_mode = params['function'] - step_each_epoch = params['step_each_epoch'] - total_epoch = params['total_epoch'] + assert decay_mode in supported_decay_mode, "Supported decay mode is {}, but got {}".format( + supported_decay_mode, decay_mode) + if decay_mode == "cosine_decay": + step_each_epoch = params['step_each_epoch'] + total_epoch = params['total_epoch'] base_lr = fluid.layers.cosine_decay( learning_rate=base_lr, step_each_epoch=step_each_epoch, epochs=total_epoch) - else: - logger.info("Only support Cosine decay currently") + elif decay_mode == "piecewise_decay": + boundaries = params["boundaries"] + decay_rate = params["decay_rate"] + values = [ + base_lr * decay_rate**idx + for idx in range(len(boundaries) + 1) + ] + base_lr = fluid.layers.piecewise_decay(boundaries, values) + optimizer = fluid.optimizer.Adam( learning_rate=base_lr, beta1=beta1, -- GitLab