Commit b1ab60da authored by W wwhu

adjust some comments

Parent bb93d5c0
import numpy as np
import math
import pdb
class RandomScheduleGenerator:
"""
The random sampling rate for scheduled sampling algoithm, which uses devcayed
sampling rate.
"""
    def __init__(self, schedule_type, a, b):
        """
        schedule_type: the type of the decay. It supports constant, linear,
        exponential, and inverse_sigmoid right now.
        a: parameter of the decay (MUST BE DOUBLE)
        b: parameter of the decay (MUST BE DOUBLE)
        """
        self.schedule_type = schedule_type
        self.a = a
        self.b = b
......@@ -24,33 +23,25 @@ class RandomScheduleGenerator:
"constant": lambda a, b, d: a,
"linear": lambda a, b, d: max(a, 1 - d / b),
"exponential": lambda a, b, d: pow(a, d / b),
"inverse_sigmoid": lambda a, b, d: b / (b + exp(d * a / b)),
"inverse_sigmoid": lambda a, b, d: b / (b + math.exp(d * a / b)),
}
assert (self.schedule_type in self.schedule_computers)
self.schedule_computer = self.schedule_computers[self.schedule_type]
    def getScheduleRate(self):
        """
        Get the schedule sampling rate. Usually it does not need to be called by
        the users.
        """
        return self.schedule_computer(self.a, self.b, self.data_processed_)
    def processBatch(self, batch_size):
        """
        Get a batch_size of sampled indexes. These indexes can be passed to a
        MultiplexLayer to select from the ground truth and generated samples
        from the last time step.
        """
        rate = self.getScheduleRate()
        numbers = np.random.rand(batch_size)
        # Each index is 0 with probability `rate` and 1 otherwise.
        indexes = (numbers >= rate).astype('int32').tolist()
        self.data_processed_ += batch_size
        return indexes
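

# Illustrative sketch, not part of the original commit: it shows how the decay
# schedules above behave for hypothetical parameter values. With "linear",
# a=0.1 and b=500000 the rate starts at 1.0, decreases linearly, and is clamped
# at 0.1 once 450000 samples have been processed; with "exponential", a=0.5 and
# b=10000 the rate halves every 10000 samples.
def _print_schedule_example():
    gen = RandomScheduleGenerator("linear", 0.1, 500000)
    for d in (0, 250000, 450000, 750000):
        gen.data_processed_ = d  # pretend d samples have been processed so far
        print("data_processed=%d -> rate=%.2f" % (d, gen.getScheduleRate()))
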
if __name__ == "__main__":
    schedule_generator = RandomScheduleGenerator("linear", 0.1, 500000)
    true_token_flag = schedule_generator.processBatch(5)
    pdb.set_trace()
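    # Illustrative sketch, not part of the original commit: emulate the selection
    # that a multiplex layer would perform with these flags. It is assumed here
    # that flag 0 picks the ground-truth token and flag 1 the token generated at
    # the last time step, which is consistent with the decaying schedule (early
    # in training most flags are 0). The token lists are hypothetical data, not
    # PaddlePaddle API calls.
    true_tokens = [1, 2, 3, 4, 5]
    generated_tokens = [11, 12, 13, 14, 15]
    mixed = [g if flag else t
             for t, g, flag in zip(true_tokens, generated_tokens, true_token_flag)]
    print(mixed)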
......@@ -74,7 +74,7 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):
        decoder_state=decoder_mem)
    gru_out_memory = paddle.layer.memory(
        name='gru_out', size=target_dict_dim)
    generated_word = paddle.layer.max_id(input=gru_out_memory)
......