未验证 提交 f15d7ff3 编写于 作者: W whs 提交者: GitHub

Fix ce of ocr. (#1133)

* Fix ce of ocr.

* Remove debug print
上级 b8e6805d
python ctc_train.py --batch_size=128 --total_step=10000 -eval_period=10000 --log_period=10000 --use_gpu=True | python _ce.py
export ce_mode=1
rm *factor.txt
python ctc_train.py --batch_size=32 --total_step=30000 --eval_period=30000 --log_period=30000 --use_gpu=True 1> ./tmp.log
cat tmp.log | python _ce.py
rm tmp.log
......@@ -25,7 +25,12 @@ class DataGenerator(object):
def __init__(self):
pass
def train_reader(self, img_root_dir, img_label_list, batchsize, cycle):
def train_reader(self,
img_root_dir,
img_label_list,
batchsize,
cycle,
shuffle=True):
'''
Reader interface for training.
......@@ -42,15 +47,12 @@ class DataGenerator(object):
'''
img_label_lines = []
if batchsize == 1:
to_file = "tmp.txt"
to_file = "tmp.txt"
if not shuffle:
cmd = "cat " + img_label_list + " | awk '{print $1,$2,$3,$4;}' > " + to_file
elif batchsize == 1:
cmd = "cat " + img_label_list + " | awk '{print $1,$2,$3,$4;}' | shuf > " + to_file
print "cmd: " + cmd
os.system(cmd)
print "finish batch shuffle"
img_label_lines = open(to_file, 'r').readlines()
else:
to_file = "tmp.txt"
#cmd1: partial shuffle
cmd = "cat " + img_label_list + " | awk '{printf(\"%04d%.4f %s\\n\", $1, rand(), $0)}' | sort | sed 1,$((1 + RANDOM % 100))d | "
#cmd2: batch merge and shuffle
......@@ -62,10 +64,9 @@ class DataGenerator(object):
) + " * 4) {for(i = 0; i < " + str(
batchsize
) + "; i++) print $(4*i+1)\" \"$(4*i+2)\" \"$(4*i+3)\" \"$(4*i+4);}}' > " + to_file
print "cmd: " + cmd
os.system(cmd)
print "finish batch shuffle"
img_label_lines = open(to_file, 'r').readlines()
os.system(cmd)
print "finish batch shuffle"
img_label_lines = open(to_file, 'r').readlines()
def reader():
sizes = len(img_label_lines) / batchsize
......@@ -191,8 +192,11 @@ def train(batch_size, train_images_dir=None, train_list_file=None, cycle=False):
train_images_dir = path.join(data_dir, TRAIN_DATA_DIR_NAME)
if train_list_file is None:
train_list_file = path.join(data_dir, TRAIN_LIST_FILE_NAME)
return generator.train_reader(train_images_dir, train_list_file, batch_size,
cycle)
shuffle = True
if 'ce_mode' in os.environ:
shuffle = False
return generator.train_reader(
train_images_dir, train_list_file, batch_size, cycle, shuffle=shuffle)
def test(batch_size=1, test_images_dir=None, test_list_file=None):
......
......@@ -63,6 +63,10 @@ def train(args, data_reader=ctc_reader):
if args.use_gpu:
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
if 'ce_mode' in os.environ:
fluid.default_startup_program().random_seed = 90
exe.run(fluid.default_startup_program())
# load init model
......@@ -148,7 +152,6 @@ def train(args, data_reader=ctc_reader):
args.batch_size))
print "kpis train_acc %f" % (
1 - total_seq_error / (args.log_period * args.batch_size))
sys.stdout.flush()
total_loss = 0.0
total_seq_error = 0.0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册