未验证 提交 3b7a4161 编写于 作者: W Wei Tang 提交者: GitHub

Update train.py

上级 b6ad4a3d
......@@ -54,55 +54,5 @@ def main():
fit(network=network, data_train=data_train, data_val=data_val, metrics=metrics, args=args, hp=hp, data_names=data_names)
def main2():
args = parse_args()
hp = Hyperparams()
if args.gpu:
contexts = [mx.context.gpu(i) for i in range(args.gpu)]
else:
contexts = [mx.context.cpu(i) for i in range(args.cpu)]
init_c = [('l%d_init_c' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)]
init_h = [('l%d_init_h' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)]
init_states = init_c + init_h
data_names = ['data'] + [x[0] for x in init_states]
data_train = ImageIterLstm(
args.data_root, args.train_file, hp.batch_size, (hp.img_width, hp.img_height), hp.num_label, init_states, name="train")
data_val = ImageIterLstm(
args.data_root, args.test_file, hp.batch_size, (hp.img_width, hp.img_height), hp.num_label, init_states, name="val")
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)
symbol = crnn_lstm(hp)
module = mx.mod.Module(
symbol,
data_names=data_names,
label_names=['label'],
context=contexts)
module.bind(data_shapes=data_train.provide_data, label_shapes=data_train.provide_label)
metrics = CtcMetrics(hp.seq_length)
module.fit(train_data=data_train,
eval_data=data_val,
# use metrics.accuracy or metrics.accuracy_lcs
eval_metric=mx.metric.np(metrics.accuracy, allow_extra_outputs=True),
optimizer='AdaDelta',
optimizer_params={'learning_rate': hp.learning_rate,
# 'momentum': hp.momentum,
'wd': 0.00001,
},
initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
num_epoch=hp.num_epoch,
batch_end_callback=mx.callback.Speedometer(hp.batch_size, 50),
epoch_end_callback=mx.callback.do_checkpoint(args.prefix),
)
if __name__ == '__main__':
main()
\ No newline at end of file
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册