Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PALM
提交
091c3698
P
PALM
项目概览
PaddlePaddle
/
PALM
通知
7
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
10
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PALM
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
10
Issue
10
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
091c3698
编写于
3月 30, 2020
作者:
X
Xiaoyao Xi
提交者:
GitHub
3月 30, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update base_reader.py
上级
abb108dc
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
48 addition
and
44 deletion
+48
-44
paddlepalm/reader/base_reader.py
paddlepalm/reader/base_reader.py
+48
-44
未找到文件。
paddlepalm/reader/base_reader.py
浏览文件 @
091c3698
...
...
@@ -12,14 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""v1.1"""
from
copy
import
copy
class
Reader
(
object
):
"""interface of data
manag
er."""
"""interface of data
read
er."""
def
__init__
(
self
,
phase
=
'train'
):
# assert isinstance(config, dict)
# self._config = config
"""该函数完成一个Reader的构造,至少需要包含一个phase参数。
注意:实现该构造函数时,必须保证对基类构造函数的调用,以创建必要的框架内建的成员变量。
Args:
phase: str类型。用于区分主干网络被调用时所处的运行阶段,目前支持训练阶段train和预测阶段predict
"""
self
.
_phase
=
phase
self
.
_batch_size
=
None
self
.
_num_epochs
=
1
...
...
@@ -31,6 +35,7 @@ class Reader(object):
return
set
()
def
clone
(
self
,
phase
=
'train'
):
"""拷贝一个新的reader对象。"""
if
phase
==
self
.
_phase
:
return
copy
(
self
)
else
:
...
...
@@ -39,14 +44,25 @@ class Reader(object):
return
ret
def
require_attr
(
self
,
attr_name
):
"""在注册器中新增一个需要产生的对象。
Args:
attr_name: 需要产出的对象的对象名,例如’segment_ids‘。
"""
self
.
_register
.
add
(
attr_name
)
def
register_with
(
self
,
backbone
):
"""根据backbone对输入对象的依赖,在注册器中对每个依赖的输入对象进行注册。
Args:
backbone: 需要对接的主干网络。
"""
for
attr
in
backbone
.
inputs_attr
:
self
.
require_attr
(
attr
)
self
.
_registered_backbone
=
backbone
def
get_registered_backbone
(
self
):
"""返回该reader所注册的backbone。"""
return
self
.
_registered_backbone
def
_get_registed_attrs
(
self
,
attrs
):
...
...
@@ -57,27 +73,27 @@ class Reader(object):
ret
[
i
]
=
attrs
[
i
]
return
ret
# @property
# def inputs_attr(self):
# """描述reader输入对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1.
# Return:
# dict类型。对各个输入对象的属性描述。例如,
# 对于文本分类任务,可能需要包含输入文本和所属标签的id
# {"text": ([], 'str'),
# "label": ([], 'int')}
# 对于标注任务,可能需要输入词序列和对应的标签
# {"tokens", ([-1], 'str'),
# "tags", ([-1], 'str')}
# 对于机器阅读理解任务,可能需要包含上下文、问题、回答、答案区域的起止位置等
# {"paragraph", ([], 'str'),
# "question", ([], 'str'),
# "start_position", ([], 'int')
# """
# raise NotImplementedError()
def
load_data
(
self
,
input_file
,
batch_size
,
num_epochs
=
None
,
\
file_format
=
'tsv'
,
shuffle_train
=
True
):
"""Load data into reader.
Noted that it requires the creation of self._batch_size and self._num_epochs when this method implemented.
Args:
input_file: the dataset file path. File format should meet the requirement of `file_format` argument.
batch_size: number of examples for once yield. CAUSIOUS! If your environment exists multiple GPU devices
(marked as dev_count), the batch_size should be divided by dev_count with no remainder!
num_epochs: the travelsal times of input examples. Default is None, means once for single-task learning
and automatically calculated for multi-task learning. This argument only works on train phase.
file_format: the file format of input file. Supported format: tsv. Default is tsv.
shuffle_train: whether to shuffle training dataset. Default is True. This argument only works on training phase.
"""
raise
NotImplementedError
()
@
property
def
outputs_attr
(
self
):
"""描述reader输出对象(被yield出的对象)的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
"""描述reader输出对象(被yield出的对象)的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据
类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
注意:当使用mini-batch梯度下降学习策略时,,应为常规的输入对象设置batch_size维度(一般为-1)
Return:
dict类型。对各个输入对象的属性描述。例如,
...
...
@@ -89,37 +105,25 @@ class Reader(object):
"label": ([-1], 'int')}
"""
raise
NotImplementedError
()
# def parse_line(self):
# """框架内部使用字典描述每个样本,字典的key为inputs_attr,value为每个input对应的符合attr描述的值。
# 该函数负责将文本行解析成符合inputs_attr描述的字典类型的样本。默认的parse_line方法会读取json格式的数据集文件,数据集的每一行为json格式描述的样本。
# 用户可通过对该方法的继承改写来适配不同格式的数据集,例如csv格式甚至tfrecord文件。
# """
# raise NotImplementedError()
#
# def tokenize(self, line):
# """框架中内置了word piece tokenizer等分词器,用户可通过修改tokenizer超参数来制定使用的分词器,若内置的分词器均无法满足需求,用户可通过对该方法的继承改写来自定义分词器。
# Args:
# - line: a unicode string.
# Return:
# a list of tokens
# """
# raise NotImplementedError()
def
iterator
(
self
):
def
_
iterator
(
self
):
"""数据集遍历接口,注意,当数据集遍历到尾部时该接口应自动完成指针重置,即重新从数据集头部开始新的遍历。
Yield:
(dict) elements that meet the requirements in output_templete
dict类型。符合outputs_attr描述的当前step的输出对象。
"""
raise
NotImplementedError
()
def
get_epoch_outputs
(
self
):
"""返回数据集每个epoch遍历后的输出对象。"""
raise
NotImplementedError
()
@
property
def
num_examples
(
self
):
"""数据集中的样本数量,即每个epoch中iterator所生成的样本数。注意,使用滑动窗口等可能导致数据集样本数发生变化的策略时,该接口应返回runtime阶段的实际样本数。"""
"""数据集中的样本数量,即每个epoch中iterator所生成的样本数。注意,使用滑动窗口等可能导致数据集样本数发生变化的策略时
该接口应返回runtime阶段的实际样本数。"""
raise
NotImplementedError
()
@
property
def
num_epochs
(
self
):
""""""
raise
NotImplementedError
()
"""数据集遍历次数"""
return
self
.
_num_epochs
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录