From 5ae580c871fd3d520ce8e14034007ff9ddeafe8a Mon Sep 17 00:00:00 2001 From: zhxfl <291221622@qq.com> Date: Thu, 5 Apr 2018 22:02:51 +0800 Subject: [PATCH] add trans delay --- .../augmentor/tests/test_data_trans.py | 20 ++++++++++ .../data_utils/augmentor/trans_delay.py | 37 +++++++++++++++++++ fluid/DeepASR/train.py | 3 +- 3 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 fluid/DeepASR/data_utils/augmentor/trans_delay.py diff --git a/fluid/DeepASR/data_utils/augmentor/tests/test_data_trans.py b/fluid/DeepASR/data_utils/augmentor/tests/test_data_trans.py index 9f76a9f8..6b18f3fa 100644 --- a/fluid/DeepASR/data_utils/augmentor/tests/test_data_trans.py +++ b/fluid/DeepASR/data_utils/augmentor/tests/test_data_trans.py @@ -8,6 +8,7 @@ import numpy as np import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm import data_utils.augmentor.trans_add_delta as trans_add_delta import data_utils.augmentor.trans_splice as trans_splice +import data_utils.augmentor.trans_delay as trans_delay class TestTransMeanVarianceNorm(unittest.TestCase): @@ -112,5 +113,24 @@ class TestTransSplict(unittest.TestCase): self.assertAlmostEqual(feature[i][j * 10 + k], cur_val) +class TestTransDelay(unittest.TestCase): + """unittest TransDelay + """ + + def test_perform(self): + label = np.zeros((10, 1), dtype="int64") + for i in xrange(10): + label[i][0] = i + + trans = trans_delay.TransDelay(5) + (_, label, _) = trans.perform_trans((None, label, None)) + + for i in xrange(5): + self.assertAlmostEqual(label[i + 5][0], i) + + for i in xrange(5): + self.assertAlmostEqual(label[i][0], 0) + + if __name__ == '__main__': unittest.main() diff --git a/fluid/DeepASR/data_utils/augmentor/trans_delay.py b/fluid/DeepASR/data_utils/augmentor/trans_delay.py new file mode 100644 index 00000000..b782498e --- /dev/null +++ b/fluid/DeepASR/data_utils/augmentor/trans_delay.py @@ -0,0 +1,37 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import math + + +class TransDelay(object): + """ Delay label, and copy first label value in the front. + Attributes: + _delay_time : the delay frame num of label + """ + + def __init__(self, delay_time): + """init construction + Args: + delay_time : the delay frame num of label + """ + self._delay_time = delay_time + + def perform_trans(self, sample): + """ + Args: + sample(object):input sample, contain feature numpy and label numpy, sample name list + Returns: + (feature, label, name) + """ + (feature, label, name) = sample + + shape = label.shape + assert len(shape) == 2 + label[self._delay_time:shape[0]] = label[0:shape[0] - self._delay_time] + for i in xrange(self._delay_time): + label[i][0] = label[self._delay_time][0] + + return (feature, label, name) diff --git a/fluid/DeepASR/train.py b/fluid/DeepASR/train.py index be99998c..50ca0470 100644 --- a/fluid/DeepASR/train.py +++ b/fluid/DeepASR/train.py @@ -12,6 +12,7 @@ import paddle.fluid as fluid import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm import data_utils.augmentor.trans_add_delta as trans_add_delta import data_utils.augmentor.trans_splice as trans_splice +import data_utils.augmentor.trans_delay as trans_delay import data_utils.async_data_reader as reader from data_utils.util import lodtensor_to_ndarray from model_utils.model import stacked_lstmp_model @@ -171,7 +172,7 @@ def train(args): ltrans = [ trans_add_delta.TransAddDelta(2, 2), trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var), - trans_splice.TransSplice() + trans_splice.TransSplice(), trans_delay.TransDelay(5) ] feature_t = fluid.LoDTensor() -- GitLab