random_schedule_generator.py 1.7 KB
Newer Older
W
wwhu 已提交
1 2 3 4 5
import numpy as np
import math


class RandomScheduleGenerator:
    """
    Sampling-rate schedule for the scheduled sampling algorithm, which uses a
    decayed sampling rate.

    The rate is a function of the number of samples processed so far
    (``data_processed_``, written ``d`` below) and two parameters ``a`` and
    ``b``:

        constant:        a
        linear:          max(a, 1 - d / b)
        exponential:     a ** (d / b)
        inverse_sigmoid: b / (b + exp(d * a / b))
    """

    def __init__(self, schedule_type, a, b):
        """
        schedule_type: the type of the decay. It supports constant, linear,
        exponential, and inverse_sigmoid right now.
        a: parameter of the decay (should be a float)
        b: parameter of the decay (should be a float)

        Fails an assertion if schedule_type is not one of the supported names.
        """
        self.schedule_type = schedule_type
        self.a = a
        self.b = b
        # Number of samples handed out so far; advanced by processBatch().
        self.data_processed_ = 0
        # b is converted to float before dividing so that integer arguments
        # cannot trigger floor division on interpreters where `/` on ints
        # truncates.
        self.schedule_computers = {
            "constant": lambda a, b, d: a,
            "linear": lambda a, b, d: max(a, 1 - d / float(b)),
            "exponential": lambda a, b, d: pow(a, d / float(b)),
            "inverse_sigmoid": self._inverse_sigmoid,
        }
        assert self.schedule_type in self.schedule_computers, (
            "unsupported schedule_type %r; expected one of %s" %
            (self.schedule_type, sorted(self.schedule_computers)))
        self.schedule_computer = self.schedule_computers[self.schedule_type]

    @staticmethod
    def _inverse_sigmoid(a, b, d):
        """
        Compute b / (b + exp(d * a / b)) without overflowing.

        math.exp raises OverflowError once its argument grows past roughly
        709, which happens after enough data has been processed. The
        expression tends to 0 in that regime, so 0.0 is returned instead of
        letting the error propagate.
        """
        try:
            return b / (b + math.exp(d * a / float(b)))
        except OverflowError:
            return 0.0

    def getScheduleRate(self):
        """
        Get the schedule sampling rate for the amount of data processed so
        far. Usually not needed to be called by the users.
        """
        return self.schedule_computer(self.a, self.b, self.data_processed_)

    def processBatch(self, batch_size):
        """
        Get a batch_size of sampled indexes. These indexes can be passed to a
        MultiplexLayer to select from the ground truth and generated samples
        from the last time step.

        Returns a list of batch_size ints: an entry is 1 when a uniform draw
        from [0, 1) is >= the current schedule rate and 0 otherwise. The
        processed-data counter is advanced by batch_size afterwards, so the
        rate used for this batch is the one from before the call.
        """
        rate = self.getScheduleRate()
        numbers = np.random.rand(batch_size)
        indexes = (numbers >= rate).astype('int32').tolist()
        self.data_processed_ += batch_size
        return indexes