diff --git a/examples/bert/bert_classifier.py b/examples/bert/bert_classifier.py index ef43ae2076e665a88fd896dd7e6f830c9e38640c..c55eb2554d9074b03135c42aa746fa6c7a33bf27 100644 --- a/examples/bert/bert_classifier.py +++ b/examples/bert/bert_classifier.py @@ -18,7 +18,8 @@ from hapi.metrics import Accuracy from hapi.configure import Config from hapi.text.bert import BertEncoder from paddle.fluid.dygraph import Linear, Layer -from hapi.model import set_device, Model, SoftmaxWithCrossEntropy, Input +from hapi.model import set_device, Model, Input +from hapi.loss import SoftmaxWithCrossEntropy import hapi.text.tokenizer.tokenization as tokenization from hapi.text.bert import Optimizer, BertConfig, BertDataLoader, BertInputExample diff --git a/examples/bert_leveldb/bert_classifier.py b/examples/bert_leveldb/bert_classifier.py index 11bc85758ebbe81ae68b3c141d4582ee8d41508c..624e49c4d8c44d05c52f2e79a65dd8399a5b9f4c 100644 --- a/examples/bert_leveldb/bert_classifier.py +++ b/examples/bert_leveldb/bert_classifier.py @@ -18,7 +18,8 @@ from hapi.metrics import Accuracy from hapi.configure import Config from hapi.text.bert import BertEncoder from paddle.fluid.dygraph import Linear, Layer -from hapi.model import set_device, Model, SoftmaxWithCrossEntropy, Input +from hapi.model import set_device, Model, Input +from hapi.loss import SoftmaxWithCrossEntropy import hapi.text.tokenizer.tokenization as tokenization from hapi.text.bert import Optimizer, BertConfig, BertDataLoader, BertInputExample diff --git a/examples/bmn/modeling.py b/examples/bmn/modeling.py index f0fa26e1a687fc1d524870a03be470edf280fc9c..bfd65b318ad5f0c8a3ea2191f9a7c2a4d18b691e 100644 --- a/examples/bmn/modeling.py +++ b/examples/bmn/modeling.py @@ -17,8 +17,9 @@ from paddle.fluid import ParamAttr import numpy as np import math -from hapi.model import Model, Loss -from hapi.download import get_weights_path +from hapi.model import Model +from hapi.loss import Loss +from hapi.download import get_weights_path_from_url __all__ = ["BMN", "BmnLoss", "bmn"] @@ -459,7 +460,7 @@ def bmn(tscale, model = BMN(tscale, dscale, prop_boundary_ratio, num_sample, num_sample_perbin) if pretrained: - weight_path = get_weights_path(*(pretrain_infos['bmn'])) + weight_path = get_weights_path_from_url(*(pretrain_infos['bmn'])) assert weight_path.endswith('.pdparams'), \ "suffix of weight must be .pdparams" model.load(weight_path) diff --git a/examples/cyclegan/cyclegan.py b/examples/cyclegan/cyclegan.py index 2c5cbd364c35c71dd5bd1d6831bb5b6d4ada07fb..076e13d3ebe079fbb639d6ec80d432403af03cfd 100644 --- a/examples/cyclegan/cyclegan.py +++ b/examples/cyclegan/cyclegan.py @@ -19,7 +19,8 @@ from __future__ import print_function import numpy as np import paddle.fluid as fluid -from hapi.model import Model, Loss +from hapi.model import Model +from hapi.loss import Loss from layers import ConvBN, DeConvBN diff --git a/examples/image_classification/imagenet_dataset.py b/examples/image_classification/imagenet_dataset.py index 25dcc338e20e53e75f3637ea8f8e3d492a1240e1..27c41d6fb4cfe752f311d2e2c65380aa29bc4323 100644 --- a/examples/image_classification/imagenet_dataset.py +++ b/examples/image_classification/imagenet_dataset.py @@ -50,7 +50,7 @@ class ImageNetDataset(DatasetFolder): def __getitem__(self, idx): img_path, label = self.samples[idx] img = cv2.imread(img_path).astype(np.float32) - label = np.array([label]) + label = np.array([label]).astype(np.int64) return self.transform(img), label def __len__(self): diff --git a/examples/image_classification/main.py 
b/examples/image_classification/main.py index 64396a6042f80cfbd53ff775ab95c41330894a9a..e5aea412fbeb619d90002f3c9f817788b021821a 100644 --- a/examples/image_classification/main.py +++ b/examples/image_classification/main.py @@ -27,7 +27,8 @@ import paddle.fluid as fluid from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.io import BatchSampler, DataLoader -from hapi.model import CrossEntropy, Input, set_device +from hapi.model import Input, set_device +from hapi.loss import CrossEntropy from hapi.distributed import DistributedBatchSampler from hapi.metrics import Accuracy import hapi.vision.models as models diff --git a/examples/ocr/seq2seq_attn.py b/examples/ocr/seq2seq_attn.py index 675e4e4ab0b30874dffd1b0bbc84b7c54c42354b..66da91ce7d84c458bd6424da8c364a9ed25776c4 100644 --- a/examples/ocr/seq2seq_attn.py +++ b/examples/ocr/seq2seq_attn.py @@ -20,7 +20,8 @@ import paddle.fluid.layers as layers from paddle.fluid.layers import BeamSearchDecoder from hapi.text import RNNCell, RNN, DynamicDecode -from hapi.model import Model, Loss +from hapi.model import Model +from hapi.loss import Loss class ConvBNPool(fluid.dygraph.Layer): diff --git a/examples/sequence_tagging/README.md b/examples/sequence_tagging/README.md index 898f3abbcbc6bbee447b258c554ef4cde98143e4..b36e9cda77efe701dcc1e342e65f50c70fd69c8d 100644 --- a/examples/sequence_tagging/README.md +++ b/examples/sequence_tagging/README.md @@ -14,7 +14,7 @@ Sequence Tagging,是一个序列标注模型,模型可用于实现,分词 #### 1.PaddlePaddle 安装 -本项目依赖 PaddlePaddle 1.7 及以上版本和PaddleHub 1.0.0及以上版本 ,PaddlePaddle安装请参考官网 [快速安装](http://www.paddlepaddle.org/paddle#quick-start),PaddleHub安装参考 [PaddleHub](https://github.com/PaddlePaddle/PaddleHub)。 +本项目依赖 PaddlePaddle 1.8 及以上版本和PaddleHub 1.0.0及以上版本 ,PaddlePaddle安装请参考官网 [快速安装](http://www.paddlepaddle.org/paddle#quick-start),PaddleHub安装参考 [PaddleHub](https://github.com/PaddlePaddle/PaddleHub)。 > Warning: GPU 和 CPU 版本的 PaddlePaddle 分别是 paddlepaddle-gpu 和 paddlepaddle,请安装时注意区别。 diff --git a/examples/sequence_tagging/predict.py b/examples/sequence_tagging/predict.py index bcb39265d7ef8a08ce6700d599b37a4f4ae19054..5067eb7c844972dd2a625901e841196b527c6e8a 100644 --- a/examples/sequence_tagging/predict.py +++ b/examples/sequence_tagging/predict.py @@ -21,6 +21,7 @@ from __future__ import print_function import io import os import sys +import six import math import argparse import numpy as np @@ -71,7 +72,12 @@ def main(args): word_len = length[i] word_ids = results[i][:word_len] tags = [dataset.id2label_dict[str(id)] for id in word_ids] - f.write("\002".join(tags) + "\n") + if six.PY3: + tags = [bytes(tag, encoding="utf8") for tag in tags] + out = b"\002".join(tags) + b"\n" + f.write(out) + else: + f.write("\002".join(tags) + "\n") if __name__ == '__main__': diff --git a/examples/sequence_tagging/reader.py b/examples/sequence_tagging/reader.py index 7a772b3fbbc80478dfc4e9096273a60ade05c79a..991a24e867c5d171247a1497f744cb96d0758f8b 100644 --- a/examples/sequence_tagging/reader.py +++ b/examples/sequence_tagging/reader.py @@ -20,7 +20,6 @@ from __future__ import print_function import io import os -import leveldb import numpy as np import shutil from functools import partial diff --git a/examples/sequence_tagging/train.py b/examples/sequence_tagging/train.py index 7d5a9337d3b0da6f116262f1b30def68b828e00b..41422fc7d722a2ebea606080151da6807156ad18 100644 --- a/examples/sequence_tagging/train.py +++ b/examples/sequence_tagging/train.py @@ -29,7 +29,8 @@ work_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 
sys.path.append(os.path.join(work_dir, "../")) from hapi.metrics import Metric -from hapi.model import Model, Input, Loss, set_device +from hapi.model import Model, Input, set_device +from hapi.loss import Loss from hapi.text.text import SequenceTagging from utils.check import check_gpu, check_version diff --git a/examples/tsm/infer.py b/examples/tsm/infer.py index df5beafa7fe033af55fd2407ec32a67987d00804..cac9745fc8ddc1ad33dc3e95dd82e8f2dbe24277 100644 --- a/examples/tsm/infer.py +++ b/examples/tsm/infer.py @@ -20,6 +20,7 @@ import argparse import numpy as np from hapi.model import Input, set_device +from hapi.vision.transforms import Compose from check import check_gpu, check_version from modeling import tsm_resnet50 diff --git a/examples/tsm/main.py b/examples/tsm/main.py index e6434fe4cd67c7fef85cce67e30d08dbd6bc06fd..deef9f6b8349033f2d7ce83b40f583ca24a56a7e 100644 --- a/examples/tsm/main.py +++ b/examples/tsm/main.py @@ -22,8 +22,10 @@ import numpy as np from paddle import fluid from paddle.fluid.dygraph.parallel import ParallelEnv -from hapi.model import Model, CrossEntropy, Input, set_device +from hapi.model import Model, Input, set_device +from hapi.loss import CrossEntropy from hapi.metrics import Accuracy +from hapi.vision.transforms import Compose from modeling import tsm_resnet50 from check import check_gpu, check_version @@ -34,11 +36,10 @@ from utils import print_arguments def make_optimizer(step_per_epoch, parameter_list=None): boundaries = [e * step_per_epoch for e in [40, 60]] - values = [FLAGS.lr * (0.1 ** i) for i in range(len(boundaries) + 1)] + values = [FLAGS.lr * (0.1**i) for i in range(len(boundaries) + 1)] learning_rate = fluid.layers.piecewise_decay( - boundaries=boundaries, - values=values) + boundaries=boundaries, values=values) optimizer = fluid.optimizer.Momentum( learning_rate=learning_rate, regularization=fluid.regularizer.L2Decay(1e-4), @@ -52,29 +53,27 @@ def main(): device = set_device(FLAGS.device) fluid.enable_dygraph(device) if FLAGS.dynamic else None - train_transform = Compose([GroupScale(), - GroupMultiScaleCrop(), - GroupRandomCrop(), - GroupRandomFlip(), - NormalizeImage()]) + train_transform = Compose([ + GroupScale(), GroupMultiScaleCrop(), GroupRandomCrop(), + GroupRandomFlip(), NormalizeImage() + ]) train_dataset = KineticsDataset( - file_list=os.path.join(FLAGS.data, 'train_10.list'), - pickle_dir=os.path.join(FLAGS.data, 'train_10'), - label_list=os.path.join(FLAGS.data, 'label_list'), - transform=train_transform) - val_transform = Compose([GroupScale(), - GroupCenterCrop(), - NormalizeImage()]) + file_list=os.path.join(FLAGS.data, 'train_10.list'), + pickle_dir=os.path.join(FLAGS.data, 'train_10'), + label_list=os.path.join(FLAGS.data, 'label_list'), + transform=train_transform) + val_transform = Compose( + [GroupScale(), GroupCenterCrop(), NormalizeImage()]) val_dataset = KineticsDataset( - file_list=os.path.join(FLAGS.data, 'val_10.list'), - pickle_dir=os.path.join(FLAGS.data, 'val_10'), - label_list=os.path.join(FLAGS.data, 'label_list'), - mode='val', - transform=val_transform) + file_list=os.path.join(FLAGS.data, 'val_10.list'), + pickle_dir=os.path.join(FLAGS.data, 'val_10'), + label_list=os.path.join(FLAGS.data, 'label_list'), + mode='val', + transform=val_transform) pretrained = FLAGS.eval_only and FLAGS.weights is None - model = tsm_resnet50(num_classes=train_dataset.num_classes, - pretrained=pretrained) + model = tsm_resnet50( + num_classes=train_dataset.num_classes, pretrained=pretrained) step_per_epoch = int(len(train_dataset) / 
FLAGS.batch_size \ / ParallelEnv().nranks) @@ -117,7 +116,9 @@ def main(): if __name__ == '__main__': parser = argparse.ArgumentParser("CNN training on TSM") parser.add_argument( - "--data", type=str, default='dataset/kinetics', + "--data", + type=str, + default='dataset/kinetics', help="path to dataset root directory") parser.add_argument( "--device", type=str, default='gpu', help="device to use, gpu or cpu") diff --git a/examples/tsm/modeling.py b/examples/tsm/modeling.py index c2422ed3f1cf57e9fd029bb01e04e55d5296e918..aafd6a4557e01222bf672f6bde47dcdbeee06ce3 100644 --- a/examples/tsm/modeling.py +++ b/examples/tsm/modeling.py @@ -18,7 +18,7 @@ from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear from hapi.model import Model -from hapi.download import get_weights_path +from hapi.download import get_weights_path_from_url __all__ = ["TSM_ResNet", "tsm_resnet50"] @@ -122,6 +122,7 @@ class TSM_ResNet(Model): seg_num (int): segment number of each video sample. Default 8. num_classes (int): video class number. Default 400. """ + def __init__(self, num_layers=50, seg_num=8, num_classes=400): super(TSM_ResNet, self).__init__() @@ -136,7 +137,11 @@ class TSM_ResNet(Model): num_filters = [64, 128, 256, 512] self.conv = ConvBNLayer( - num_channels=3, num_filters=64, filter_size=7, stride=2, act='relu') + num_channels=3, + num_filters=64, + filter_size=7, + stride=2, + act='relu') self.pool2d_max = Pool2D( pool_size=3, pool_stride=2, pool_padding=1, pool_type='max') @@ -193,7 +198,7 @@ def _tsm_resnet(num_layers, seg_num=8, num_classes=400, pretrained=True): assert num_layers in pretrain_infos.keys(), \ "TSM-ResNet{} do not have pretrained weights now, " \ "pretrained should be set as False".format(num_layers) - weight_path = get_weights_path(*(pretrain_infos[num_layers])) + weight_path = get_weights_path_from_url(*(pretrain_infos[num_layers])) assert weight_path.endswith('.pdparams'), \ "suffix of weight must be .pdparams" model.load(weight_path) diff --git a/examples/tsm/transforms.py b/examples/tsm/transforms.py index 230e2f8a332797002d76cf1317f425adea893da2..1fa9f1197c7ab482f78ccd9f12a6fb216593b946 100644 --- a/examples/tsm/transforms.py +++ b/examples/tsm/transforms.py @@ -21,24 +21,7 @@ import logging logger = logging.getLogger(__name__) __all__ = ['GroupScale', 'GroupMultiScaleCrop', 'GroupRandomCrop', - 'GroupRandomFlip', 'GroupCenterCrop', 'NormalizeImage', - 'Compose'] - - -class Compose(object): - def __init__(self, transforms=[]): - self.transforms = transforms - - def __call__(self, *data): - for f in self.transforms: - try: - data = f(*data) - except Exception as e: - stack_info = traceback.format_exc() - logger.info("fail to perform transform [{}] with error: " - "{} and stack:\n{}".format(f, e, str(stack_info))) - raise e - return data + 'GroupRandomFlip', 'GroupCenterCrop', 'NormalizeImage'] class GroupScale(object): diff --git a/examples/yolov3/main.py b/examples/yolov3/main.py index 60adb8de8f41d35f26a077ab91c3f567cc381d9e..e3c773fbc40e8afd19c568d993d903f6a52240dd 100644 --- a/examples/yolov3/main.py +++ b/examples/yolov3/main.py @@ -27,7 +27,7 @@ from paddle.io import DataLoader from hapi.model import Model, Input, set_device from hapi.distributed import DistributedBatchSampler -from hapi.vision.transforms import BatchCompose +from hapi.vision.transforms import Compose, BatchCompose from modeling import yolov3_darknet53, YoloLoss from coco import COCODataset diff --git a/examples/yolov3/modeling.py 
b/examples/yolov3/modeling.py index 0b74bf93449a3eba2be47126525db40b434e89fe..982c0beae8f215a0bc00441895e8c9bd883b83cb 100644 --- a/examples/yolov3/modeling.py +++ b/examples/yolov3/modeling.py @@ -20,8 +20,9 @@ from paddle.fluid.dygraph.nn import Conv2D, BatchNorm from paddle.fluid.param_attr import ParamAttr from paddle.fluid.regularizer import L2Decay -from hapi.model import Model, Loss -from hapi.download import get_weights_path +from hapi.model import Model +from hapi.loss import Loss +from hapi.download import get_weights_path_from_url from hapi.vision.models import darknet53 __all__ = ['YoloLoss', 'YOLOv3', 'yolov3_darknet53'] @@ -315,7 +316,7 @@ def _yolov3_darknet(num_layers=53, assert num_layers in pretrain_infos.keys(), \ "YOLOv3-DarkNet{} do not have pretrained weights now, " \ "pretrained should be set as False".format(num_layers) - weight_path = get_weights_path(*(pretrain_infos[num_layers])) + weight_path = get_weights_path_from_url(*(pretrain_infos[num_layers])) assert weight_path.endswith('.pdparams'), \ "suffix of weight must be .pdparams" model.load(weight_path) diff --git a/examples/yolov3/transforms.py b/examples/yolov3/transforms.py index 4eca95a95d692cbe9e9db654cf727e289361ff5f..a220b5de34ec9f5ff08e819052ee6cd5ea9905a7 100644 --- a/examples/yolov3/transforms.py +++ b/examples/yolov3/transforms.py @@ -20,7 +20,6 @@ import traceback import numpy as np __all__ = [ - "Compose", 'ColorDistort', 'RandomExpand', 'RandomCrop', @@ -34,37 +33,6 @@ __all__ = [ ] -class Compose(object): - """Composes several transforms together. - - Args: - transforms (list of ``Transform`` objects): list of transforms to compose. - - """ - - def __init__(self, transforms): - self.transforms = transforms - - def __call__(self, *data): - for f in self.transforms: - try: - data = f(*data) - except Exception as e: - stack_info = traceback.format_exc() - print("fail to perform transform [{}] with error: " - "{} and stack:\n{}".format(f, e, str(stack_info))) - raise e - return data - - def __repr__(self): - format_string = self.__class__.__name__ + '(' - for t in self.transforms: - format_string += '\n' - format_string += ' {0}'.format(t) - format_string += '\n)' - return format_string - - class ColorDistort(object): """Random color distortion. diff --git a/hapi/__init__.py b/hapi/__init__.py index eb3f008db4e690a5cf8999862432bedddbf2ef1c..3860aafc7306c764cfc055745038a78ba99de1fd 100644 --- a/hapi/__init__.py +++ b/hapi/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +from hapi import logger from hapi.configure import Config from hapi import callbacks from hapi import datasets @@ -22,16 +22,11 @@ from hapi import model from hapi import progressbar from hapi import text from hapi import vision +from hapi import loss + +logger.setup_logger() __all__ = [ - 'Config', - 'callbacks', - 'datasets', - 'distributed', - 'download', - 'metrics', - 'model', - 'progressbar', - 'text', - 'vision', + 'Config', 'callbacks', 'datasets', 'distributed', 'download', 'metrics', + 'model', 'progressbar', 'text', 'vision', 'loss' ] diff --git a/hapi/datasets/flowers.py b/hapi/datasets/flowers.py index c360e8fc287dd97fc5747ae9f65c668e3b7a1cf1..9d543c318dff1540122842aaa5e8a0ae9592988b 100644 --- a/hapi/datasets/flowers.py +++ b/hapi/datasets/flowers.py @@ -121,7 +121,7 @@ class Flowers(Dataset): image = np.array(Image.open(io.BytesIO(image))) if self.transform is not None: - image, label = self.transform(image, label) + image = self.transform(image) return image, label.astype('int64') diff --git a/hapi/datasets/folder.py b/hapi/datasets/folder.py index c13710ea033dd62b665d60967d3acc91cb84c4ef..1d8c2a3e54403710b2c054e9fbf23af989eb1a52 100644 --- a/hapi/datasets/folder.py +++ b/hapi/datasets/folder.py @@ -150,7 +150,7 @@ class DatasetFolder(Dataset): path, target = self.samples[index] sample = self.loader(path) if self.transform is not None: - sample, target = self.transform(sample) + sample = self.transform(sample) return sample, target diff --git a/hapi/datasets/mnist.py b/hapi/datasets/mnist.py index 11b5f310ffc6baf2df85a9bcae716c54715097fe..a264f0a387c62a71dd52fa8e3362fd2e9bd59a24 100644 --- a/hapi/datasets/mnist.py +++ b/hapi/datasets/mnist.py @@ -149,7 +149,7 @@ class MNIST(Dataset): def __getitem__(self, idx): image, label = self.images[idx], self.labels[idx] if self.transform is not None: - image, label = self.transform(image, label) + image = self.transform(image) return image, label def __len__(self): diff --git a/hapi/download.py b/hapi/download.py index e9a89ba53bc3bc74f03977659156121bce6db577..58c60c3885950aa0799b67a6824c3ca188e0840a 100644 --- a/hapi/download.py +++ b/hapi/download.py @@ -29,7 +29,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv import logging logger = logging.getLogger(__name__) -__all__ = ['get_weights_path', 'is_url'] +__all__ = ['get_weights_path_from_url', 'is_url'] WEIGHTS_HOME = osp.expanduser("~/.cache/paddle/hapi/weights") @@ -45,48 +45,56 @@ def is_url(path): return path.startswith('http://') or path.startswith('https://') -def get_weights_path(url, md5sum=None): +def get_weights_path_from_url(url, md5sum=None): """Get weights path from WEIGHT_HOME, if not exists, download it from url. + + Args: + url (str): download url + md5sum (str): md5 sum of download package + + Returns: + str: a local path to save downloaded weights. """ - path, _ = get_path(url, WEIGHTS_HOME, md5sum) + path = get_path_from_url(url, WEIGHTS_HOME, md5sum) return path -def map_path(url, root_dir): +def _map_path(url, root_dir): # parse path after download under root_dir fname = osp.split(url)[-1] fpath = fname return osp.join(root_dir, fpath) -def get_path(url, root_dir, md5sum=None, check_exist=True): +def get_path_from_url(url, root_dir, md5sum=None, check_exist=True): """ Download from given url to root_dir. if file or directory specified by url is exists under root_dir, return the path directly, otherwise download from url and decompress it, return the path. 
- url (str): download url - root_dir (str): root dir for downloading, it should be - WEIGHTS_HOME or DATASET_HOME - md5sum (str): md5 sum of download package + Args: + url (str): download url + root_dir (str): root dir for downloading, it should be + WEIGHTS_HOME or DATASET_HOME + md5sum (str): md5 sum of download package + + Returns: + str: a local path to save downloaded models & weights & datasets. """ assert is_url(url), "downloading from {} not a url".format(url) # parse path after download to decompress under root_dir - fullpath = map_path(url, root_dir) + fullpath = _map_path(url, root_dir) - exist_flag = False if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum): - exist_flag = True - if ParallelEnv().local_rank == 0: - logger.info("Found {}".format(fullpath)) + logger.info("Found {}".format(fullpath)) else: if ParallelEnv().local_rank == 0: fullpath = _download(url, root_dir, md5sum) else: while not os.path.exists(fullpath): time.sleep(1) - return fullpath, exist_flag + return fullpath def _download(url, path, md5sum=None): @@ -109,8 +117,8 @@ def _download(url, path, md5sum=None): else: raise RuntimeError("Download from {} failed. " "Retry limit reached".format(url)) - if ParallelEnv().local_rank == 0: - logger.info("Downloading {} from {}".format(fname, url)) + + logger.info("Downloading {} from {}".format(fname, url)) req = requests.get(url, stream=True) if req.status_code != 200: @@ -141,8 +149,8 @@ def _download(url, path, md5sum=None): def _md5check(fullname, md5sum=None): if md5sum is None: return True - if ParallelEnv().local_rank == 0: - logger.info("File {} md5 checking...".format(fullname)) + + logger.info("File {} md5 checking...".format(fullname)) md5 = hashlib.md5() with open(fullname, 'rb') as f: for chunk in iter(lambda: f.read(4096), b""): @@ -150,8 +158,7 @@ def _md5check(fullname, md5sum=None): calc_md5sum = md5.hexdigest() if calc_md5sum != md5sum: - if ParallelEnv().local_rank == 0: - logger.info("File {} md5 check failed, {}(calc) != " - "{}(base)".format(fullname, calc_md5sum, md5sum)) + logger.info("File {} md5 check failed, {}(calc) != " + "{}(base)".format(fullname, calc_md5sum, md5sum)) return False return True diff --git a/hapi/logger.py b/hapi/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..83e8a35e5da0ea6a80705778ddf41a73f26b80e5 --- /dev/null +++ b/hapi/logger.py @@ -0,0 +1,66 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import logging + +from paddle.fluid.dygraph.parallel import ParallelEnv + + +def setup_logger(output=None, name="hapi", log_level=logging.INFO): + """ + Initialize logger of hapi and set its verbosity level to "INFO". + + Args: + output (str): a file name or a directory to save log. If None, will not save log file. + If ends with ".txt" or ".log", assumed to be a file name. + Otherwise, logs will be saved to `output/log.txt`. + name (str): the root module name of this logger. Default: 'hapi'. 
+ log_level (int): logging level, e.g. logging.INFO, logging.DEBUG, logging.ERROR. Default: logging.INFO. + Returns: + logging.Logger: a logger + """ + logger = logging.getLogger(name) + logger.propagate = False + logger.setLevel(log_level) + + format_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + + # stdout logging: only local rank==0 + local_rank = ParallelEnv().local_rank + if local_rank == 0: + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(log_level) + ch.setFormatter(logging.Formatter(format_str)) + logger.addHandler(ch) + + # file logging if output is not None: all workers + if output is not None: + if output.endswith(".txt") or output.endswith(".log"): + filename = output + else: + filename = os.path.join(output, "log.txt") + if local_rank > 0: + filename = filename + ".rank{}".format(local_rank) + + if not os.path.exists(os.path.dirname(filename)): + os.makedirs(os.path.dirname(filename)) + + fh = logging.FileHandler(filename) + fh.setLevel(log_level) + fh.setFormatter(logging.Formatter(format_str)) + logger.addHandler(fh) + + return logger diff --git a/hapi/loss.py b/hapi/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..7abddf22f1519b6bd1a649f663b22f315366ca7a --- /dev/null +++ b/hapi/loss.py @@ -0,0 +1,97 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + +import os + +from paddle import fluid +from paddle.fluid.framework import in_dygraph_mode, Variable +from paddle.fluid.dygraph.base import to_variable + +from hapi.utils import to_list + +__all__ = ['Loss', 'CrossEntropy', 'SoftmaxWithCrossEntropy'] + + +class Loss(object): + """ + Base class for loss, encapsulates loss logic and APIs + + Usage: + custom_loss = CustomLoss() + loss = custom_loss(inputs, labels) + """ + + def __init__(self, average=True): + super(Loss, self).__init__() + self.average = average + + def forward(self, outputs, labels): + raise NotImplementedError() + + def __call__(self, outputs, labels=None): + labels = to_list(labels) + if in_dygraph_mode() and labels: + labels = [to_variable(l) for l in labels] + losses = to_list(self.forward(to_list(outputs), labels)) + if self.average: + losses = [fluid.layers.reduce_mean(l) for l in losses] + else: + losses = [fluid.layers.reduce_sum(l) for l in losses] + return losses + + +class CrossEntropy(Loss): + """ + Args: + input (list[Variable]): Input tensor, the data type is float32, + float64, int32, int64. + label (list[Variable]): Label tensor, the data type is float32, + float64, int32, int64. + average (bool, optional): Indicate whether to average the loss, Default: True. + Returns: + list[Variable]: The tensor variable storing the cross_entropy_loss of inputs and labels.
+ """ + + def __init__(self, average=True): + super(CrossEntropy, self).__init__() + + def forward(self, outputs, labels): + return [ + fluid.layers.cross_entropy(o, l) for o, l in zip(outputs, labels) + ] + + +class SoftmaxWithCrossEntropy(Loss): + """ + this op combined softmax and cross entropy. + Args: + input (list[Variable]): Input tensor, the data type is float32, + float64, int32, int64. + label (list[Variable]): Label tensor, the data type is float32, + float64, int32, int64. + average (bool, optional): Indicate whether to average the loss, Default: True. + Returns: + list[Variable]: The tensor variable storing the cross_entropy_loss of inputs and labels. + """ + + def __init__(self, average=True): + super(SoftmaxWithCrossEntropy, self).__init__() + + def forward(self, outputs, labels): + return [ + fluid.layers.softmax_with_cross_entropy( + o, l, return_softmax=False) for o, l in zip(outputs, labels) + ] diff --git a/hapi/model.py b/hapi/model.py index cde4ba6040be334f3ba902413ebf6953e2f35140..8c1c5216287b26645ca2d06178cb3b5176e7ab31 100644 --- a/hapi/model.py +++ b/hapi/model.py @@ -34,13 +34,16 @@ from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy from paddle.fluid.incubate.fleet.base import role_maker from paddle.io import DataLoader, Dataset +from hapi.loss import Loss from hapi.distributed import DistributedBatchSampler, _all_gather, prepare_distributed_context, _parallel_context_initialized from hapi.metrics import Metric from hapi.callbacks import config_callbacks +from hapi.utils import to_list, to_numpy, flatten_list, restore_flatten_list __all__ = [ - 'Model', 'Loss', 'CrossEntropy', 'Input', 'set_device', - 'SoftmaxWithCrossEntropy' + 'Model', + 'Input', + 'set_device', ] @@ -63,49 +66,6 @@ def set_device(device): return place -def to_list(value): - if value is None: - return value - if isinstance(value, (list, tuple)): - return list(value) - return [value] - - -def to_numpy(var): - assert isinstance(var, (Variable, fluid.core.VarBase)), "not a variable" - if isinstance(var, fluid.core.VarBase): - return var.numpy() - t = global_scope().find_var(var.name).get_tensor() - return np.array(t) - - -def flatten_list(l): - assert isinstance(l, list), "not a list" - outl = [] - splits = [] - for sl in l: - assert isinstance(sl, list), "sub content not a list" - splits.append(len(sl)) - outl += sl - return outl, splits - - -def restore_flatten_list(l, splits): - outl = [] - for split in splits: - assert len(l) >= split, "list length invalid" - sl, l = l[:split], l[split:] - outl.append(sl) - return outl - - -def extract_args(func): - if hasattr(inspect, 'getfullargspec'): - return inspect.getfullargspec(func)[0] - else: - return inspect.getargspec(func)[0] - - class Input(fluid.dygraph.Layer): def __init__(self, shape=None, dtype=None, name=None): super(Input, self).__init__() @@ -117,47 +77,6 @@ class Input(fluid.dygraph.Layer): return fluid.data(self.name, shape=self.shape, dtype=self.dtype) -class Loss(object): - def __init__(self, average=True): - super(Loss, self).__init__() - self.average = average - - def forward(self, outputs, labels): - raise NotImplementedError() - - def __call__(self, outputs, labels=None): - labels = to_list(labels) - if in_dygraph_mode() and labels: - labels = [to_variable(l) for l in labels] - losses = to_list(self.forward(to_list(outputs), labels)) - if self.average: - losses = [fluid.layers.reduce_mean(l) for l in losses] - else: - losses = [fluid.layers.reduce_sum(l) for l in losses] - return losses - - -class 
CrossEntropy(Loss): - def __init__(self, average=True): - super(CrossEntropy, self).__init__() - - def forward(self, outputs, labels): - return [ - fluid.layers.cross_entropy(o, l) for o, l in zip(outputs, labels) - ] - - -class SoftmaxWithCrossEntropy(Loss): - def __init__(self, average=True): - super(SoftmaxWithCrossEntropy, self).__init__() - - def forward(self, outputs, labels): - return [ - fluid.layers.softmax_with_cross_entropy( - o, l, return_softmax=False) for o, l in zip(outputs, labels) - ] - - class StaticGraphAdapter(object): def __init__(self, model): super(StaticGraphAdapter, self).__init__() @@ -576,15 +495,14 @@ class DynamicGraphAdapter(object): if labels is not None: labels = [to_variable(l) for l in to_list(labels)] if self._nranks > 1: - outputs = self.ddp_model.forward( - * [to_variable(x) for x in inputs]) + outputs = self.ddp_model.forward(*[to_variable(x) for x in inputs]) losses = self.model._loss_function(outputs, labels) final_loss = fluid.layers.sum(losses) final_loss = self.ddp_model.scale_loss(final_loss) final_loss.backward() self.ddp_model.apply_collective_grads() else: - outputs = self.model.forward(* [to_variable(x) for x in inputs]) + outputs = self.model.forward(*[to_variable(x) for x in inputs]) losses = self.model._loss_function(outputs, labels) final_loss = fluid.layers.sum(losses) final_loss.backward() @@ -593,9 +511,9 @@ class DynamicGraphAdapter(object): self.model.clear_gradients() metrics = [] for metric in self.model._metrics: - metric_outs = metric.add_metric_op(*(to_list(outputs) + to_list( - labels))) - m = metric.update(* [to_numpy(m) for m in to_list(metric_outs)]) + metric_outs = metric.add_metric_op(*( + to_list(outputs) + to_list(labels))) + m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)]) metrics.append(m) return ([to_numpy(l) for l in losses], metrics) \ @@ -607,7 +525,7 @@ class DynamicGraphAdapter(object): inputs = to_list(inputs) if labels is not None: labels = [to_variable(l) for l in to_list(labels)] - outputs = self.model.forward(* [to_variable(x) for x in inputs]) + outputs = self.model.forward(*[to_variable(x) for x in inputs]) if self.model._loss_function: losses = self.model._loss_function(outputs, labels) else: @@ -633,9 +551,9 @@ class DynamicGraphAdapter(object): self._merge_count[self.mode + '_total'] += samples self._merge_count[self.mode + '_batch'] = samples - metric_outs = metric.add_metric_op(*(to_list(outputs) + to_list( - labels))) - m = metric.update(* [to_numpy(m) for m in to_list(metric_outs)]) + metric_outs = metric.add_metric_op(*( + to_list(outputs) + to_list(labels))) + m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)]) metrics.append(m) # To be consistent with static graph @@ -874,8 +792,16 @@ class Model(fluid.dygraph.Layer): global _parallel_context_initialized if ParallelEnv().nranks > 1 and not _parallel_context_initialized: if fluid.in_dygraph_mode(): + main_prog_seed = fluid.default_main_program().random_seed + startup_prog_seed = fluid.default_startup_program( + ).random_seed fluid.disable_dygraph() fluid.enable_dygraph(self._place) + # enable_dygraph would create and switch to a new program, + # thus also copy seed to the new program + fluid.default_main_program().random_seed = main_prog_seed + fluid.default_startup_program( + ).random_seed = startup_prog_seed fluid.dygraph.parallel.prepare_context() else: prepare_distributed_context(self._place) @@ -1208,6 +1134,51 @@ class Model(fluid.dygraph.Layer): return outputs + def save_inference_model(self, + save_dir, + 
model_filename=None, + params_filename=None, + program_only=False): + """ + Save inference model must in static mode. + + Args: + dirname(str): The directory path to save the inference model. + model_filename(str|None): The name of file to save the inference program + itself. If is set None, a default filename + :code:`__model__` will be used. + params_filename(str|None): The name of file to save all related parameters. + If it is set None, parameters will be saved + in separate files . + program_only(bool): If True, It will save inference program only, and do not + save params of Program. + Default: False. + + Returns: + list: The fetch variables' name list + """ + assert not fluid.in_dygraph_mode( + ), 'Save inference model must in static mode!' + + prog = self._adapter._progs.get('test', None) + assert prog, \ + "Model is not ready, please call `model.prepare()` first" + + infer_prog = prog.clone(for_test=True) + + input_names = [v.name for v in self._adapter._input_vars['test']] + endpoints = self._adapter._endpoints['test']['output'] + + return fluid.io.save_inference_model( + save_dir, + input_names, + endpoints, + self._adapter._executor, + main_program=infer_prog, + model_filename=model_filename, + params_filename=params_filename, + program_only=program_only) + def _run_one_epoch(self, data_loader, callbacks, diff --git a/hapi/tests/dist_mnist.py b/hapi/tests/dist_mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..ea407a8b3d1445b5a7b8ee9aaa634f2a5fb0ad12 --- /dev/null +++ b/hapi/tests/dist_mnist.py @@ -0,0 +1,107 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
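
The `save_inference_model` method added above requires static-graph mode and a prior `prepare()` call (it exports the internally built 'test' program). A minimal export/reload sketch, following the usage in the accompanying test; the save directory here is just a placeholder:

```python
import os
import numpy as np
from paddle import fluid

from hapi.model import Input
from hapi.vision.models import resnet18

# static-graph mode: do not call fluid.enable_dygraph()
model = resnet18()
inputs = [Input([None, 3, 224, 224], 'float32', name='image')]
model.prepare(inputs=inputs)  # builds the 'test' program used by the export

save_dir = 'saved_infer_model'  # placeholder directory
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
model.save_inference_model(save_dir)

# the exported program/params load back through the plain fluid API
exe = fluid.Executor(fluid.CPUPlace())
program, feed_names, fetch_targets = fluid.io.load_inference_model(
    dirname=save_dir, executor=exe)
img = np.random.random((1, 3, 224, 224)).astype('float32')
results = exe.run(program, feed={feed_names[0]: img}, fetch_list=fetch_targets)
```
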
+ +from __future__ import division +from __future__ import print_function + +import unittest + +import os + +import numpy as np +import contextlib + +import paddle +from paddle import fluid + +from hapi.model import Model, Input, set_device +from hapi.loss import CrossEntropy +from hapi.vision.models import LeNet +from hapi.metrics import Accuracy +from hapi.callbacks import ProgBarLogger +from hapi.datasets import MNIST + + +class MnistDataset(MNIST): + def __init__(self, mode, return_label=True): + super(MnistDataset, self).__init__(mode=mode) + self.return_label = return_label + + def __getitem__(self, idx): + img = np.reshape(self.images[idx], [1, 28, 28]) + if self.return_label: + return img, np.array(self.labels[idx]).astype('int64') + return img, + + def __len__(self): + return len(self.images) + + +def get_predict_accuracy(pred, gt): + pred = np.argmax(pred, -1) + gt = np.array(gt) + + correct = pred[:, np.newaxis] == gt + + return np.sum(correct) / correct.shape[0] + + +class TestModel(unittest.TestCase): + def fit(self, dynamic): + device = set_device('gpu') + fluid.enable_dygraph(device) if dynamic else None + + im_shape = (-1, 784) + batch_size = 128 + + inputs = [Input(im_shape, 'float32', name='image')] + labels = [Input([None, 1], 'int64', name='label')] + + train_dataset = MnistDataset(mode='train') + val_dataset = MnistDataset(mode='test') + test_dataset = MnistDataset(mode='test', return_label=False) + + model = LeNet() + optim = fluid.optimizer.Momentum( + learning_rate=0.01, momentum=.9, parameter_list=model.parameters()) + loss = CrossEntropy() + model.prepare(optim, loss, Accuracy(), inputs, labels, device=device) + cbk = ProgBarLogger(50) + + model.fit(train_dataset, + val_dataset, + epochs=2, + batch_size=batch_size, + callbacks=cbk) + + eval_result = model.evaluate(val_dataset, batch_size=batch_size) + + output = model.predict( + test_dataset, batch_size=batch_size, stack_outputs=True) + + np.testing.assert_equal(output[0].shape[0], len(test_dataset)) + + acc = get_predict_accuracy(output[0], val_dataset.labels) + + np.testing.assert_allclose(acc, eval_result['acc']) + + def test_multiple_gpus_static(self): + self.fit(False) + + def test_multiple_gpus_dygraph(self): + self.fit(True) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_bert_dataloader.py b/hapi/tests/test_bert_dataloader.py similarity index 100% rename from tests/test_bert_dataloader.py rename to hapi/tests/test_bert_dataloader.py diff --git a/tests/test_callbacks.py b/hapi/tests/test_callbacks.py similarity index 100% rename from tests/test_callbacks.py rename to hapi/tests/test_callbacks.py diff --git a/tests/test_datasets.py b/hapi/tests/test_datasets.py similarity index 82% rename from tests/test_datasets.py rename to hapi/tests/test_datasets.py index 6adc9b667ac12c95fce0632ce2647db15e9fd470..857d037eb1195cfa4b8b835e7c1e63be5bd63d37 100644 --- a/tests/test_datasets.py +++ b/hapi/tests/test_datasets.py @@ -12,25 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-# when test, you should add hapi root path to the PYTHONPATH, -# export PYTHONPATH=PATH_TO_HAPI:$PYTHONPATH - import unittest +import os import numpy as np +import tempfile +import shutil +import cv2 from hapi.datasets import * class TestFolderDatasets(unittest.TestCase): + def makedata(self): + self.data_dir = tempfile.mkdtemp() + for i in range(2): + sub_dir = os.path.join(self.data_dir, 'class_' + str(i)) + if not os.path.exists(sub_dir): + os.makedirs(sub_dir) + for j in range(2): + fake_img = (np.random.random( + (32, 32, 3)) * 255).astype('uint8') + cv2.imwrite(os.path.join(sub_dir, str(j) + '.jpg'), fake_img) + def test_dataset(self): - dataset_folder = DatasetFolder('tests/test_data') + self.makedata() + dataset_folder = DatasetFolder(self.data_dir) for _ in dataset_folder: pass - assert len(dataset_folder) == 3 + assert len(dataset_folder) == 4 assert len(dataset_folder.classes) == 2 + shutil.rmtree(self.data_dir) + class TestMNISTTest(unittest.TestCase): def test_main(self): diff --git a/hapi/tests/test_distributed.py b/hapi/tests/test_distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..7183d78fe11269dcbd64eb40accb241b26fce60b --- /dev/null +++ b/hapi/tests/test_distributed.py @@ -0,0 +1,143 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
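
For code being migrated by this patch, losses are now imported from `hapi.loss` instead of `hapi.model`. A minimal sketch of calling the relocated `CrossEntropy` directly (dygraph mode; shapes and values are illustrative, mirroring `test_loss.py`):

```python
import numpy as np
from paddle import fluid

from hapi.loss import CrossEntropy  # previously: from hapi.model import CrossEntropy

fluid.enable_dygraph()

# fluid.layers.cross_entropy expects normalized probabilities per row
probs = np.random.uniform(0.1, 1.0, (8, 10)).astype('float32')
probs /= probs.sum(axis=1, keepdims=True)
labels = np.random.randint(0, 10, (8, 1)).astype('int64')

loss_fn = CrossEntropy()
out = loss_fn(
    fluid.dygraph.to_variable(probs),
    fluid.dygraph.to_variable(labels))
# the Loss base class returns a list with one averaged loss tensor per output/label pair
print([l.numpy() for l in out])
# SoftmaxWithCrossEntropy is used the same way, but on unnormalized logits.
```
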
+ +from __future__ import print_function + +import unittest +import os +import time +import six +import copy +from argparse import ArgumentParser, REMAINDER +import paddle +import paddle.fluid as fluid + +from paddle.distributed.utils import * +import paddle.distributed.cloud_utils as cloud_utils + + +def get_cluster_from_args(selected_gpus): + cluster_node_ips = '127.0.0.1' + node_ip = '127.0.0.1' + use_paddlecloud = False + started_port = None + node_ips = [x.strip() for x in cluster_node_ips.split(',')] + + node_rank = node_ips.index(node_ip) + + free_ports = None + if not use_paddlecloud and len(node_ips) <= 1 and started_port is None: + free_ports = find_free_ports(len(selected_gpus)) + if free_ports is not None: + free_ports = list(free_ports) + else: + started_port = 6070 + + free_ports = [ + x for x in range(started_port, started_port + len(selected_gpus)) + ] + return get_cluster(node_ips, node_ip, free_ports, selected_gpus) + + +def get_gpus(selected_gpus): + cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") + if cuda_visible_devices is None or cuda_visible_devices == "": + selected_gpus = [x.strip() for x in selected_gpus.split(',')] + else: + cuda_visible_devices_list = cuda_visible_devices.split(',') + for x in selected_gpus.split(','): + assert x in cuda_visible_devices_list, "Can't find "\ + "your selected_gpus %s in CUDA_VISIBLE_DEVICES[%s]."\ + % (x, cuda_visible_devices) + selected_gpus = [ + cuda_visible_devices_list.index(x.strip()) + for x in selected_gpus.split(',') + ] + return selected_gpus + + +def start_local_trainers(cluster, + pod, + training_script, + training_script_args, + log_dir=None): + current_env = copy.copy(os.environ.copy()) + #paddle broadcast ncclUniqueId use socket, and + #proxy maybe make trainers unreachable, so delete them. + #if we set them to "", grpc will log error message "bad uri" + #so just delete them. 
+ current_env.pop("http_proxy", None) + current_env.pop("https_proxy", None) + + procs = [] + for idx, t in enumerate(pod.trainers): + proc_env = { + "FLAGS_selected_gpus": "%s" % ",".join([str(g) for g in t.gpus]), + "PADDLE_TRAINER_ID": "%d" % t.rank, + "PADDLE_CURRENT_ENDPOINT": "%s" % t.endpoint, + "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(), + "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()) + } + + current_env.update(proc_env) + + print("trainer proc env:{}".format(current_env)) + + cmd = "python -m coverage run --branch -p " + training_script + + print("start trainer proc:{} env:{}".format(cmd, proc_env)) + + fn = None + + proc = subprocess.Popen(cmd.split(" "), env=current_env) + + tp = TrainerProc() + tp.proc = proc + tp.rank = t.rank + tp.log_fn = fn + tp.cmd = cmd + + procs.append(tp) + + return procs + + +class TestMultipleGpus(unittest.TestCase): + def test_mnist_2gpu(self): + if fluid.core.get_cuda_device_count() == 0: + return + + selected_gpus = get_gpus('0,1') + cluster = None + pod = None + + cluster, pod = get_cluster_from_args(selected_gpus) + + procs = start_local_trainers( + cluster, + pod, + training_script='dist_mnist.py', + training_script_args=[]) + + while True: + alive = watch_local_trainers(procs, cluster.trainers_nranks()) + + if not alive: + print("Local procs complete, POD info:{}".format(pod)) + break + time.sleep(3) + + +if __name__ == "__main__": + unittest.main() diff --git a/hapi/tests/test_logger.py b/hapi/tests/test_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..561253dd32f3e88e17656e23d1338b2e4ad74200 --- /dev/null +++ b/hapi/tests/test_logger.py @@ -0,0 +1,49 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import unittest +import os +import numpy as np +import shutil +import tempfile + +from hapi.logger import setup_logger + + +class TestSetupLogger(unittest.TestCase): + def setUp(self): + self.save_dir = tempfile.mkdtemp() + self.save_file = os.path.join(self.save_dir, 'logger.txt') + + def tearDown(self): + shutil.rmtree(self.save_dir) + + def logger(self, output=None): + setup_logger(output=output) + + def test_logger_no_output(self): + self.logger() + + def test_logger_dir(self): + self.logger(self.save_dir) + + def test_logger_file(self): + self.logger(self.save_file) + + +if __name__ == '__main__': + unittest.main() diff --git a/hapi/tests/test_loss.py b/hapi/tests/test_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0bf8a48a47d3b36f49df4437fa12dcee40519a7c --- /dev/null +++ b/hapi/tests/test_loss.py @@ -0,0 +1,112 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import unittest +import os +import six +import numpy as np +import shutil +import copy + +import paddle +from paddle import fluid + +from hapi.model import Model, Input +from hapi.vision.models import resnet18 +from hapi.loss import CrossEntropy, SoftmaxWithCrossEntropy + + +def stable_softmax(x): + """Compute the softmax of vector x in a numerically stable way.""" + # clip to shiftx, otherwise, when calc loss with + # log(exp(shiftx)), may get log(0)=INF + shiftx = (x - np.max(x)).clip(-64.) + exps = np.exp(shiftx) + return exps / np.sum(exps) + + +def randomize_probability(batch_size, class_num, dtype='float32'): + prob = np.random.uniform( + 0.1, 1.0, size=(batch_size, class_num)).astype(dtype) + prob_sum = prob.sum(axis=1) + for i in six.moves.xrange(len(prob)): + prob[i] /= prob_sum[i] + return prob + + +def numpy_ce(x, label): + return np.asmatrix( + [[-np.log(x[i][label[i][0]])] for i in range(x.shape[0])], + dtype="float32").mean() + + +class TestLoss(unittest.TestCase): + def test_cross_entropy(self): + class_num = 100 + batch_size = 128 + inputs = [randomize_probability(128, class_num) for _ in range(2)] + + labels = [ + np.random.randint( + 0, class_num, (batch_size, 1), dtype="int64") for _ in range(2) + ] + + gt_out = [numpy_ce(inputs[i], labels[i]) for i in range(2)] + + fluid.enable_dygraph() + cross_entropy = CrossEntropy() + out = cross_entropy( + [fluid.dygraph.to_variable(x) for x in inputs], + [fluid.dygraph.to_variable(label) for label in labels]) + out = [o.numpy() for o in out] + + for o, g in zip(out, gt_out): + np.testing.assert_allclose(o, g, atol=1e-5) + + def test_soft_cross_entronpy(self): + class_num = 100 + batch_size = 128 + + inputs = [randomize_probability(128, class_num) for _ in range(2)] + + labels = [ + np.random.randint( + 0, class_num, (batch_size, 1), dtype="int64") for _ in range(2) + ] + + fluid.enable_dygraph() + softmax_cross_entropy = SoftmaxWithCrossEntropy() + + softmax_cross_entropy( + [fluid.dygraph.to_variable(x) for x in inputs], + [fluid.dygraph.to_variable(label) for label in labels]) + + softmax_cross_entropy = SoftmaxWithCrossEntropy(average=False) + + inputs = [randomize_probability(128, class_num)] + + labels = [ + np.random.randint( + 0, class_num, (batch_size, 1), dtype="int64") + ] + + softmax_cross_entropy([fluid.dygraph.to_variable(x) for x in inputs], + fluid.dygraph.to_variable(labels[0])) + + +if __name__ == '__main__': + unittest.main() diff --git a/hapi/tests/test_model.py b/hapi/tests/test_model.py new file mode 100644 index 0000000000000000000000000000000000000000..16b4368e1afeca4f03775994426669981f419675 --- /dev/null +++ b/hapi/tests/test_model.py @@ -0,0 +1,229 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import unittest + +import os +import cv2 +import numpy as np + +import paddle +from paddle import fluid +from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear +from paddle.fluid.dygraph.container import Sequential +from paddle.io import BatchSampler, DataLoader + +from hapi.model import Model, Input, set_device +from hapi.loss import Loss +from hapi.metrics import Accuracy +from hapi.datasets import MNIST +from hapi.vision.models import LeNet +from hapi.download import get_weights_path_from_url + + +class LeNetDygraph(fluid.dygraph.Layer): + """LeNet model from + `"LeCun Y, Bottou L, Bengio Y, et al. Gradient-based learning applied to document recognition[J]. Proceedings of the IEEE, 1998, 86(11): 2278-2324.`_ + + Args: + num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer + will not be defined. Default: 10. + classifier_activation (str): activation for the last fc layer. Default: 'softmax'. + """ + + def __init__(self, num_classes=10, classifier_activation='softmax'): + super(LeNetDygraph, self).__init__() + self.num_classes = num_classes + self.features = Sequential( + Conv2D( + 1, 6, 3, stride=1, padding=1), + Pool2D(2, 'max', 2), + Conv2D( + 6, 16, 5, stride=1, padding=0), + Pool2D(2, 'max', 2)) + + if num_classes > 0: + self.fc = Sequential( + Linear(400, 120), + Linear(120, 84), + Linear( + 84, 10, act=classifier_activation)) + + def forward(self, inputs): + x = self.features(inputs) + + if self.num_classes > 0: + x = fluid.layers.flatten(x, 1) + x = self.fc(x) + return x + + +class MnistDataset(MNIST): + def __init__(self, mode, return_label=True): + super(MnistDataset, self).__init__(mode=mode) + self.return_label = return_label + + def __getitem__(self, idx): + img = np.reshape(self.images[idx], [1, 28, 28]) + if self.return_label: + return img, np.array(self.labels[idx]).astype('int64') + return img, + + def __len__(self): + return len(self.images) + + +def get_predict_accuracy(pred, gt): + pred = np.argmax(pred, -1) + gt = np.array(gt) + + correct = pred[:, np.newaxis] == gt + return np.sum(correct) / correct.shape[0] + + +def low_level_lenet_dygraph_train(model, dataloader): + optim = fluid.optimizer.Adam( + learning_rate=0.001, parameter_list=model.parameters()) + model.train() + for inputs, labels in dataloader: + outputs = model(inputs) + loss = fluid.layers.cross_entropy(outputs, labels) + avg_loss = fluid.layers.reduce_sum(loss) + avg_loss.backward() + optim.minimize(avg_loss) + model.clear_gradients() + + +def low_level_dynamic_evaluate(model, dataloader): + with fluid.dygraph.no_grad(): + model.eval() + cnt = 0 + for inputs, labels in dataloader: + outputs = model(inputs) + + cnt += (np.argmax(outputs.numpy(), -1)[:, np.newaxis] == + labels.numpy()).astype('int').sum() + + return cnt / len(dataloader.dataset) + + +class TestEvaluatePredict(unittest.TestCase): + def setUp(self): + self.device = set_device('gpu') + self.train_dataset = MnistDataset(mode='train') + self.val_dataset = MnistDataset(mode='test') + self.test_dataset = MnistDataset(mode='test', 
return_label=False) + + fluid.enable_dygraph(self.device) + train_dataloader = fluid.io.DataLoader( + self.train_dataset, places=self.device, batch_size=64) + val_dataloader = fluid.io.DataLoader( + self.val_dataset, places=self.device, batch_size=64) + self.lenet_dygraph = LeNetDygraph() + low_level_lenet_dygraph_train(self.lenet_dygraph, train_dataloader) + self.acc1 = low_level_dynamic_evaluate(self.lenet_dygraph, + val_dataloader) + + def evaluate(self, dynamic): + fluid.enable_dygraph(self.device) if dynamic else None + + inputs = [Input([-1, 1, 28, 28], 'float32', name='image')] + labels = [Input([None, 1], 'int64', name='label')] + + if fluid.in_dygraph_mode(): + feed_list = None + else: + feed_list = [x.forward() for x in inputs + labels] + + self.train_dataloader = fluid.io.DataLoader( + self.train_dataset, + places=self.device, + batch_size=64, + feed_list=feed_list) + self.val_dataloader = fluid.io.DataLoader( + self.val_dataset, + places=self.device, + batch_size=64, + feed_list=feed_list) + self.test_dataloader = fluid.io.DataLoader( + self.test_dataset, + places=self.device, + batch_size=64, + feed_list=feed_list) + + model = LeNet() + model.load_dict(self.lenet_dygraph.state_dict()) + model.prepare(metrics=Accuracy(), inputs=inputs, labels=labels) + + result = model.evaluate(self.val_dataloader) + + np.testing.assert_allclose(result['acc'], self.acc1) + + def predict(self, dynamic): + fluid.enable_dygraph(self.device) if dynamic else None + + inputs = [Input([-1, 1, 28, 28], 'float32', name='image')] + labels = [Input([None, 1], 'int64', name='label')] + + if fluid.in_dygraph_mode(): + feed_list = None + else: + feed_list = [x.forward() for x in inputs + labels] + + self.train_dataloader = fluid.io.DataLoader( + self.train_dataset, + places=self.device, + batch_size=64, + feed_list=feed_list) + self.val_dataloader = fluid.io.DataLoader( + self.val_dataset, + places=self.device, + batch_size=64, + feed_list=feed_list) + self.test_dataloader = fluid.io.DataLoader( + self.test_dataset, + places=self.device, + batch_size=64, + feed_list=feed_list) + + model = LeNet() + model.load_dict(self.lenet_dygraph.state_dict()) + model.prepare(metrics=Accuracy(), inputs=inputs, labels=labels) + + output = model.predict(self.test_dataloader, stack_outputs=True) + + np.testing.assert_equal(output[0].shape[0], len(self.test_dataset)) + + acc = get_predict_accuracy(output[0], self.val_dataset.labels) + + np.testing.assert_allclose(acc, self.acc1) + + def test_evaluate_dygraph(self): + self.evaluate(True) + + def test_evaluate_static(self): + self.evaluate(False) + + def test_predict_dygraph(self): + self.predict(True) + + def test_predict_static(self): + self.predict(False) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_progressbar.py b/hapi/tests/test_progressbar.py similarity index 100% rename from tests/test_progressbar.py rename to hapi/tests/test_progressbar.py diff --git a/hapi/tests/test_save_inference_model.py b/hapi/tests/test_save_inference_model.py new file mode 100644 index 0000000000000000000000000000000000000000..51d8cb533c7d5ec638a68575a12ea7cb79a8d9cf --- /dev/null +++ b/hapi/tests/test_save_inference_model.py @@ -0,0 +1,69 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import unittest +import os +import numpy as np +import shutil +import tempfile + +import paddle +from paddle import fluid + +from hapi.model import Model, Input +from hapi.vision.models import resnet18 + + +class TestSaveInferenceModel(unittest.TestCase): + def tearDown(self): + shutil.rmtree(self.save_dir) + + def export_deploy_model(self): + model = resnet18() + + inputs = [Input([None, 3, 224, 224], 'float32', name='image')] + + model.prepare(inputs=inputs) + + self.save_dir = tempfile.mkdtemp() + if not os.path.exists(self.save_dir): + os.makedirs(self.save_dir) + + model.save_inference_model(self.save_dir) + + place = fluid.CPUPlace() if not fluid.is_compiled_with_cuda( + ) else fluid.CUDAPlace(0) + exe = fluid.Executor(place) + + [inference_program, feed_target_names, fetch_targets] = ( + fluid.io.load_inference_model( + dirname=self.save_dir, executor=exe)) + tensor_img = np.array( + np.random.random((1, 3, 224, 224)), dtype=np.float32) + ori_results = model.test_batch(tensor_img) + results = exe.run(inference_program, + feed={feed_target_names[0]: tensor_img}, + fetch_list=fetch_targets) + + np.testing.assert_allclose(results, ori_results) + + def test_save_inference_model(self): + self.export_deploy_model() + + +if __name__ == '__main__': + unittest.main() diff --git a/hapi/tests/test_transforms.py b/hapi/tests/test_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc8d5067b7b1cb6501d5f1ed9dea69624dea2db --- /dev/null +++ b/hapi/tests/test_transforms.py @@ -0,0 +1,132 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# when test, you should add hapi root path to the PYTHONPATH, +# export PYTHONPATH=PATH_TO_HAPI:$PYTHONPATH +import unittest +import os +import tempfile +import cv2 +import shutil +import numpy as np + +from hapi.datasets import DatasetFolder +import hapi.vision.transforms as transforms + + +class TestTransforms(unittest.TestCase): + def setUp(self): + self.data_dir = tempfile.mkdtemp() + for i in range(2): + sub_dir = os.path.join(self.data_dir, 'class_' + str(i)) + if not os.path.exists(sub_dir): + os.makedirs(sub_dir) + for j in range(2): + if j == 0: + fake_img = (np.random.random( + (280, 350, 3)) * 255).astype('uint8') + else: + fake_img = (np.random.random( + (400, 300, 3)) * 255).astype('uint8') + cv2.imwrite(os.path.join(sub_dir, str(j) + '.jpg'), fake_img) + + def tearDown(self): + shutil.rmtree(self.data_dir) + + def do_transform(self, trans): + dataset_folder = DatasetFolder(self.data_dir, transform=trans) + + for _ in dataset_folder: + pass + + def test_trans_all(self): + normalize = transforms.Normalize( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375]) + trans = transforms.Compose([ + transforms.RandomResizedCrop(224), transforms.GaussianNoise(), + transforms.ColorJitter( + brightness=0.4, contrast=0.4, saturation=0.4, + hue=0.4), transforms.RandomHorizontalFlip(), + transforms.Permute(mode='CHW'), normalize + ]) + + self.do_transform(trans) + + def test_trans_resize(self): + trans = transforms.Compose([ + transforms.Resize(300, [0, 1]), + transforms.RandomResizedCrop((280, 280)), + transforms.Resize(280, [0, 1]), + transforms.Resize((256, 200)), + transforms.Resize((180, 160)), + transforms.CenterCrop(128), + transforms.CenterCrop((128, 128)), + ]) + self.do_transform(trans) + + def test_trans_centerCrop(self): + trans = transforms.Compose([ + transforms.CenterCropResize(224), + transforms.CenterCropResize(128, 160), + ]) + self.do_transform(trans) + + def test_flip(self): + trans = transforms.Compose([ + transforms.RandomHorizontalFlip(1.0), + transforms.RandomHorizontalFlip(0.0), + transforms.RandomVerticalFlip(0.0), + transforms.RandomVerticalFlip(1.0), + ]) + self.do_transform(trans) + + def test_color_jitter(self): + trans = transforms.BatchCompose([ + transforms.BrightnessTransform(0.0), + transforms.HueTransform(0.0), + transforms.SaturationTransform(0.0), + transforms.ContrastTransform(0.0), + ]) + self.do_transform(trans) + + def test_exception(self): + trans = transforms.Compose([transforms.Resize(-1)]) + + trans_batch = transforms.BatchCompose([transforms.Resize(-1)]) + + with self.assertRaises(Exception): + self.do_transform(trans) + + with self.assertRaises(Exception): + self.do_transform(trans_batch) + + with self.assertRaises(ValueError): + transforms.ContrastTransform(-1.0) + + with self.assertRaises(ValueError): + transforms.SaturationTransform(-1.0), + + with self.assertRaises(ValueError): + transforms.HueTransform(-1.0) + + with self.assertRaises(ValueError): + transforms.BrightnessTransform(-1.0) + + def test_info(self): + str(transforms.Compose([transforms.Resize((224, 224))])) + str(transforms.BatchCompose([transforms.Resize((224, 224))])) + + +if __name__ == '__main__': + unittest.main() diff --git a/hapi/tests/test_vison_models.py b/hapi/tests/test_vison_models.py new file mode 100644 index 0000000000000000000000000000000000000000..05d3a10b34573cbf095c8e95f149f26edbbdfeec --- /dev/null +++ b/hapi/tests/test_vison_models.py @@ -0,0 +1,87 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np + +import hapi.vision.models as models +from hapi.model import Input + + +class TestVisonModels(unittest.TestCase): + def models_infer(self, arch, pretrained=False, batch_norm=False): + + x = np.array(np.random.random((2, 3, 224, 224)), dtype=np.float32) + if batch_norm: + model = models.__dict__[arch](pretrained=pretrained, + batch_norm=True) + else: + model = models.__dict__[arch](pretrained=pretrained) + inputs = [Input([None, 3, 224, 224], 'float32', name='image')] + + model.prepare(inputs=inputs) + + model.test_batch(x) + + def test_mobilenetv2_pretrained(self): + self.models_infer('mobilenet_v2', pretrained=True) + + def test_mobilenetv1(self): + self.models_infer('mobilenet_v1') + + def test_vgg11(self): + self.models_infer('vgg11') + + def test_vgg13(self): + self.models_infer('vgg13') + + def test_vgg16(self): + self.models_infer('vgg16') + + def test_vgg16_bn(self): + self.models_infer('vgg16', batch_norm=True) + + def test_vgg19(self): + self.models_infer('vgg19') + + def test_resnet18(self): + self.models_infer('resnet18') + + def test_resnet34(self): + self.models_infer('resnet34') + + def test_resnet50(self): + self.models_infer('resnet50') + + def test_resnet101(self): + self.models_infer('resnet101') + + def test_resnet152(self): + self.models_infer('resnet152') + + def test_darknet53(self): + self.models_infer('darknet53') + + def test_lenet(self): + lenet = models.__dict__['LeNet']() + + inputs = [Input([None, 1, 28, 28], 'float32', name='x')] + lenet.prepare(inputs=inputs) + + x = np.array(np.random.random((2, 1, 28, 28)), dtype=np.float32) + lenet.test_batch(x) + + +if __name__ == '__main__': + unittest.main() diff --git a/hapi/utils.py b/hapi/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..de928945dc68a1800c3cd9b14aaa0659e50c9945 --- /dev/null +++ b/hapi/utils.py @@ -0,0 +1,63 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import inspect +import numpy as np + +from paddle import fluid +from paddle.fluid.framework import Variable +from paddle.fluid.executor import global_scope + + +def to_list(value): + if value is None: + return value + if isinstance(value, (list, tuple)): + return list(value) + return [value] + + +def to_numpy(var): + assert isinstance(var, (Variable, fluid.core.VarBase)), "not a variable" + if isinstance(var, fluid.core.VarBase): + return var.numpy() + t = global_scope().find_var(var.name).get_tensor() + return np.array(t) + + +def flatten_list(l): + assert isinstance(l, list), "not a list" + outl = [] + splits = [] + for sl in l: + assert isinstance(sl, list), "sub content not a list" + splits.append(len(sl)) + outl += sl + return outl, splits + + +def restore_flatten_list(l, splits): + outl = [] + for split in splits: + assert len(l) >= split, "list length invalid" + sl, l = l[:split], l[split:] + outl.append(sl) + return outl + + +def extract_args(func): + if hasattr(inspect, 'getfullargspec'): + return inspect.getfullargspec(func)[0] + else: + return inspect.getargspec(func)[0] \ No newline at end of file diff --git a/hapi/vision/models/__init__.py b/hapi/vision/models/__init__.py index d444cd6627e8228a796c29cd7396d459e10cc4c7..4150cb5d88278e7e29575b771f956c3ac8abbd4d 100644 --- a/hapi/vision/models/__init__.py +++ b/hapi/vision/models/__init__.py @@ -17,15 +17,18 @@ from . import vgg from . import mobilenetv1 from . import mobilenetv2 from . import darknet +from . import lenet from .resnet import * from .mobilenetv1 import * from .mobilenetv2 import * from .vgg import * from .darknet import * +from .lenet import * __all__ = resnet.__all__ \ + vgg.__all__ \ + mobilenetv1.__all__ \ + mobilenetv2.__all__ \ - + darknet.__all__ + + darknet.__all__\ + + lenet.__all__ diff --git a/hapi/vision/models/darknet.py b/hapi/vision/models/darknet.py index 5525b6c0489c993669de5d675b25518dc74a6ca6..582b4c56cee3f8aeb450db54f3a6551dc8a04689 100755 --- a/hapi/vision/models/darknet.py +++ b/hapi/vision/models/darknet.py @@ -20,14 +20,15 @@ from paddle.fluid.regularizer import L2Decay from paddle.fluid.dygraph.nn import Conv2D, BatchNorm, Pool2D, Linear from hapi.model import Model -from hapi.download import get_weights_path +from hapi.download import get_weights_path_from_url __all__ = ['DarkNet', 'darknet53'] # {num_layers: (url, md5)} -pretrain_infos = { - 53: ('https://paddle-hapi.bj.bcebos.com/models/darknet53.pdparams', - 'ca506a90e2efecb9a2093f8ada808708') +model_urls = { + 'darknet53': + ('https://paddle-hapi.bj.bcebos.com/models/darknet53.pdparams', + 'ca506a90e2efecb9a2093f8ada808708') } @@ -213,16 +214,15 @@ class DarkNet(Model): return out -def _darknet(num_layers=53, pretrained=False, **kwargs): +def _darknet(arch, num_layers=53, pretrained=False, **kwargs): model = DarkNet(num_layers, **kwargs) if pretrained: - assert num_layers in pretrain_infos.keys(), \ - "DarkNet{} do not have pretrained weights now, " \ - "pretrained should be set as False".format(num_layers) - weight_path = get_weights_path(*(pretrain_infos[num_layers])) + assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format( + arch) + weight_path = get_weights_path_from_url(*(model_urls[arch])) assert weight_path.endswith('.pdparams'), \ "suffix of weight must be .pdparams" - model.load(weight_path[:-9]) + model.load(weight_path) return model @@ -234,4 +234,4 @@ def darknet53(pretrained=False, **kwargs): pretrained (bool): If True, returns a model pre-trained on ImageNet, 
default True. """ - return _darknet(53, pretrained, **kwargs) + return _darknet('darknet53', 53, pretrained, **kwargs) diff --git a/hapi/vision/models/mobilenetv1.py b/hapi/vision/models/mobilenetv1.py index 31c0acbee2fdc107b0d776605c296c2c9296bcfd..b725afac14c25c70008a5ef3167e0b18f3f9b521 100644 --- a/hapi/vision/models/mobilenetv1.py +++ b/hapi/vision/models/mobilenetv1.py @@ -20,7 +20,7 @@ from paddle.fluid.param_attr import ParamAttr from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear from hapi.model import Model -from hapi.download import get_weights_path +from hapi.download import get_weights_path_from_url __all__ = ['MobileNetV1', 'mobilenet_v1'] @@ -267,11 +267,11 @@ def _mobilenet(arch, pretrained=False, **kwargs): if pretrained: assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format( arch) - weight_path = get_weights_path(model_urls[arch][0], - model_urls[arch][1]) + weight_path = get_weights_path_from_url(model_urls[arch][0], + model_urls[arch][1]) assert weight_path.endswith( '.pdparams'), "suffix of weight must be .pdparams" - model.load(weight_path[:-9]) + model.load(weight_path) return model diff --git a/hapi/vision/models/mobilenetv2.py b/hapi/vision/models/mobilenetv2.py index d624625bcda1b763a0b3e511b6146776245e2fd5..c9591b22a505d3bd7a60c12597919bda182f88b5 100644 --- a/hapi/vision/models/mobilenetv2.py +++ b/hapi/vision/models/mobilenetv2.py @@ -19,7 +19,7 @@ from paddle.fluid.param_attr import ParamAttr from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear from hapi.model import Model -from hapi.download import get_weights_path +from hapi.download import get_weights_path_from_url __all__ = ['MobileNetV2', 'mobilenet_v2'] @@ -241,11 +241,11 @@ def _mobilenet(arch, pretrained=False, **kwargs): if pretrained: assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format( arch) - weight_path = get_weights_path(model_urls[arch][0], - model_urls[arch][1]) + weight_path = get_weights_path_from_url(model_urls[arch][0], + model_urls[arch][1]) assert weight_path.endswith( '.pdparams'), "suffix of weight must be .pdparams" - model.load(weight_path[:-9]) + model.load(weight_path) return model diff --git a/hapi/vision/models/resnet.py b/hapi/vision/models/resnet.py index ac0944ee651224b106db71d0c87e9e5c29fd14d9..1adb085c7d0e26fb89a036303812a537d028cf3f 100644 --- a/hapi/vision/models/resnet.py +++ b/hapi/vision/models/resnet.py @@ -23,7 +23,7 @@ from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear from paddle.fluid.dygraph.container import Sequential from hapi.model import Model -from hapi.download import get_weights_path +from hapi.download import get_weights_path_from_url __all__ = [ 'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152' @@ -267,11 +267,11 @@ def _resnet(arch, Block, depth, pretrained, **kwargs): if pretrained: assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format( arch) - weight_path = get_weights_path(model_urls[arch][0], - model_urls[arch][1]) + weight_path = get_weights_path_from_url(model_urls[arch][0], + model_urls[arch][1]) assert weight_path.endswith( '.pdparams'), "suffix of weight must be .pdparams" - model.load(weight_path[:-9]) + model.load(weight_path) return model diff --git a/hapi/vision/models/vgg.py b/hapi/vision/models/vgg.py index 41cf34eddf7d4d379f9ea3a6bc5490f9763919dc..0cd7cb79e514873991382b7a577b2b2a8d204fba 
100644 --- a/hapi/vision/models/vgg.py +++ b/hapi/vision/models/vgg.py @@ -18,7 +18,7 @@ from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear from paddle.fluid.dygraph.container import Sequential from hapi.model import Model -from hapi.download import get_weights_path +from hapi.download import get_weights_path_from_url __all__ = [ 'VGG', @@ -128,11 +128,11 @@ def _vgg(arch, cfg, batch_norm, pretrained, **kwargs): if pretrained: assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format( arch) - weight_path = get_weights_path(model_urls[arch][0], - model_urls[arch][1]) + weight_path = get_weights_path_from_url(model_urls[arch][0], + model_urls[arch][1]) assert weight_path.endswith( '.pdparams'), "suffix of weight must be .pdparams" - model.load(weight_path[:-9]) + model.load(weight_path) return model diff --git a/hapi/vision/transforms/transforms.py b/hapi/vision/transforms/transforms.py index 87e49862489c1d9284b9d3e6d018e0de2f183bcb..b71b2571bafa23ae3ef58ec943d4e147749332f7 100644 --- a/hapi/vision/transforms/transforms.py +++ b/hapi/vision/transforms/transforms.py @@ -29,8 +29,10 @@ import traceback from . import functional as F if sys.version_info < (3, 3): + Sequence = collections.Sequence Iterable = collections.Iterable else: + Sequence = collections.abc.Sequence Iterable = collections.abc.Iterable __all__ = [ @@ -54,20 +56,45 @@ __all__ = [ class Compose(object): - """Composes several transforms together. + """ + Composes several transforms together so that a list of transforms + can be used as a single dataset transform. Args: transforms (list of ``Transform`` objects): list of transforms to compose. + Returns: + A callable Compose object, whose __call__ applies each of the + given :attr:`transforms` sequentially. + + Examples: + + ..
code-block:: python + + from hapi.datasets import Flowers + from hapi.vision.transforms import Compose, ColorJitter, Resize + + transform = Compose([ColorJitter(), Resize(size=608)]) + flowers = Flowers(mode='test', transform=transform) + + for i in range(10): + sample = flowers[i] + print(sample[0].shape, sample[1]) + """ def __init__(self, transforms): self.transforms = transforms - def __call__(self, data): + def __call__(self, *data): for f in self.transforms: try: - data = f(data) + # multiple fields in a sample + if isinstance(data, Sequence): + data = f(*data) + # single field in a sample, call transform directly + else: + data = f(data) except Exception as e: stack_info = traceback.format_exc() print("fail to perform transform [{}] with error: " diff --git a/mnist.py b/mnist.py index 4e6240c2d5783b820a8f33f3d75064bd1d495693..397c51e2b796ee5fa5722d1865378a603c66ef1a 100644 --- a/mnist.py +++ b/mnist.py @@ -26,7 +26,8 @@ from paddle.fluid.optimizer import Momentum from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear from hapi.datasets.mnist import MNIST as MnistDataset -from hapi.model import Model, CrossEntropy, Input, set_device +from hapi.model import Model, Input, set_device +from hapi.loss import CrossEntropy from hapi.metrics import Accuracy diff --git a/tests/test_data/class_a/ILSVRC2012_val_00000293.JPEG b/tests/test_data/class_a/ILSVRC2012_val_00000293.JPEG deleted file mode 100644 index 1b332471a78cbb3e362a0871d8e2dfad14320910..0000000000000000000000000000000000000000 Binary files a/tests/test_data/class_a/ILSVRC2012_val_00000293.JPEG and /dev/null differ diff --git a/tests/test_data/class_a/ILSVRC2012_val_00002138.JPEG b/tests/test_data/class_a/ILSVRC2012_val_00002138.JPEG deleted file mode 100644 index 251f84450c8734d9683ad2bfba59dcf0ff2c9109..0000000000000000000000000000000000000000 Binary files a/tests/test_data/class_a/ILSVRC2012_val_00002138.JPEG and /dev/null differ diff --git a/tests/test_data/class_b/ILSVRC2012_val_00000236.JPEG b/tests/test_data/class_b/ILSVRC2012_val_00000236.JPEG deleted file mode 100644 index a62f618980125faa60af6649d1b88799bde25228..0000000000000000000000000000000000000000 Binary files a/tests/test_data/class_b/ILSVRC2012_val_00000236.JPEG and /dev/null differ diff --git a/tests/test_model.py b/tests/test_model.py deleted file mode 100644 index 3aea2d1353e2e414d35e9b6714bdb0985d1249c7..0000000000000000000000000000000000000000 --- a/tests/test_model.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
- -from __future__ import division -from __future__ import print_function - -# when test, you should add hapi root path to the PYTHONPATH, -# export PYTHONPATH=PATH_TO_HAPI:$PYTHONPATH - -import unittest - -import os - -import numpy as np -import contextlib - -import paddle -from paddle import fluid -from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear -from paddle.io import BatchSampler, DataLoader - -from hapi.model import Model, CrossEntropy, Input, Loss, set_device -from hapi.metrics import Accuracy -from hapi.callbacks import ProgBarLogger -from hapi.datasets import MNIST as MnistDataset - - -class SimpleImgConvPool(fluid.dygraph.Layer): - def __init__(self, - num_channels, - num_filters, - filter_size, - pool_size, - pool_stride, - pool_padding=0, - pool_type='max', - global_pooling=False, - conv_stride=1, - conv_padding=0, - conv_dilation=1, - conv_groups=None, - act=None, - use_cudnn=False, - param_attr=None, - bias_attr=None): - super(SimpleImgConvPool, self).__init__('SimpleConv') - - self._conv2d = Conv2D( - num_channels=num_channels, - num_filters=num_filters, - filter_size=filter_size, - stride=conv_stride, - padding=conv_padding, - dilation=conv_dilation, - groups=conv_groups, - param_attr=None, - bias_attr=None, - use_cudnn=use_cudnn) - - self._pool2d = Pool2D( - pool_size=pool_size, - pool_type=pool_type, - pool_stride=pool_stride, - pool_padding=pool_padding, - global_pooling=global_pooling, - use_cudnn=use_cudnn) - - def forward(self, inputs): - x = self._conv2d(inputs) - x = self._pool2d(x) - return x - - -class MNIST(Model): - def __init__(self): - super(MNIST, self).__init__() - self._simple_img_conv_pool_1 = SimpleImgConvPool( - 1, 20, 5, 2, 2, act="relu") - - self._simple_img_conv_pool_2 = SimpleImgConvPool( - 20, 50, 5, 2, 2, act="relu") - - pool_2_shape = 50 * 4 * 4 - SIZE = 10 - scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5 - self._fc = Linear( - 800, - 10, - param_attr=fluid.param_attr.ParamAttr( - initializer=fluid.initializer.NormalInitializer( - loc=0.0, scale=scale)), - act="softmax") - - def forward(self, inputs): - inputs = fluid.layers.reshape(inputs, [-1, 1, 28, 28]) - x = self._simple_img_conv_pool_1(inputs) - x = self._simple_img_conv_pool_2(x) - x = fluid.layers.flatten(x, axis=1) - x = self._fc(x) - return x - - -class MLP(Model): - def __init__(self): - super(MLP, self).__init__() - SIZE = 10 - self._fc1 = Linear(784, 200, act="relu") - self._fc2 = Linear(200, 200, act="relu") - self._fc3 = Linear(200, 200, act="relu") - self._fc4 = Linear(200, 10, act="softmax") - self._fc5 = Linear(200, 10, act="softmax") - - def forward(self, inputs): - x1 = self._fc1(inputs) - x2 = self._fc2(x1) - x3 = self._fc3(x2) - o1 = self._fc5(x3) - o2 = self._fc4(x2) - return o1, o2 - - -class MyCrossEntropy(Loss): - def __init__(self, average=True): - super(MyCrossEntropy, self).__init__() - - def forward(self, outputs, labels): - loss1 = fluid.layers.cross_entropy(outputs[0], labels[0]) - loss2 = fluid.layers.cross_entropy(outputs[1], labels[0]) - return [loss1, loss2] - - -class TestMnistDataset(MnistDataset): - def __init__(self): - super(TestMnistDataset, self).__init__(mode='test') - - def __getitem__(self, idx): - return self.images[idx], - - def __len__(self): - return len(self.images) - - -def get_predict_accuracy(pred, gt): - pred = np.argmax(pred, -1) - gt = np.array(gt) - - correct = pred[:, np.newaxis] == gt - - return np.sum(correct) / correct.shape[0] - - -class TestModel(unittest.TestCase): - def fit(self, dynamic, is_mlp=False): - device = 
set_device('gpu') - fluid.enable_dygraph(device) if dynamic else None - - im_shape = (-1, 784) - batch_size = 128 - - inputs = [Input(im_shape, 'float32', name='image')] - labels = [Input([None, 1], 'int64', name='label')] - - train_dataset = MnistDataset(mode='train') - val_dataset = MnistDataset(mode='test') - test_dataset = TestMnistDataset() - - model = MNIST() if not is_mlp else MLP() - optim = fluid.optimizer.Momentum( - learning_rate=0.01, momentum=.9, parameter_list=model.parameters()) - loss = CrossEntropy() if not is_mlp else MyCrossEntropy() - model.prepare(optim, loss, Accuracy(), inputs, labels, device=device) - cbk = ProgBarLogger(50) - - model.fit(train_dataset, - val_dataset, - epochs=2, - batch_size=batch_size, - callbacks=cbk) - - eval_result = model.evaluate(val_dataset, batch_size=batch_size) - - output = model.predict( - test_dataset, batch_size=batch_size, stack_outputs=True) - - np.testing.assert_equal(output[0].shape[0], len(test_dataset)) - - acc = get_predict_accuracy(output[0], val_dataset.labels) - - np.testing.assert_allclose(acc, eval_result['acc']) - - def test_fit_static(self): - self.fit(False) - - def test_fit_dygraph(self): - self.fit(True) - - def test_fit_static_multi_loss(self): - self.fit(False, MyCrossEntropy()) - - def test_fit_dygraph_multi_loss(self): - self.fit(True, MyCrossEntropy()) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_transforms.py b/tests/test_transforms.py deleted file mode 100644 index 4471470d62ee1ba88ed6bb1bcebce6252908dc03..0000000000000000000000000000000000000000 --- a/tests/test_transforms.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# when test, you should add hapi root path to the PYTHONPATH, -# export PYTHONPATH=PATH_TO_HAPI:$PYTHONPATH -import unittest - -from hapi.datasets import DatasetFolder -import hapi.vision.transforms as transforms - - -class TestTransforms(unittest.TestCase): - def do_transform(self, trans): - dataset_folder = DatasetFolder('tests/test_data', transform=trans) - - for _ in dataset_folder: - pass - - def test_trans0(self): - normalize = transforms.Normalize( - mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375]) - trans = transforms.Compose([ - transforms.RandomResizedCrop(224), transforms.GaussianNoise(), - transforms.ColorJitter( - brightness=0.4, contrast=0.4, saturation=0.4, - hue=0.4), transforms.RandomHorizontalFlip(), - transforms.Permute(mode='CHW'), normalize - ]) - - self.do_transform(trans) - - def test_trans1(self): - trans = transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - ]) - self.do_transform(trans) - - def test_trans2(self): - trans = transforms.Compose([transforms.CenterCropResize(224)]) - self.do_transform(trans) - - -if __name__ == '__main__': - unittest.main()
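For context on how the reorganized imports exercised above fit together, a minimal end-to-end sketch follows. It is illustrative only, not part of the patch: it assumes hapi and PaddlePaddle are installed and the MNIST download is reachable, uses CPU, and picks arbitrary epoch/batch settings. Loss classes now come from hapi.loss, while Model, Input, and set_device remain in hapi.model; the flow mirrors the prepare/fit/evaluate pattern used in the tests in this change.

# Minimal usage sketch (assumptions: hapi + paddle installed, MNIST download
# available, CPU device, arbitrary hyperparameters).
import numpy as np
from paddle import fluid

from hapi.model import Input, set_device
from hapi.loss import CrossEntropy          # moved out of hapi.model by this change
from hapi.metrics import Accuracy
from hapi.datasets import MNIST
from hapi.vision.models import LeNet


class MnistDataset(MNIST):
    # reshape the flat 784-pixel images to [1, 28, 28] so LeNet's conv
    # layers accept them, mirroring the wrapper used in the tests above
    def __getitem__(self, idx):
        img = np.reshape(self.images[idx], [1, 28, 28])
        return img, np.array(self.labels[idx]).astype('int64')


device = set_device('cpu')      # 'gpu' also works when CUDA is available
fluid.enable_dygraph(device)    # drop this line to run in static graph mode

inputs = [Input([None, 1, 28, 28], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]

model = LeNet()
optim = fluid.optimizer.Momentum(
    learning_rate=0.01, momentum=0.9, parameter_list=model.parameters())
model.prepare(optim, CrossEntropy(), Accuracy(), inputs, labels, device=device)

train_dataset = MnistDataset(mode='train')
val_dataset = MnistDataset(mode='test')
model.fit(train_dataset, val_dataset, epochs=1, batch_size=64)
print(model.evaluate(val_dataset, batch_size=64))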