提交 5ae580c8 编写于 作者: Z zhxfl

add trans delay

上级 a4d00a37
......@@ -8,6 +8,7 @@ import numpy as np
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice
import data_utils.augmentor.trans_delay as trans_delay
class TestTransMeanVarianceNorm(unittest.TestCase):
......@@ -112,5 +113,24 @@ class TestTransSplict(unittest.TestCase):
self.assertAlmostEqual(feature[i][j * 10 + k], cur_val)
class TestTransDelay(unittest.TestCase):
"""unittest TransDelay
"""
def test_perform(self):
label = np.zeros((10, 1), dtype="int64")
for i in xrange(10):
label[i][0] = i
trans = trans_delay.TransDelay(5)
(_, label, _) = trans.perform_trans((None, label, None))
for i in xrange(5):
self.assertAlmostEqual(label[i + 5][0], i)
for i in xrange(5):
self.assertAlmostEqual(label[i][0], 0)
if __name__ == '__main__':
unittest.main()
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import math
class TransDelay(object):
""" Delay label, and copy first label value in the front.
Attributes:
_delay_time : the delay frame num of label
"""
def __init__(self, delay_time):
"""init construction
Args:
delay_time : the delay frame num of label
"""
self._delay_time = delay_time
def perform_trans(self, sample):
"""
Args:
sample(object):input sample, contain feature numpy and label numpy, sample name list
Returns:
(feature, label, name)
"""
(feature, label, name) = sample
shape = label.shape
assert len(shape) == 2
label[self._delay_time:shape[0]] = label[0:shape[0] - self._delay_time]
for i in xrange(self._delay_time):
label[i][0] = label[self._delay_time][0]
return (feature, label, name)
......@@ -12,6 +12,7 @@ import paddle.fluid as fluid
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice
import data_utils.augmentor.trans_delay as trans_delay
import data_utils.async_data_reader as reader
from data_utils.util import lodtensor_to_ndarray
from model_utils.model import stacked_lstmp_model
......@@ -171,7 +172,7 @@ def train(args):
ltrans = [
trans_add_delta.TransAddDelta(2, 2),
trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var),
trans_splice.TransSplice()
trans_splice.TransSplice(), trans_delay.TransDelay(5)
]
feature_t = fluid.LoDTensor()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册