diff --git a/fluid/ocr/ctc_train.py b/fluid/ocr/ctc_train.py index 66949d527ab4b439ea055fe6a1e726e0281aaeda..2e7cf5d53058093b8aa64d51f4421562621198e8 100644 --- a/fluid/ocr/ctc_train.py +++ b/fluid/ocr/ctc_train.py @@ -24,7 +24,7 @@ from utility import add_arguments, print_arguments parser = argparse.ArgumentParser(description=__doc__) add_arg = functools.partial(add_arguments, argparser=parser) # yapf: disable -add_arg('batch_size', int, 2, "Minibatch size.") +add_arg('batch_size', int, 16, "Minibatch size.") add_arg('pass_num', int, 16, "# of training epochs.") add_arg('learning_rate', float, 1.0e-3, "Learning rate.") add_arg('l2', float, 0.0005, "L2 regularizer.") @@ -121,7 +121,7 @@ def train(args, data_reader=dummy_reader): avg_cost = fluid.layers.mean(x=cost) optimizer = fluid.optimizer.Momentum( learning_rate=args.learning_rate, momentum=args.momentum) - opts = optimizer.minimize(cost) + optimizer.minimize(avg_cost) # decoder and evaluator decoded_out = fluid.layers.ctc_greedy_decoder( input=fc_out, blank=num_classes)