Commit ae380a84, authored by Chen xinhao

fix(bert) fix pylint error

Parent 9766a399
@@ -23,7 +23,6 @@ import copy
 import json
 import math
 import os
-import sys
 import urllib
 import urllib.request
 from io import open
@@ -39,7 +38,7 @@ from megengine.module.activation import Softmax


 def transpose(inp, a, b):
-    cur_shape = [i for i in range(0, len(inp.shape))]
+    cur_shape = list(range(0, len(inp.shape)))
     cur_shape[a], cur_shape[b] = cur_shape[b], cur_shape[a]
     return inp.dimshuffle(*cur_shape)
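The transpose helper swaps two axes by permuting the shape indices and calling dimshuffle. A minimal usage sketch, assuming the MegEngine 0.x tensor API where dimshuffle is available; the shape is illustrative:

    import numpy as np
    import megengine as mge

    # Hypothetical input: swapping axes 1 and 2 of a (2, 3, 4) tensor yields (2, 4, 3).
    x = mge.tensor(np.zeros((2, 3, 4), dtype="float32"))
    y = transpose(x, 1, 2)  # builds [0, 2, 1] and calls x.dimshuffle(0, 2, 1)
    print(y.shape)          # expected: (2, 4, 3)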
@@ -84,7 +83,7 @@ def gelu(x):


 ACT2FN = {"gelu": gelu, "relu": F.relu}


-class BertConfig(object):
+class BertConfig:
     """Configuration class to store the configuration of a `BertModel`.
     """
@@ -441,6 +440,7 @@ class BertModel(Module):
     """

     def __init__(self, config):
+        super().__init__()
         self.embeddings = BertEmbeddings(config)
         self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config)
@@ -537,6 +537,7 @@ class BertForSequenceClassification(Module):
     """

     def __init__(self, config, num_labels, bert=None):
+        super().__init__()
         if bert is None:
             self.bert = BertModel(config)
         else:
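The bert argument lets a pretrained backbone be injected instead of building a fresh BertModel. A hedged wiring sketch; create_hub_bert's exact signature and the model name are not shown in this diff, so both are assumptions:

    # Hypothetical: obtain a pretrained backbone and its config, then wrap it
    # with a 2-way classification head (e.g. for MRPC).
    bert, config = create_hub_bert("uncased_L-12_H-768_A-12", pretrained=True)
    model = BertForSequenceClassification(config, num_labels=2, bert=bert)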
@@ -577,9 +578,7 @@ MODEL_NAME = {


 def download_file(url, filename):
-    try:
-        urllib.URLopener().retrieve(url, filename)
-    except:
-        urllib.request.urlretrieve(url, filename)
+    # urllib.URLopener().retrieve(url, filename)
+    urllib.request.urlretrieve(url, filename)
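With the bare except removed, download_file is a thin wrapper over urllib.request.urlretrieve, which now raises on failure instead of being reached as a silent fallback. A usage sketch; the URL and filename are placeholders, not entries from MODEL_NAME:

    import os

    target = "bert_config.json"  # placeholder filename
    if not os.path.exists(target):
        download_file("https://example.com/bert_config.json", target)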
......
@@ -20,7 +20,7 @@ from tokenization import BertTokenizer

 logger = mge.get_logger(__name__)


-class DataProcessor(object):
+class DataProcessor:
     """Base class for data converters for sequence classification data sets."""

     def get_train_examples(self, data_dir):
@@ -46,7 +46,7 @@ class DataProcessor(object):
         return lines


-class InputFeatures(object):
+class InputFeatures:
     """A single set of features of data."""

     def __init__(self, input_ids, input_mask, segment_ids, label_id):
@@ -56,7 +56,7 @@ class InputFeatures(object):
         self.label_id = label_id


-class InputExample(object):
+class InputExample:
     """A single training/test example for simple sequence classification."""

     def __init__(self, guid, text_a, text_b=None, label=None):
@@ -195,12 +195,12 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
         label_id = label_map[example.label]
         if ex_index < 0:
             logger.info("*** Example ***")
-            logger.info("guid: %s" % (example.guid))
-            logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
-            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
-            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
-            logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
-            logger.info("label: %s (id = %d)" % (example.label, label_id))
+            logger.info("guid: {}".format(example.guid))
+            logger.info("tokens: {}".format(" ".join([str(x) for x in tokens])))
+            logger.info("input_ids: {}".format(" ".join([str(x) for x in input_ids])))
+            logger.info("input_mask: {}".format(" ".join([str(x) for x in input_mask])))
+            logger.info("segment_ids: {}".format(" ".join([str(x) for x in segment_ids])))
+            logger.info("label: {} (id = {})".format(example.label, label_id))
         features.append(
             InputFeatures(
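A hedged end-to-end sketch of the classes above; the label list, sequence length, and tokenizer are illustrative assumptions, since only fragments of these APIs appear in this diff:

    # Assumes `tokenizer` is a BertTokenizer built elsewhere (see the tokenization changes below).
    example = InputExample(
        guid="train-0",
        text_a="He said the food was good.",
        text_b="The food was delicious.",
        label="1",
    )
    features = convert_examples_to_features(
        [example], label_list=["0", "1"], max_seq_length=128, tokenizer=tokenizer
    )
    feat = features[0]  # InputFeatures carrying input_ids, input_mask, segment_ids, label_id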
......
@@ -12,16 +12,16 @@ import megengine.functional as F
 from megengine.jit import trace
 from tqdm import tqdm

-from config import get_args
 from model import BertForSequenceClassification, create_hub_bert
 from mrpc_dataset import MRPCDataset

-args = get_args()
+# pylint: disable=import-outside-toplevel
+import config_args
+
+args = config_args.get_args()
 logger = mge.get_logger(__name__)


 @trace(symbolic=True)
-def net_eval(input_ids, segment_ids, input_mask, label_ids, opt=None, net=None):
+def net_eval(input_ids, segment_ids, input_mask, label_ids, net=None):
     net.eval()
     results = net(input_ids, segment_ids, input_mask, label_ids)
     logits, loss = results
@@ -39,7 +39,7 @@ def eval(dataloader, net):
     sum_loss, sum_accuracy, total_steps, total_examples = 0, 0, 0, 0

-    for step, batch in enumerate(tqdm(dataloader, desc="Iteration")):
+    for _, batch in enumerate(tqdm(dataloader, desc="Iteration")):
         input_ids, input_mask, segment_ids, label_ids = tuple(
             mge.tensor(t) for t in batch
         )
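For reference, a hedged sketch of how the traced net_eval above is presumably invoked per batch now that the unused opt parameter is gone; what net_eval returns beyond the visible logits, loss unpacking is not shown here, so the single outputs name is an assumption:

    for _, batch in enumerate(tqdm(dataloader, desc="Iteration")):
        # each batch entry is converted to a MegEngine tensor before the traced call
        input_ids, input_mask, segment_ids, label_ids = tuple(mge.tensor(t) for t in batch)
        outputs = net_eval(input_ids, segment_ids, input_mask, label_ids, net=net)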
......
@@ -22,7 +22,7 @@ import os
 import unicodedata
 from io import open

-import megengine as megengine
+import megengine

 logger = megengine.get_logger(__name__)
@@ -54,7 +54,7 @@ def whitespace_tokenize(text):
     return tokens


-class BertTokenizer(object):
+class BertTokenizer:
     """Runs end-to-end tokenization: punctuation splitting + wordpiece"""

     def __init__(
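A hedged usage sketch for BertTokenizer; its __init__ signature is truncated in this diff, so the vocab file argument, the do_lower_case flag, and the convert_tokens_to_ids call are assumptions based on the common BERT tokenizer interface:

    tokenizer = BertTokenizer("vocab.txt", do_lower_case=True)  # hypothetical vocab path
    tokens = tokenizer.tokenize("MegEngine makes BERT fine-tuning simple.")
    ids = tokenizer.convert_tokens_to_ids(tokens)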
@@ -150,7 +150,7 @@ class BertTokenizer(object):
         return vocab_file


-class BasicTokenizer(object):
+class BasicTokenizer:
     """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

     def __init__(
@@ -243,18 +243,19 @@ class BasicTokenizer(object):
         # as is Japanese Hiragana and Katakana. Those alphabets are used to write
         # space-separated words, so they are not treated specially and handled
         # like all of the other languages.
-        if (
-            (cp >= 0x4E00 and cp <= 0x9FFF)
-            or (cp >= 0x3400 and cp <= 0x4DBF)
-            or (cp >= 0x20000 and cp <= 0x2A6DF)
-            or (cp >= 0x2A700 and cp <= 0x2B73F)
-            or (cp >= 0x2B740 and cp <= 0x2B81F)
-            or (cp >= 0x2B820 and cp <= 0x2CEAF)
-            or (cp >= 0xF900 and cp <= 0xFAFF)
-            or (cp >= 0x2F800 and cp <= 0x2FA1F)
-        ):
-            return True
+        cp_range = [
+            (0x4E00, 0x9FFF),
+            (0x3400, 0x4DBF),
+            (0x20000, 0x2A6DF),
+            (0x2A700, 0x2B73F),
+            (0x2B740, 0x2B81F),
+            (0x2B820, 0x2CEAF),
+            (0xF900, 0xFAFF),
+            (0x2F800, 0x2FA1F),
+        ]
+        for min_cp, max_cp in cp_range:
+            if min_cp <= cp <= max_cp:
+                return True
         return False

     def _clean_text(self, text):
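The rewritten check walks a list of (min, max) codepoint pairs instead of one long boolean chain. A standalone sketch of the same pattern; the helper name is illustrative, and the ranges match the CJK blocks listed above:

    CJK_RANGES = [
        (0x4E00, 0x9FFF), (0x3400, 0x4DBF), (0x20000, 0x2A6DF), (0x2A700, 0x2B73F),
        (0x2B740, 0x2B81F), (0x2B820, 0x2CEAF), (0xF900, 0xFAFF), (0x2F800, 0x2FA1F),
    ]

    def is_cjk_codepoint(cp):
        # True when the codepoint falls inside any CJK ideograph block
        return any(min_cp <= cp <= max_cp for min_cp, max_cp in CJK_RANGES)

    assert is_cjk_codepoint(ord("中"))
    assert not is_cjk_codepoint(ord("a"))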
@@ -271,7 +272,7 @@ class BasicTokenizer(object):
         return "".join(output)


-class WordpieceTokenizer(object):
+class WordpieceTokenizer:
     """Runs WordPiece tokenization."""

     def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
@@ -335,7 +336,7 @@ def _is_whitespace(char):
     """Checks whether `chars` is a whitespace character."""
     # \t, \n, and \r are technically control characters but we treat them
     # as whitespace since they are generally considered as such.
-    if char == " " or char == "\t" or char == "\n" or char == "\r":
+    if char in (" ", "\t", "\n", "\r"):
         return True
     cat = unicodedata.category(char)
     if cat == "Zs":
@@ -347,7 +348,7 @@ def _is_control(char):
     """Checks whether `chars` is a control character."""
     # These are technically control characters but we count them as whitespace
     # characters.
-    if char == "\t" or char == "\n" or char == "\r":
+    if char in ("\t", "\n", "\r"):
         return False
     cat = unicodedata.category(char)
     if cat.startswith("C"):
@@ -363,10 +364,10 @@ def _is_punctuation(char):
     # Punctuation class but we treat them as punctuation anyways, for
     # consistency.
     if (
-        (cp >= 33 and cp <= 47)
-        or (cp >= 58 and cp <= 64)
-        or (cp >= 91 and cp <= 96)
-        or (cp >= 123 and cp <= 126)
+        (33 <= cp <= 47)
+        or (58 <= cp <= 64)
+        or (91 <= cp <= 96)
+        or (123 <= cp <= 126)
     ):
         return True
     cat = unicodedata.category(char)
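A few hedged sanity checks for the rewritten character predicates; they only exercise the logic visible above:

    import unicodedata

    assert _is_whitespace(" ") and _is_whitespace("\t")
    assert unicodedata.category("\u00a0") == "Zs"          # no-break space is also whitespace
    assert not _is_control("\n")                           # \t, \n, \r count as whitespace here
    assert _is_punctuation(",") and _is_punctuation("~")   # codepoints 44 and 126 hit the ranges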
......
@@ -13,16 +13,16 @@ import megengine.optimizer as optim
 from megengine.jit import trace
 from tqdm import tqdm

-from config import get_args
 from model import BertForSequenceClassification, create_hub_bert
 from mrpc_dataset import MRPCDataset

-args = get_args()
+# pylint: disable=import-outside-toplevel
+import config_args
+
+args = config_args.get_args()
 logger = mge.get_logger(__name__)


 @trace(symbolic=True)
-def net_eval(input_ids, segment_ids, input_mask, label_ids, opt=None, net=None):
+def net_eval(input_ids, segment_ids, input_mask, label_ids, net=None):
     net.eval()
     results = net(input_ids, segment_ids, input_mask, label_ids)
     logits, loss = results
@@ -49,7 +49,7 @@ def eval(dataloader, net):
     sum_loss, sum_accuracy, total_steps, total_examples = 0, 0, 0, 0

-    for step, batch in enumerate(tqdm(dataloader, desc="Iteration")):
+    for _, batch in enumerate(tqdm(dataloader, desc="Iteration")):
         input_ids, input_mask, segment_ids, label_ids = tuple(
             mge.tensor(t) for t in batch
         )
@@ -79,7 +79,7 @@ def train(dataloader, net, opt):
     logger.info("batch size = %d", args.train_batch_size)
     sum_loss, sum_accuracy, total_steps, total_examples = 0, 0, 0, 0

-    for step, batch in enumerate(tqdm(dataloader, desc="Iteration")):
+    for _, batch in enumerate(tqdm(dataloader, desc="Iteration")):
         input_ids, input_mask, segment_ids, label_ids = tuple(
             mge.tensor(t) for t in batch
         )
......