From c9ce72a73eb67a13cdb1a42f2174aeb79d358ff8 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Mon, 5 Mar 2018 11:15:37 +0800
Subject: [PATCH] Use reduce_sum op instead of mean_op

---
 fluid/ocr_recognition/crnn_ctc_model.py | 21 +++++++--------------
 fluid/ocr_recognition/ctc_train.py      | 19 ++++---------------
 2 files changed, 11 insertions(+), 29 deletions(-)

diff --git a/fluid/ocr_recognition/crnn_ctc_model.py b/fluid/ocr_recognition/crnn_ctc_model.py
index 192ef6a9..719c0158 100644
--- a/fluid/ocr_recognition/crnn_ctc_model.py
+++ b/fluid/ocr_recognition/crnn_ctc_model.py
@@ -54,6 +54,7 @@ def ocr_convs(input,
     tmp = input
     tmp = conv_bn_pool(
         tmp, 2, [16, 16], param=w1, bias=b, param_0=w0, is_test=is_test)
+
     tmp = conv_bn_pool(tmp, 2, [32, 32], param=w1, bias=b, is_test=is_test)
     tmp = conv_bn_pool(tmp, 2, [64, 64], param=w1, bias=b, is_test=is_test)
     tmp = conv_bn_pool(tmp, 2, [128, 128], param=w1, bias=b, is_test=is_test)
@@ -142,23 +143,19 @@ def ctc_train_net(images, label, args, num_classes):
         gradient_clip=gradient_clip)
 
     cost = fluid.layers.warpctc(
-        input=fc_out,
-        label=label,
-        # size=num_classes + 1,
-        blank=num_classes,
-        norm_by_times=True)
-    avg_cost = fluid.layers.mean(x=cost)
+        input=fc_out, label=label, blank=num_classes, norm_by_times=True)
+    sum_cost = fluid.layers.reduce_sum(cost)
 
     optimizer = fluid.optimizer.Momentum(
         learning_rate=args.learning_rate, momentum=args.momentum)
-    optimizer.minimize(avg_cost)
-    # decoder and evaluator
+    optimizer.minimize(sum_cost)
+
     decoded_out = fluid.layers.ctc_greedy_decoder(
         input=fc_out, blank=num_classes)
     casted_label = fluid.layers.cast(x=label, dtype='int64')
     error_evaluator = fluid.evaluator.EditDistance(
         input=decoded_out, label=casted_label)
-    return avg_cost, error_evaluator
+    return sum_cost, error_evaluator
 
 
 def ctc_infer(images, num_classes):
@@ -176,10 +173,6 @@ def ctc_eval(images, label, num_classes):
         input=decoded_out, label=casted_label)
 
     cost = fluid.layers.warpctc(
-        input=fc_out,
-        label=label,
-        #size=num_classes + 1,
-        blank=num_classes,
-        norm_by_times=True)
+        input=fc_out, label=label, blank=num_classes, norm_by_times=True)
 
     return error_evaluator, cost
diff --git a/fluid/ocr_recognition/ctc_train.py b/fluid/ocr_recognition/ctc_train.py
index b599dbec..85b1d2e7 100644
--- a/fluid/ocr_recognition/ctc_train.py
+++ b/fluid/ocr_recognition/ctc_train.py
@@ -34,7 +34,6 @@ def load_parameter(place):
             t.set(params[name], place)
 
 
-
 def train(args, data_reader=dummy_reader):
     """OCR CTC training"""
     num_classes = data_reader.num_classes()
@@ -42,7 +41,7 @@ def train(args, data_reader=dummy_reader):
     # define network
     images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
     label = fluid.layers.data(name='label', shape=[1], dtype='int32', lod_level=1)
-    avg_cost, error_evaluator = ctc_train_net(images, label, args, num_classes)
+    sum_cost, error_evaluator = ctc_train_net(images, label, args, num_classes)
     # data reader
     train_reader = data_reader.train(args.batch_size)
     test_reader = data_reader.test()
@@ -53,20 +52,10 @@ def train(args, data_reader=dummy_reader):
     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())
 
-    load_parameter(place)
+    #load_parameter(place)
     inference_program = fluid.io.get_inference_program(error_evaluator)
 
-    # evaluate model on test data
-    error_evaluator.reset(exe)
-    for data in test_reader():
-        exe.run(inference_program, feed=get_feeder_data(data, place))
-    _, test_seq_error = error_evaluator.eval(exe)
-    print "\nEnd pass[%d]; Test seq error: %s.\n" % (
-        -1, str(test_seq_error[0]))
-
-
-
     for pass_id in range(args.pass_num):
         error_evaluator.reset(exe)
         batch_id = 1
@@ -77,7 +66,7 @@ def train(args, data_reader=dummy_reader):
             batch_loss, _, batch_seq_error = exe.run(
                 fluid.default_main_program(),
                 feed=get_feeder_data(data, place),
-                fetch_list=[avg_cost] + error_evaluator.metrics)
+                fetch_list=[sum_cost] + error_evaluator.metrics)
             total_loss += batch_loss[0]
             total_seq_error += batch_seq_error[0]
             if batch_id % 10 == 1:
@@ -85,7 +74,7 @@ def train(args, data_reader=dummy_reader):
                 sys.stdout.flush()
             if batch_id % args.log_period == 1:
                 print "\nPass[%d]-batch[%d]; Avg Warp-CTC loss: %s; Avg seq error: %s." % (
-                    pass_id, batch_id, total_loss / batch_id, total_seq_error / (batch_id * args.batch_size))
+                    pass_id, batch_id, total_loss / (batch_id * args.batch_size), total_seq_error / (batch_id * args.batch_size))
                 sys.stdout.flush()
             batch_id += 1
 
--
GitLab
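
For reference, a minimal sketch (not part of the patch) of the loss-aggregation change
above. It assumes fluid has already been imported as in the patched files, and reuses
the names fc_out, label, num_classes, and args from ctc_train_net() and train():

    # Warp-CTC returns one cost value per sequence in the batch.
    cost = fluid.layers.warpctc(
        input=fc_out, label=label, blank=num_classes, norm_by_times=True)
    # Minimize the batch total instead of fluid.layers.mean(x=cost).
    sum_cost = fluid.layers.reduce_sum(cost)

    # Since the fetched value is now a batch total rather than a mean, the
    # training loop reports a per-sample average at log time:
    #     total_loss / (batch_id * args.batch_size)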