由于Scheduled Sampling是对Sequence to Sequence模型的改进,其整体实现框架与Sequence to Sequence模型较为相似。为突出本文重点,这里仅介绍与Scheduled Sampling相关的部分,完整的代码见`scheduled_sampling.py`。
首先定义控制衰减概率的类`RandomScheduleGenerator`,如下:
```python
importnumpyasnp
importmath
classRandomScheduleGenerator:
"""
The random sampling rate for scheduled sampling algoithm, which uses devcayed
sampling rate.
"""
def__init__(self,schedule_type,a,b):
"""
schduled_type: is the type of the decay. It supports constant, linear,