MegEngine 天元 / Models
Commit ae380a84
Authored April 24, 2020 by Chen xinhao
Parent: 9766a399

fix(bert) fix pylint error

Showing 6 changed files with 49 additions and 49 deletions (+49 -49).
official/nlp/bert/config_args.py    +0   -0
official/nlp/bert/model.py          +6   -7
official/nlp/bert/mrpc_dataset.py   +9   -9
official/nlp/bert/test.py           +5   -5
official/nlp/bert/tokenization.py   +23  -22
official/nlp/bert/train.py          +6   -6
official/nlp/bert/config.py → official/nlp/bert/config_args.py (file moved)
official/nlp/bert/model.py
@@ -23,7 +23,6 @@ import copy
 import json
 import math
 import os
 import sys
 import urllib
 import urllib.request
-from io import open
@@ -39,7 +38,7 @@ from megengine.module.activation import Softmax
def
transpose
(
inp
,
a
,
b
):
cur_shape
=
[
i
for
i
in
range
(
0
,
len
(
inp
.
shape
))]
cur_shape
=
list
(
range
(
0
,
len
(
inp
.
shape
)))
cur_shape
[
a
],
cur_shape
[
b
]
=
cur_shape
[
b
],
cur_shape
[
a
]
return
inp
.
dimshuffle
(
*
cur_shape
)
...
...
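The comprehension `[i for i in range(...)]` merely copies what `range` already yields, which is presumably the pylint unnecessary-comprehension (R1721) message this commit addresses. A standalone sketch (names hypothetical, not from the repo) confirming the refactor is behavior-preserving:

    # Identity comprehension vs. direct list(): both produce the same axes.
    ndim = 4  # stands in for len(inp.shape)
    axes_old = [i for i in range(0, ndim)]  # form the commit removes
    axes_new = list(range(0, ndim))         # form the commit adds
    assert axes_old == axes_new == [0, 1, 2, 3]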
@@ -84,7 +83,7 @@ def gelu(x):
 ACT2FN = {"gelu": gelu, "relu": F.relu}


-class BertConfig(object):
+class BertConfig:
     """Configuration class to store the configuration of a `BertModel`.
     """
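Dropping the explicit `(object)` base is likely a fix for pylint's useless-object-inheritance (R0205): in Python 3 every class is new-style and inherits from object implicitly. A quick standalone check:

    class WithBase(object):  # style the commit removes
        pass

    class WithoutBase:  # style the commit adopts
        pass

    # Both are ordinary new-style classes with object at the root of the MRO.
    assert WithBase.__mro__[-1] is object
    assert WithoutBase.__mro__[-1] is object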
@@ -441,6 +440,7 @@ class BertModel(Module):
     """

     def __init__(self, config):
+        super().__init__()
         self.embeddings = BertEmbeddings(config)
         self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config)
@@ -537,6 +537,7 @@ class BertForSequenceClassification(Module):
     """

     def __init__(self, config, num_labels, bert=None):
+        super().__init__()
         if bert is None:
             self.bert = BertModel(config)
         else:
@@ -577,9 +578,7 @@ MODEL_NAME = {


 def download_file(url, filename):
-    try:
-        urllib.URLopener().retrieve(url, filename)
-    except:
-        urllib.request.urlretrieve(url, filename)
+    # urllib.URLopener().retrieve(url, filename)
+    urllib.request.urlretrieve(url, filename)
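The rewrite removes a bare `except:` (pylint bare-except, W0702) along with a dead branch: under Python 3, `urllib` has no `URLopener` attribute (the deprecated class lives in `urllib.request`), so the try branch always failed over to the except path anyway. A hedged usage sketch of the surviving call, with a placeholder URL and filename not taken from the repo:

    import urllib.request

    # Fetch a URL to a local file; returns (local_path, response_headers).
    path, _headers = urllib.request.urlretrieve(
        "https://example.com/bert-base.pkl", "bert-base.pkl"
    )
    print(path)  # -> bert-base.pkl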
official/nlp/bert/mrpc_dataset.py
@@ -20,7 +20,7 @@ from tokenization import BertTokenizer
 logger = mge.get_logger(__name__)


-class DataProcessor(object):
+class DataProcessor:
     """Base class for data converters for sequence classification data sets."""

     def get_train_examples(self, data_dir):
@@ -46,7 +46,7 @@ class DataProcessor(object):
         return lines


-class InputFeatures(object):
+class InputFeatures:
     """A single set of features of data."""

     def __init__(self, input_ids, input_mask, segment_ids, label_id):
@@ -56,7 +56,7 @@ class InputFeatures(object):
         self.label_id = label_id


-class InputExample(object):
+class InputExample:
     """A single training/test example for simple sequence classification."""

     def __init__(self, guid, text_a, text_b=None, label=None):
@@ -195,12 +195,12 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
         label_id = label_map[example.label]
         if ex_index < 0:
             logger.info("*** Example ***")
-            logger.info("guid: %s" % (example.guid))
-            logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
-            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
-            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
-            logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
-            logger.info("label: %s (id = %d)" % (example.label, label_id))
+            logger.info("guid: {}".format(example.guid))
+            logger.info("tokens: {}".format(" ".join([str(x) for x in tokens])))
+            logger.info("input_ids: {}".format(" ".join([str(x) for x in input_ids])))
+            logger.info("input_mask: {}".format(" ".join([str(x) for x in input_mask])))
+            logger.info("segment_ids: {}".format(" ".join([str(x) for x in segment_ids])))
+            logger.info("label: {} (id = {})".format(example.label, label_id))

         features.append(
             InputFeatures(
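Switching from `%` interpolation to `str.format` keeps the messages identical. As an aside (not part of the commit): pylint's logging checks, logging-not-lazy (W1201) and logging-format-interpolation (W1202), actually prefer deferring formatting to the logging call itself, so the string is only built when the record is emitted. A standalone comparison with a hypothetical logger:

    import logging

    logger = logging.getLogger("mrpc")  # hypothetical logger name
    guid = "train-1"                    # hypothetical example id

    logger.info("guid: %s" % guid)        # eager %-interpolation (removed style)
    logger.info("guid: {}".format(guid))  # str.format (style this commit adds)
    logger.info("guid: %s", guid)         # lazy form the pylint checks prefer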
official/nlp/bert/test.py
@@ -12,16 +12,16 @@ import megengine.functional as F
 from megengine.jit import trace
 from tqdm import tqdm

-from config import get_args
 from model import BertForSequenceClassification, create_hub_bert
 from mrpc_dataset import MRPCDataset

-args = get_args()
+# pylint: disable=import-outside-toplevel
+import config_args
+
+args = config_args.get_args()
 logger = mge.get_logger(__name__)


 @trace(symbolic=True)
-def net_eval(input_ids, segment_ids, input_mask, label_ids, opt=None, net=None):
+def net_eval(input_ids, segment_ids, input_mask, label_ids, net=None):
     net.eval()
     results = net(input_ids, segment_ids, input_mask, label_ids)
     logits, loss = results
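Removing `opt=None` drops a parameter the evaluation path never reads, presumably pylint's unused-argument (W0613). A minimal standalone illustration (hypothetical names, not the repo's functions):

    def net_eval_old(batch, opt=None, net=None):  # `opt` is never read: W0613
        return net(batch)

    def net_eval_new(batch, net=None):  # unused parameter removed
        return net(batch)

    assert net_eval_new([1, 2, 3], net=len) == 3  # trivial stand-in "net"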
@@ -39,7 +39,7 @@ def eval(dataloader, net):
     sum_loss, sum_accuracy, total_steps, total_examples = 0, 0, 0, 0
-    for step, batch in enumerate(tqdm(dataloader, desc="Iteration")):
+    for _, batch in enumerate(tqdm(dataloader, desc="Iteration")):
         input_ids, input_mask, segment_ids, label_ids = tuple(mge.tensor(t) for t in batch)
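`step` was never used inside the loop body (pylint unused-variable, W0612); renaming it to `_` is the conventional signal that the value is intentionally ignored. Since the index is discarded entirely, iterating without `enumerate` would also satisfy the checker; the commit keeps `enumerate`, which minimizes the diff. A standalone sketch:

    batches = [[1, 2], [3, 4]]

    for _, batch in enumerate(batches):  # index intentionally discarded
        print(sum(batch))

    for batch in batches:  # equivalent and simpler when no index is needed
        print(sum(batch))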
official/nlp/bert/tokenization.py
@@ -22,7 +22,7 @@ import os
 import unicodedata
 from io import open

-import megengine as megengine
+import megengine

 logger = megengine.get_logger(__name__)
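`import megengine as megengine` binds the module to its own name, which pylint reports as useless-import-alias (C0414); the plain import is identical in effect. Demonstrated with a stdlib module so the check runs anywhere:

    import json as json  # C0414: the alias merely repeats the module name
    import json          # identical binding, no warning

    assert json.loads("[1, 2]") == [1, 2]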
@@ -54,7 +54,7 @@ def whitespace_tokenize(text):
     return tokens


-class BertTokenizer(object):
+class BertTokenizer:
     """Runs end-to-end tokenization: punctuation splitting + wordpiece"""

     def __init__(
@@ -150,7 +150,7 @@ class BertTokenizer(object):
         return vocab_file


-class BasicTokenizer(object):
+class BasicTokenizer:
     """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

     def __init__(
@@ -243,18 +243,19 @@ class BasicTokenizer(object):
         # as is Japanese Hiragana and Katakana. Those alphabets are used to write
         # space-separated words, so they are not treated specially and handled
         # like the all of the other languages.
-        if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
-                (cp >= 0x3400 and cp <= 0x4DBF) or  #
-                (cp >= 0x20000 and cp <= 0x2A6DF) or  #
-                (cp >= 0x2A700 and cp <= 0x2B73F) or  #
-                (cp >= 0x2B740 and cp <= 0x2B81F) or  #
-                (cp >= 0x2B820 and cp <= 0x2CEAF) or
-                (cp >= 0xF900 and cp <= 0xFAFF) or
-                (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
-            return True
-
-        return False
+        cp_range = [
+            (0x4E00, 0x9FFF),
+            (0x3400, 0x4DBF),
+            (0x20000, 0x2A6DF),
+            (0x2A700, 0x2B73F),
+            (0x2B740, 0x2B81F),
+            (0x2B820, 0x2CEAF),
+            (0xF900, 0xFAFF),
+            (0x2F800, 0x2FA1F),
+        ]
+        for min_cp, max_cp in cp_range:
+            if min_cp <= cp <= max_cp:
+                return True
+        return False

     def _clean_text(self, text):
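The refactor turns a wall of paired `cp >= a and cp <= b` tests into a table of ranges plus one chained comparison per range. A standalone sketch of the same table-driven pattern, with an abbreviated range list (the real code lists all eight CJK blocks):

    CJK_RANGES = [(0x4E00, 0x9FFF), (0x3400, 0x4DBF)]  # abbreviated table

    def is_cjk(cp: int) -> bool:
        # Chained comparison per range, any() over the table.
        return any(lo <= cp <= hi for lo, hi in CJK_RANGES)

    assert is_cjk(ord("中")) and not is_cjk(ord("A"))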
@@ -271,7 +272,7 @@ class BasicTokenizer(object):
         return "".join(output)


-class WordpieceTokenizer(object):
+class WordpieceTokenizer:
     """Runs WordPiece tokenization."""

     def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
@@ -335,7 +336,7 @@ def _is_whitespace(char):
     """Checks whether `chars` is a whitespace character."""
     # \t, \n, and \r are technically contorl characters but we treat them
     # as whitespace since they are generally considered as such.
-    if char == " " or char == "\t" or char == "\n" or char == "\r":
+    if char in (" ", "\t", "\n", "\r"):
         return True
     cat = unicodedata.category(char)
     if cat == "Zs":
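Collapsing repeated `==` tests on one variable into a tuple membership test is what pylint's consider-using-in (R1714) suggests, and the two forms agree for every input. Standalone check:

    for char in (" ", "\t", "\n", "\r", "x"):
        separate = char == " " or char == "\t" or char == "\n" or char == "\r"
        merged = char in (" ", "\t", "\n", "\r")
        assert separate == merged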
@@ -347,7 +348,7 @@ def _is_control(char):
     """Checks whether `chars` is a control character."""
     # These are technically control characters but we count them as whitespace
     # characters.
-    if char == "\t" or char == "\n" or char == "\r":
+    if char in ("\t", "\n", "\r"):
         return False
     cat = unicodedata.category(char)
     if cat.startswith("C"):
@@ -363,10 +364,10 @@ def _is_punctuation(char):
     # Punctuation class but we treat them as punctuation anyways, for
     # consistency.
     if (
-        (cp >= 33 and cp <= 47)
-        or (cp >= 58 and cp <= 64)
-        or (cp >= 91 and cp <= 96)
-        or (cp >= 123 and cp <= 126)
+        (33 <= cp <= 47)
+        or (58 <= cp <= 64)
+        or (91 <= cp <= 96)
+        or (123 <= cp <= 126)
     ):
         return True
     cat = unicodedata.category(char)
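Rewriting `cp >= lo and cp <= hi` as `lo <= cp <= hi` uses Python's chained comparisons, per pylint's chained-comparison (R1716); both evaluate identically. Standalone check over the ASCII range:

    for cp in range(128):
        spelled_out = (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64)
        chained = (33 <= cp <= 47) or (58 <= cp <= 64)
        assert spelled_out == chained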
official/nlp/bert/train.py
@@ -13,16 +13,16 @@ import megengine.optimizer as optim
 from megengine.jit import trace
 from tqdm import tqdm

-from config import get_args
 from model import BertForSequenceClassification, create_hub_bert
 from mrpc_dataset import MRPCDataset

-args = get_args()
+# pylint: disable=import-outside-toplevel
+import config_args
+
+args = config_args.get_args()
 logger = mge.get_logger(__name__)


 @trace(symbolic=True)
-def net_eval(input_ids, segment_ids, input_mask, label_ids, opt=None, net=None):
+def net_eval(input_ids, segment_ids, input_mask, label_ids, net=None):
     net.eval()
     results = net(input_ids, segment_ids, input_mask, label_ids)
     logits, loss = results
@@ -49,7 +49,7 @@ def eval(dataloader, net):
     sum_loss, sum_accuracy, total_steps, total_examples = 0, 0, 0, 0
-    for step, batch in enumerate(tqdm(dataloader, desc="Iteration")):
+    for _, batch in enumerate(tqdm(dataloader, desc="Iteration")):
         input_ids, input_mask, segment_ids, label_ids = tuple(mge.tensor(t) for t in batch)
@@ -79,7 +79,7 @@ def train(dataloader, net, opt):
     logger.info("batch size = %d", args.train_batch_size)
     sum_loss, sum_accuracy, total_steps, total_examples = 0, 0, 0, 0
-    for step, batch in enumerate(tqdm(dataloader, desc="Iteration")):
+    for _, batch in enumerate(tqdm(dataloader, desc="Iteration")):
         input_ids, input_mask, segment_ids, label_ids = tuple(mge.tensor(t) for t in batch)