未验证 提交 b1c37965 编写于 作者: Y Yang yaming 提交者: GitHub

Merge pull request #635 from pkuyym/fix-630

Change to parallel reader
......@@ -9,9 +9,9 @@ import time
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import paddle.v2.fluid.profiler as profiler
import data_utils.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.trans_add_delta as trans_add_delta
import data_utils.trans_splice as trans_splice
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice
import data_utils.data_reader as reader
......@@ -22,6 +22,12 @@ def parse_args():
type=int,
default=32,
help='The sequence number of a batch data. (default: %(default)d)')
parser.add_argument(
'--minimum_batch_size',
type=int,
default=1,
help='The minimum sequence number of a batch data. (default: %(default)d)'
)
parser.add_argument(
'--stacked_num',
type=int,
......@@ -160,14 +166,15 @@ def train(args):
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
# @TODO datareader should take the responsibility (parsing from config file)
ltrans = [
trans_add_delta.TransAddDelta(2, 2),
trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var),
trans_splice.TransSplice()
]
data_reader = reader.DataRead(args.feature_lst, args.label_lst)
data_reader.set_trans(ltrans)
data_reader = reader.DataReader(args.feature_lst, args.label_lst)
data_reader.set_transformers(ltrans)
res_feature = fluid.LoDTensor()
res_label = fluid.LoDTensor()
......@@ -175,22 +182,15 @@ def train(args):
pass_start_time = time.time()
words_seen = 0
accuracy.reset(exe)
batch_id = 0
while True:
# load_data
one_batch = data_reader.get_one_batch(args.batch_size)
if one_batch == None:
break
(bat_feature, bat_label, lod) = one_batch
for batch_id, batch_data in enumerate(
data_reader.batch_iterator(args.batch_size,
args.minimum_batch_size)):
(bat_feature, bat_label, lod) = batch_data
res_feature.set(bat_feature, place)
res_feature.set_lod([lod])
res_label.set(bat_label, place)
res_label.set_lod([lod])
batch_id += 1
words_seen += lod[-1]
loss, acc = exe.run(
fluid.default_main_program(),
feed={"feature": res_feature,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册