diff --git a/fluid/DeepASR/.gitignore b/fluid/DeepASR/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..485dee64bcfb48793379b200a1afd14e85a8aaf4
--- /dev/null
+++ b/fluid/DeepASR/.gitignore
@@ -0,0 +1 @@
+.idea
diff --git a/fluid/DeepASR/data_utils/util.py b/fluid/DeepASR/data_utils/util.py
index 27f5ba8305c0e72852230658858d77fed9d233a4..4a5a8a3f1dad1c46ed773fd48d713e276717d5e5 100644
--- a/fluid/DeepASR/data_utils/util.py
+++ b/fluid/DeepASR/data_utils/util.py
@@ -25,16 +25,6 @@ def to_lodtensor(data, place):
     return res
 
 
-def lodtensor_to_ndarray(lod_tensor):
-    """conver lodtensor to ndarray
-    """
-    dims = lod_tensor._get_dims()
-    ret = np.zeros(shape=dims).astype('float32')
-    for i in xrange(np.product(dims)):
-        ret.ravel()[i] = lod_tensor.get_float_element(i)
-    return ret, lod_tensor.lod()
-
-
 def split_infer_result(infer_seq, lod):
     infer_batch = []
     for i in xrange(0, len(lod[0]) - 1):
diff --git a/fluid/DeepASR/decoder/.gitignore b/fluid/DeepASR/decoder/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..ef5c97cfb5c06f3308980ca65c87e9c4b9440171
--- /dev/null
+++ b/fluid/DeepASR/decoder/.gitignore
@@ -0,0 +1,4 @@
+ThreadPool
+build
+post_latgen_faster_mapped.so
+pybind11
diff --git a/fluid/DeepASR/examples/aishell/.gitignore b/fluid/DeepASR/examples/aishell/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..c173dd880ae9e06c16989800e06d4d3d7a1a7d5f
--- /dev/null
+++ b/fluid/DeepASR/examples/aishell/.gitignore
@@ -0,0 +1,4 @@
+aux.tar.gz
+aux
+data
+checkpoints
diff --git a/fluid/DeepASR/examples/aishell/train.sh b/fluid/DeepASR/examples/aishell/train.sh
index 06fe488d4572782d946e8daa7c22ded8ef0212c6..168581c0ee579ef62f138bb0d8f5bb8886beb90b 100644
--- a/fluid/DeepASR/examples/aishell/train.sh
+++ b/fluid/DeepASR/examples/aishell/train.sh
@@ -1,5 +1,5 @@
 export CUDA_VISIBLE_DEVICES=4,5,6,7
-python -u ../../train.py --train_feature_lst data/train_feature.lst \
+python -u ../../train.py --train_feature_lst data/train_feature.lst \
                --train_label_lst data/train_label.lst \
                --val_feature_lst data/val_feature.lst \
                --val_label_lst data/val_label.lst \
@@ -7,7 +7,8 @@ python -u ../../train.py --train_feature_lst data/train_feature.lst \
                --checkpoints checkpoints \
                --frame_dim 80 \
                --class_num 3040 \
+               --print_per_batches 100 \
                --infer_models '' \
-               --batch_size 64 \
+               --batch_size 16 \
                --learning_rate 6.4e-5 \
                --parallel
diff --git a/fluid/DeepASR/model_utils/model.py b/fluid/DeepASR/model_utils/model.py
index 31892a7ad93eed0ba6ad6e7b53377e897f6df29d..0b086b55a898a0a29f57132b438684a655e30caf 100644
--- a/fluid/DeepASR/model_utils/model.py
+++ b/fluid/DeepASR/model_utils/model.py
@@ -5,19 +5,21 @@ from __future__ import print_function
 import paddle.fluid as fluid
 
 
-def stacked_lstmp_model(frame_dim,
+def stacked_lstmp_model(feature,
+                        label,
                         hidden_dim,
                         proj_dim,
                         stacked_num,
                         class_num,
                         parallel=False,
                         is_train=True):
-    """ The model for DeepASR. The main structure is composed of stacked
-        identical LSTMP (LSTM with recurrent projection) layers.
+    """
+    The model for DeepASR. The main structure is composed of stacked
+    identical LSTMP (LSTM with recurrent projection) layers.
 
-        When running in training and validation phase, the feeding dictionary
-        is {'feature', 'label'}, fed by the LodTensor for feature data and
-        label data respectively. And in inference, only `feature` is needed.
+    When running in the training and validation phases, the feeding dictionary
+    is {'feature', 'label'}, fed with LoDTensors holding the feature data and
+    the label data respectively. In inference, only `feature` is needed.
 
     Args:
         frame_dim(int): The frame dimension of feature data.
@@ -28,80 +30,45 @@ def stacked_lstmp_model(frame_dim,
         is_train(bool): Run in training phase or not, default `True`.
         class_dim(int): The number of output classes.
     """
+    conv1 = fluid.layers.conv2d(
+        input=feature,
+        num_filters=32,
+        filter_size=3,
+        stride=1,
+        padding=1,
+        bias_attr=True,
+        act="relu")
 
-    # network configuration
-    def _net_conf(feature, label):
-        conv1 = fluid.layers.conv2d(
-            input=feature,
-            num_filters=32,
-            filter_size=3,
-            stride=1,
-            padding=1,
-            bias_attr=True,
-            act="relu")
-
-        pool1 = fluid.layers.pool2d(
-            conv1, pool_size=3, pool_type="max", pool_stride=2, pool_padding=0)
-
-        stack_input = pool1
-        for i in range(stacked_num):
-            fc = fluid.layers.fc(input=stack_input,
-                                 size=hidden_dim * 4,
-                                 bias_attr=None)
-            proj, cell = fluid.layers.dynamic_lstmp(
-                input=fc,
-                size=hidden_dim * 4,
-                proj_size=proj_dim,
-                bias_attr=True,
-                use_peepholes=True,
-                is_reverse=False,
-                cell_activation="tanh",
-                proj_activation="tanh")
-            bn = fluid.layers.batch_norm(
-                input=proj,
-                is_test=not is_train,
-                momentum=0.9,
-                epsilon=1e-05,
-                data_layout='NCHW')
-            stack_input = bn
-
-        prediction = fluid.layers.fc(input=stack_input,
-                                     size=class_num,
-                                     act='softmax')
+    pool1 = fluid.layers.pool2d(
+        conv1, pool_size=3, pool_type="max", pool_stride=2, pool_padding=0)
 
-        cost = fluid.layers.cross_entropy(input=prediction, label=label)
-        avg_cost = fluid.layers.mean(x=cost)
-        acc = fluid.layers.accuracy(input=prediction, label=label)
-        return prediction, avg_cost, acc
-
-    # data feeder
-    feature = fluid.layers.data(
-        name="feature",
-        shape=[-1, 3, 11, frame_dim],
-        dtype="float32",
-        lod_level=1)
-    label = fluid.layers.data(
-        name="label", shape=[-1, 1], dtype="int64", lod_level=1)
-
-    if parallel:
-        # When the execution place is specified to CUDAPlace, the program will
-        # run on all $CUDA_VISIBLE_DEVICES GPUs. Otherwise the program will
-        # run on all CPU devices.
-        places = fluid.layers.device.get_places()
-        pd = fluid.layers.ParallelDo(places)
-        with pd.do():
-            feat_ = pd.read_input(feature)
-            label_ = pd.read_input(label)
-            prediction, avg_cost, acc = _net_conf(feat_, label_)
-            for out in [prediction, avg_cost, acc]:
-                pd.write_output(out)
+    stack_input = pool1
+    for i in range(stacked_num):
+        fc = fluid.layers.fc(input=stack_input,
+                             size=hidden_dim * 4,
+                             bias_attr=None)
+        proj, cell = fluid.layers.dynamic_lstmp(
+            input=fc,
+            size=hidden_dim * 4,
+            proj_size=proj_dim,
+            bias_attr=True,
+            use_peepholes=True,
+            is_reverse=False,
+            cell_activation="tanh",
+            proj_activation="tanh")
+        bn = fluid.layers.batch_norm(
+            input=proj,
+            is_test=not is_train,
+            momentum=0.9,
+            epsilon=1e-05,
+            data_layout='NCHW')
+        stack_input = bn
 
-        # get mean loss and acc through every devices.
-        prediction, avg_cost, acc = pd()
-        prediction.stop_gradient = True
-        avg_cost = fluid.layers.mean(x=avg_cost)
-        acc = fluid.layers.mean(x=acc)
-    else:
-        prediction, avg_cost, acc = _net_conf(feature, label)
+    prediction = fluid.layers.fc(input=stack_input,
+                                 size=class_num,
+                                 act='softmax')
+
+    cost = fluid.layers.cross_entropy(input=prediction, label=label)
+    avg_cost = fluid.layers.mean(x=cost)
+    acc = fluid.layers.accuracy(input=prediction, label=label)
     return prediction, avg_cost, acc
diff --git a/fluid/DeepASR/train.py b/fluid/DeepASR/train.py
index 1c35a6637f534abf4a37763fe1915c35e18e1f94..1a1dd6cf9ea33bb546cc3bdf65c36be0441832cb 100644
--- a/fluid/DeepASR/train.py
+++ b/fluid/DeepASR/train.py
@@ -14,7 +14,6 @@ import data_utils.augmentor.trans_add_delta as trans_add_delta
 import data_utils.augmentor.trans_splice as trans_splice
 import data_utils.augmentor.trans_delay as trans_delay
 import data_utils.async_data_reader as reader
-from data_utils.util import lodtensor_to_ndarray
 from model_utils.model import stacked_lstmp_model
 
 
@@ -24,7 +23,8 @@ def parse_args():
         '--batch_size',
         type=int,
         default=32,
-        help='The sequence number of a batch data. (default: %(default)d)')
+        help='The number of sequences in one batch, i.e. the batch size per GPU. (default: %(default)d)'
+    )
     parser.add_argument(
         '--minimum_batch_size',
         type=int,
@@ -147,29 +147,72 @@ def train(args):
     if args.infer_models != '' and not os.path.exists(args.infer_models):
         os.mkdir(args.infer_models)
 
-    prediction, avg_cost, accuracy = stacked_lstmp_model(
-        frame_dim=args.frame_dim,
-        hidden_dim=args.hidden_dim,
-        proj_dim=args.proj_dim,
-        stacked_num=args.stacked_num,
-        class_num=args.class_num,
-        parallel=args.parallel)
-
-    # program for test
-    test_program = fluid.default_main_program().clone()
-
-    #optimizer = fluid.optimizer.Momentum(learning_rate=args.learning_rate, momentum=0.9)
-    optimizer = fluid.optimizer.Adam(
-        learning_rate=fluid.layers.exponential_decay(
-            learning_rate=args.learning_rate,
-            decay_steps=1879,
-            decay_rate=1 / 1.2,
-            staircase=True))
-    optimizer.minimize(avg_cost)
+    train_program = fluid.Program()
+    train_startup = fluid.Program()
+    with fluid.program_guard(train_program, train_startup):
+        with fluid.unique_name.guard():
+            py_train_reader = fluid.layers.py_reader(
+                capacity=10,
+                shapes=([-1, 3, 11, args.frame_dim], [-1, 1]),
+                dtypes=['float32', 'int64'],
+                lod_levels=[1, 1],
+                name='train_reader')
+            feature, label = fluid.layers.read_file(py_train_reader)
+            prediction, avg_cost, accuracy = stacked_lstmp_model(
+                feature=feature,
+                label=label,
+                hidden_dim=args.hidden_dim,
+                proj_dim=args.proj_dim,
+                stacked_num=args.stacked_num,
+                class_num=args.class_num)
+            # optimizer = fluid.optimizer.Momentum(learning_rate=args.learning_rate, momentum=0.9)
+            optimizer = fluid.optimizer.Adam(
+                learning_rate=fluid.layers.exponential_decay(
+                    learning_rate=args.learning_rate,
+                    decay_steps=1879,
+                    decay_rate=1 / 1.2,
+                    staircase=True))
+            optimizer.minimize(avg_cost)
+            fluid.memory_optimize(train_program)
+
+    test_program = fluid.Program()
+    test_startup = fluid.Program()
+    with fluid.program_guard(test_program, test_startup):
+        with fluid.unique_name.guard():
+            py_test_reader = fluid.layers.py_reader(
+                capacity=10,
+                shapes=([-1, 3, 11, args.frame_dim], [-1, 1]),
+                dtypes=['float32', 'int64'],
+                lod_levels=[1, 1],
+                name='test_reader')
+            feature, label = fluid.layers.read_file(py_test_reader)
+            prediction, avg_cost, accuracy = stacked_lstmp_model(
+                feature=feature,
+                label=label,
+                hidden_dim=args.hidden_dim,
+                proj_dim=args.proj_dim,
+                stacked_num=args.stacked_num,
+                class_num=args.class_num)
+    test_program = test_program.clone(for_test=True)
 
     place = fluid.CPUPlace() if args.device == 'CPU' else fluid.CUDAPlace(0)
     exe = fluid.Executor(place)
-    exe.run(fluid.default_startup_program())
+    exe.run(train_startup)
+    exe.run(test_startup)
+
+    if args.parallel:
+        exec_strategy = fluid.ExecutionStrategy()
+        exec_strategy.num_iteration_per_drop_scope = 10
+        train_exe = fluid.ParallelExecutor(
+            use_cuda=(args.device == 'GPU'),
+            loss_name=avg_cost.name,
+            exec_strategy=exec_strategy,
+            main_program=train_program)
+        test_exe = fluid.ParallelExecutor(
+            use_cuda=(args.device == 'GPU'),
+            main_program=test_program,
+            exec_strategy=exec_strategy,
+            share_vars_from=train_exe)
 
     # resume training if initial model provided.
     if args.init_model_path is not None:
@@ -181,15 +224,24 @@ def train(args):
         trans_splice.TransSplice(5, 5), trans_delay.TransDelay(5)
     ]
 
-    feature_t = fluid.LoDTensor()
-    label_t = fluid.LoDTensor()
+    # bind train_reader
+    train_data_reader = reader.AsyncDataReader(
+        args.train_feature_lst,
+        args.train_label_lst,
+        -1,
+        split_sentence_threshold=1024)
 
-    # validation
-    def test(exe):
-        # If test data not found, return invalid cost and accuracy
-        if not (os.path.exists(args.val_feature_lst) and
-                os.path.exists(args.val_label_lst)):
-            return -1.0, -1.0
+    train_data_reader.set_transformers(ltrans)
+
+    def train_data_provider():
+        for data in train_data_reader.batch_iterator(args.batch_size,
+                                                     args.minimum_batch_size):
+            yield batch_data_to_lod_tensors(args, data, fluid.CPUPlace())
+
+    py_train_reader.decorate_tensor_provider(train_data_provider)
+
+    if (os.path.exists(args.val_feature_lst) and
+            os.path.exists(args.val_label_lst)):
         # test data reader
         test_data_reader = reader.AsyncDataReader(
             args.val_feature_lst,
@@ -197,86 +249,101 @@ def train(args):
             -1,
             split_sentence_threshold=1024)
         test_data_reader.set_transformers(ltrans)
-        test_costs, test_accs = [], []
-        for batch_id, batch_data in enumerate(
-                test_data_reader.batch_iterator(args.batch_size,
-                                                args.minimum_batch_size)):
-            # load_data
-            (features, labels, lod, _) = batch_data
-            features = np.reshape(features, (-1, 11, 3, args.frame_dim))
-            features = np.transpose(features, (0, 2, 1, 3))
-            feature_t.set(features, place)
-            feature_t.set_lod([lod])
-            label_t.set(labels, place)
-            label_t.set_lod([lod])
-
-            cost, acc = exe.run(test_program,
-                                feed={"feature": feature_t,
-                                      "label": label_t},
-                                fetch_list=[avg_cost, accuracy],
-                                return_numpy=False)
-            test_costs.append(lodtensor_to_ndarray(cost)[0])
-            test_accs.append(lodtensor_to_ndarray(acc)[0])
-        return np.mean(test_costs), np.mean(test_accs)
-
-    # train data reader
-    train_data_reader = reader.AsyncDataReader(
-        args.train_feature_lst,
-        args.train_label_lst,
-        -1,
-        split_sentence_threshold=1024)
+
+        def test_data_provider():
+            for data in test_data_reader.batch_iterator(
+                    args.batch_size, args.minimum_batch_size):
+                yield batch_data_to_lod_tensors(args, data, fluid.CPUPlace())
+
+        py_test_reader.decorate_tensor_provider(test_data_provider)
+
+    # validation
+    def test(exe):
+        # If test data not found, return invalid cost and accuracy
+        if not (os.path.exists(args.val_feature_lst) and
+                os.path.exists(args.val_label_lst)):
+            return -1.0, -1.0
+        batch_id = 0
+        test_costs = []
+        test_accs = []
+        while True:
+            if batch_id == 0:
+                py_test_reader.start()
+            try:
+                if args.parallel:
+                    cost, acc = exe.run(
+                        fetch_list=[avg_cost.name, accuracy.name],
+                        return_numpy=False)
+                else:
+                    cost, acc = exe.run(program=test_program,
+                                        fetch_list=[avg_cost, accuracy],
+                                        return_numpy=False)
+                sys.stdout.write('.')
+                sys.stdout.flush()
+                test_costs.append(np.array(cost)[0])
+                test_accs.append(np.array(acc)[0])
+                batch_id += 1
+            except fluid.core.EOFException:
+                py_test_reader.reset()
+                break
+        return np.mean(test_costs), np.mean(test_accs)
 
-    train_data_reader.set_transformers(ltrans)
     # train
     for pass_id in xrange(args.pass_num):
         pass_start_time = time.time()
-        for batch_id, batch_data in enumerate(
-                train_data_reader.batch_iterator(args.batch_size,
-                                                 args.minimum_batch_size)):
-            # load_data
-            (features, labels, lod, name_lst) = batch_data
-            features = np.reshape(features, (-1, 11, 3, args.frame_dim))
-            features = np.transpose(features, (0, 2, 1, 3))
-            feature_t.set(features, place)
-            feature_t.set_lod([lod])
-            label_t.set(labels, place)
-            label_t.set_lod([lod])
-
+        batch_id = 0
+        while True:
+            if batch_id == 0:
+                py_train_reader.start()
             to_print = batch_id > 0 and (batch_id % args.print_per_batches == 0)
-            outs = exe.run(fluid.default_main_program(),
-                           feed={"feature": feature_t,
-                                 "label": label_t},
-                           fetch_list=[avg_cost, accuracy] if to_print else [],
-                           return_numpy=False)
+            try:
+                if args.parallel:
+                    outs = train_exe.run(
+                        fetch_list=[avg_cost.name, accuracy.name]
+                        if to_print else [],
+                        return_numpy=False)
+                else:
+                    outs = exe.run(program=train_program,
+                                   fetch_list=[avg_cost, accuracy]
+                                   if to_print else [],
+                                   return_numpy=False)
+            except fluid.core.EOFException:
+                py_train_reader.reset()
+                break
 
             if to_print:
-                print("\nBatch %d, train cost: %f, train acc: %f" %
-                      (batch_id, lodtensor_to_ndarray(outs[0])[0],
-                       lodtensor_to_ndarray(outs[1])[0]))
+                if args.parallel:
+                    print("\nBatch %d, train cost: %f, train acc: %f" %
+                          (batch_id, np.mean(outs[0]), np.mean(outs[1])))
+                else:
+                    print("\nBatch %d, train cost: %f, train acc: %f" % (
+                        batch_id, np.array(outs[0])[0], np.array(outs[1])[0]))
                 # save the latest checkpoint
                 if args.checkpoints != '':
                     model_path = os.path.join(args.checkpoints,
                                               "deep_asr.latest.checkpoint")
-                    fluid.io.save_persistables(exe, model_path)
+                    fluid.io.save_persistables(exe, model_path, train_program)
             else:
                 sys.stdout.write('.')
                 sys.stdout.flush()
+
+            batch_id += 1
 
         # run test
-        val_cost, val_acc = test(exe)
+        val_cost, val_acc = test(test_exe if args.parallel else exe)
 
         # save checkpoint per pass
         if args.checkpoints != '':
             model_path = os.path.join(
                 args.checkpoints,
                 "deep_asr.pass_" + str(pass_id) + ".checkpoint")
-            fluid.io.save_persistables(exe, model_path)
+            fluid.io.save_persistables(exe, model_path, train_program)
         # save inference model
         if args.infer_models != '':
             model_path = os.path.join(
                 args.infer_models,
                 "deep_asr.pass_" + str(pass_id) + ".infer.model")
             fluid.io.save_inference_model(model_path, ["feature"],
-                                          [prediction], exe)
+                                          [prediction], exe, train_program)
         # cal pass time
         pass_end_time = time.time()
         time_consumed = pass_end_time - pass_start_time
@@ -285,6 +352,19 @@ def train(args):
               (pass_id, time_consumed, val_cost, val_acc))
 
 
+def batch_data_to_lod_tensors(args, batch_data, place):
+    features, labels, lod, name_lst = batch_data
+    features = np.reshape(features, (-1, 11, 3, args.frame_dim))
+    features = np.transpose(features, (0, 2, 1, 3))
+    feature_t = fluid.LoDTensor()
+    label_t = fluid.LoDTensor()
+    feature_t.set(features, place)
+    feature_t.set_lod([lod])
+    label_t.set(labels, place)
+    label_t.set_lod([lod])
+    return feature_t, label_t
+
+
 if __name__ == '__main__':
     args = parse_args()
     print_arguments(args)
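
Below is a minimal, self-contained sketch (not part of the patch) of the py_reader feeding pattern this change adopts in train.py: a Python generator is bound to the reader with decorate_tensor_provider, and each pass wraps exe.run in a start() / EOFException / reset() cycle. The toy two-class network and random NumPy batches are hypothetical stand-ins for stacked_lstmp_model and AsyncDataReader; the reader calls themselves are the same paddle.fluid 1.x APIs used in the diff.

# Sketch: py_reader-driven training loop (assumes paddle.fluid 1.x).
import numpy as np
import paddle.fluid as fluid


def provider():
    # Stand-in for AsyncDataReader: yield one (feature, label) tuple per batch.
    for _ in range(8):
        feature = np.random.random((4, 10)).astype('float32')
        label = np.random.randint(0, 2, (4, 1)).astype('int64')
        yield feature, label


py_reader = fluid.layers.py_reader(
    capacity=4,
    shapes=([-1, 10], [-1, 1]),
    dtypes=['float32', 'int64'],
    name='toy_reader')
feature, label = fluid.layers.read_file(py_reader)
# Toy classifier standing in for stacked_lstmp_model.
prediction = fluid.layers.fc(input=feature, size=2, act='softmax')
avg_cost = fluid.layers.mean(
    fluid.layers.cross_entropy(input=prediction, label=label))
fluid.optimizer.SGD(learning_rate=1e-3).minimize(avg_cost)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
py_reader.decorate_tensor_provider(provider)

for pass_id in range(2):
    py_reader.start()  # spawn the feeding thread for this pass
    while True:
        try:
            exe.run(fetch_list=[avg_cost])
        except fluid.core.EOFException:
            py_reader.reset()  # provider exhausted; reset before the next pass
            break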