提交 2b8e12dd 编写于 作者: B breezedeus

optimize model directory names

上级 d661eb19
# Directory that holds the sample index files (train.txt / test.txt) read by the targets below.
DATA_ROOT_DIR = data/sample-data
# Path prefix under which the generated *.lst / *.idx / *.rec files are placed.
REC_DATA_ROOT_DIR = data/sample-data-lst
# ['conv', 'conv-lite-rnn', 'densenet', 'densenet-lite']
# NOTE(review): the list above presumably enumerates the accepted embedding model types;
# confirm against the choices defined by the training script.
EMB_MODEL_TYPE = densenet-lite
SEQ_MODEL_TYPE = fc
# Model name is "<embedding type>-<sequence type>", e.g. densenet-lite-fc with the values above.
MODEL_NAME = $(EMB_MODEL_TYPE)-$(SEQ_MODEL_TYPE)
# Generate the *.lst files.
# NOTE(review): the recipe lines below carry no leading TAB, and both a hard-coded
# data/selected invocation and a $(DATA_ROOT_DIR)-based one are present. This looks like
# a flattened diff (removed and added lines side by side); confirm which command is
# intended and restore the TABs before running.
gen-lst:
python scripts/im2rec.py --list --num-label 20 --chunks 1 --train-idx-fp data/selected/train.txt --test-idx-fp data/selected/test.txt --prefix data/selected-lst/selected-data
python scripts/im2rec.py --list --num-label 20 --chunks 1 \
--train-idx-fp $(DATA_ROOT_DIR)/train.txt --test-idx-fp $(DATA_ROOT_DIR)/test.txt --prefix $(REC_DATA_ROOT_DIR)/sample-data
# Use the *.lst files to generate the *.idx and *.rec files.
# The actual image files are stored in the `examples` directory, which can be specified via `--root`.
# NOTE(review): two invocations are listed (one rooted at data/selected, one at examples);
# presumably only one of them is current. Confirm before running.
gen-rec:
python scripts/im2rec.py --pack-label --color 1 --num-thread 1 --prefix data/selected-lst --root data/selected
python scripts/im2rec.py --pack-label --color 1 --num-thread 1 --prefix $(REC_DATA_ROOT_DIR) --root examples
# Train the model.
# The embedding and sequence model types come from EMB_MODEL_TYPE / SEQ_MODEL_TYPE.
# NOTE(review): two training commands with different epochs, learning rates and data files
# are listed (50 epochs / lr 1e-5 on data/selected-lst vs. 20 epochs / lr 1e-4 on
# $(REC_DATA_ROOT_DIR)); presumably only one is current. Confirm before running.
train:
python scripts/cnocr_train.py --gpu 0 --emb_model_type $(EMB_MODEL_TYPE) --seq_model_type $(SEQ_MODEL_TYPE) --optimizer adam --epoch 50 --lr 1e-5 --train_file data/selected-lst/selected-data_train --test_file data/selected-lst/selected-data_test
python scripts/cnocr_train.py --gpu 0 --emb_model_type $(EMB_MODEL_TYPE) --seq_model_type $(SEQ_MODEL_TYPE) \
--optimizer adam --epoch 20 --lr 1e-4 \
--train_file $(REC_DATA_ROOT_DIR)/sample-data_train --test_file $(REC_DATA_ROOT_DIR)/sample-data_test
# Run scripts/cnocr_evaluate.py for the model $(MODEL_NAME) on a test.txt index file;
# `-o evaluate/$(MODEL_NAME)` is presumably the output location (verify against the script).
# NOTE(review): two invocations with different --model-epoch values and input data are listed;
# confirm which one is intended.
evaluate:
python scripts/cnocr_evaluate.py --model-name $(MODEL_NAME) --model-epoch 2 -v -i data/selected/test.txt --image-prefix-dir data/selected --batch-size 128 -o evaluate/$(MODEL_NAME)
python scripts/cnocr_evaluate.py --model-name $(MODEL_NAME) --model-epoch 1 -v -i $(DATA_ROOT_DIR)/test.txt \
--image-prefix-dir examples --batch-size 128 -o evaluate/$(MODEL_NAME)
# Run scripts/cnocr_predict.py with the model $(MODEL_NAME) on the single image examples/rand_cn1.png.
predict:
python scripts/cnocr_predict.py --model_name $(MODEL_NAME) --file examples/rand_cn1.png
......
......@@ -126,9 +126,7 @@ class CnOcr(object):
self._model_epoch = model_epoch or AVAILABLE_MODELS[model_name][0]
root = os.path.join(root, __version__)
self._model_dir = os.path.join(
root, '%s-%04d' % (self._model_name, self._model_epoch)
)
self._model_dir = os.path.join(root, self._model_name)
self._assert_and_prepare_model_files()
self._alphabet, inv_alph_dict = read_charset(
os.path.join(self._model_dir, 'label_cn.txt')
......@@ -161,8 +159,6 @@ class CnOcr(object):
if file_prepared:
return
if os.path.exists(model_dir):
os.removedirs(model_dir)
get_model_file(model_dir)
def _get_module(self):
......
......@@ -13,10 +13,10 @@ root_url = (
)
# Maps model name -> (epochs, url); the url is root_url plus the model's zip file name.
# NOTE(review): the first four keys each appear twice below (an epoch-suffixed '-00NN.zip'
# URL followed by an unsuffixed '.zip' URL). In a Python dict literal the later entry wins,
# so the unsuffixed URLs are the effective values. This looks like a flattened diff with
# removed and added lines both present; confirm and drop the stale entries.
AVAILABLE_MODELS = {
'conv-lstm': (50, root_url + '/conv-lstm-0050.zip'),
'conv-lite-lstm': (45, root_url + '/conv-lite-lstm-0045.zip'),
'conv-lite-fc': (27, root_url + '/conv-lite-fc-0027.zip'),
'densenet-lite-lstm': (42, root_url + '/densenet-lite-lstm-0042.zip'),
'conv-lstm': (50, root_url + '/conv-lstm.zip'),
'conv-lite-lstm': (45, root_url + '/conv-lite-lstm.zip'),
'conv-lite-fc': (27, root_url + '/conv-lite-fc.zip'),
'densenet-lite-lstm': (42, root_url + '/densenet-lite-lstm.zip'),
'densenet-lite-fc': (32, root_url + '/densenet-lite-fc.zip'),
}
......
......@@ -71,7 +71,7 @@ def get_model_file(model_dir):
zip_file_path = model_dir + '.zip'
if not os.path.exists(zip_file_path):
model_name = os.path.basename(model_dir).rsplit('-', maxsplit=1)[0]
model_name = os.path.basename(model_dir)
if model_name not in AVAILABLE_MODELS:
raise NotImplementedError('%s is not an available downloaded model' % model_name)
url = AVAILABLE_MODELS[model_name][1]
......
......@@ -39,41 +39,32 @@ from cnocr.fit.fit import fit
def parse_args():
# Parse command line arguments
parser = argparse.ArgumentParser()
default_model_prefix = os.path.join(
data_dir(), 'models', 'cnocr-v{}'.format(__version__)
)
parser.add_argument(
"--emb_model_type",
help="which embedding model to use",
choices=EMB_MODEL_TYPES,
type=str,
default='conv-rnn',
default='conv-lite',
)
parser.add_argument(
"--seq_model_type",
help='which sequence model to use',
default='lstm',
default='fc',
type=str,
choices=SEQ_MODEL_TYPES,
)
parser.add_argument(
"--data_root",
help="Path to image files",
type=str,
default='/Users/king/Documents/WhatIHaveDone/Test/text_renderer/output/wechat_simulator',
)
parser.add_argument(
"--train_file",
help="Path to train txt file",
type=str,
default='/Users/king/Documents/WhatIHaveDone/Test/text_renderer/output/wechat_simulator/train.txt',
default='data/sample-data-lst/train.txt',
)
parser.add_argument(
"--test_file",
help="Path to test txt file",
type=str,
default='/Users/king/Documents/WhatIHaveDone/Test/text_renderer/output/wechat_simulator/test.txt',
default='data/sample-data-lst/test.txt',
)
parser.add_argument(
"--use_train_image_aug",
......@@ -81,7 +72,10 @@ def parse_args():
help="Whether to use image augmentation for training",
)
parser.add_argument(
"--gpu", help="Number of GPUs for training [Default 0, means using cpu]", type=int, default=0
"--gpu",
help="Number of GPUs for training [Default 0, means using cpu]",
type=int,
default=0,
)
parser.add_argument(
"--optimizer",
......@@ -97,12 +91,7 @@ def parse_args():
type=int,
help='load the model on an epoch using the model-load-prefix [Default: no trained model will be loaded]',
)
parser.add_argument(
'--lr',
type=float,
default=0.001,
help='learning rate',
)
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument(
'--wd', type=float, default=0.0, help='weight decay factor [Default: 0.0]'
)
......@@ -113,9 +102,9 @@ def parse_args():
help='value for clip gradient [Default: None, means no gradient will be clip]',
)
parser.add_argument(
"--prefix",
help="Checkpoint prefix [Default '{}']".format(default_model_prefix),
default=default_model_prefix,
"--out_model_dir",
help='output model directory',
default=os.path.join(data_dir(), __version__),
)
return parser.parse_args()
......@@ -124,7 +113,13 @@ def train_cnocr(args):
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)
args.model_name = args.emb_model_type + '-' + args.seq_model_type
args.prefix = '{}-{}'.format(args.prefix, args.model_name)
out_dir = os.path.join(args.out_model_dir, args.model_name)
print('save models to dir: %s' % out_dir, flush=True)
if not os.path.exists(out_dir):
os.makedirs(out_dir)
args.prefix = os.path.join(
out_dir, 'cnocr-v{}-{}'.format(__version__, args.model_name)
)
hp = CnHyperparams()
hp = _update_hp(hp, args)
......
......@@ -152,6 +152,8 @@ def make_list_new(args):
prefix = ''
else:
working_dir = os.path.dirname(args.prefix)
if not os.path.exists(working_dir):
os.makedirs(working_dir)
prefix = os.path.basename(args.prefix)
test_list = read_file(args.test_idx_fp)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册