# trans_mean_variance_norm.py (committed by zhxfl)
#by zhxfl 2018.01.29
import numpy
import math


class TransMeanVarianceNorm(object):
    """Normalization of mean variance for feature data.

    A norm file holds one "<mean> <variance>" pair per line, one line per
    feature dimension. Each dimension is transformed as
    ``(x - mean) * scale`` with ``scale = 1 / sqrt(variance)``, where the
    scale is capped at ``_MAX_SCALE`` so near-constant dimensions do not
    blow up.
    """

    # Upper bound applied to 1 / sqrt(variance).
    _MAX_SCALE = 100000.0

    def __init__(self, snorm_path):
        """init construction
            Args:
                snorm_path: the path of mean and variance
        """
        self._mean = None
        self._var = None
        self._nLen = 0
        self._load_norm(snorm_path)

    def _load_norm(self, snorm_path):
        """Load the global mean/variance file.

            Args:
                snorm_path: path of a text file with two whitespace
                    separated numbers (mean, variance) per line.
                    Blank lines are ignored.
            Raises:
                ValueError: if a non-blank line does not hold exactly two
                    fields, a field is not a number, a variance is
                    negative, or the file holds no entries at all.
        """
        entries = []
        # The context manager guarantees the handle is closed, also on error.
        with open(snorm_path) as norm_file:
            for line_no, line in enumerate(norm_file, 1):
                fields = line.split()
                if not fields:
                    continue
                if len(fields) != 2:
                    raise ValueError(
                        "%s:%d: expected '<mean> <variance>', got %d fields"
                        % (snorm_path, line_no, len(fields)))
                entries.append((float(fields[0]), float(fields[1]), line_no))

        if not entries:
            # An empty table would make perform_trans divide by zero.
            raise ValueError(
                "no mean/variance entries found in %s" % (snorm_path,))

        dim = len(entries)
        mean = numpy.zeros((dim,), dtype="float32")
        scale = numpy.zeros((dim,), dtype="float32")
        for idx, (mean_val, variance, line_no) in enumerate(entries):
            if variance < 0.0:
                raise ValueError(
                    "%s:%d: negative variance %r"
                    % (snorm_path, line_no, variance))
            if variance == 0.0:
                # 1 / sqrt(0) is unbounded; use the cap instead of crashing.
                inv_std = self._MAX_SCALE
            else:
                inv_std = 1.0 / math.sqrt(variance)
                if inv_std > self._MAX_SCALE:
                    inv_std = self._MAX_SCALE
            mean[idx] = mean_val
            scale[idx] = inv_std

        self._mean = mean
        # NOTE: despite its name, _var stores the multiplicative scale
        # 1 / sqrt(variance), not the variance itself.
        self._var = scale
        self._nLen = dim

    def get_mean_var(self):
        """Get the mean and the scale.

            Returns:
                tuple (mean, var) of float32 arrays with one entry per
                feature dimension; ``var`` holds the capped
                1 / sqrt(variance) scale used by perform_trans.
        """
        return (self._mean, self._var)

    def perform_trans(self, sample):
        """feature = (feature - mean) * var

            Args:
                sample: tuple (feature, label). ``feature`` is a 2-D numpy
                    array whose total element count is a multiple of the
                    number of normalization dimensions; it is processed as
                    consecutive blocks of that length in row-major order.
            Returns:
                tuple (feature, label) with the normalized feature in its
                original shape. When ``feature`` can be reshaped without a
                copy, the input array itself is overwritten.
            Raises:
                ValueError: if ``feature`` is not 2-D or its size is not a
                    multiple of the number of normalization dimensions.
        """
        (feature, label) = sample
        shape = feature.shape
        if len(shape) != 2:
            raise ValueError(
                "feature must be 2-D, got %d dimension(s)" % (len(shape),))
        nfeature_len = shape[0] * shape[1]
        if nfeature_len % self._nLen != 0:
            raise ValueError(
                "feature size %d is not a multiple of norm length %d"
                % (nfeature_len, self._nLen))
        # One row per block of _nLen values; broadcasting applies the
        # per-dimension mean and scale to every block at once.
        blocks = feature.reshape((nfeature_len // self._nLen, self._nLen))
        # Assign through the view so the result is written back in place
        # and cast to the feature's own dtype.
        blocks[...] = (blocks - self._mean) * self._var
        return (blocks.reshape(shape), label)