未验证 提交 b1c37965 编写于 作者: Y Yang yaming 提交者: GitHub

Merge pull request #635 from pkuyym/fix-630

Change to parallel reader
@@ -9,9 +9,9 @@ import time
 import paddle.v2 as paddle
 import paddle.v2.fluid as fluid
 import paddle.v2.fluid.profiler as profiler
-import data_utils.trans_mean_variance_norm as trans_mean_variance_norm
-import data_utils.trans_add_delta as trans_add_delta
-import data_utils.trans_splice as trans_splice
+import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
+import data_utils.augmentor.trans_add_delta as trans_add_delta
+import data_utils.augmentor.trans_splice as trans_splice
 import data_utils.data_reader as reader
@@ -22,6 +22,12 @@ def parse_args():
         type=int,
         default=32,
         help='The sequence number of a batch data. (default: %(default)d)')
+    parser.add_argument(
+        '--minimum_batch_size',
+        type=int,
+        default=1,
+        help='The minimum sequence number of a batch data. (default: %(default)d)'
+    )
     parser.add_argument(
         '--stacked_num',
         type=int,
@@ -160,14 +166,15 @@ def train(args):
     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())
+    # @TODO datareader should take the responsibility (parsing from config file)
     ltrans = [
         trans_add_delta.TransAddDelta(2, 2),
         trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var),
         trans_splice.TransSplice()
     ]
-    data_reader = reader.DataRead(args.feature_lst, args.label_lst)
-    data_reader.set_trans(ltrans)
+    data_reader = reader.DataReader(args.feature_lst, args.label_lst)
+    data_reader.set_transformers(ltrans)
     res_feature = fluid.LoDTensor()
     res_label = fluid.LoDTensor()
@@ -175,22 +182,15 @@ def train(args):
         pass_start_time = time.time()
         words_seen = 0
         accuracy.reset(exe)
-        batch_id = 0
-        while True:
-            # load_data
-            one_batch = data_reader.get_one_batch(args.batch_size)
-            if one_batch == None:
-                break
-            (bat_feature, bat_label, lod) = one_batch
+        for batch_id, batch_data in enumerate(
+                data_reader.batch_iterator(args.batch_size,
+                                           args.minimum_batch_size)):
+            (bat_feature, bat_label, lod) = batch_data
             res_feature.set(bat_feature, place)
             res_feature.set_lod([lod])
             res_label.set(bat_label, place)
             res_label.set_lod([lod])
-            batch_id += 1
             words_seen += lod[-1]
             loss, acc = exe.run(
                 fluid.default_main_program(),
                 feed={"feature": res_feature,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册