PaddlePaddle / models, commit 300be16c (unverified)

Authored by pkpk on Jan 08, 2020; committed via GitHub on Jan 08, 2020.
Commit message: test=develop (#4175)
Parent: 071dc299
Showing 1 changed file with 192 additions and 143 deletions.

PaddleNLP/PaddleDialogue/dialogue_general_understanding/dgu/reader.py (+192, -143)

The hunks below show the file as it stands after this commit.
@@ -23,19 +23,21 @@ import numpy as np

from dgu import tokenization
from dgu.batching import prepare_batch_data

if sys.version[0] == '2':
    reload(sys)
    sys.setdefaultencoding('utf-8')


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def __init__(self,
                 data_dir,
                 vocab_path,
                 max_seq_len,
                 do_lower_case,
                 in_tokens,
                 task_name,
                 random_seed=None):
        self.data_dir = data_dir
        self.max_seq_len = max_seq_len
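The guard added here matters because sys.setdefaultencoding was removed in Python 3; it only exists on Python 2, and only after reload(sys). A minimal standalone sketch of the same check, using sys.version_info rather than the file's sys.version[0] string test:

import sys

# Only Python 2 has setdefaultencoding (and only after reload(sys));
# on Python 3 this block is skipped entirely.
if sys.version_info[0] == 2:
    reload(sys)
    sys.setdefaultencoding('utf-8')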
@@ -92,7 +94,7 @@ class DataProcessor(object):
            mask_id=-1,
            return_input_mask=True,
            return_max_len=False,
            return_num_token=False):
        """generate batch data"""
        return prepare_batch_data(
            self.task_name,
@@ -114,7 +116,7 @@ class DataProcessor(object):
        f = io.open(input_file, "r", encoding="utf8")
        reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
        lines = []
        for line in reader:
            lines.append(line)
        return lines
@@ -147,21 +149,21 @@ class DataProcessor(object):
            raise ValueError(
                "Unknown phase, which should be in ['train', 'dev', 'test'].")

        def instance_reader():
            """generate instance data"""
            if shuffle:
                np.random.shuffle(examples)
            for (index, example) in enumerate(examples):
                feature = self.convert_example(index, example,
                                               self.get_labels(),
                                               self.max_seq_len,
                                               self.tokenizer)
                instance = self.generate_instance(feature)
                yield instance

        def batch_reader(reader, batch_size, in_tokens):
            """read batch data"""
            batch, total_token_num, max_len = [], 0, 0
            for instance in reader():
                token_ids, sent_ids, pos_ids, label = instance[:4]
                max_len = max(max_len, len(token_ids))
                if in_tokens:
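The hunk cuts off inside the in_tokens branch, but the batch/total_token_num/max_len bookkeeping suggests token-budget batching: batch_size counts padded tokens rather than instances. A hedged, self-contained sketch of that pattern (an illustration only, not the file's code):

def batch_by_token_count(instances, batch_size):
    """Group instances so the padded token count stays within batch_size."""
    batch, max_len = [], 0
    for token_ids in instances:
        max_len = max(max_len, len(token_ids))
        # Cost of the batch if this instance joins, counting padding to max_len.
        if (len(batch) + 1) * max_len > batch_size:
            yield batch
            batch, max_len = [token_ids], len(token_ids)
        else:
            batch.append(token_ids)
    if batch:
        yield batch

# With batch_size=8, sequences of lengths 3, 2, 4, 1 split into two batches,
# because padding all four to length 4 would cost 16 tokens.
print([len(b) for b in batch_by_token_count([[0]*3, [0]*2, [0]*4, [0]*1], 8)])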
@@ -179,13 +181,13 @@ class DataProcessor(object):
            if len(batch) > 0:
                yield batch, total_token_num

        def wrapper():
            """yield batch data to network"""
            for batch_data, total_token_num in batch_reader(
                    instance_reader, batch_size, self.in_tokens):
                if self.in_tokens:
                    max_seq = -1
                else:
                    max_seq = self.max_seq_len
                batch_data = self.generate_batch_data(
                    batch_data,
@@ -199,7 +201,7 @@ class DataProcessor(object):
                yield batch_data

        return wrapper


class InputExample(object):
    """A single training/test example for simple sequence classification."""
@@ -250,19 +252,24 @@ class InputFeatures(object):
        self.label_id = label_id


class UDCProcessor(DataProcessor):
    """Processor for the UDC data set."""

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        print("UDC dataset is too big, loading data spent a long time, please wait patiently..................")
        for (i, line) in enumerate(lines):
            if len(line) < 3:
                print("data format error: %s" % "\t".join(line))
                print("data row contains at least three parts: label\tconv1\t.....\tresponse")
                continue
            guid = "%s-%d" % (set_type, i)
            text_a = "\t".join(line[1:-1])
            text_a = tokenization.convert_to_unicode(text_a)
            text_a = text_a.split('\t')
            text_b = line[-1]
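A worked view of the UDC line handling above, with hypothetical data (the tokenization.convert_to_unicode pass-through is omitted): every middle column becomes the tab-joined history text_a, and the last column is the candidate response text_b.

# Hypothetical UDC row: label, two history turns, one candidate response.
line = ["1", "hi , how are you ?", "fine , thanks .", "glad to hear it ."]
label = line[0]                     # "1": the response is the true next turn
text_a = "\t".join(line[1:-1])      # all middle columns form the history
history = text_a.split('\t')        # ["hi , how are you ?", "fine , thanks ."]
text_b = line[-1]                   # the candidate response
print(label, history, text_b)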
@@ -273,21 +280,21 @@ class UDCProcessor(DataProcessor):
                    guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

    def get_train_examples(self, data_dir):
        """See base class."""
        examples = []
        lines = self._read_tsv(os.path.join(data_dir, "train.txt"))
        examples = self._create_examples(lines, "train")
        return examples

    def get_dev_examples(self, data_dir):
        """See base class."""
        examples = []
        lines = self._read_tsv(os.path.join(data_dir, "dev.txt"))
        examples = self._create_examples(lines, "dev")
        return examples

    def get_test_examples(self, data_dir):
        """See base class."""
        examples = []
        lines = self._read_tsv(os.path.join(data_dir, "test.txt"))
@@ -295,19 +302,20 @@ class UDCProcessor(DataProcessor):
        return examples

    @staticmethod
    def get_labels():
        """See base class."""
        return ["0", "1"]


class SWDAProcessor(DataProcessor):
    """Processor for the SWDA data set."""

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = create_multi_turn_examples(lines, set_type)
        return examples

    def get_train_examples(self, data_dir):
        """See base class."""
        examples = []
        lines = self._read_tsv(os.path.join(data_dir, "train.txt"))
@@ -329,21 +337,22 @@ class SWDAProcessor(DataProcessor):
        return examples

    @staticmethod
    def get_labels():
        """See base class."""
        labels = range(42)
        labels = [str(label) for label in labels]
        return labels


class MRDAProcessor(DataProcessor):
    """Processor for the MRDA data set."""

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = create_multi_turn_examples(lines, set_type)
        return examples

    def get_train_examples(self, data_dir):
        """See base class."""
        examples = []
        lines = self._read_tsv(os.path.join(data_dir, "train.txt"))
@@ -365,22 +374,25 @@ class MRDAProcessor(DataProcessor):
        return examples

    @staticmethod
    def get_labels():
        """See base class."""
        labels = range(42)
        labels = [str(label) for label in labels]
        return labels


class ATISSlotProcessor(DataProcessor):
    """Processor for the ATIS Slot data set."""

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if len(line) != 2:
                print("data format error: %s" % "\t".join(line))
                print("data row contains two parts: conversation_content\tlabel1 label2 label3")
                continue
            guid = "%s-%d" % (set_type, i)
            text_a = line[0]
@@ -392,7 +404,7 @@ class ATISSlotProcessor(DataProcessor):
                    guid=guid, text_a=text_a, label=label_list))
        return examples

    def get_train_examples(self, data_dir):
        """See base class."""
        examples = []
        lines = self._read_tsv(os.path.join(data_dir, "train.txt"))
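Per the error message above, each ATIS slot line carries an utterance and one space-separated label per token. A small sketch of that alignment (the label ids here are hypothetical):

# Hypothetical line: utterance and one label id per token, space separated.
line = ["i want to fly from boston to denver", "0 0 0 0 0 48 0 78"]
tokens = line[0].split()
labels = line[1].split()
assert len(tokens) == len(labels)   # slot labels align one-to-one with tokens
for tok, lab in zip(tokens, labels):
    print(tok, lab)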
@@ -414,30 +426,30 @@ class ATISSlotProcessor(DataProcessor):
        return examples

    @staticmethod
    def get_labels():
        """See base class."""
        labels = range(130)
        labels = [str(label) for label in labels]
        return labels


class ATISIntentProcessor(DataProcessor):
    """Processor for the ATIS intent data set."""

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if len(line) != 2:
                print("data format error: %s" % "\t".join(line))
                print("data row contains two parts: label\tconversation_content")
                continue
            guid = "%s-%d" % (set_type, i)
            text_a = line[1]
            text_a = tokenization.convert_to_unicode(text_a)
            label = tokenization.convert_to_unicode(line[0])
            examples.append(
                InputExample(
                    guid=guid, text_a=text_a, label=label))
        return examples

    def get_train_examples(self, data_dir):
@@ -469,53 +481,60 @@ class ATISIntentProcessor(DataProcessor):
        return labels


class DSTC2Processor(DataProcessor):
    """Processor for the DSTC2 data set."""

    def _create_turns(self, conv_example):
        """create multi turn dataset"""
        samples = []
        max_turns = 20
        for i in range(len(conv_example)):
            conv_turns = conv_example[max(i - max_turns, 0):i + 1]
            conv_info = "\1".join([sample[0] for sample in conv_turns])
            samples.append((conv_info.split('\1'), conv_example[i][1]))
        return samples

    def _create_examples(self, lines, set_type):
        """Creates examples for multi-turn dialogue sets."""
        examples = []
        conv_id = -1
        index = 0
        conv_example = []
        for (i, line) in enumerate(lines):
            if len(line) != 3:
                print("data format error: %s" % "\t".join(line))
                print("data row contains three parts: conversation_content\tquestion\1answer\tstate1 state2 state3......")
                continue
            conv_no = line[0]
            text_a = line[1]
            label_list = line[2].split()
            if conv_no != conv_id and i != 0:
                samples = self._create_turns(conv_example)
                for sample in samples:
                    guid = "%s-%s" % (set_type, index)
                    index += 1
                    history = sample[0]
                    dst_label = sample[1]
                    examples.append(
                        InputExample(
                            guid=guid, text_a=history, label=dst_label))
                conv_example = []
                conv_id = conv_no
            if i == 0:
                conv_id = conv_no
            conv_example.append((text_a, label_list))
        if conv_example:
            samples = self._create_turns(conv_example)
            for sample in samples:
                guid = "%s-%s" % (set_type, index)
                index += 1
                history = sample[0]
                dst_label = sample[1]
                examples.append(
                    InputExample(
                        guid=guid, text_a=history, label=dst_label))
        return examples

    def get_train_examples(self, data_dir):
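A worked illustration of _create_turns above, on toy data: sample i keeps at most the previous max_turns = 20 turns plus turn i itself, so early samples simply have shorter histories.

# Toy conversation: (utterance, state-label list) pairs.
conv_example = [("t0", ["a"]), ("t1", ["b"]), ("t2", ["c"])]
max_turns = 20
for i in range(len(conv_example)):
    conv_turns = conv_example[max(i - max_turns, 0):i + 1]
    history = [turn[0] for turn in conv_turns]   # joined with "\1" in the file
    print(history, conv_example[i][1])
# -> ['t0'] ['a']
#    ['t0', 't1'] ['b']
#    ['t0', 't1', 't2'] ['c']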
@@ -547,20 +566,22 @@ class DSTC2Processor(DataProcessor):
        return labels


class MULTIWOZProcessor(DataProcessor):
    """Processor for the MULTIWOZ data set."""

    def _create_turns(self, conv_example):
        """create multi turn dataset"""
        samples = []
        max_turns = 2
        for i in range(len(conv_example)):
            prefix_turns = conv_example[max(i - max_turns, 0):i]
            conv_info = "\1".join([turn[0] for turn in prefix_turns])
            current_turns = conv_example[i][0]
            samples.append((conv_info.split('\1'), current_turns.split('\1'),
                            conv_example[i][1]))
        return samples

    def _create_examples(self, lines, set_type):
        """Creates examples for multi-turn dialogue sets."""
        examples = []
        conv_id = -1
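The MULTIWOZ variant of _create_turns differs from DSTC2 in two ways: the history slice ends at i rather than i + 1, so the current turn is kept separately, and max_turns is 2. A toy run:

# Toy turns: (utterance, label).
conv_example = [("t0", "l0"), ("t1", "l1"), ("t2", "l2"), ("t3", "l3")]
max_turns = 2
for i in range(len(conv_example)):
    prefix = [turn[0] for turn in conv_example[max(i - max_turns, 0):i]]
    print(prefix, conv_example[i][0], conv_example[i][1])
# i = 3 -> ['t1', 't2'] t3 l3: two turns of history, current turn kept apart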
@@ -570,7 +591,7 @@ class MULTIWOZProcessor(DataProcessor):
            conv_no = line[0]
            text_a = line[2]
            label_list = line[1].split()
            if conv_no != conv_id and i != 0:
                samples = self._create_turns(conv_example)
                for sample in samples:
                    guid = "%s-%s" % (set_type, index)
@@ -578,13 +599,18 @@ class MULTIWOZProcessor(DataProcessor):
                    history = sample[0]
                    current = sample[1]
                    dst_label = sample[2]
                    examples.append(
                        InputExample(
                            guid=guid,
                            text_a=history,
                            text_b=current,
                            label=dst_label))
                conv_example = []
                conv_id = conv_no
            if i == 0:
                conv_id = conv_no
            conv_example.append((text_a, label_list))
        if conv_example:
            samples = self._create_turns(conv_example)
            for sample in samples:
                guid = "%s-%s" % (set_type, index)
@@ -592,10 +618,15 @@ class MULTIWOZProcessor(DataProcessor):
                history = sample[0]
                current = sample[1]
                dst_label = sample[2]
                examples.append(
                    InputExample(
                        guid=guid,
                        text_a=history,
                        text_b=current,
                        label=dst_label))
        return examples

    def get_train_examples(self, data_dir):
        """See base class."""
        examples = []
        lines = self._read_tsv(os.path.join(data_dir, "train.txt"))
@@ -624,34 +655,38 @@ class MULTIWOZProcessor(DataProcessor):
        return labels


def create_dialogue_examples(conv):
    """Creates dialogue sample"""
    samples = []
    for i in range(len(conv)):
        cur_txt = "%s : %s" % (conv[i][2], conv[i][3])
        pre_txt = [
            "%s : %s" % (c[2], c[3]) for c in conv[max(0, i - 5):i]
        ]
        suf_txt = [
            "%s : %s" % (c[2], c[3])
            for c in conv[i + 1:min(len(conv), i + 3)]
        ]
        sample = [conv[i][1], pre_txt, cur_txt, suf_txt]
        samples.append(sample)
    return samples


def create_multi_turn_examples(lines, set_type):
    """Creates examples for multi-turn dialogue sets."""
    conv_id = -1
    examples = []
    conv_example = []
    index = 0
    for (i, line) in enumerate(lines):
        if len(line) != 4:
            print("data format error: %s" % "\t".join(line))
            print("data row contains four parts: conversation_id\tlabel\tcaller\tconversation_content")
            continue
        tokens = line
        conv_no = tokens[0]
        if conv_no != conv_id and i != 0:
            samples = create_dialogue_examples(conv_example)
            for sample in samples:
                guid = "%s-%s" % (set_type, index)
                index += 1
                label = sample[0]
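A worked view of the context window in create_dialogue_examples: turn i keeps up to five preceding and two following "caller : utterance" strings (toy rows of the form [conversation_id, label, caller, utterance]):

# Toy rows: [conversation_id, label, caller, utterance].
conv = [["c0", "l%d" % k, "A" if k % 2 == 0 else "B", "u%d" % k]
        for k in range(8)]
i = 5
cur_txt = "%s : %s" % (conv[i][2], conv[i][3])      # "B : u5"
pre_txt = ["%s : %s" % (c[2], c[3]) for c in conv[max(0, i - 5):i]]   # u0..u4
suf_txt = ["%s : %s" % (c[2], c[3])
           for c in conv[i + 1:min(len(conv), i + 3)]]                # u6, u7
print(conv[i][1], pre_txt, cur_txt, suf_txt)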
@@ -659,15 +694,20 @@ def create_multi_turn_examples(lines, set_type):
                text_b = sample[2]
                text_c = sample[3]
                examples.append(
                    InputExample(
                        guid=guid,
                        text_a=text_a,
                        text_b=text_b,
                        text_c=text_c,
                        label=label))
            conv_example = []
            conv_id = conv_no
        if i == 0:
            conv_id = conv_no
        conv_example.append(tokens)
    if conv_example:
        samples = create_dialogue_examples(conv_example)
        for sample in samples:
            guid = "%s-%s" % (set_type, index)
            index += 1
            label = sample[0]
@@ -675,62 +715,67 @@ def create_multi_turn_examples(lines, set_type):
            text_b = sample[2]
            text_c = sample[3]
            examples.append(
                InputExample(
                    guid=guid,
                    text_a=text_a,
                    text_b=text_b,
                    text_c=text_c,
                    label=label))
    return examples


def convert_tokens(tokens, sep_id, tokenizer):
    """Converts tokens to ids"""
    tokens_ids = []
    if not tokens:
        return tokens_ids
    if isinstance(tokens, list):
        for text in tokens:
            tok_text = tokenizer.tokenize(text)
            ids = tokenizer.convert_tokens_to_ids(tok_text)
            tokens_ids.extend(ids)
            tokens_ids.append(sep_id)
        tokens_ids = tokens_ids[:-1]
    else:
        tok_text = tokenizer.tokenize(tokens)
        tokens_ids = tokenizer.convert_tokens_to_ids(tok_text)
    return tokens_ids


def convert_single_example(ex_index, example, label_list, max_seq_length,
                           tokenizer, task_name):
    """Converts a single DA `InputExample` into a single `InputFeatures`."""
    label_map = {}
    SEP = 102
    CLS = 101
    if task_name == 'udc':
        INNER_SEP = 1
        limit_length = 60
    elif task_name == 'swda':
        INNER_SEP = 1
        limit_length = 50
    elif task_name == 'mrda':
        INNER_SEP = 1
        limit_length = 50
    elif task_name == 'atis_intent':
        INNER_SEP = -1
        limit_length = -1
    elif task_name == 'atis_slot':
        INNER_SEP = -1
        limit_length = -1
    elif task_name == 'dstc2':
        INNER_SEP = 1
        limit_length = -1
    elif task_name == 'dstc2_asr':
        INNER_SEP = 1
        limit_length = -1
    elif task_name == 'multi-woz':
        INNER_SEP = 1
        limit_length = 200

    for (i, label) in enumerate(label_list):
        label_map[label] = i
    tokens_a = example.text_a
    tokens_b = example.text_b
    tokens_c = example.text_c
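For a list input, convert_tokens joins the per-utterance ids with sep_id and drops the trailing separator. The same logic inlined with a toy tokenizer (a stand-in for the real dgu tokenizer):

class ToyTokenizer(object):
    vocab = {"hello": 7, "world": 8}

    def tokenize(self, text):
        return text.split()

    def convert_tokens_to_ids(self, toks):
        return [self.vocab[t] for t in toks]

tok = ToyTokenizer()
ids = []
for text in ["hello world", "hello"]:       # list input: one entry per turn
    ids.extend(tok.convert_tokens_to_ids(tok.tokenize(text)))
    ids.append(1)                            # INNER_SEP between turns
ids = ids[:-1]                               # no trailing separator
print(ids)                                   # [7, 8, 1, 7]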
@@ -739,30 +784,36 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
    tokens_b_ids = convert_tokens(tokens_b, INNER_SEP, tokenizer)
    tokens_c_ids = convert_tokens(tokens_c, INNER_SEP, tokenizer)

    if tokens_b_ids:
        tokens_b_ids = tokens_b_ids[:min(limit_length, len(tokens_b_ids))]
    else:
        if len(tokens_a_ids) > max_seq_length - 2:
            tokens_a_ids = tokens_a_ids[len(tokens_a_ids) - max_seq_length + 2:]
    if not tokens_c_ids:
        if len(tokens_a_ids) > max_seq_length - len(tokens_b_ids) - 3:
            tokens_a_ids = tokens_a_ids[len(tokens_a_ids) - max_seq_length +
                                        len(tokens_b_ids) + 3:]
    else:
        if len(tokens_a_ids) + len(tokens_b_ids) + len(
                tokens_c_ids) > max_seq_length - 4:
            left_num = max_seq_length - len(tokens_b_ids) - 4
            if len(tokens_a_ids) > len(tokens_c_ids):
                suffix_num = int(left_num / 2)
                tokens_c_ids = tokens_c_ids[:min(len(tokens_c_ids), suffix_num)]
                prefix_num = left_num - len(tokens_c_ids)
                tokens_a_ids = tokens_a_ids[max(0,
                                                len(tokens_a_ids) - prefix_num):]
            else:
                if not tokens_a_ids:
                    tokens_c_ids = tokens_c_ids[max(0,
                                                    len(tokens_c_ids) - left_num):]
                else:
                    prefix_num = int(left_num / 2)
                    tokens_a_ids = tokens_a_ids[max(
                        0, len(tokens_a_ids) - prefix_num):]
                    suffix_num = left_num - len(tokens_a_ids)
                    tokens_c_ids = tokens_c_ids[:min(len(tokens_c_ids),
                                                     suffix_num)]

    input_ids = []
    segment_ids = []
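Worked numbers for the three-segment truncation above, assuming max_seq_length = 20 and hypothetical lengths 12/6/4 for a/b/c: the four reserved positions are the special tokens ([CLS] plus three [SEP]), so left_num = 20 - 6 - 4 = 10 is split between the tail of a and the head of c.

max_seq_length = 20
tokens_a_ids = list(range(12))
tokens_b_ids = list(range(6))
tokens_c_ids = list(range(4))

left_num = max_seq_length - len(tokens_b_ids) - 4        # 10 ids left for a and c
suffix_num = int(left_num / 2)                           # 5 (a is longer than c)
tokens_c_ids = tokens_c_ids[:min(len(tokens_c_ids), suffix_num)]  # c keeps all 4
prefix_num = left_num - len(tokens_c_ids)                # 6
tokens_a_ids = tokens_a_ids[max(0, len(tokens_a_ids) - prefix_num):]  # tail of a
# 6 + 6 + 4 content ids plus 4 special tokens fill the sequence exactly.
assert len(tokens_a_ids) + len(tokens_b_ids) + len(tokens_c_ids) + 4 == max_seq_length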
@@ -772,31 +823,31 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
    segment_ids.extend([0] * len(tokens_a_ids))
    input_ids.append(SEP)
    segment_ids.append(0)
    if tokens_b_ids:
        input_ids.extend(tokens_b_ids)
        segment_ids.extend([1] * len(tokens_b_ids))
        input_ids.append(SEP)
        segment_ids.append(1)
    if tokens_c_ids:
        input_ids.extend(tokens_c_ids)
        segment_ids.extend([0] * len(tokens_c_ids))
        input_ids.append(SEP)
        segment_ids.append(0)
    input_mask = [1] * len(input_ids)
    if task_name == 'atis_slot':
        label_id = [0] + [label_map[l] for l in example.label] + [0]
    elif task_name in ['dstc2', 'dstc2_asr', 'multi-woz']:
        label_id_enty = [label_map[l] for l in example.label]
        label_id = []
        for i in range(len(label_map)):
            if i in label_id_enty:
                label_id.append(1)
            else:
                label_id.append(0)
    else:
        label_id = label_map[example.label]
    if ex_index < 5:
        print("*** Example ***")
        print("guid: %s" % (example.guid))
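Pieced together with the CLS = 101 constant defined earlier (the [CLS] append itself falls in an elided part of the diff, so treating it as the first token is an assumption), the final layout is [CLS] a [SEP] b [SEP] c [SEP], with segment ids 0 for the a block, 1 for the b block, and 0 again for the c block. A toy reconstruction:

a, b, c = [11, 12], [21], [31, 32]
CLS, SEP = 101, 102
input_ids = [CLS] + a + [SEP] + b + [SEP] + c + [SEP]
segment_ids = [0] * (len(a) + 2) + [1] * (len(b) + 1) + [0] * (len(c) + 1)
input_mask = [1] * len(input_ids)    # every position is a real token here
print(input_ids)     # [101, 11, 12, 102, 21, 102, 31, 32, 102]
print(segment_ids)   # [0, 0, 0, 0, 1, 1, 0, 0, 0]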
@@ -809,7 +860,5 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_id=label_id)
    return feature