From ae380a8438937a641c3c7348c47e3fca16800b45 Mon Sep 17 00:00:00 2001 From: Chen xinhao Date: Fri, 24 Apr 2020 15:27:32 +0800 Subject: [PATCH] fix(bert) fix pylint error --- .../nlp/bert/{config.py => config_args.py} | 0 official/nlp/bert/model.py | 13 +++--- official/nlp/bert/mrpc_dataset.py | 18 ++++---- official/nlp/bert/test.py | 10 ++--- official/nlp/bert/tokenization.py | 45 ++++++++++--------- official/nlp/bert/train.py | 12 ++--- 6 files changed, 49 insertions(+), 49 deletions(-) rename official/nlp/bert/{config.py => config_args.py} (100%) diff --git a/official/nlp/bert/config.py b/official/nlp/bert/config_args.py similarity index 100% rename from official/nlp/bert/config.py rename to official/nlp/bert/config_args.py diff --git a/official/nlp/bert/model.py b/official/nlp/bert/model.py index 1e65929..879879f 100644 --- a/official/nlp/bert/model.py +++ b/official/nlp/bert/model.py @@ -23,7 +23,6 @@ import copy import json import math import os -import sys import urllib import urllib.request from io import open @@ -39,7 +38,7 @@ from megengine.module.activation import Softmax def transpose(inp, a, b): - cur_shape = [i for i in range(0, len(inp.shape))] + cur_shape = list(range(0, len(inp.shape))) cur_shape[a], cur_shape[b] = cur_shape[b], cur_shape[a] return inp.dimshuffle(*cur_shape) @@ -84,7 +83,7 @@ def gelu(x): ACT2FN = {"gelu": gelu, "relu": F.relu} -class BertConfig(object): +class BertConfig: """Configuration class to store the configuration of a `BertModel`. """ @@ -441,6 +440,7 @@ class BertModel(Module): """ def __init__(self, config): + super().__init__() self.embeddings = BertEmbeddings(config) self.encoder = BertEncoder(config) self.pooler = BertPooler(config) @@ -537,6 +537,7 @@ class BertForSequenceClassification(Module): """ def __init__(self, config, num_labels, bert=None): + super().__init__() if bert is None: self.bert = BertModel(config) else: @@ -577,10 +578,8 @@ MODEL_NAME = { def download_file(url, filename): - try: - urllib.URLopener().retrieve(url, filename) - except: - urllib.request.urlretrieve(url, filename) + # urllib.URLopener().retrieve(url, filename) + urllib.request.urlretrieve(url, filename) def create_hub_bert(model_name, pretrained): diff --git a/official/nlp/bert/mrpc_dataset.py b/official/nlp/bert/mrpc_dataset.py index f3042a9..4839781 100644 --- a/official/nlp/bert/mrpc_dataset.py +++ b/official/nlp/bert/mrpc_dataset.py @@ -20,7 +20,7 @@ from tokenization import BertTokenizer logger = mge.get_logger(__name__) -class DataProcessor(object): +class DataProcessor: """Base class for data converters for sequence classification data sets.""" def get_train_examples(self, data_dir): @@ -46,7 +46,7 @@ class DataProcessor(object): return lines -class InputFeatures(object): +class InputFeatures: """A single set of features of data.""" def __init__(self, input_ids, input_mask, segment_ids, label_id): @@ -56,7 +56,7 @@ class InputFeatures(object): self.label_id = label_id -class InputExample(object): +class InputExample: """A single training/test example for simple sequence classification.""" def __init__(self, guid, text_a, text_b=None, label=None): @@ -195,12 +195,12 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer label_id = label_map[example.label] if ex_index < 0: logger.info("*** Example ***") - logger.info("guid: %s" % (example.guid)) - logger.info("tokens: %s" % " ".join([str(x) for x in tokens])) - logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) - logger.info("input_mask: %s" % " ".join([str(x) 
for x in input_mask])) - logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) - logger.info("label: %s (id = %d)" % (example.label, label_id)) + logger.info("guid: {}".format(example.guid)) + logger.info("tokens: {}".format(" ".join([str(x) for x in tokens]))) + logger.info("input_ids: {}".format(" ".join([str(x) for x in input_ids]))) + logger.info("input_mask: {}".format(" ".join([str(x) for x in input_mask]))) + logger.info("segment_ids: {}".format(" ".join([str(x) for x in segment_ids]))) + logger.info("label: {} (id = {})".format(example.label, label_id)) features.append( InputFeatures( diff --git a/official/nlp/bert/test.py b/official/nlp/bert/test.py index ebb7ad4..924d0ad 100644 --- a/official/nlp/bert/test.py +++ b/official/nlp/bert/test.py @@ -12,16 +12,16 @@ import megengine.functional as F from megengine.jit import trace from tqdm import tqdm -from config import get_args from model import BertForSequenceClassification, create_hub_bert from mrpc_dataset import MRPCDataset - -args = get_args() +# pylint: disable=import-outside-toplevel +import config_args +args = config_args.get_args() logger = mge.get_logger(__name__) @trace(symbolic=True) -def net_eval(input_ids, segment_ids, input_mask, label_ids, opt=None, net=None): +def net_eval(input_ids, segment_ids, input_mask, label_ids, net=None): net.eval() results = net(input_ids, segment_ids, input_mask, label_ids) logits, loss = results @@ -39,7 +39,7 @@ def eval(dataloader, net): sum_loss, sum_accuracy, total_steps, total_examples = 0, 0, 0, 0 - for step, batch in enumerate(tqdm(dataloader, desc="Iteration")): + for _, batch in enumerate(tqdm(dataloader, desc="Iteration")): input_ids, input_mask, segment_ids, label_ids = tuple( mge.tensor(t) for t in batch ) diff --git a/official/nlp/bert/tokenization.py b/official/nlp/bert/tokenization.py index 1ee550a..20b0600 100644 --- a/official/nlp/bert/tokenization.py +++ b/official/nlp/bert/tokenization.py @@ -22,7 +22,7 @@ import os import unicodedata from io import open -import megengine as megengine +import megengine logger = megengine.get_logger(__name__) @@ -54,7 +54,7 @@ def whitespace_tokenize(text): return tokens -class BertTokenizer(object): +class BertTokenizer: """Runs end-to-end tokenization: punctuation splitting + wordpiece""" def __init__( @@ -150,7 +150,7 @@ class BertTokenizer(object): return vocab_file -class BasicTokenizer(object): +class BasicTokenizer: """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" def __init__( @@ -243,18 +243,19 @@ class BasicTokenizer(object): # as is Japanese Hiragana and Katakana. Those alphabets are used to write # space-separated words, so they are not treated specially and handled # like the all of the other languages. 
- if ( - (cp >= 0x4E00 and cp <= 0x9FFF) - or (cp >= 0x3400 and cp <= 0x4DBF) # - or (cp >= 0x20000 and cp <= 0x2A6DF) # - or (cp >= 0x2A700 and cp <= 0x2B73F) # - or (cp >= 0x2B740 and cp <= 0x2B81F) # - or (cp >= 0x2B820 and cp <= 0x2CEAF) # - or (cp >= 0xF900 and cp <= 0xFAFF) - or (cp >= 0x2F800 and cp <= 0x2FA1F) # - ): # - return True - + cp_range = [ + (0x4E00, 0x9FFF), + (0x3400, 0x4DBF), + (0x20000, 0x2A6DF), + (0x2A700, 0x2B73F), + (0x2B740, 0x2B81F), + (0x2B820, 0x2CEAF), + (0xF900, 0xFAFF), + (0x2F800, 0x2FA1F), + ] + for min_cp, max_cp in cp_range: + if min_cp <= cp <= max_cp: + return True return False def _clean_text(self, text): @@ -271,7 +272,7 @@ class BasicTokenizer(object): return "".join(output) -class WordpieceTokenizer(object): +class WordpieceTokenizer: """Runs WordPiece tokenization.""" def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): @@ -335,7 +336,7 @@ def _is_whitespace(char): """Checks whether `chars` is a whitespace character.""" # \t, \n, and \r are technically contorl characters but we treat them # as whitespace since they are generally considered as such. - if char == " " or char == "\t" or char == "\n" or char == "\r": + if char in (" ", "\t", "\n", "\r"): return True cat = unicodedata.category(char) if cat == "Zs": @@ -347,7 +348,7 @@ def _is_control(char): """Checks whether `chars` is a control character.""" # These are technically control characters but we count them as whitespace # characters. - if char == "\t" or char == "\n" or char == "\r": + if char in ("\t", "\n", "\r"): return False cat = unicodedata.category(char) if cat.startswith("C"): @@ -363,10 +364,10 @@ def _is_punctuation(char): # Punctuation class but we treat them as punctuation anyways, for # consistency. if ( - (cp >= 33 and cp <= 47) - or (cp >= 58 and cp <= 64) - or (cp >= 91 and cp <= 96) - or (cp >= 123 and cp <= 126) + (33 <= cp <= 47) + or (58 <= cp <= 64) + or (91 <= cp <= 96) + or (123 <= cp <= 126) ): return True cat = unicodedata.category(char) diff --git a/official/nlp/bert/train.py b/official/nlp/bert/train.py index 7b93903..1b82d84 100644 --- a/official/nlp/bert/train.py +++ b/official/nlp/bert/train.py @@ -13,16 +13,16 @@ import megengine.optimizer as optim from megengine.jit import trace from tqdm import tqdm -from config import get_args from model import BertForSequenceClassification, create_hub_bert from mrpc_dataset import MRPCDataset - -args = get_args() +# pylint: disable=import-outside-toplevel +import config_args +args = config_args.get_args() logger = mge.get_logger(__name__) @trace(symbolic=True) -def net_eval(input_ids, segment_ids, input_mask, label_ids, opt=None, net=None): +def net_eval(input_ids, segment_ids, input_mask, label_ids, net=None): net.eval() results = net(input_ids, segment_ids, input_mask, label_ids) logits, loss = results @@ -49,7 +49,7 @@ def eval(dataloader, net): sum_loss, sum_accuracy, total_steps, total_examples = 0, 0, 0, 0 - for step, batch in enumerate(tqdm(dataloader, desc="Iteration")): + for _, batch in enumerate(tqdm(dataloader, desc="Iteration")): input_ids, input_mask, segment_ids, label_ids = tuple( mge.tensor(t) for t in batch ) @@ -79,7 +79,7 @@ def train(dataloader, net, opt): logger.info("batch size = %d", args.train_batch_size) sum_loss, sum_accuracy, total_steps, total_examples = 0, 0, 0, 0 - for step, batch in enumerate(tqdm(dataloader, desc="Iteration")): + for _, batch in enumerate(tqdm(dataloader, desc="Iteration")): input_ids, input_mask, segment_ids, label_ids = tuple( mge.tensor(t) for t 
in batch
        )
--
GitLab
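Note: the most substantive refactor in this patch is the range-table rewrite of `_is_chinese_char` in tokenization.py. Below is a minimal, self-contained sketch of the same pattern, assuming nothing beyond the code shown in the hunk; the module-level constant `CJK_RANGES` and the standalone helper `is_chinese_char` are illustrative names, not identifiers introduced by the patch itself.

    # Illustrative sketch of the range-table check used in the tokenization.py hunk.
    CJK_RANGES = [
        (0x4E00, 0x9FFF),    # CJK Unified Ideographs
        (0x3400, 0x4DBF),    # Extension A
        (0x20000, 0x2A6DF),  # Extension B
        (0x2A700, 0x2B73F),  # Extension C
        (0x2B740, 0x2B81F),  # Extension D
        (0x2B820, 0x2CEAF),  # Extension E
        (0xF900, 0xFAFF),    # Compatibility Ideographs
        (0x2F800, 0x2FA1F),  # Compatibility Ideographs Supplement
    ]

    def is_chinese_char(cp):
        """Return True if code point `cp` falls inside a CJK ideograph block."""
        return any(min_cp <= cp <= max_cp for min_cp, max_cp in CJK_RANGES)

    # Example: ord("中") == 0x4E2D lies in the first range; "a" does not.
    assert is_chinese_char(ord("中"))
    assert not is_chinese_char(ord("a"))

Keeping the code point ranges in a single data structure avoids the long chained boolean expression that pylint flags, and adding a new block later becomes a one-line change.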