Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
74ec0647
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
280
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
74ec0647
编写于
4月 08, 2019
作者:
Z
Zeyu Chen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
reorg finetune and reader
上级
810cdd3a
变更
10
展开全部
隐藏空白更改
内联
并排
Showing
10 changed file
with
335 addition
and
624 deletion
+335
-624
demo/ernie-classification/finetune_with_hub.py
demo/ernie-classification/finetune_with_hub.py
+1
-0
paddlehub/__init__.py
paddlehub/__init__.py
+0
-3
paddlehub/common/__init__.py
paddlehub/common/__init__.py
+1
-0
paddlehub/common/utils.py
paddlehub/common/utils.py
+13
-0
paddlehub/finetune/config.py
paddlehub/finetune/config.py
+1
-1
paddlehub/finetune/evaluate.py
paddlehub/finetune/evaluate.py
+10
-2
paddlehub/finetune/finetune.py
paddlehub/finetune/finetune.py
+3
-15
paddlehub/reader/__init__.py
paddlehub/reader/__init__.py
+3
-3
paddlehub/reader/nlp_reader.py
paddlehub/reader/nlp_reader.py
+303
-198
paddlehub/reader/task_reader.py
paddlehub/reader/task_reader.py
+0
-402
未找到文件。
demo/ernie-classification/finetune_with_hub.py
浏览文件 @
74ec0647
...
...
@@ -49,6 +49,7 @@ if __name__ == '__main__':
# Set up the running config for the PaddleHub Finetune API
config
=
hub
.
RunConfig
(
eval_interval
=
10
,
use_cuda
=
True
,
num_epoch
=
args
.
num_epoch
,
batch_size
=
args
.
batch_size
,
...
...
paddlehub/__init__.py
浏览文件 @
74ec0647
...
...
@@ -40,6 +40,3 @@ from .finetune.finetune import finetune_and_eval
from
.finetune.config
import
RunConfig
from
.finetune.strategy
import
BERTFinetuneStrategy
from
.finetune.strategy
import
DefaultStrategy
from
.reader
import
BERTTokenizeReader
from
.reader.cv_reader
import
ImageClassificationReader
paddlehub/common/__init__.py
浏览文件 @
74ec0647
...
...
@@ -13,3 +13,4 @@
# limitations under the License.
from
.
import
utils
from
.utils
import
get_running_device_info
paddlehub/common/utils.py
浏览文件 @
74ec0647
...
...
@@ -17,6 +17,8 @@ from __future__ import division
from
__future__
import
print_function
import
os
import
time
import
multiprocessing
import
hashlib
import
paddle
...
...
@@ -185,6 +187,17 @@ def is_yaml_file(file_path):
return
get_file_ext
(
file_path
)
==
".yml"
def get_running_device_info(config):
    """Choose the execution place and device count described by ``config``.

    Args:
        config: Object exposing a boolean ``use_cuda`` attribute.

    Returns:
        A ``(place, dev_count)`` tuple.  With ``use_cuda`` set, the place is
        ``fluid.CUDAPlace(0)`` and the count comes from
        ``fluid.core.get_cuda_device_count()``.  Otherwise the place is
        ``fluid.CPUPlace()`` and the count is the ``CPU_NUM`` environment
        variable, falling back to ``multiprocessing.cpu_count()``.
    """
    if not config.use_cuda:
        cpu_place = fluid.CPUPlace()
        # CPU_NUM arrives as a string when set, hence the int() conversion.
        cpu_total = int(
            os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
        return cpu_place, cpu_total

    # Device 0 is always used as the place, whatever the device count is.
    gpu_place = fluid.CUDAPlace(0)
    gpu_total = fluid.core.get_cuda_device_count()
    return gpu_place, gpu_total
if
__name__
==
"__main__"
:
print
(
is_yaml_file
(
"test.yml"
))
print
(
is_csv_file
(
"test.yml"
))
...
...
paddlehub/finetune/config.py
浏览文件 @
74ec0647
...
...
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
datetime
import
datetime
import
time
from
datetime
import
datetime
from
paddlehub.finetune.strategy
import
DefaultStrategy
from
paddlehub.common.logger
import
logger
...
...
paddlehub/finetune/evaluate.py
浏览文件 @
74ec0647
...
...
@@ -12,6 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle.fluid
as
fluid
import
paddlehub
as
hub
from
paddlehub.common.logger
import
logger
def
evaluate_cls_task
(
task
,
data_reader
,
feed_list
,
phase
=
"test"
,
config
=
None
):
logger
.
info
(
"Evaluation on {} dataset start"
.
format
(
phase
))
...
...
@@ -20,7 +28,7 @@ def evaluate_cls_task(task, data_reader, feed_list, phase="test", config=None):
loss
=
task
.
variable
(
"loss"
)
accuracy
=
task
.
variable
(
"accuracy"
)
batch_size
=
config
.
batch_size
place
,
dev_count
=
_
get_running_device_info
(
config
)
place
,
dev_count
=
hub
.
common
.
get_running_device_info
(
config
)
exe
=
fluid
.
Executor
(
place
=
place
)
with
fluid
.
program_guard
(
inference_program
):
data_feeder
=
fluid
.
DataFeeder
(
feed_list
=
feed_list
,
place
=
place
)
...
...
@@ -64,7 +72,7 @@ def evaluate_seq_labeling_task(task,
logger
.
info
(
"Evaluation on {} dataset start"
.
format
(
phase
))
inference_program
=
task
.
inference_program
()
batch_size
=
config
.
batch_size
place
,
dev_count
=
_
get_running_device_info
(
config
)
place
,
dev_count
=
hub
.
common
.
get_running_device_info
(
config
)
exe
=
fluid
.
Executor
(
place
=
place
)
num_labels
=
len
(
data_reader
.
get_labels
())
with
fluid
.
program_guard
(
inference_program
):
...
...
paddlehub/finetune/finetune.py
浏览文件 @
74ec0647
...
...
@@ -18,10 +18,10 @@ from __future__ import print_function
import
os
import
time
import
multiprocessing
import
paddle
import
paddle.fluid
as
fluid
import
paddlehub
as
hub
import
numpy
as
np
from
paddlehub.common.logger
import
logger
...
...
@@ -29,18 +29,6 @@ from paddlehub.finetune.strategy import BERTFinetuneStrategy, DefaultStrategy
from
paddlehub.finetune.checkpoint
import
load_checkpoint
,
save_checkpoint
from
paddlehub.finetune.evaluate
import
evaluate_cls_task
,
evaluate_seq_labeling_task
from
visualdl
import
LogWriter
import
paddlehub
as
hub
def _get_running_device_info(config):
    """Return ``(place, dev_count)`` for the device selected by ``config``.

    ``config.use_cuda`` selects ``fluid.CUDAPlace(0)`` together with the
    CUDA device count reported by ``fluid.core``; otherwise the CPU place is
    used and the count is read from the ``CPU_NUM`` environment variable,
    defaulting to the number of CPUs reported by ``multiprocessing``.
    """
    on_gpu = config.use_cuda
    if on_gpu:
        device = fluid.CUDAPlace(0)
        device_total = fluid.core.get_cuda_device_count()
        return device, device_total

    device = fluid.CPUPlace()
    configured = os.environ.get('CPU_NUM', multiprocessing.cpu_count())
    device_total = int(configured)
    return device, device_total
def
_do_memory_optimization
(
task
,
config
):
...
...
@@ -80,7 +68,7 @@ def _finetune_seq_label_task(task,
num_epoch
=
config
.
num_epoch
batch_size
=
config
.
batch_size
place
,
dev_count
=
_
get_running_device_info
(
config
)
place
,
dev_count
=
hub
.
common
.
get_running_device_info
(
config
)
with
fluid
.
program_guard
(
main_program
,
startup_program
):
exe
=
fluid
.
Executor
(
place
=
place
)
data_feeder
=
fluid
.
DataFeeder
(
feed_list
=
feed_list
,
place
=
place
)
...
...
@@ -177,7 +165,7 @@ def _finetune_cls_task(task, data_reader, feed_list, config=None,
log_writter
=
LogWriter
(
os
.
path
.
join
(
config
.
checkpoint_dir
,
"vdllog"
),
sync_cycle
=
10
)
place
,
dev_count
=
_
get_running_device_info
(
config
)
place
,
dev_count
=
hub
.
common
.
get_running_device_info
(
config
)
with
fluid
.
program_guard
(
main_program
,
startup_program
):
exe
=
fluid
.
Executor
(
place
=
place
)
data_feeder
=
fluid
.
DataFeeder
(
feed_list
=
feed_list
,
place
=
place
)
...
...
paddlehub/reader/__init__.py
浏览文件 @
74ec0647
...
...
@@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
.nlp_reader
import
BERTTokenize
Reader
from
.
task_reader
import
Classify
Reader
from
.
task_reader
import
SequenceLabel
Reader
from
.nlp_reader
import
Classify
Reader
from
.
nlp_reader
import
SequenceLabel
Reader
from
.
cv_reader
import
ImageClassification
Reader
paddlehub/reader/nlp_reader.py
浏览文件 @
74ec0647
此差异已折叠。
点击以展开。
paddlehub/reader/task_reader.py
已删除
100644 → 0
浏览文件 @
810cdd3a
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
csv
import
json
import
numpy
as
np
from
collections
import
namedtuple
from
paddlehub.reader
import
tokenization
from
.batching
import
pad_batch_data
class BaseReader(object):
    """Turn dataset examples into padded batches for BERT/ERNIE-style models.

    Each example is tokenized with a ``FullTokenizer`` built from
    ``vocab_path``, converted into a record of id lists and grouped into
    batches.  Padding is delegated to ``_pad_batch_records``, which concrete
    readers must implement.
    """

    # Built once for the whole class instead of once per converted example.
    _Record = namedtuple(
        'Record', ['token_ids', 'text_type_ids', 'position_ids', 'label_id'])

    # 'val' is accepted everywhere as a synonym of 'dev', but the example
    # counters are stored under 'dev' only.
    _PHASE_ALIASES = {'val': 'dev'}

    def __init__(self,
                 dataset,
                 vocab_path,
                 label_map_config=None,
                 max_seq_len=512,
                 do_lower_case=True,
                 in_tokens=False,
                 random_seed=None):
        """Build the tokenizer, the label map and the progress counters.

        Args:
            dataset: Object providing ``get_labels()`` and the
                ``get_*_examples()`` accessors used by this class.
            vocab_path: Vocabulary file handed to ``FullTokenizer``.
            label_map_config: Unused; kept so existing callers keep working.
            max_seq_len: Maximum record length, special tokens included.
            do_lower_case: Forwarded to ``FullTokenizer``.
            in_tokens: When true, ``batch_size`` is treated as a budget of
                padded tokens per batch instead of a number of examples.
            random_seed: Passed to ``np.random.seed``.  Note that this
                reseeds NumPy's global generator.
        """
        self.max_seq_len = max_seq_len
        self.tokenizer = tokenization.FullTokenizer(
            vocab_file=vocab_path, do_lower_case=do_lower_case)
        self.vocab = self.tokenizer.vocab
        self.dataset = dataset
        self.pad_id = self.vocab["[PAD]"]
        self.cls_id = self.vocab["[CLS]"]
        self.sep_id = self.vocab["[SEP]"]
        self.in_tokens = in_tokens

        np.random.seed(random_seed)

        # Label ids follow the order in which the dataset lists its labels.
        self.label_map = {
            label: index
            for index, label in enumerate(self.dataset.get_labels())
        }
        print("Dataset label map = {}".format(self.label_map))

        self.current_example = 0
        self.current_epoch = 0
        # -1 means "not counted yet"; data_generator() fills the real value.
        self.num_examples = {'train': -1, 'dev': -1, 'test': -1}

    def get_train_examples(self):
        """Gets a collection of `InputExample`s for the train set."""
        return self.dataset.get_train_examples()

    def get_dev_examples(self):
        """Gets a collection of `InputExample`s for the dev set."""
        return self.dataset.get_dev_examples()

    def get_val_examples(self):
        """Gets a collection of `InputExample`s for the val set."""
        return self.dataset.get_val_examples()

    def get_test_examples(self):
        """Gets a collection of `InputExample`s for prediction."""
        return self.dataset.get_test_examples()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        return self.dataset.get_labels()

    def get_train_progress(self):
        """Gets progress for training phase as (example index, epoch)."""
        return self.current_example, self.current_epoch

    def _pad_batch_records(self, batch_records):
        """Pad a list of records into model inputs; provided by subclasses."""
        raise NotImplementedError(
            "Subclasses of BaseReader must implement _pad_batch_records.")

    def _truncate_seq_pair(self, tokens_a, tokens_b, max_length):
        """Truncates a sequence pair in place to the maximum length."""
        # Simple heuristic: always drop one token from the longer sequence.
        # This beats truncating an equal percentage from each, because a
        # token of a very short sequence likely carries more information
        # than a token of a long one.
        while len(tokens_a) + len(tokens_b) > max_length:
            longer = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
            longer.pop()

    def _convert_example_to_record(self, example, max_seq_length, tokenizer):
        """Converts a single `Example` into a single `Record`.

        Args:
            example: Object with ``text_a``, ``text_b`` (may be ``None``)
                and ``label`` attributes.
            max_seq_length: Upper bound on the record length, including the
                [CLS]/[SEP] tokens added here.
            tokenizer: Object providing ``tokenize`` and
                ``convert_tokens_to_ids``.

        Returns:
            A ``Record`` namedtuple with ``token_ids``, ``text_type_ids``,
            ``position_ids`` and ``label_id``.
        """
        text_a = tokenization.convert_to_unicode(example.text_a)
        tokens_a = tokenizer.tokenize(text_a)
        tokens_b = None
        if example.text_b is not None:
            text_b = tokenization.convert_to_unicode(example.text_b)
            tokens_b = tokenizer.tokenize(text_b)

        if tokens_b:
            # Truncates `tokens_a` and `tokens_b` in place.
            # Account for [CLS], [SEP], [SEP] with "- 3".
            self._truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
        elif len(tokens_a) > max_seq_length - 2:
            # Account for [CLS] and [SEP] with "- 2".
            tokens_a = tokens_a[0:(max_seq_length - 2)]

        # The convention in BERT/ERNIE is:
        # (a) For sequence pairs:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
        # (b) For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids: 0     0   0   0  0     0 0
        #
        # "type_ids" tell the first sequence from the second one.  This is
        # not *strictly* necessary since [SEP] already separates the
        # sequences, but it makes the concept easier for the model to learn.
        #
        # For classification tasks, the first vector (corresponding to
        # [CLS]) is used as the "sentence vector".  This only makes sense
        # because the entire model is fine-tuned.
        tokens = ["[CLS]"]
        tokens.extend(tokens_a)
        tokens.append("[SEP]")
        text_type_ids = [0] * len(tokens)

        if tokens_b:
            tokens.extend(tokens_b)
            tokens.append("[SEP]")
            text_type_ids.extend([1] * (len(tokens_b) + 1))

        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        position_ids = list(range(len(token_ids)))

        # With an empty label map the raw label is passed through unchanged.
        if self.label_map:
            label_id = self.label_map[example.label]
        else:
            label_id = example.label

        return self._Record(
            token_ids=token_ids,
            text_type_ids=text_type_ids,
            position_ids=position_ids,
            label_id=label_id)

    def _prepare_batch_data(self, examples, batch_size, phase=None):
        """Generate padded batches from ``examples``.

        Args:
            examples: Iterable of examples accepted by
                ``_convert_example_to_record``.
            batch_size: Examples per batch, or the padded-token budget per
                batch when ``self.in_tokens`` is true.
            phase: When equal to ``"train"``, ``self.current_example`` is
                updated with the index of the example being processed.

        Yields:
            The result of ``_pad_batch_records`` for each batch.
        """
        batch_records, max_len = [], 0
        for index, example in enumerate(examples):
            if phase == "train":
                self.current_example = index
            record = self._convert_example_to_record(
                example, self.max_seq_len, self.tokenizer)
            max_len = max(max_len, len(record.token_ids))
            if self.in_tokens:
                # Every record is padded to max_len, so the batch cost is
                # (number of records) * (longest record).
                to_append = (len(batch_records) + 1) * max_len <= batch_size
            else:
                to_append = len(batch_records) < batch_size
            if to_append:
                batch_records.append(record)
                continue
            # A record that alone exceeds the budget would otherwise flush
            # an empty batch here; skip the flush and let the record form
            # its own batch.
            if batch_records:
                yield self._pad_batch_records(batch_records)
            batch_records, max_len = [record], len(record.token_ids)

        if batch_records:
            yield self._pad_batch_records(batch_records)

    def get_num_examples(self, phase):
        """Get number of examples for train, dev or test.

        Args:
            phase: One of 'train', 'val', 'dev' or 'test'; 'val' is treated
                as 'dev'.

        Returns:
            The count recorded by ``data_generator`` for that phase, or -1
            if no generator has been created for it yet.

        Raises:
            ValueError: If ``phase`` is not one of the accepted names.
        """
        if phase not in ('train', 'val', 'dev', 'test'):
            raise ValueError(
                "Unknown phase, which should be in ['train', 'val'/'dev', 'test']."
            )
        return self.num_examples[self._PHASE_ALIASES.get(phase, phase)]

    def data_generator(self, batch_size, phase='train', shuffle=True):
        """Create a reader function yielding batches for ``phase``.

        Args:
            batch_size: Forwarded to ``_prepare_batch_data``.
            phase: 'train', 'val'/'dev' or 'test'.
            shuffle: Shuffle the examples (in place, with NumPy's global
                generator) each time the returned function is called.

        Returns:
            A zero-argument generator function; each item it yields is a
            one-element list wrapping a padded batch.

        Raises:
            ValueError: If ``phase`` is not one of the accepted names.
        """
        if phase == 'train':
            examples = self.get_train_examples()
            self.num_examples['train'] = len(examples)
        elif phase in ('val', 'dev'):
            examples = self.get_dev_examples()
            self.num_examples['dev'] = len(examples)
        elif phase == 'test':
            examples = self.get_test_examples()
            self.num_examples['test'] = len(examples)
        else:
            raise ValueError(
                "Unknown phase, which should be in ['train', 'dev', 'test'].")

        def wrapper():
            if shuffle:
                np.random.shuffle(examples)

            for batch_data in self._prepare_batch_data(
                    examples, batch_size, phase=phase):
                yield [batch_data]

        return wrapper
class ClassifyReader(BaseReader):
    """Reader whose batches carry one label id per example."""

    def _pad_batch_records(self, batch_records):
        """Pad the records of one batch and attach their labels.

        Args:
            batch_records: Records exposing ``token_ids``,
                ``text_type_ids``, ``position_ids`` and ``label_id``.

        Returns:
            ``[padded_token_ids, padded_position_ids, padded_text_type_ids,
            input_mask, batch_labels]`` where ``batch_labels`` is an int64
            array reshaped to ``[-1, 1]``.
        """
        token_seqs, type_seqs, position_seqs, label_ids = [], [], [], []
        for rec in batch_records:
            token_seqs.append(rec.token_ids)
            type_seqs.append(rec.text_type_ids)
            position_seqs.append(rec.position_ids)
            label_ids.append(rec.label_id)

        # One label per row, as a column vector.
        labels = np.array(label_ids)
        labels = labels.astype("int64")
        labels = labels.reshape([-1, 1])

        # Every field is padded with the same settings; only the token ids
        # additionally request the input mask.
        pad_options = dict(max_seq_len=self.max_seq_len, pad_idx=self.pad_id)
        tokens_padded, mask = pad_batch_data(
            token_seqs, return_input_mask=True, **pad_options)
        types_padded = pad_batch_data(type_seqs, **pad_options)
        positions_padded = pad_batch_data(position_seqs, **pad_options)

        return [tokens_padded, positions_padded, types_padded, mask, labels]
class SequenceLabelReader(BaseReader):
    """Reader for sequence labeling: one label id per (sub-)token."""

    # Separator between the tokens of ``text_a`` and between the labels of
    # ``label``.  The recovered source showed an empty string here, which
    # ``str.split`` rejects with ``ValueError: empty separator``; the
    # upstream ERNIE data format uses the invisible \x02 control character,
    # so it is spelled out as an escape.
    # NOTE(review): confirm the delimiter against the dataset files.
    _FIELD_SEP = u"\x02"

    # Built once for the whole class instead of once per converted example.
    _SeqLabelRecord = namedtuple(
        'Record', ['token_ids', 'text_type_ids', 'position_ids', 'label_ids'])

    def _pad_batch_records(self, batch_records):
        """Pad the records of one batch, including their label sequences.

        Args:
            batch_records: Records exposing ``token_ids``,
                ``text_type_ids``, ``position_ids`` and ``label_ids``.

        Returns:
            ``[padded_token_ids, padded_position_ids, padded_text_type_ids,
            input_mask, padded_label_ids, batch_seq_lens]``.
        """
        batch_token_ids = [record.token_ids for record in batch_records]
        batch_text_type_ids = [
            record.text_type_ids for record in batch_records
        ]
        batch_position_ids = [record.position_ids for record in batch_records]
        batch_label_ids = [record.label_ids for record in batch_records]

        padded_token_ids, input_mask, batch_seq_lens = pad_batch_data(
            batch_token_ids,
            pad_idx=self.pad_id,
            max_seq_len=self.max_seq_len,
            return_input_mask=True,
            return_seq_lens=True)
        padded_text_type_ids = pad_batch_data(
            batch_text_type_ids,
            max_seq_len=self.max_seq_len,
            pad_idx=self.pad_id)
        padded_position_ids = pad_batch_data(
            batch_position_ids,
            max_seq_len=self.max_seq_len,
            pad_idx=self.pad_id)
        # Labels are padded with the last label id rather than pad_id
        # (presumably the "no entity" label -- verify against the dataset).
        padded_label_ids = pad_batch_data(
            batch_label_ids,
            max_seq_len=self.max_seq_len,
            pad_idx=len(self.label_map) - 1)

        return [
            padded_token_ids, padded_position_ids, padded_text_type_ids,
            input_mask, padded_label_ids, batch_seq_lens
        ]

    def _reseg_token_label(self, tokens, labels, tokenizer):
        """Re-tokenize ``tokens`` and expand ``labels`` to match.

        Each token is split by ``tokenizer.tokenize``.  The first piece
        keeps the token's label; every further piece gets the same label,
        except that a ``B-xxx`` label becomes ``I-xxx``.  Tokens that
        tokenize to nothing are dropped together with their label.

        Returns:
            ``(ret_tokens, ret_labels)``, two lists of equal length.
        """
        assert len(tokens) == len(labels)
        ret_tokens = []
        ret_labels = []
        for token, label in zip(tokens, labels):
            sub_token = tokenizer.tokenize(token)
            if not sub_token:
                continue
            ret_tokens.extend(sub_token)
            ret_labels.append(label)
            if len(sub_token) < 2:
                continue
            # Continuation pieces must not start a new entity.
            sub_label = label
            if label.startswith("B-"):
                sub_label = "I-" + label[2:]
            ret_labels.extend([sub_label] * (len(sub_token) - 1))

        assert len(ret_tokens) == len(ret_labels)
        return ret_tokens, ret_labels

    def _convert_example_to_record(self, example, max_seq_length, tokenizer):
        """Convert one example into a record with per-token label ids.

        Args:
            example: Object whose ``text_a`` and ``label`` hold the tokens
                and labels joined by ``_FIELD_SEP``.
            max_seq_length: Upper bound on the record length, including the
                [CLS] and [SEP] tokens added here.
            tokenizer: Object providing ``tokenize`` and
                ``convert_tokens_to_ids``.

        Returns:
            A ``Record`` namedtuple with ``token_ids``, ``text_type_ids``,
            ``position_ids`` and ``label_ids``.
        """
        tokens = tokenization.convert_to_unicode(example.text_a).split(
            self._FIELD_SEP)
        labels = tokenization.convert_to_unicode(example.label).split(
            self._FIELD_SEP)
        tokens, labels = self._reseg_token_label(tokens, labels, tokenizer)

        # Leave room for [CLS] and [SEP].
        if len(tokens) > max_seq_length - 2:
            tokens = tokens[0:(max_seq_length - 2)]
            labels = labels[0:(max_seq_length - 2)]

        tokens = ["[CLS]"] + tokens + ["[SEP]"]
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        position_ids = list(range(len(token_ids)))
        text_type_ids = [0] * len(token_ids)

        # [CLS] and [SEP] get the last label id (presumably the "no entity"
        # label -- verify against the dataset's label order).
        no_entity_id = len(self.label_map) - 1
        label_ids = [no_entity_id]
        label_ids.extend(self.label_map[label] for label in labels)
        label_ids.append(no_entity_id)

        return self._SeqLabelRecord(
            token_ids=token_ids,
            text_type_ids=text_type_ids,
            position_ids=position_ids,
            label_ids=label_ids)
class ExtractEmbeddingReader(BaseReader):
    """Reader whose batches contain model inputs only, without labels."""

    def _pad_batch_records(self, batch_records):
        """Pad the records of one batch.

        Args:
            batch_records: Records exposing ``token_ids``,
                ``text_type_ids`` and ``position_ids``.

        Returns:
            ``[padded_token_ids, padded_text_type_ids, padded_position_ids,
            input_mask, seq_lens]``.
        """
        # NOTE(review): text type ids precede position ids here, unlike the
        # other readers in this file -- confirm the consumer expects that.
        token_seqs = []
        type_seqs = []
        position_seqs = []
        for rec in batch_records:
            token_seqs.append(rec.token_ids)
            type_seqs.append(rec.text_type_ids)
            position_seqs.append(rec.position_ids)

        pad_options = dict(pad_idx=self.pad_id, max_seq_len=self.max_seq_len)

        # Only the token ids also request the input mask and the lengths.
        tokens_padded, mask, lengths = pad_batch_data(
            token_seqs,
            return_input_mask=True,
            return_seq_lens=True,
            **pad_options)
        types_padded = pad_batch_data(type_seqs, **pad_options)
        positions_padded = pad_batch_data(position_seqs, **pad_options)

        return [tokens_padded, types_padded, positions_padded, mask, lengths]
# Script entry point: intentionally a no-op, nothing runs when this module
# is executed directly.
if __name__ == '__main__':
    pass
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录