diff --git a/paddlepalm/reader/base_reader.py b/paddlepalm/reader/base_reader.py index b35e1d158e9c7b018fdbcf427a9f3079c32468c6..8c3fb3f3ec2f1de1a7e7ac2a8466c7f630dead9e 100644 --- a/paddlepalm/reader/base_reader.py +++ b/paddlepalm/reader/base_reader.py @@ -12,14 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""v1.1""" + from copy import copy class Reader(object): - """interface of data manager.""" + """interface of data reader.""" def __init__(self, phase='train'): - # assert isinstance(config, dict) - # self._config = config + """该函数完成一个Reader的构造,至少需要包含一个phase参数。 + 注意:实现该构造函数时,必须保证对基类构造函数的调用,以创建必要的框架内建的成员变量。 + Args: + phase: str类型。用于区分主干网络被调用时所处的运行阶段,目前支持训练阶段train和预测阶段predict + """ + self._phase = phase self._batch_size = None self._num_epochs = 1 @@ -31,6 +35,7 @@ class Reader(object): return set() def clone(self, phase='train'): + """拷贝一个新的reader对象。""" if phase == self._phase: return copy(self) else: @@ -39,14 +44,25 @@ class Reader(object): return ret def require_attr(self, attr_name): + """在注册器中新增一个需要产生的对象。 + + Args: + attr_name: 需要产出的对象的对象名,例如’segment_ids‘。 + """ self._register.add(attr_name) def register_with(self, backbone): + """根据backbone对输入对象的依赖,在注册器中对每个依赖的输入对象进行注册。 + + Args: + backbone: 需要对接的主干网络。 + """ for attr in backbone.inputs_attr: self.require_attr(attr) self._registered_backbone = backbone def get_registered_backbone(self): + """返回该reader所注册的backbone。""" return self._registered_backbone def _get_registed_attrs(self, attrs): @@ -57,27 +73,27 @@ class Reader(object): ret[i] = attrs[i] return ret - # @property - # def inputs_attr(self): - # """描述reader输入对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1. - # Return: - # dict类型。对各个输入对象的属性描述。例如, - # 对于文本分类任务,可能需要包含输入文本和所属标签的id - # {"text": ([], 'str'), - # "label": ([], 'int')} - # 对于标注任务,可能需要输入词序列和对应的标签 - # {"tokens", ([-1], 'str'), - # "tags", ([-1], 'str')} - # 对于机器阅读理解任务,可能需要包含上下文、问题、回答、答案区域的起止位置等 - # {"paragraph", ([], 'str'), - # "question", ([], 'str'), - # "start_position", ([], 'int') - # """ - # raise NotImplementedError() + def load_data(self, input_file, batch_size, num_epochs=None, \ + file_format='tsv', shuffle_train=True): + """Load data into reader. + + Noted that it requires the creation of self._batch_size and self._num_epochs when this method implemented. + + Args: + input_file: the dataset file path. File format should meet the requirement of `file_format` argument. + batch_size: number of examples for once yield. CAUSIOUS! If your environment exists multiple GPU devices + (marked as dev_count), the batch_size should be divided by dev_count with no remainder! + num_epochs: the travelsal times of input examples. Default is None, means once for single-task learning + and automatically calculated for multi-task learning. This argument only works on train phase. + file_format: the file format of input file. Supported format: tsv. Default is tsv. + shuffle_train: whether to shuffle training dataset. Default is True. This argument only works on training phase. + """ + raise NotImplementedError() @property def outputs_attr(self): - """描述reader输出对象(被yield出的对象)的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 + """描述reader输出对象(被yield出的对象)的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据 + 类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 注意:当使用mini-batch梯度下降学习策略时,,应为常规的输入对象设置batch_size维度(一般为-1) Return: dict类型。对各个输入对象的属性描述。例如, @@ -89,37 +105,25 @@ class Reader(object): "label": ([-1], 'int')} """ raise NotImplementedError() - - # def parse_line(self): - # """框架内部使用字典描述每个样本,字典的key为inputs_attr,value为每个input对应的符合attr描述的值。 - # 该函数负责将文本行解析成符合inputs_attr描述的字典类型的样本。默认的parse_line方法会读取json格式的数据集文件,数据集的每一行为json格式描述的样本。 - # 用户可通过对该方法的继承改写来适配不同格式的数据集,例如csv格式甚至tfrecord文件。 - # """ - # raise NotImplementedError() - # - # def tokenize(self, line): - # """框架中内置了word piece tokenizer等分词器,用户可通过修改tokenizer超参数来制定使用的分词器,若内置的分词器均无法满足需求,用户可通过对该方法的继承改写来自定义分词器。 - # Args: - # - line: a unicode string. - # Return: - # a list of tokens - # """ - # raise NotImplementedError() - def iterator(self): + def _iterator(self): """数据集遍历接口,注意,当数据集遍历到尾部时该接口应自动完成指针重置,即重新从数据集头部开始新的遍历。 Yield: - (dict) elements that meet the requirements in output_templete + dict类型。符合outputs_attr描述的当前step的输出对象。 """ raise NotImplementedError() + def get_epoch_outputs(self): + """返回数据集每个epoch遍历后的输出对象。""" + raise NotImplementedError() + @property def num_examples(self): - """数据集中的样本数量,即每个epoch中iterator所生成的样本数。注意,使用滑动窗口等可能导致数据集样本数发生变化的策略时,该接口应返回runtime阶段的实际样本数。""" + """数据集中的样本数量,即每个epoch中iterator所生成的样本数。注意,使用滑动窗口等可能导致数据集样本数发生变化的策略时 + 该接口应返回runtime阶段的实际样本数。""" raise NotImplementedError() @property def num_epochs(self): - """""" - raise NotImplementedError() - + """数据集遍历次数""" + return self._num_epochs