提交 a62a6f92 编写于 作者: Y Yibing Liu

Add validation at the end of each training pass

上级 2f5debd7
......@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import sys
import os
import numpy as np
import argparse
import time
......@@ -75,15 +76,25 @@ def parse_args():
default='data/global_mean_var_search26kHr',
help='mean var path')
parser.add_argument(
'--feature_lst',
'--train_feature_lst',
type=str,
default='data/feature.lst',
help='feature list path.')
help='feature list path for training.')
parser.add_argument(
'--label_lst',
'--train_label_lst',
type=str,
default='data/label.lst',
help='label list path.')
help='label list path for training.')
parser.add_argument(
'--val_feature_lst',
type=str,
default='data/val_feature.lst',
help='feature list path for validation.')
parser.add_argument(
'--val_label_lst',
type=str,
default='data/val_label.lst',
help='label list path for validation.')
args = parser.parse_args()
return args
......@@ -104,6 +115,11 @@ def train(args):
adam_optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
adam_optimizer.minimize(avg_cost)
# program for test
test_program = fluid.default_main_program().clone()
with fluid.program_guard(test_program):
test_program = fluid.io.get_inference_program([avg_cost, accuracy])
place = fluid.CPUPlace() if args.device == 'CPU' else fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
......@@ -114,16 +130,49 @@ def train(args):
trans_splice.TransSplice()
]
data_reader = reader.DataReader(args.feature_lst, args.label_lst)
data_reader.set_transformers(ltrans)
res_feature = fluid.LoDTensor()
res_label = fluid.LoDTensor()
# validation
def test(exe):
    """Evaluate the current model on the validation set.

    Relies on names from the enclosing training scope: ``args``, ``ltrans``,
    ``place``, ``test_program``, ``avg_cost``, ``accuracy`` and the reusable
    ``res_feature`` / ``res_label`` tensors.

    Args:
        exe: executor used to run ``test_program``.

    Returns:
        A ``(cost, accuracy)`` pair, each the unweighted mean of the
        per-batch values. ``(-1.0, -1.0)`` is returned as an "invalid"
        marker when either validation list file is missing or when the
        reader yields no batches.
    """
    invalid_result = (-1.0, -1.0)

    # If test data not found, return invalid cost and accuracy
    if not (os.path.exists(args.val_feature_lst) and
            os.path.exists(args.val_label_lst)):
        return invalid_result

    # test data reader
    test_data_reader = reader.DataReader(args.val_feature_lst,
                                         args.val_label_lst)
    test_data_reader.set_transformers(ltrans)

    test_costs, test_accs = [], []
    for bat_feature, bat_label, lod in test_data_reader.batch_iterator(
            args.batch_size, args.minimum_batch_size):
        # load_data
        res_feature.set(bat_feature, place)
        res_feature.set_lod([lod])
        res_label.set(bat_label, place)
        res_label.set_lod([lod])
        cost, acc = exe.run(
            test_program,
            feed={"feature": res_feature,
                  "label": res_label},
            fetch_list=[avg_cost, accuracy],
            return_numpy=False)
        test_costs.append(lodtensor_to_ndarray(cost)[0])
        test_accs.append(lodtensor_to_ndarray(acc)[0])

    # np.mean([]) would yield nan plus a RuntimeWarning; report the same
    # invalid marker as for missing data instead.
    if not test_costs:
        return invalid_result
    return np.mean(test_costs), np.mean(test_accs)
train_data_reader = reader.DataReader(args.train_feature_lst,
args.train_label_lst)
train_data_reader.set_transformers(ltrans)
# train
for pass_id in xrange(args.pass_num):
pass_start_time = time.time()
for batch_id, batch_data in enumerate(
data_reader.batch_iterator(args.batch_size,
args.minimum_batch_size)):
train_data_reader.batch_iterator(args.batch_size,
args.minimum_batch_size)):
# load_data
(bat_feature, bat_label, lod) = batch_data
res_feature.set(bat_feature, place)
......@@ -144,11 +193,12 @@ def train(args):
sys.stdout.write('.')
sys.stdout.flush()
val_cost, val_acc = test(exe)
pass_end_time = time.time()
time_consumed = pass_end_time - pass_start_time
# need to add test logic (kuke)
print("\nPass %d, time consumed: %fs, test accuracy: 0.0f\n" %
(pass_id, time_consumed))
print("\nPass %d, time consumed: %f s, val cost: %f, val acc: %f\n" %
(pass_id, time_consumed, val_cost, val_acc))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册