提交 3e2cc77b 编写于 作者: B breezedeus

add lr and optimizer params

上级 65f55c97
......@@ -80,7 +80,13 @@ def parse_args():
default=2,
)
parser.add_argument(
"--gpu", help="Number of GPUs for training [Default 0]", type=int
"--gpu", help="Number of GPUs for training [Default 0]", type=int, default=0
)
parser.add_argument(
"--optimizer",
help="optimizer for training [Default: Adam]",
type=str,
default='Adam',
)
parser.add_argument(
'--epoch', type=int, default=20, help='train epochs [Default: 20]'
......@@ -90,6 +96,12 @@ def parse_args():
type=int,
help='load the model on an epoch using the model-load-prefix [Default: no trained model will be loaded]',
)
parser.add_argument(
'--lr',
type=float,
default=None,
help='learning rate [Default: None, means lr from hp will be used]',
)
parser.add_argument(
"--prefix",
help="Checkpoint prefix [Default '{}']".format(default_model_prefix),
......@@ -98,36 +110,16 @@ def parse_args():
parser.add_argument(
"--loss", help="'ctc' or 'warpctc' loss [Default 'ctc']", default='ctc'
)
parser.add_argument(
"--num_proc",
help="Number CAPTCHA generating processes [Default 4]",
type=int,
default=4,
)
parser.add_argument(
"--font_path", help="Path to ttf font file or directory containing ttf files"
)
return parser.parse_args()
def get_fonts(path):
fonts = list()
if os.path.isdir(path):
for filename in os.listdir(path):
if filename.endswith('.ttf') or filename.endswith('.ttc'):
fonts.append(os.path.join(path, filename))
else:
fonts.append(path)
return fonts
def run_cn_ocr(args):
def train_cnocr(args):
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)
args.prefix = '{}-{}'.format(args.prefix, args.model_name)
hp = CnHyperparams()
hp._num_epoch = args.epoch
hp = _update_hp(hp, args)
network, hp = gen_network(args.model_name, hp)
metrics = CtcMetrics(hp.seq_length)
......@@ -147,6 +139,14 @@ def run_cn_ocr(args):
)
def _update_hp(hp, args):
hp._num_epoch = args.epoch
hp.optimizer = args.optimizer
if args.lr is not None:
hp._learning_rate = args.lr
return hp
def _gen_iters(hp, train_fp_prefix, val_fp_prefix, use_train_image_aug):
height, width = hp.img_height, hp.img_width
augs = None
......@@ -192,4 +192,4 @@ def _gen_iters(hp, train_fp_prefix, val_fp_prefix, use_train_image_aug):
if __name__ == '__main__':
args = parse_args()
run_cn_ocr(args)
train_cnocr(args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册