未验证 提交 091c3698 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Update base_reader.py

上级 abb108dc
...@@ -12,14 +12,18 @@ ...@@ -12,14 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""v1.1"""
from copy import copy from copy import copy
class Reader(object): class Reader(object):
"""interface of data manager.""" """interface of data reader."""
def __init__(self, phase='train'): def __init__(self, phase='train'):
# assert isinstance(config, dict) """该函数完成一个Reader的构造,至少需要包含一个phase参数。
# self._config = config 注意:实现该构造函数时,必须保证对基类构造函数的调用,以创建必要的框架内建的成员变量。
Args:
phase: str类型。用于区分主干网络被调用时所处的运行阶段,目前支持训练阶段train和预测阶段predict
"""
self._phase = phase self._phase = phase
self._batch_size = None self._batch_size = None
self._num_epochs = 1 self._num_epochs = 1
...@@ -31,6 +35,7 @@ class Reader(object): ...@@ -31,6 +35,7 @@ class Reader(object):
return set() return set()
def clone(self, phase='train'): def clone(self, phase='train'):
"""拷贝一个新的reader对象。"""
if phase == self._phase: if phase == self._phase:
return copy(self) return copy(self)
else: else:
...@@ -39,14 +44,25 @@ class Reader(object): ...@@ -39,14 +44,25 @@ class Reader(object):
return ret return ret
def require_attr(self, attr_name): def require_attr(self, attr_name):
"""在注册器中新增一个需要产生的对象。
Args:
attr_name: 需要产出的对象的对象名,例如’segment_ids‘。
"""
self._register.add(attr_name) self._register.add(attr_name)
def register_with(self, backbone): def register_with(self, backbone):
"""根据backbone对输入对象的依赖,在注册器中对每个依赖的输入对象进行注册。
Args:
backbone: 需要对接的主干网络。
"""
for attr in backbone.inputs_attr: for attr in backbone.inputs_attr:
self.require_attr(attr) self.require_attr(attr)
self._registered_backbone = backbone self._registered_backbone = backbone
def get_registered_backbone(self): def get_registered_backbone(self):
"""返回该reader所注册的backbone。"""
return self._registered_backbone return self._registered_backbone
def _get_registed_attrs(self, attrs): def _get_registed_attrs(self, attrs):
...@@ -57,27 +73,27 @@ class Reader(object): ...@@ -57,27 +73,27 @@ class Reader(object):
ret[i] = attrs[i] ret[i] = attrs[i]
return ret return ret
# @property def load_data(self, input_file, batch_size, num_epochs=None, \
# def inputs_attr(self): file_format='tsv', shuffle_train=True):
# """描述reader输入对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1. """Load data into reader.
# Return:
# dict类型。对各个输入对象的属性描述。例如, Noted that it requires the creation of self._batch_size and self._num_epochs when this method implemented.
# 对于文本分类任务,可能需要包含输入文本和所属标签的id
# {"text": ([], 'str'), Args:
# "label": ([], 'int')} input_file: the dataset file path. File format should meet the requirement of `file_format` argument.
# 对于标注任务,可能需要输入词序列和对应的标签 batch_size: number of examples for once yield. CAUSIOUS! If your environment exists multiple GPU devices
# {"tokens", ([-1], 'str'), (marked as dev_count), the batch_size should be divided by dev_count with no remainder!
# "tags", ([-1], 'str')} num_epochs: the travelsal times of input examples. Default is None, means once for single-task learning
# 对于机器阅读理解任务,可能需要包含上下文、问题、回答、答案区域的起止位置等 and automatically calculated for multi-task learning. This argument only works on train phase.
# {"paragraph", ([], 'str'), file_format: the file format of input file. Supported format: tsv. Default is tsv.
# "question", ([], 'str'), shuffle_train: whether to shuffle training dataset. Default is True. This argument only works on training phase.
# "start_position", ([], 'int') """
# """ raise NotImplementedError()
# raise NotImplementedError()
@property @property
def outputs_attr(self): def outputs_attr(self):
"""描述reader输出对象(被yield出的对象)的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 """描述reader输出对象(被yield出的对象)的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据
类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
注意:当使用mini-batch梯度下降学习策略时,,应为常规的输入对象设置batch_size维度(一般为-1) 注意:当使用mini-batch梯度下降学习策略时,,应为常规的输入对象设置batch_size维度(一般为-1)
Return: Return:
dict类型。对各个输入对象的属性描述。例如, dict类型。对各个输入对象的属性描述。例如,
...@@ -89,37 +105,25 @@ class Reader(object): ...@@ -89,37 +105,25 @@ class Reader(object):
"label": ([-1], 'int')} "label": ([-1], 'int')}
""" """
raise NotImplementedError() raise NotImplementedError()
# def parse_line(self):
# """框架内部使用字典描述每个样本,字典的key为inputs_attr,value为每个input对应的符合attr描述的值。
# 该函数负责将文本行解析成符合inputs_attr描述的字典类型的样本。默认的parse_line方法会读取json格式的数据集文件,数据集的每一行为json格式描述的样本。
# 用户可通过对该方法的继承改写来适配不同格式的数据集,例如csv格式甚至tfrecord文件。
# """
# raise NotImplementedError()
#
# def tokenize(self, line):
# """框架中内置了word piece tokenizer等分词器,用户可通过修改tokenizer超参数来制定使用的分词器,若内置的分词器均无法满足需求,用户可通过对该方法的继承改写来自定义分词器。
# Args:
# - line: a unicode string.
# Return:
# a list of tokens
# """
# raise NotImplementedError()
def iterator(self): def _iterator(self):
"""数据集遍历接口,注意,当数据集遍历到尾部时该接口应自动完成指针重置,即重新从数据集头部开始新的遍历。 """数据集遍历接口,注意,当数据集遍历到尾部时该接口应自动完成指针重置,即重新从数据集头部开始新的遍历。
Yield: Yield:
(dict) elements that meet the requirements in output_templete dict类型。符合outputs_attr描述的当前step的输出对象。
""" """
raise NotImplementedError() raise NotImplementedError()
def get_epoch_outputs(self):
"""返回数据集每个epoch遍历后的输出对象。"""
raise NotImplementedError()
@property @property
def num_examples(self): def num_examples(self):
"""数据集中的样本数量,即每个epoch中iterator所生成的样本数。注意,使用滑动窗口等可能导致数据集样本数发生变化的策略时,该接口应返回runtime阶段的实际样本数。""" """数据集中的样本数量,即每个epoch中iterator所生成的样本数。注意,使用滑动窗口等可能导致数据集样本数发生变化的策略时
该接口应返回runtime阶段的实际样本数。"""
raise NotImplementedError() raise NotImplementedError()
@property @property
def num_epochs(self): def num_epochs(self):
"""""" """数据集遍历次数"""
raise NotImplementedError() return self._num_epochs
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册