Commit 300be16c (unverified), authored Jan 08, 2020 by pkpk, committed via GitHub on Jan 08, 2020.

Commit message: test=develop (#4175)

Parent: 071dc299
Showing 1 changed file with 192 additions and 143 deletions (+192 −143):

PaddleNLP/PaddleDialogue/dialogue_general_understanding/dgu/reader.py
```diff
@@ -23,8 +23,10 @@ import numpy as np
 from dgu import tokenization
 from dgu.batching import prepare_batch_data
 
-reload(sys)
-sys.setdefaultencoding('utf-8')
+if sys.version[0] == '2':
+    reload(sys)
+    sys.setdefaultencoding('utf-8')
 
 
 class DataProcessor(object):
     """Base class for data converters for sequence classification data sets."""
```
```diff
@@ -152,9 +154,9 @@ class DataProcessor(object):
             if shuffle:
                 np.random.shuffle(examples)
             for (index, example) in enumerate(examples):
-                feature = self.convert_example(
-                    index, example,
-                    self.get_labels(), self.max_seq_len, self.tokenizer)
+                feature = self.convert_example(index, example,
+                                               self.get_labels(),
+                                               self.max_seq_len, self.tokenizer)
                 instance = self.generate_instance(feature)
                 yield instance
```
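The generator method around this hunk (inside `DataProcessor`) shuffles the examples, converts each to features, and yields instances one at a time. A minimal sketch of that pattern outside the class; the names here are stand-ins, not the file's API:

```python
import numpy as np

def generate(examples, convert_example, shuffle=True):
    # Optionally shuffle in place, then lazily convert and yield,
    # mirroring the loop in the hunk above.
    if shuffle:
        np.random.shuffle(examples)
    for index, example in enumerate(examples):
        yield convert_example(index, example)

# Toy usage: "conversion" just pairs the index with a doubled value.
for instance in generate(list(range(5)), lambda i, e: (i, e * 2)):
    print(instance)
```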
```diff
@@ -252,17 +254,22 @@ class InputFeatures(object):
 class UDCProcessor(DataProcessor):
     """Processor for the UDC data set."""
 
     def _create_examples(self, lines, set_type):
         """Creates examples for the training and dev sets."""
         examples = []
-        print("UDC dataset is too big, loading data spent a long time, please wait patiently..................")
+        print(
+            "UDC dataset is too big, loading data spent a long time, please wait patiently.................."
+        )
         for (i, line) in enumerate(lines):
             if len(line) < 3:
                 print("data format error: %s" % "\t".join(line))
-                print("data row contains at least three parts: label\tconv1\t.....\tresponse")
+                print(
+                    "data row contains at least three parts: label\tconv1\t.....\tresponse"
+                )
                 continue
             guid = "%s-%d" % (set_type, i)
             text_a = "\t".join(line[1:-1])
             text_a = tokenization.convert_to_unicode(text_a)
             text_a = text_a.split('\t')
             text_b = line[-1]
```
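For context, the checks above imply each UDC row is `label\tturn_1\t...\tturn_n\tresponse`: everything between the label and the last column becomes `text_a`, the last column `text_b`. A sketch with a made-up row:

```python
# Hypothetical UDC-style row: label, two context turns, one response.
line = "1\thi there\thow can I help?\tmy wifi is down".split("\t")

label = line[0]                  # "1"
text_a = "\t".join(line[1:-1])   # context turns, re-joined with tabs
text_b = line[-1]                # candidate response
print(label, text_a.split("\t"), text_b)
```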
```diff
@@ -302,6 +309,7 @@ class UDCProcessor(DataProcessor):
 class SWDAProcessor(DataProcessor):
     """Processor for the SWDA data set."""
 
+
     def _create_examples(self, lines, set_type):
         """Creates examples for the training and dev sets."""
         examples = create_multi_turn_examples(lines, set_type)
```
```diff
@@ -338,6 +346,7 @@ class SWDAProcessor(DataProcessor):
 class MRDAProcessor(DataProcessor):
     """Processor for the MRDA data set."""
 
+
     def _create_examples(self, lines, set_type):
         """Creates examples for the training and dev sets."""
         examples = create_multi_turn_examples(lines, set_type)
```
```diff
@@ -374,13 +383,16 @@ class MRDAProcessor(DataProcessor):
 class ATISSlotProcessor(DataProcessor):
     """Processor for the ATIS Slot data set."""
 
     def _create_examples(self, lines, set_type):
         """Creates examples for the training and dev sets."""
         examples = []
         for (i, line) in enumerate(lines):
             if len(line) != 2:
                 print("data format error: %s" % "\t".join(line))
-                print("data row contains two parts: conversation_content\tlabel1 label2 label3")
+                print(
+                    "data row contains two parts: conversation_content\tlabel1 label2 label3"
+                )
                 continue
             guid = "%s-%d" % (set_type, i)
             text_a = line[0]
```
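The `len(line) != 2` check above matches rows shaped `conversation_content\tlabel1 label2 label3`, i.e. one slot tag per token. A quick sketch (the row is invented):

```python
# Hypothetical ATIS-slot-style row: utterance, then space-separated tags.
line = "show flights from boston to denver\tO O O B-fromloc O B-toloc".split("\t")
text_a, labels = line[0], line[1].split(" ")
assert len(text_a.split(" ")) == len(labels)  # one tag per token
```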
```diff
@@ -423,21 +435,21 @@ class ATISSlotProcessor(DataProcessor):
 class ATISIntentProcessor(DataProcessor):
     """Processor for the ATIS intent data set."""
 
     def _create_examples(self, lines, set_type):
         """Creates examples for the training and dev sets."""
         examples = []
         for (i, line) in enumerate(lines):
             if len(line) != 2:
                 print("data format error: %s" % "\t".join(line))
-                print("data row contains two parts: label\tconversation_content")
+                print(
+                    "data row contains two parts: label\tconversation_content")
                 continue
             guid = "%s-%d" % (set_type, i)
             text_a = line[1]
             text_a = tokenization.convert_to_unicode(text_a)
             label = tokenization.convert_to_unicode(line[0])
-            examples.append(InputExample(guid=guid, text_a=text_a, label=label))
+            examples.append(
+                InputExample(
+                    guid=guid, text_a=text_a, label=label))
         return examples
 
     def get_train_examples(self, data_dir):
```
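`InputExample` itself is defined earlier in reader.py and is untouched by this commit; the rewrap only changes line breaks. A stand-in sketch mirroring the keyword arguments this file actually passes to it (the real class may differ):

```python
class InputExample(object):
    # Stand-in with the fields used throughout this diff; hypothetical.
    def __init__(self, guid, text_a, text_b=None, text_c=None, label=None):
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.text_c = text_c
        self.label = label

ex = InputExample(guid="train-0", text_a="book a flight to denver",
                  label="atis_flight")
```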
```diff
@@ -471,12 +483,13 @@ class ATISIntentProcessor(DataProcessor):
 class DSTC2Processor(DataProcessor):
     """Processor for the DSTC2 data set."""
 
     def _create_turns(self, conv_example):
         """create multi turn dataset"""
         samples = []
         max_turns = 20
         for i in range(len(conv_example)):
             conv_turns = conv_example[max(i - max_turns, 0):i + 1]
             conv_info = "\1".join([sample[0] for sample in conv_turns])
             samples.append((conv_info.split('\1'), conv_example[i][1]))
         return samples
```
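`_create_turns` gives each turn `i` a window of up to `max_turns` preceding turns plus the turn itself. A worked sketch on a toy conversation:

```python
# Toy (turn_text, state_label) pairs; DSTC2 uses max_turns = 20, so with
# five turns every window is simply the full prefix.
conv_example = [("turn-%s" % t, "state-%s" % t) for t in "abcde"]
max_turns = 20
for i in range(len(conv_example)):
    conv_turns = conv_example[max(i - max_turns, 0):i + 1]
    history = [sample[0] for sample in conv_turns]
    print(history, conv_example[i][1])
# ['turn-a'] state-a
# ['turn-a', 'turn-b'] state-b  ... and so on
```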
```diff
@@ -490,7 +503,9 @@ class DSTC2Processor(DataProcessor):
         for (i, line) in enumerate(lines):
             if len(line) != 3:
                 print("data format error: %s" % "\t".join(line))
-                print("data row contains three parts: conversation_content\tquestion\1answer\tstate1 state2 state3......")
+                print(
+                    "data row contains three parts: conversation_content\tquestion\1answer\tstate1 state2 state3......"
+                )
                 continue
             conv_no = line[0]
             text_a = line[1]
```
```diff
@@ -502,7 +517,9 @@ class DSTC2Processor(DataProcessor):
                     index += 1
                     history = sample[0]
                     dst_label = sample[1]
-                    examples.append(InputExample(guid=guid, text_a=history, label=dst_label))
+                    examples.append(
+                        InputExample(
+                            guid=guid, text_a=history, label=dst_label))
                 conv_example = []
                 conv_id = conv_no
                 if i == 0:
```
```diff
@@ -515,7 +532,9 @@ class DSTC2Processor(DataProcessor):
             index += 1
             history = sample[0]
             dst_label = sample[1]
-            examples.append(InputExample(guid=guid, text_a=history, label=dst_label))
+            examples.append(
+                InputExample(
+                    guid=guid, text_a=history, label=dst_label))
         return examples
 
     def get_train_examples(self, data_dir):
```
```diff
@@ -549,15 +568,17 @@ class DSTC2Processor(DataProcessor):
 class MULTIWOZProcessor(DataProcessor):
     """Processor for the MULTIWOZ data set."""
 
     def _create_turns(self, conv_example):
         """create multi turn dataset"""
         samples = []
         max_turns = 2
         for i in range(len(conv_example)):
             prefix_turns = conv_example[max(i - max_turns, 0):i]
             conv_info = "\1".join([turn[0] for turn in prefix_turns])
             current_turns = conv_example[i][0]
-            samples.append((conv_info.split('\1'), current_turns.split('\1'), conv_example[i][1]))
+            samples.append((conv_info.split('\1'),
+                            current_turns.split('\1'),
+                            conv_example[i][1]))
         return samples
 
     def _create_examples(self, lines, set_type):
```
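Unlike DSTC2, the MULTIWOZ window excludes the current turn (the slice stops at `i`, with `max_turns = 2`), which is kept as a separate `current_turns` field. A sketch of the difference, using the same `'\1'` control-character separator as the file:

```python
# Toy (turn_text, label) pairs; each turn text holds two '\1'-separated parts.
conv_example = [("u-%s\1s-%s" % (t, t), "label-%s" % t) for t in "abc"]
max_turns = 2
for i in range(len(conv_example)):
    prefix_turns = conv_example[max(i - max_turns, 0):i]
    history = "\1".join(t[0] for t in prefix_turns).split('\1')
    current = conv_example[i][0].split('\1')
    print(history, current, conv_example[i][1])
# for i == 0 the joined history is the empty string; it never
# contains the current turn.
```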
```diff
@@ -578,7 +599,12 @@ class MULTIWOZProcessor(DataProcessor):
                     history = sample[0]
                     current = sample[1]
                     dst_label = sample[2]
-                    examples.append(InputExample(guid=guid, text_a=history, text_b=current, label=dst_label))
+                    examples.append(
+                        InputExample(
+                            guid=guid,
+                            text_a=history,
+                            text_b=current,
+                            label=dst_label))
                 conv_example = []
                 conv_id = conv_no
                 if i == 0:
```
```diff
@@ -592,7 +618,12 @@ class MULTIWOZProcessor(DataProcessor):
             history = sample[0]
             current = sample[1]
             dst_label = sample[2]
-            examples.append(InputExample(guid=guid, text_a=history, text_b=current, label=dst_label))
+            examples.append(
+                InputExample(
+                    guid=guid,
+                    text_a=history,
+                    text_b=current,
+                    label=dst_label))
         return examples
 
     def get_train_examples(self, data_dir):
```
```diff
@@ -629,8 +660,10 @@ def create_dialogue_examples(conv):
     samples = []
     for i in range(len(conv)):
         cur_txt = "%s : %s" % (conv[i][2], conv[i][3])
-        pre_txt = ["%s : %s" % (c[2], c[3]) for c in conv[max(0, i - 5):i]]
-        suf_txt = ["%s : %s" % (c[2], c[3]) for c in conv[i + 1:min(len(conv), i + 3)]]
+        pre_txt = ["%s : %s" % (c[2], c[3])
+                   for c in conv[max(0, i - 5):i]]
+        suf_txt = ["%s : %s" % (c[2], c[3])
+                   for c in conv[i + 1:min(len(conv), i + 3)]]
         sample = [conv[i][1], pre_txt, cur_txt, suf_txt]
         samples.append(sample)
     return samples
```
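Each sample pairs a turn's label with up to five preceding and up to two following formatted turns. A worked sketch on a toy conversation, where rows are assumed (from the indexing above) to be `[conv_id, label, caller, text]`:

```python
conv = [["c1", "L%d" % k, "A" if k % 2 == 0 else "B", "utt%d" % k]
        for k in range(4)]
i = 2
cur_txt = "%s : %s" % (conv[i][2], conv[i][3])
pre_txt = ["%s : %s" % (c[2], c[3]) for c in conv[max(0, i - 5):i]]
suf_txt = ["%s : %s" % (c[2], c[3]) for c in conv[i + 1:min(len(conv), i + 3)]]
print(pre_txt, cur_txt, suf_txt)
# ['A : utt0', 'B : utt1'] 'A : utt2' ['B : utt3']
```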
```diff
@@ -645,7 +678,9 @@ def create_multi_turn_examples(lines, set_type):
     for (i, line) in enumerate(lines):
         if len(line) != 4:
             print("data format error: %s" % "\t".join(line))
-            print("data row contains four parts: conversation_id\tlabel\tcaller\tconversation_content")
+            print(
+                "data row contains four parts: conversation_id\tlabel\tcaller\tconversation_content"
+            )
             continue
         tokens = line
         conv_no = tokens[0]
```
```diff
@@ -659,7 +694,12 @@ def create_multi_turn_examples(lines, set_type):
                 text_b = sample[2]
                 text_c = sample[3]
-                examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, text_c=text_c, label=label))
+                examples.append(
+                    InputExample(
+                        guid=guid,
+                        text_a=text_a,
+                        text_b=text_b,
+                        text_c=text_c,
+                        label=label))
             conv_example = []
             conv_id = conv_no
             if i == 0:
```
```diff
@@ -675,7 +715,12 @@ def create_multi_turn_examples(lines, set_type):
             text_b = sample[2]
             text_c = sample[3]
-            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, text_c=text_c, label=label))
+            examples.append(
+                InputExample(
+                    guid=guid,
+                    text_a=text_a,
+                    text_b=text_b,
+                    text_c=text_c,
+                    label=label))
         return examples
```
```diff
@@ -690,7 +735,7 @@ def convert_tokens(tokens, sep_id, tokenizer):
             ids = tokenizer.convert_tokens_to_ids(tok_text)
             tokens_ids.extend(ids)
             tokens_ids.append(sep_id)
-        tokens_ids = tokens_ids[: -1]
+        tokens_ids = tokens_ids[:-1]
     else:
         tok_text = tokenizer.tokenize(tokens)
         tokens_ids = tokenizer.convert_tokens_to_ids(tok_text)
```
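In the list branch shown here, a `[SEP]` id is appended after every segment and the final trailing one is dropped by `[:-1]`. A sketch with toy stand-ins for the tokenizer (the real one comes from `dgu.tokenization`):

```python
sep_id = 102  # hypothetical [SEP] id
segments = ["hello there", "general kenobi"]

tokens_ids = []
for text in segments:
    ids = [len(w) for w in text.split()]  # fake convert_tokens_to_ids
    tokens_ids.extend(ids)
    tokens_ids.append(sep_id)
tokens_ids = tokens_ids[:-1]  # drop the separator after the last segment
print(tokens_ids)  # [5, 5, 102, 7, 6]
```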
```diff
@@ -746,23 +791,29 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
         tokens_a_ids = tokens_a_ids[len(tokens_a_ids) - max_seq_length + 2:]
     if not tokens_c_ids:
         if len(tokens_a_ids) > max_seq_length - len(tokens_b_ids) - 3:
-            tokens_a_ids = tokens_a_ids[len(tokens_a_ids) - max_seq_length + len(tokens_b_ids) + 3:]
+            tokens_a_ids = tokens_a_ids[len(tokens_a_ids) - max_seq_length +
+                                        len(tokens_b_ids) + 3:]
     else:
-        if len(tokens_a_ids) + len(tokens_b_ids) + len(tokens_c_ids) > max_seq_length - 4:
+        if len(tokens_a_ids) + len(tokens_b_ids) + len(
+                tokens_c_ids) > max_seq_length - 4:
             left_num = max_seq_length - len(tokens_b_ids) - 4
             if len(tokens_a_ids) > len(tokens_c_ids):
                 suffix_num = int(left_num / 2)
-                tokens_c_ids = tokens_c_ids[:min(len(tokens_c_ids), suffix_num)]
+                tokens_c_ids = tokens_c_ids[:min(
+                    len(tokens_c_ids), suffix_num)]
                 prefix_num = left_num - len(tokens_c_ids)
-                tokens_a_ids = tokens_a_ids[max(0, len(tokens_a_ids) - prefix_num):]
+                tokens_a_ids = tokens_a_ids[max(
+                    0, len(tokens_a_ids) - prefix_num):]
             else:
                 if not tokens_a_ids:
-                    tokens_c_ids = tokens_c_ids[max(0, len(tokens_c_ids) - left_num):]
+                    tokens_c_ids = tokens_c_ids[max(
+                        0, len(tokens_c_ids) - left_num):]
                 else:
                     prefix_num = int(left_num / 2)
-                    tokens_a_ids = tokens_a_ids[max(0, len(tokens_a_ids) - prefix_num):]
+                    tokens_a_ids = tokens_a_ids[max(
+                        0, len(tokens_a_ids) - prefix_num):]
                     suffix_num = left_num - len(tokens_a_ids)
-                    tokens_c_ids = tokens_c_ids[:min(len(tokens_c_ids), suffix_num)]
+                    tokens_c_ids = tokens_c_ids[:min(
+                        len(tokens_c_ids), suffix_num)]
 
     input_ids = []
     segment_ids = []
```
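The arithmetic here: with all three segments present, four special tokens are reserved ([CLS] plus three [SEP]s, judging by the `- 4`), `text_b` is kept whole, and the remaining budget `left_num` is split between the tail of `text_a` and the head of `text_c`. A worked check of the `len(tokens_a_ids) > len(tokens_c_ids)` branch with made-up lengths:

```python
max_seq_length = 20
tokens_a_ids = list(range(12))  # longer than c, so the first branch runs
tokens_b_ids = list(range(6))
tokens_c_ids = list(range(5))

left_num = max_seq_length - len(tokens_b_ids) - 4            # 10 ids for a + c
suffix_num = int(left_num / 2)                               # 5 for the suffix
tokens_c_ids = tokens_c_ids[:min(len(tokens_c_ids), suffix_num)]      # keep 5
prefix_num = left_num - len(tokens_c_ids)                    # 5 left for a
tokens_a_ids = tokens_a_ids[max(0, len(tokens_a_ids) - prefix_num):]  # tail 5
assert len(tokens_a_ids) + len(tokens_b_ids) + len(tokens_c_ids) \
    == max_seq_length - 4
```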
```diff
@@ -811,5 +862,3 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
         label_id=label_id)
     return feature
-
-
```