MegEngine 天元 / Models · Commit e9286a5e

Merge pull request #12 from ChenXinhao/master

fix(bert) fix pylint error

Authored by ChenXinhao on Jun 12, 2020; committed via GitHub on Jun 12, 2020.
Parents: 10cff870, ae380a84
Showing 6 changed files with 49 additions and 49 deletions (+49, -49):
official/nlp/bert/config_args.py    +0  -0
official/nlp/bert/model.py          +6  -7
official/nlp/bert/mrpc_dataset.py   +9  -9
official/nlp/bert/test.py           +5  -5
official/nlp/bert/tokenization.py  +23 -22
official/nlp/bert/train.py          +6  -6
official/nlp/bert/config.py → official/nlp/bert/config_args.py

File moved (rename only, no content changes).
official/nlp/bert/model.py

```diff
@@ -23,7 +23,6 @@ import copy
 import json
 import math
 import os
-import sys
 import urllib
 import urllib.request
 from io import open
```
```diff
@@ -39,7 +38,7 @@ from megengine.module.activation import Softmax
 def transpose(inp, a, b):
-    cur_shape = [i for i in range(0, len(inp.shape))]
+    cur_shape = list(range(0, len(inp.shape)))
     cur_shape[a], cur_shape[b] = cur_shape[b], cur_shape[a]
     return inp.dimshuffle(*cur_shape)
```
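This is pylint's unnecessary-comprehension check (R1721): `[i for i in range(...)]` merely copies what `range` already yields, so `list(range(...))` builds the same list directly. A standalone sketch of the equivalence, and of the axis swap `transpose` performs with the resulting permutation (no MegEngine required):

```python
# Both forms build the identity permutation [0, 1, ..., n-1];
# the pylint fix changes spelling, not behavior.
n = 4
assert [i for i in range(0, n)] == list(range(0, n)) == [0, 1, 2, 3]

# transpose(inp, a, b) then swaps entries a and b of that permutation
# before passing it to dimshuffle:
a, b = 0, 2
perm = list(range(n))
perm[a], perm[b] = perm[b], perm[a]
assert perm == [2, 1, 0, 3]
```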
```diff
@@ -84,7 +83,7 @@ def gelu(x):
 ACT2FN = {"gelu": gelu, "relu": F.relu}
 
 
-class BertConfig(object):
+class BertConfig:
     """Configuration class to store the configuration of a `BertModel`.
     """
```
```diff
@@ -441,6 +440,7 @@ class BertModel(Module):
     """
 
     def __init__(self, config):
+        super().__init__()
         self.embeddings = BertEmbeddings(config)
         self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config)
```
```diff
@@ -537,6 +537,7 @@ class BertForSequenceClassification(Module):
     """
 
     def __init__(self, config, num_labels, bert=None):
+        super().__init__()
         if bert is None:
             self.bert = BertModel(config)
         else:
```
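Both `super().__init__()` insertions address pylint's super-init-not-called warning (W0231): an `__init__` that skips the base initializer leaves whatever state it sets up missing. A minimal sketch with a stand-in `Module` class (hypothetical, not MegEngine's actual base):

```python
class Module:  # stand-in for megengine.module.Module, illustration only
    def __init__(self):
        self._children = []  # base-class state subclasses rely on


class Broken(Module):
    def __init__(self, config):  # W0231: Module.__init__ never runs
        self.config = config


class Fixed(Module):
    def __init__(self, config):
        super().__init__()  # base state exists before our own attributes
        self.config = config


assert Fixed(None)._children == []
assert not hasattr(Broken(None), "_children")  # the latent bug
```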
```diff
@@ -577,10 +578,8 @@ MODEL_NAME = {
 def download_file(url, filename):
-    try:
-        urllib.URLopener().retrieve(url, filename)
-    except:
-        urllib.request.urlretrieve(url, filename)
+    # urllib.URLopener().retrieve(url, filename)
+    urllib.request.urlretrieve(url, filename)
 
 
 def create_hub_bert(model_name, pretrained):
```
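Two problems meet in the removed code: `urllib.URLopener` is a Python 2 name (its Python 3 descendant lives, deprecated, in `urllib.request`), so the `try` branch presumably always raised AttributeError, and the bare `except:` (pylint W0702, bare-except) silently swallowed that along with everything else, up to and including KeyboardInterrupt. The fix keeps only the supported call. If a guarded variant were still wanted, a hedged sketch of the narrower pattern:

```python
import urllib.error
import urllib.request


def download_file(url, filename):
    # Catch only network-level failures rather than a bare `except:`
    # (pylint W0702), which would also hide programming errors.
    try:
        urllib.request.urlretrieve(url, filename)
    except urllib.error.URLError as exc:
        raise RuntimeError("failed to download {}".format(url)) from exc
```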
official/nlp/bert/mrpc_dataset.py

```diff
@@ -20,7 +20,7 @@ from tokenization import BertTokenizer
 logger = mge.get_logger(__name__)
 
 
-class DataProcessor(object):
+class DataProcessor:
     """Base class for data converters for sequence classification data sets."""
 
     def get_train_examples(self, data_dir):
```
```diff
@@ -46,7 +46,7 @@ class DataProcessor(object):
         return lines
 
 
-class InputFeatures(object):
+class InputFeatures:
     """A single set of features of data."""
 
     def __init__(self, input_ids, input_mask, segment_ids, label_id):
```
```diff
@@ -56,7 +56,7 @@ class InputFeatures(object):
         self.label_id = label_id
 
 
-class InputExample(object):
+class InputExample:
     """A single training/test example for simple sequence classification."""
 
     def __init__(self, guid, text_a, text_b=None, label=None):
```
```diff
@@ -195,12 +195,12 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
         label_id = label_map[example.label]
         if ex_index < 0:
             logger.info("*** Example ***")
-            logger.info("guid: %s" % (example.guid))
-            logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
-            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
-            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
-            logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
-            logger.info("label: %s (id = %d)" % (example.label, label_id))
+            logger.info("guid: {}".format(example.guid))
+            logger.info("tokens: {}".format(" ".join([str(x) for x in tokens])))
+            logger.info("input_ids: {}".format(" ".join([str(x) for x in input_ids])))
+            logger.info("input_mask: {}".format(" ".join([str(x) for x in input_mask])))
+            logger.info("segment_ids: {}".format(" ".join([str(x) for x in segment_ids])))
+            logger.info("label: {} (id = {})".format(example.label, label_id))
         features.append(
             InputFeatures(
```
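The %-interpolation performed before the call is what pylint's logging-not-lazy check (W1201) flags; the commit switches to `str.format`. Note that pylint's companion check, logging-format-interpolation (W1202), objects to `.format` inside logging calls as well; the form both checks accept passes the arguments through and lets the logger format only the records it actually emits, as in this sketch:

```python
import logging

logger = logging.getLogger(__name__)
label, label_id = "equivalent", 1

# Eager: the message string is built even when INFO is disabled.
logger.info("label: {} (id = {})".format(label, label_id))

# Lazy: arguments ride on the LogRecord and are formatted only if emitted.
logger.info("label: %s (id = %d)", label, label_id)
```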
official/nlp/bert/test.py

```diff
@@ -12,16 +12,16 @@ import megengine.functional as F
 from megengine.jit import trace
 from tqdm import tqdm
 
-from config import get_args
 from model import BertForSequenceClassification, create_hub_bert
 from mrpc_dataset import MRPCDataset
 
-args = get_args()
+# pylint: disable=import-outside-toplevel
+import config_args
+
+args = config_args.get_args()
 logger = mge.get_logger(__name__)
 
 
 @trace(symbolic=True)
-def net_eval(input_ids, segment_ids, input_mask, label_ids, opt=None, net=None):
+def net_eval(input_ids, segment_ids, input_mask, label_ids, net=None):
     net.eval()
     results = net(input_ids, segment_ids, input_mask, label_ids)
     logits, loss = results
```
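Two cleanups share this hunk: the module renamed in the first change (`config.py` → `config_args.py`) is now imported under its new name, and `net_eval` drops the `opt=None` parameter its body never reads, which pylint would report as unused-argument (W0613). A minimal sketch of the latter:

```python
def net_eval_old(inputs, opt=None, net=None):  # W0613: `opt` is never used
    return net(inputs)


def net_eval_new(inputs, net=None):  # dead parameter removed
    return net(inputs)
```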
```diff
@@ -39,7 +39,7 @@ def eval(dataloader, net):
     sum_loss, sum_accuracy, total_steps, total_examples = 0, 0, 0, 0
-    for step, batch in enumerate(tqdm(dataloader, desc="Iteration")):
+    for _, batch in enumerate(tqdm(dataloader, desc="Iteration")):
         input_ids, input_mask, segment_ids, label_ids = tuple(
             mge.tensor(t) for t in batch
         )
```
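Renaming the unused loop index to `_` silences pylint's unused-variable warning (W0612). Once the index is discarded, `enumerate` itself is redundant; a sketch of the further simplification (not what the commit does, just the idiomatic endpoint):

```python
batches = [(1, 2), (3, 4)]

# The commit's version: keep enumerate, throw the index away.
for _, batch in enumerate(batches):
    print(batch)

# Equivalent and simpler once the index is gone:
for batch in batches:
    print(batch)
```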
official/nlp/bert/tokenization.py

```diff
@@ -22,7 +22,7 @@ import os
 import unicodedata
 from io import open
 
-import megengine as megengine
+import megengine
 
 logger = megengine.get_logger(__name__)
```
```diff
@@ -54,7 +54,7 @@ def whitespace_tokenize(text):
     return tokens
 
 
-class BertTokenizer(object):
+class BertTokenizer:
     """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
 
     def __init__(
```
```diff
@@ -150,7 +150,7 @@ class BertTokenizer(object):
         return vocab_file
 
 
-class BasicTokenizer(object):
+class BasicTokenizer:
     """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
 
     def __init__(
```
```diff
@@ -243,18 +243,19 @@ class BasicTokenizer(object):
         # as is Japanese Hiragana and Katakana. Those alphabets are used to write
         # space-separated words, so they are not treated specially and handled
         # like the all of the other languages.
-        if (
-            (cp >= 0x4E00 and cp <= 0x9FFF)
-            or (cp >= 0x3400 and cp <= 0x4DBF)  #
-            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
-            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
-            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
-            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
-            or (cp >= 0xF900 and cp <= 0xFAFF)
-            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
-        ):  #
-            return True
+        cp_range = [
+            (0x4E00, 0x9FFF),
+            (0x3400, 0x4DBF),
+            (0x20000, 0x2A6DF),
+            (0x2A700, 0x2B73F),
+            (0x2B740, 0x2B81F),
+            (0x2B820, 0x2CEAF),
+            (0xF900, 0xFAFF),
+            (0x2F800, 0x2FA1F),
+        ]
+        for min_cp, max_cp in cp_range:
+            if min_cp <= cp <= max_cp:
+                return True
 
         return False
 
     def _clean_text(self, text):
```
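Turning the chained `or` of bounds checks into a table of `(min_cp, max_cp)` pairs keeps the CJK Unicode blocks in one data structure and lets each test use the chained comparison pylint prefers (R1716) instead of `cp >= lo and cp <= hi`. A standalone sketch of the same shape, with a subset of the blocks from the diff (the helper name `is_cjk` is illustrative, not the method's real name):

```python
# (min_cp, max_cp) pairs, a subset of the ranges in the diff above.
CJK_RANGES = [
    (0x4E00, 0x9FFF),    # CJK Unified Ideographs
    (0x3400, 0x4DBF),    # Extension A
    (0x20000, 0x2A6DF),  # Extension B
    (0xF900, 0xFAFF),    # Compatibility Ideographs
]


def is_cjk(cp):
    return any(min_cp <= cp <= max_cp for min_cp, max_cp in CJK_RANGES)


assert is_cjk(ord("中")) and not is_cjk(ord("A"))
```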
```diff
@@ -271,7 +272,7 @@ class BasicTokenizer(object):
         return "".join(output)
 
 
-class WordpieceTokenizer(object):
+class WordpieceTokenizer:
     """Runs WordPiece tokenization."""
 
     def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
```
```diff
@@ -335,7 +336,7 @@ def _is_whitespace(char):
     """Checks whether `chars` is a whitespace character."""
     # \t, \n, and \r are technically contorl characters but we treat them
     # as whitespace since they are generally considered as such.
-    if char == " " or char == "\t" or char == "\n" or char == "\r":
+    if char in (" ", "\t", "\n", "\r"):
         return True
     cat = unicodedata.category(char)
     if cat == "Zs":
```
```diff
@@ -347,7 +348,7 @@ def _is_control(char):
     """Checks whether `chars` is a control character."""
     # These are technically control characters but we count them as whitespace
     # characters.
-    if char == "\t" or char == "\n" or char == "\r":
+    if char in ("\t", "\n", "\r"):
         return False
     cat = unicodedata.category(char)
     if cat.startswith("C"):
```
```diff
@@ -363,10 +364,10 @@ def _is_punctuation(char):
     # Punctuation class but we treat them as punctuation anyways, for
     # consistency.
     if (
-        (cp >= 33 and cp <= 47)
-        or (cp >= 58 and cp <= 64)
-        or (cp >= 91 and cp <= 96)
-        or (cp >= 123 and cp <= 126)
+        (33 <= cp <= 47)
+        or (58 <= cp <= 64)
+        or (91 <= cp <= 96)
+        or (123 <= cp <= 126)
     ):
         return True
     cat = unicodedata.category(char)
```
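The last three hunks apply the same two pylint idioms to the module's helper predicates: a run of `x == a or x == b` becomes a membership test (R1714, consider-using-in), and paired `>=`/`<=` checks become chained comparisons (R1716). Both are pure respellings:

```python
char, cp = "\t", 44

# R1714: membership instead of repeated equality tests.
assert (char == " " or char == "\t") == (char in (" ", "\t"))

# R1716: chained comparison instead of an explicit conjunction.
assert (cp >= 33 and cp <= 47) == (33 <= cp <= 47)
```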
official/nlp/bert/train.py

```diff
@@ -13,16 +13,16 @@ import megengine.optimizer as optim
 from megengine.jit import trace
 from tqdm import tqdm
 
-from config import get_args
 from model import BertForSequenceClassification, create_hub_bert
 from mrpc_dataset import MRPCDataset
 
-args = get_args()
+# pylint: disable=import-outside-toplevel
+import config_args
+
+args = config_args.get_args()
 logger = mge.get_logger(__name__)
 
 
 @trace(symbolic=True)
-def net_eval(input_ids, segment_ids, input_mask, label_ids, opt=None, net=None):
+def net_eval(input_ids, segment_ids, input_mask, label_ids, net=None):
     net.eval()
     results = net(input_ids, segment_ids, input_mask, label_ids)
     logits, loss = results
```
```diff
@@ -49,7 +49,7 @@ def eval(dataloader, net):
     sum_loss, sum_accuracy, total_steps, total_examples = 0, 0, 0, 0
-    for step, batch in enumerate(tqdm(dataloader, desc="Iteration")):
+    for _, batch in enumerate(tqdm(dataloader, desc="Iteration")):
         input_ids, input_mask, segment_ids, label_ids = tuple(
             mge.tensor(t) for t in batch
         )
```
```diff
@@ -79,7 +79,7 @@ def train(dataloader, net, opt):
     logger.info("batch size = %d", args.train_batch_size)
     sum_loss, sum_accuracy, total_steps, total_examples = 0, 0, 0, 0
-    for step, batch in enumerate(tqdm(dataloader, desc="Iteration")):
+    for _, batch in enumerate(tqdm(dataloader, desc="Iteration")):
         input_ids, input_mask, segment_ids, label_ids = tuple(
             mge.tensor(t) for t in batch
         )
```