Commit b1ab60da authored by W wwhu

adjust some comments

Parent bb93d5c0
 import numpy as np
 import math
-import pdb
-'''
-The random sampling rate for scheduled sampling algoithm, which uses devcayed
-sampling rate.
-'''
 class RandomScheduleGenerator:
-    '''
-    schduled_type: is the type of the decay. It supports constant, linear,
-    exponential, and inverse_sigmoid right now.
-    a: parameter of the decay (MUST BE DOUBLE)
-    b: parameter of the decay (MUST BE DOUBLE)
-    '''
-    def __init__(self, schedule_type, a, b):
+    """
+    The random sampling rate generator for the scheduled sampling algorithm,
+    which uses a decayed sampling rate.
+    """
+
+    def __init__(self, schedule_type, a, b):
+        """
+        schedule_type: the type of the decay. It supports constant, linear,
+        exponential, and inverse_sigmoid right now.
+        a: parameter of the decay (MUST BE DOUBLE)
+        b: parameter of the decay (MUST BE DOUBLE)
+        """
         self.schedule_type = schedule_type
         self.a = a
         self.b = b
@@ -24,33 +23,25 @@ class RandomScheduleGenerator:
             "constant": lambda a, b, d: a,
             "linear": lambda a, b, d: max(a, 1 - d / b),
             "exponential": lambda a, b, d: pow(a, d / b),
-            "inverse_sigmoid": lambda a, b, d: b / (b + exp(d * a / b)),
+            "inverse_sigmoid": lambda a, b, d: b / (b + math.exp(d * a / b)),
         }
         assert (self.schedule_type in self.schedule_computers)
         self.schedule_computer = self.schedule_computers[self.schedule_type]
-    '''
-    Get the schedule sampling rate. Usually not needed to be called by the users
-    '''
     def getScheduleRate(self):
+        """
+        Get the schedule sampling rate. Usually it does not need to be called
+        by users.
+        """
         return self.schedule_computer(self.a, self.b, self.data_processed_)
-    '''
-    Get a batch_size of sampled indexes. These indexes can be passed to a
-    MultiplexLayer to select from the grouth truth and generated samples
-    from the last time step.
-    '''
-    def processBatch(self, batch_size):
+    def processBatch(self, batch_size):
+        """
+        Get a batch_size of sampled indexes. These indexes can be passed to a
+        MultiplexLayer to select from the ground truth and generated samples
+        from the last time step.
+        """
         rate = self.getScheduleRate()
         numbers = np.random.rand(batch_size)
         indexes = (numbers >= rate).astype('int32').tolist()
         self.data_processed_ += batch_size
         return indexes
-if __name__ == "__main__":
-    schedule_generator = RandomScheduleGenerator("linear", 0.1, 500000)
-    true_token_flag = schedule_generator.processBatch(5)
-    pdb.set_trace()
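For reference, the removed __main__ block exercised the class with a linear schedule (a=0.1, b=500000) and then dropped into pdb. A minimal, pdb-free sketch of the same usage, assuming RandomScheduleGenerator from the file above is in scope, would be:

schedule_generator = RandomScheduleGenerator("linear", 0.1, 500000)

# The linear schedule computes rate = max(a, 1 - d / b): it starts at 1.0
# (always keep the ground-truth token) and decays by 1/b per processed
# example until it bottoms out at a = 0.1.
for _ in range(3):
    true_token_flag = schedule_generator.processBatch(5)  # five 0/1 flags
    print("rate=%.6f flags=%s" % (schedule_generator.getScheduleRate(),
                                  true_token_flag))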
@@ -74,7 +74,7 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):
         decoder_state=decoder_mem)
     gru_out_memory = paddle.layer.memory(
-        name='gru_out', size=target_dict_dim)  # , boot_with_const_id=0)
+        name='gru_out', size=target_dict_dim)
     generated_word = paddle.layer.max_id(input=gru_out_memory)
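The indexes returned by processBatch are the true_token_flag values that, per the docstring above, feed a MultiplexLayer choosing between the ground-truth word and the word generated at the previous step. The actual layer wiring is outside this hunk, so the following is only a rough numpy sketch of the selection semantics, assuming the convention that flag 0 keeps the ground-truth token (flags are 0 with probability `rate`) and flag 1 takes the previous prediction; the flag values and word ids are made up:

import numpy as np

true_words = np.array([10, 11, 12, 13, 14])       # hypothetical ground-truth word ids
generated_words = np.array([20, 21, 22, 23, 24])  # hypothetical ids produced at step t-1
true_token_flag = np.array([0, 1, 0, 0, 1])       # e.g. one output of processBatch(5)

# Early in training (rate near 1.0) the decoder is fed mostly ground truth;
# as the rate decays it increasingly sees its own previous predictions.
decoder_input = np.where(true_token_flag == 0, true_words, generated_words)
print(decoder_input.tolist())                     # [10, 21, 12, 13, 24]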