# utils.py
import math

import numpy as np


class RandomScheduleGenerator:
    """
    Random index generator for the scheduled sampling algorithm, driven by a
    sampling rate that is recomputed from the number of samples processed so
    far.
    """

    def __init__(self, schedule_type, a, b):
        """
        schedule_type: the type of the decay. It supports "constant",
            "linear", "exponential", and "inverse_sigmoid" right now.
        a: parameter of the decay.
        b: parameter of the decay.

        Both parameters are converted to float, so integer arguments do not
        trigger truncating integer division in the rate formulas.

        Raises AssertionError if schedule_type is not a supported name.
        """
        self.schedule_type = schedule_type
        self.a = float(a)
        self.b = float(b)
        # Number of samples handed out so far; this is the `d` argument of
        # every schedule function below.
        self.data_processed_ = 0
        # Each entry maps (a, b, d) to the sampling rate.
        self.schedule_computers = {
            "constant": lambda a, b, d: a,
            "linear": lambda a, b, d: max(a, 1 - d / b),
            "exponential": lambda a, b, d: pow(a, d / b),
            "inverse_sigmoid": self._inverse_sigmoid,
        }
        assert self.schedule_type in self.schedule_computers, (
            "unsupported schedule_type: %r" % (self.schedule_type, ))
        self.schedule_computer = self.schedule_computers[self.schedule_type]

    @staticmethod
    def _inverse_sigmoid(a, b, d):
        """
        Compute b / (b + exp(d * a / b)) without overflowing for large d.
        """
        try:
            return b / (b + math.exp(d * a / b))
        except OverflowError:
            # math.exp raises once its argument exceeds roughly 709. The
            # denominator is then astronomically large, so the rate has
            # effectively decayed to zero.
            return 0.0

    def getScheduleRate(self):
        """
        Get the schedule sampling rate for the amount of data processed so
        far. Usually not needed to be called by the users.
        """
        return self.schedule_computer(self.a, self.b, self.data_processed_)

    def processBatch(self, batch_size):
        """
        Get a batch_size of sampled indexes. These indexes can be passed to a
        MultiplexLayer to select from the ground truth and generated samples
        from the last time step.

        Returns a list of batch_size ints: an entry is 1 when its uniform
        random draw in [0, 1) is >= the current rate, and 0 otherwise. The
        processed-data counter is advanced by batch_size afterwards, so the
        whole batch is drawn with the same rate.
        """
        rate = self.getScheduleRate()
        numbers = np.random.rand(batch_size)
        indexes = (numbers >= rate).astype('int32').tolist()
        self.data_processed_ += batch_size
        return indexes