diff --git a/demo/ernie-classification/question_answering.py b/demo/ernie-classification/question_answering.py
new file mode 100644
index 0000000000000000000000000000000000000000..b09ee78256050d122fbdb5057de399e2aa05349b
--- /dev/null
+++ b/demo/ernie-classification/question_answering.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning on a classification task."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+import argparse
+import numpy as np
+
+import paddle
+import paddle.fluid as fluid
+import paddlehub as hub
+
+# yapf: disable
+parser = argparse.ArgumentParser(__doc__)
+parser.add_argument("--num_epoch", type=int, default=3, help="Number of epochs for fine-tuning.")
+parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
+parser.add_argument("--hub_module_dir", type=str, default=None, help="PaddleHub module directory.")
+parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate for L2 regularizer.")
+parser.add_argument("--data_dir", type=str, default=None, help="Path to training data.")
+parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to save model checkpoints.")
+parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest sequence.")
+parser.add_argument("--batch_size", type=int, default=32, help="Total number of examples in a training batch.")
+args = parser.parse_args()
+# yapf: enable
+
+if __name__ == '__main__':
+    # Select a finetune strategy
+    strategy = hub.BERTFinetuneStrategy(
+        weight_decay=args.weight_decay,
+        learning_rate=args.learning_rate,
+        warmup_strategy="linear_warmup_decay",
+    )
+
+    # Set up the running config for PaddleHub Finetune API
+    config = hub.RunConfig(
+        eval_interval=100,
+        use_cuda=True,
+        num_epoch=args.num_epoch,
+        batch_size=args.batch_size,
+        checkpoint_dir=args.checkpoint_dir,
+        strategy=strategy)
+
+    # Load PaddleHub's ERNIE pretrained model
+    module = hub.Module(name="ernie")
+
+    # Sentence classification dataset reader
+    reader = hub.reader.ClassifyReader(
+        dataset=hub.dataset.NLPCC_DBQA(),  # download NLPCC_DBQA dataset
+        vocab_path=module.get_vocab_path(),
+        max_seq_len=args.max_seq_len)
+
+    num_labels = len(reader.get_labels())
+
+    input_dict, output_dict, program = module.context(
+        sign_name="tokens", trainable=True, max_seq_len=args.max_seq_len)
+
+    with fluid.program_guard(program):
+        label = fluid.layers.data(name="label", shape=[1], dtype='int64')
+
+        # Use "pooled_output" for classification tasks on an entire sentence.
+        # Use "sequence_output" for token-level output.
+        pooled_output = output_dict["pooled_output"]
+
+        # Set up the feed list for the data feeder.
+        # All tensors that the ERNIE module needs must be fed.
+        feed_list = [
+            input_dict["input_ids"].name, input_dict["position_ids"].name,
+            input_dict["segment_ids"].name, input_dict["input_mask"].name,
+            label.name
+        ]
+        # Define a classification finetune task by PaddleHub's API
+        cls_task = hub.create_text_classification_task(
+            pooled_output, label, num_classes=num_labels)
+
+        # Finetune and evaluate by PaddleHub's API;
+        # training, evaluation, testing and model saving are handled automatically.
+        hub.finetune_and_eval(
+            task=cls_task,
+            data_reader=reader,
+            feed_list=feed_list,
+            config=config)
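
Note on the strategy above: "linear_warmup_decay" commonly means the learning rate climbs linearly from zero to the configured rate during warmup, then decays linearly back to zero over the remaining steps. A minimal sketch of that schedule, assuming that common definition (the exact curve used by the PaddlePaddle implementation may differ):

    def linear_warmup_decay(base_lr, step, warmup_steps, total_steps):
        """Learning rate at `step`: linear warmup, then linear decay to zero."""
        if step < warmup_steps:
            return base_lr * step / warmup_steps
        # Past warmup: decay linearly from base_lr down to zero.
        return base_lr * (total_steps - step) / (total_steps - warmup_steps)
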
diff --git a/demo/ernie-classification/question_matching.py b/demo/ernie-classification/question_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..64ee1f79b3a03edcd013a82bf6254a3580c742c4
--- /dev/null
+++ b/demo/ernie-classification/question_matching.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning on a classification task."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+import argparse
+import numpy as np
+
+import paddle
+import paddle.fluid as fluid
+import paddlehub as hub
+
+# yapf: disable
+parser = argparse.ArgumentParser(__doc__)
+parser.add_argument("--num_epoch", type=int, default=3, help="Number of epochs for fine-tuning.")
+parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
+parser.add_argument("--hub_module_dir", type=str, default=None, help="PaddleHub module directory.")
+parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate for L2 regularizer.")
+parser.add_argument("--data_dir", type=str, default=None, help="Path to training data.")
+parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to save model checkpoints.")
+parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest sequence.")
+parser.add_argument("--batch_size", type=int, default=32, help="Total number of examples in a training batch.")
+args = parser.parse_args()
+# yapf: enable
+
+if __name__ == '__main__':
+    # Select a finetune strategy
+    strategy = hub.BERTFinetuneStrategy(
+        weight_decay=args.weight_decay,
+        learning_rate=args.learning_rate,
+        warmup_strategy="linear_warmup_decay",
+    )
+
+    # Set up the running config for PaddleHub Finetune API
+    config = hub.RunConfig(
+        eval_interval=100,
+        use_cuda=True,
+        num_epoch=args.num_epoch,
+        batch_size=args.batch_size,
+        checkpoint_dir=args.checkpoint_dir,
+        strategy=strategy)
+
+    # Load PaddleHub's ERNIE pretrained model
+    module = hub.Module(name="ernie")
+
+    # Sentence classification dataset reader
+    reader = hub.reader.ClassifyReader(
+        dataset=hub.dataset.LCQMC(),  # download LCQMC dataset
+        vocab_path=module.get_vocab_path(),
+        max_seq_len=args.max_seq_len)
+
+    num_labels = len(reader.get_labels())
+
+    input_dict, output_dict, program = module.context(
+        sign_name="tokens", trainable=True, max_seq_len=args.max_seq_len)
+
+    with fluid.program_guard(program):
+        label = fluid.layers.data(name="label", shape=[1], dtype='int64')
+
+        # Use "pooled_output" for classification tasks on an entire sentence.
+        # Use "sequence_output" for token-level output.
+        pooled_output = output_dict["pooled_output"]
+
+        # Set up the feed list for the data feeder.
+        # All tensors that the ERNIE module needs must be fed.
+        feed_list = [
+            input_dict["input_ids"].name, input_dict["position_ids"].name,
+            input_dict["segment_ids"].name, input_dict["input_mask"].name,
+            label.name
+        ]
+        # Define a classification finetune task by PaddleHub's API
+        cls_task = hub.create_text_classification_task(
+            pooled_output, label, num_classes=num_labels)
+
+        # Finetune and evaluate by PaddleHub's API;
+        # training, evaluation, testing and model saving are handled automatically.
+        hub.finetune_and_eval(
+            task=cls_task,
+            data_reader=reader,
+            feed_list=feed_list,
+            config=config)
diff --git a/demo/ernie-classification/run_question_answering.sh b/demo/ernie-classification/run_question_answering.sh
new file mode 100644
index 0000000000000000000000000000000000000000..56fa22adeeaaf37623e01c835d79b5b49191b99e
--- /dev/null
+++ b/demo/ernie-classification/run_question_answering.sh
@@ -0,0 +1,10 @@
+export CUDA_VISIBLE_DEVICES=3
+
+CKPT_DIR="./ckpt_dbqa"
+python -u question_answering.py \
+    --batch_size 8 \
+    --weight_decay 0.01 \
+    --checkpoint_dir $CKPT_DIR \
+    --num_epoch 3 \
+    --max_seq_len 512 \
+    --learning_rate 2e-5
diff --git a/demo/ernie-classification/run_question_matching.sh b/demo/ernie-classification/run_question_matching.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2230d8e0b713fc285f083dd3fb26d08a98d744df
--- /dev/null
+++ b/demo/ernie-classification/run_question_matching.sh
@@ -0,0 +1,10 @@
+export CUDA_VISIBLE_DEVICES=0
+
+CKPT_DIR="./ckpt_question_matching"
+python -u question_matching.py \
+    --batch_size 32 \
+    --weight_decay 0.0 \
+    --checkpoint_dir $CKPT_DIR \
+    --num_epoch 3 \
+    --max_seq_len 128 \
+    --learning_rate 2e-5
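
Both demos above feed sentence pairs (question/answer for NLPCC-DBQA, question/question for LCQMC) through the same four input tensors. For orientation, a sketch of the BERT/ERNIE-style packing convention behind `input_ids`, `segment_ids`, `position_ids` and `input_mask`; the reader performs this internally, and the helper name here is illustrative:

    def build_pair_inputs(tokens_a, tokens_b):
        # Pack the pair as: [CLS] text_a [SEP] text_b [SEP]
        tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
        # Segment ids: 0 for text_a and its delimiters, 1 for text_b and its [SEP].
        segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
        # Positions are simply 0..len-1; the mask flags real tokens (vs. padding).
        position_ids = list(range(len(tokens)))
        input_mask = [1] * len(tokens)
        return tokens, segment_ids, position_ids, input_mask
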
diff --git a/demo/ernie-classification/run_fintune_with_hub.sh b/demo/ernie-classification/run_sentiment_cls.sh
similarity index 81%
rename from demo/ernie-classification/run_fintune_with_hub.sh
rename to demo/ernie-classification/run_sentiment_cls.sh
index b267e3da28d60738c977820d3b86ad15a5ab0081..34203b1ae2e4d9c6230610f2b31759abb8c57930 100644
--- a/demo/ernie-classification/run_fintune_with_hub.sh
+++ b/demo/ernie-classification/run_sentiment_cls.sh
@@ -1,7 +1,7 @@
 export CUDA_VISIBLE_DEVICES=3
 
-CKPT_DIR="./ckpt"
-python -u finetune_with_hub.py \
+CKPT_DIR="./ckpt_sentiment_cls"
+python -u sentiment_cls.py \
     --batch_size 32 \
     --weight_decay 0.01 \
     --checkpoint_dir $CKPT_DIR \
diff --git a/demo/ernie-classification/finetune_with_hub.py b/demo/ernie-classification/sentiment_cls.py
similarity index 98%
rename from demo/ernie-classification/finetune_with_hub.py
rename to demo/ernie-classification/sentiment_cls.py
index 841afb39cd37c417435d539c712cd2c9a4862f7f..c5fd6d496b8a23dcde8b285612116ed7b4149508 100644
--- a/demo/ernie-classification/finetune_with_hub.py
+++ b/demo/ernie-classification/sentiment_cls.py
@@ -49,10 +49,11 @@ if __name__ == '__main__':
 
     # Setup runing config for PaddleHub Finetune API
     config = hub.RunConfig(
-        eval_interval=10,
+        eval_interval=100,
         use_cuda=True,
         num_epoch=args.num_epoch,
         batch_size=args.batch_size,
+        checkpoint_dir=args.checkpoint_dir,
         strategy=strategy)
 
     # loading Paddlehub ERNIE pretrained model
diff --git a/paddlehub/dataset/__init__.py b/paddlehub/dataset/__init__.py
index b059666e32b1dbd33b95d76a4d4efd9252440e53..1cb0086c10d4e0f7ded663d90695cd347712bbce 100644
--- a/paddlehub/dataset/__init__.py
+++ b/paddlehub/dataset/__init__.py
@@ -12,8 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# NLP Dataset
 from .dataset import InputExample, HubDataset
 from .chnsenticorp import ChnSentiCorp
 from .msra_ner import MSRA_NER
+from .nlpcc_dbqa import NLPCC_DBQA
+from .lcqmc import LCQMC
+
+# CV Dataset
 from .dogcat import DogCatDataset as DogCat
 from .flowers import FlowersDataset as Flowers
diff --git a/paddlehub/dataset/chnsenticorp.py b/paddlehub/dataset/chnsenticorp.py
index 3cb879dda55e7ad3c86baf8bfa50876fec8a6686..9bac2bb4ac99402d5f39e24693f4ae58951db5bd 100644
--- a/paddlehub/dataset/chnsenticorp.py
+++ b/paddlehub/dataset/chnsenticorp.py
@@ -16,12 +16,12 @@ from collections import namedtuple
 import os
 import csv
 
-from paddlehub.dataset import InputExample
-from paddlehub.dataset import HubDataset
+from paddlehub.dataset import InputExample, HubDataset
 from paddlehub.common.downloader import default_downloader
 from paddlehub.common.dir import DATA_HOME
+from paddlehub.common.logger import logger
 
-DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/chnsenticorp_data.tar.gz"
+DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/chnsenticorp.tar.gz"
 
 
 class ChnSentiCorp(HubDataset):
@@ -31,8 +31,12 @@ class ChnSentiCorp(HubDataset):
     """
 
     def __init__(self):
-        ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
-            url=DATA_URL, save_path=DATA_HOME, print_progress=True)
+        self.dataset_dir = os.path.join(DATA_HOME, "chnsenticorp")
+        if not os.path.exists(self.dataset_dir):
+            ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
+                url=DATA_URL, save_path=DATA_HOME, print_progress=True)
+        else:
+            logger.info("Dataset {} already cached.".format(self.dataset_dir))
 
         self._load_train_examples()
         self._load_test_examples()
@@ -69,6 +73,7 @@ class ChnSentiCorp(HubDataset):
             reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
             examples = []
             seq_id = 0
+            header = next(reader)  # skip header
             for line in reader:
                 example = InputExample(
                     guid=seq_id, label=line[0], text_a=line[1])
@@ -81,4 +86,4 @@ class ChnSentiCorp(HubDataset):
 if __name__ == "__main__":
     ds = ChnSentiCorp()
     for e in ds.get_train_examples():
-        print(e)
+        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
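
The download-or-reuse-cache logic introduced here is duplicated verbatim across ChnSentiCorp, MSRA_NER, LCQMC and NLPCC_DBQA below. A possible shared helper (hypothetical, not part of this patch) capturing the pattern:

    import os

    from paddlehub.common.dir import DATA_HOME
    from paddlehub.common.downloader import default_downloader
    from paddlehub.common.logger import logger

    def fetch_dataset_dir(subdir, url):
        """Return the local dataset directory, downloading it on first use."""
        dataset_dir = os.path.join(DATA_HOME, subdir)
        if not os.path.exists(dataset_dir):
            ret, tips, dataset_dir = default_downloader.download_file_and_uncompress(
                url=url, save_path=DATA_HOME, print_progress=True)
        else:
            logger.info("Dataset {} already cached.".format(dataset_dir))
        return dataset_dir

Each dataset constructor would then reduce to `self.dataset_dir = fetch_dataset_dir("chnsenticorp", DATA_URL)`.
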
diff --git a/paddlehub/dataset/dataset.py b/paddlehub/dataset/dataset.py
index aefd698ed891548bd0588f906bf83a185fc7931e..ec13f9c1a9cdd0afb5cac73c5cb89d4af3779880 100644
--- a/paddlehub/dataset/dataset.py
+++ b/paddlehub/dataset/dataset.py
@@ -22,7 +22,6 @@ class InputExample(object):
 
     def __init__(self, guid, text_a, text_b=None, label=None):
         """Constructs a InputExample.
-
         Args:
             guid: Unique id for the example.
             text_a: string. The untokenized text of the first sequence. For single
diff --git a/paddlehub/dataset/lcqmc.py b/paddlehub/dataset/lcqmc.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a17733e3eacdb9852debc4db24788ff87482e3d
--- /dev/null
+++ b/paddlehub/dataset/lcqmc.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import namedtuple
+import os
+import csv
+
+from paddlehub.dataset import InputExample, HubDataset
+from paddlehub.common.downloader import default_downloader
+from paddlehub.common.dir import DATA_HOME
+from paddlehub.common.logger import logger
+
+DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/lcqmc.tar.gz"
+
+
+class LCQMC(HubDataset):
+    def __init__(self):
+        self.dataset_dir = os.path.join(DATA_HOME, "lcqmc")
+        if not os.path.exists(self.dataset_dir):
+            ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
+                url=DATA_URL, save_path=DATA_HOME, print_progress=True)
+        else:
+            logger.info("Dataset {} already cached.".format(self.dataset_dir))
+
+        self._load_train_examples()
+        self._load_test_examples()
+        self._load_dev_examples()
+
+    def _load_train_examples(self):
+        self.train_file = os.path.join(self.dataset_dir, "train.tsv")
+        self.train_examples = self._read_tsv(self.train_file)
+
+    def _load_dev_examples(self):
+        self.dev_file = os.path.join(self.dataset_dir, "dev.tsv")
+        self.dev_examples = self._read_tsv(self.dev_file)
+
+    def _load_test_examples(self):
+        self.test_file = os.path.join(self.dataset_dir, "test.tsv")
+        self.test_examples = self._read_tsv(self.test_file)
+
+    def get_train_examples(self):
+        return self.train_examples
+
+    def get_dev_examples(self):
+        return self.dev_examples
+
+    def get_test_examples(self):
+        return self.test_examples
+
+    def get_labels(self):
+        """See base class."""
+        return ["0", "1"]
+
+    def _read_tsv(self, input_file, quotechar=None):
+        """Reads a tab separated value file."""
+        with open(input_file, "r") as f:
+            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
+            examples = []
+            seq_id = 0
+            header = next(reader)  # skip header
+            for line in reader:
+                example = InputExample(
+                    guid=seq_id, label=line[2], text_a=line[0], text_b=line[1])
+                seq_id += 1
+                examples.append(example)
+
+            return examples
+
+
+if __name__ == "__main__":
+    ds = LCQMC()
+    for e in ds.get_train_examples():
+        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
diff --git a/paddlehub/dataset/msra_ner.py b/paddlehub/dataset/msra_ner.py
index fd091f7ebeca973e7e0cb7d66f7231ce3e9142c0..37d0d78de32a1475d151b94696cc3bc0ad399afb 100644
--- a/paddlehub/dataset/msra_ner.py
+++ b/paddlehub/dataset/msra_ner.py
@@ -19,14 +19,19 @@ from collections import namedtuple
 
 from paddlehub.common.downloader import default_downloader
 from paddlehub.common.dir import DATA_HOME
+from paddlehub.common.logger import logger
 
 DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/msra_ner.tar.gz"
 
 
 class MSRA_NER(object):
     def __init__(self):
-        ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
-            url=DATA_URL, save_path=DATA_HOME, print_progress=True)
+        self.dataset_dir = os.path.join(DATA_HOME, "msra_ner")
+        if not os.path.exists(self.dataset_dir):
+            ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
+                url=DATA_URL, save_path=DATA_HOME, print_progress=True)
+        else:
+            logger.info("Dataset {} already cached.".format(self.dataset_dir))
 
         self._load_label_map()
         self._load_train_examples()
diff --git a/paddlehub/dataset/nlpcc_dbqa.py b/paddlehub/dataset/nlpcc_dbqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..110930ddc446235be791e2e5a3353156cf1fb890
--- /dev/null
+++ b/paddlehub/dataset/nlpcc_dbqa.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import namedtuple
+import os
+import csv
+
+from paddlehub.dataset import InputExample, HubDataset
+from paddlehub.common.downloader import default_downloader
+from paddlehub.common.dir import DATA_HOME
+from paddlehub.common.logger import logger
+
+DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/nlpcc-dbqa.tar.gz"
+
+
+class NLPCC_DBQA(HubDataset):
+    def __init__(self):
+        self.dataset_dir = os.path.join(DATA_HOME, "nlpcc-dbqa")
+        if not os.path.exists(self.dataset_dir):
+            ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
+                url=DATA_URL, save_path=DATA_HOME, print_progress=True)
+        else:
+            logger.info("Dataset {} already cached.".format(self.dataset_dir))
+
+        self._load_train_examples()
+        self._load_test_examples()
+        self._load_dev_examples()
+
+    def _load_train_examples(self):
+        self.train_file = os.path.join(self.dataset_dir, "train.tsv")
+        self.train_examples = self._read_tsv(self.train_file)
+
+    def _load_dev_examples(self):
+        self.dev_file = os.path.join(self.dataset_dir, "dev.tsv")
+        self.dev_examples = self._read_tsv(self.dev_file)
+
+    def _load_test_examples(self):
+        self.test_file = os.path.join(self.dataset_dir, "test.tsv")
+        self.test_examples = self._read_tsv(self.test_file)
+
+    def get_train_examples(self):
+        return self.train_examples
+
+    def get_dev_examples(self):
+        return self.dev_examples
+
+    def get_test_examples(self):
+        return self.test_examples
+
+    def get_labels(self):
+        """See base class."""
+        return ["0", "1"]
+
+    def _read_tsv(self, input_file, quotechar=None):
+        """Reads a tab separated value file."""
+        with open(input_file, "r") as f:
+            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
+            examples = []
+            seq_id = 0
+            header = next(reader)  # skip header
+            for line in reader:
+                example = InputExample(
+                    guid=seq_id, label=line[3], text_a=line[1], text_b=line[2])
+                seq_id += 1
+                examples.append(example)
+
+            return examples
+
+
+if __name__ == "__main__":
+    ds = NLPCC_DBQA()
+    for e in ds.get_train_examples():
+        print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
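
Both new readers skip one header row and then index fixed columns, but the two corpora lay their columns out differently. Illustrative rows (the text content here is made up; for NLPCC-DBQA, column 0 is unused by the reader and is presumably a query id):

    # LCQMC: text_a <TAB> text_b <TAB> label
    lcqmc_row = ["How is the weather today", "What is today's weather like", "1"]
    text_a, text_b, label = lcqmc_row[0], lcqmc_row[1], lcqmc_row[2]

    # NLPCC-DBQA: id <TAB> question <TAB> candidate answer <TAB> label
    dbqa_row = ["0", "Who maintains PaddleHub", "PaddleHub is maintained by the PaddlePaddle team", "1"]
    text_a, text_b, label = dbqa_row[1], dbqa_row[2], dbqa_row[3]
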
diff --git a/paddlehub/finetune/evaluate.py b/paddlehub/finetune/evaluate.py
index 758a0b665bb4341b06afdddc0f94bdf1a1c267af..44d23d7a245bea5a80b57e34ca89ef88dfa593e3 100644
--- a/paddlehub/finetune/evaluate.py
+++ b/paddlehub/finetune/evaluate.py
@@ -48,6 +48,8 @@ def evaluate_cls_task(task, data_reader, feed_list, phase="test", config=None):
                 feed=data_feeder.feed(batch),
                 fetch_list=[loss.name, accuracy.name])
             num_eval_examples += num_batch_examples
+            if num_eval_examples % 10000 == 0:
+                logger.info("{} examples evaluated.".format(num_eval_examples))
             acc_sum += accuracy_v * num_batch_examples
             loss_sum += loss_v * num_batch_examples
         eval_time_used = time.time() - eval_time_begin
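
One caveat about the new progress log: `num_eval_examples` grows by whole batches, so the `% 10000 == 0` test only fires when the running total lands exactly on a multiple of 10000; with batch sizes that do not divide 10000 evenly it never triggers. A sketch of a batch-size-independent variant:

    LOG_EVERY = 10000
    last_logged = 0
    num_eval_examples = 0
    for batch in batches:  # stand-in for the evaluation loop above
        num_eval_examples += len(batch)
        if num_eval_examples - last_logged >= LOG_EVERY:
            logger.info("{} examples evaluated.".format(num_eval_examples))
            last_logged = num_eval_examples
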
diff --git a/paddlehub/finetune/task.py b/paddlehub/finetune/task.py
index db2a4b32a2298fb24e4c7421129ce1693265fb08..a6dd964b25fa236fad2874a73bc96167eb628999 100644
--- a/paddlehub/finetune/task.py
+++ b/paddlehub/finetune/task.py
@@ -109,8 +109,15 @@ def create_img_classification_task(feature, num_classes, hidden_units=None):
     """
-    Append a multi-layer perceptron classifier for binary classification base
-    on input feature
+    Create a transfer learning task for image classification.
+    Args:
+        feature: the feature tensor output by the pretrained model.
+
+    Returns:
+        Task
+
+    Raises:
+        None
     """
     cls_feats = feature
     # append fully connected layer according to hidden_units
diff --git a/paddlehub/module/checker.py b/paddlehub/module/checker.py
index d8ecb4ff7e7f2d41bd8f2bb8fa57e9efd8f9b610..ad5a50529df24887e44591968c22d8ed74c7363b 100644
--- a/paddlehub/module/checker.py
+++ b/paddlehub/module/checker.py
@@ -154,16 +154,16 @@ class ModuleChecker:
             if not os.path.exists(file_path):
                 if file_info.is_need:
                     logger.error(
-                        "module lack of necessary file [%s]" % file_path)
+                        "Module incomplete! Missing file [%s]" % file_path)
                     return False
             else:
                 if file_type == check_info_pb2.FILE:
                     if not os.path.isfile(file_path):
-                        logger.error("file type error %s" % file_path)
+                        logger.error("File type error %s" % file_path)
                         return False
                 if file_type == check_info_pb2.DIR:
                     if not os.path.isdir(file_path):
-                        logger.error("file type error %s" % file_path)
+                        logger.error("File type error %s" % file_path)
                         return False
 
         return True
diff --git a/paddlehub/module/manager.py b/paddlehub/module/manager.py
index 98bfa11835dc97a28fb622d0f72466c81b6b2971..59edcbd3330bf974fd2dc167ecd8257f0910d7f5 100644
--- a/paddlehub/module/manager.py
+++ b/paddlehub/module/manager.py
@@ -61,13 +61,14 @@ class LocalModuleManager:
         self.all_modules(update=True)
         if module_name in self.modules_dict:
             module_dir = self.modules_dict[module_name]
-            tips = "module %s already install in %s" % (module_name, module_dir)
+            tips = "Module %s already installed in %s" % (module_name,
+                                                          module_dir)
             return True, tips, module_dir
         url = hub.default_hub_server.get_module_url(
             module_name, version=module_version)
         #TODO(wuzewu): add compatibility check
         if not url:
-            tips = "can't found module %s" % module_name
+            tips = "Can't find module %s" % module_name
             if module_version:
                 tips += " with version %s" % module_version
             return False, tips, None
diff --git a/paddlehub/module/module.py b/paddlehub/module/module.py
index a9fbbb253810874ae6e594eb9640d7ba201e376c..8fb644c217296a9512963050b0d0acf39fb3b4d2 100644
--- a/paddlehub/module/module.py
+++ b/paddlehub/module/module.py
@@ -128,7 +128,7 @@ class Module(object):
             self._generate_module_info(module_info)
             self._init_with_signature(signatures=signatures)
         else:
-            raise "Error! HubModule can't init with nothing"
+            raise ValueError("Error! Module initialized with empty parameters")
 
     def _init_with_name(self, name):
         logger.info("Try installing module %s" % name)
@@ -191,7 +191,8 @@ class Module(object):
     def _init_with_module_file(self, module_dir):
         checker = ModuleChecker(module_dir)
         if not checker.check():
-            logger.error("Module init failed on {}".format(module_dir))
+            logger.error(
+                "Module initialization failed on {}".format(module_dir))
             exit(1)
 
         self.helper = ModuleHelper(module_dir)
@@ -223,7 +224,9 @@ class Module(object):
         self.program = signatures[0].inputs[0].block.program
         for sign in signatures:
             if sign.name in self.signatures:
-                raise "Error! signature array contains repeat signatrue %s" % sign
+                raise ValueError(
+                    "Error! Signature array contains duplicated signatures %s" %
+                    sign)
             if self.default_signature is None and sign.for_predict:
                 self.default_signature = sign
             self.signatures[sign.name] = sign
@@ -265,7 +268,7 @@ class Module(object):
             self.module_info = {}
         else:
             if not utils.is_yaml_file(module_info):
-                logger.critical("module info file should in yaml format")
+                logger.critical("Module info file should be in YAML format")
                 exit(1)
             self.module_info = yaml_parser.parse(module_info)
         self.author = self.module_info.get('author', 'UNKNOWN')
@@ -532,7 +535,7 @@ class Module(object):
         return self.get_name_prefix() + var_name
 
     def _check_signatures(self):
-        assert self.signatures, "signature array should not be None"
+        assert self.signatures, "Signature array should not be None"
 
         for key, sign in self.signatures.items():
             assert isinstance(sign,
diff --git a/paddlehub/module/signature.py b/paddlehub/module/signature.py
index 11d5bdea3b6466131b27ceab2f467d342ebd3251..243dc5dc3b465dfe25b16a9a508d97dbda9a42b0 100644
--- a/paddlehub/module/signature.py
+++ b/paddlehub/module/signature.py
@@ -47,13 +47,11 @@ class Signature:
         self.name = name
         for item in inputs:
             assert isinstance(
-                item,
-                Variable), "the item of inputs list shoule be paddle Variable"
+                item, Variable), "the item of inputs list should be a Variable"
 
         for item in outputs:
             assert isinstance(
-                item,
-                Variable), "the item of outputs list shoule be paddle Variable"
+                item, Variable), "the item of outputs list should be a Variable"
 
         self.inputs = inputs
         self.outputs = outputs
diff --git a/paddlehub/reader/nlp_reader.py b/paddlehub/reader/nlp_reader.py
index abed54c0bc70486e7c666ab70f09ef755cfa645f..549b032ba5a812e15b3beb4ba7437a2941a157bd 100644
--- a/paddlehub/reader/nlp_reader.py
+++ b/paddlehub/reader/nlp_reader.py
@@ -18,6 +18,7 @@ import numpy as np
 from collections import namedtuple
 
 from paddlehub.reader import tokenization
+from paddlehub.common.logger import logger
 from .batching import pad_batch_data
 
@@ -46,18 +47,12 @@ class BaseReader(object):
         self.label_map = {}
         for index, label in enumerate(self.dataset.get_labels()):
             self.label_map[label] = index
-        print("Dataset label map = {}".format(self.label_map))
+        logger.info("Dataset label map = {}".format(self.label_map))
 
         self.current_example = 0
         self.current_epoch = 0
 
         self.num_examples = 0
 
-        # if label_map_config:
-        #     with open(label_map_config) as f:
-        #         self.label_map = json.load(f)
-        # else:
-        #     self.label_map = None
-
         self.num_examples = {'train': -1, 'dev': -1, 'test': -1}
 
     def get_train_examples(self):
@@ -160,24 +155,13 @@ class BaseReader(object):
         position_ids = list(range(len(token_ids)))
 
         if self.label_map:
+            if example.label not in self.label_map:
+                raise KeyError(
+                    "example.label = {%s} not in label map" % example.label)
             label_id = self.label_map[example.label]
         else:
             label_id = example.label
 
-        # Record = namedtuple(
-        #     'Record',
-        #     ['token_ids', 'text_type_ids', 'position_ids', 'label_id', 'qid'])
-
-        # qid = None
-        # if "qid" in example._fields:
-        #     qid = example.qid
-
-        # record = Record(
-        #     token_ids=token_ids,
-        #     text_type_ids=text_type_ids,
-        #     position_ids=position_ids,
-        #     label_id=label_id,
-        #     qid=qid)
         Record = namedtuple(
             'Record',
             ['token_ids', 'text_type_ids', 'position_ids', 'label_id'])
@@ -211,10 +195,6 @@ class BaseReader(object):
         if batch_records:
             yield self._pad_batch_records(batch_records)
 
-    # def get_num_examples(self, input_file):
-    #     examples = self._read_tsv(input_file)
-    #     return len(examples)
-
     def get_num_examples(self, phase):
         """Get number of examples for train, dev or test."""
         if phase not in ['train', 'val', 'dev', 'test']:
diff --git a/paddlehub/version.py b/paddlehub/version.py
index 61c98cc21c95e4098a5aee992863bdc72f91b0be..e6a0c61fdcfdf0ee0e0f613bd038b1c2a1b73ace 100644
--- a/paddlehub/version.py
+++ b/paddlehub/version.py
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ PaddleHub version string """
-hub_version = "0.3.1.alpha"
+hub_version = "0.4.0.alpha"
 module_proto_version = "0.1.0"