diff --git a/image_classification/README.MD b/image_classification/README.MD
new file mode 100644
index 0000000000000000000000000000000000000000..2f450d5a58f17a6a1c3e7d294c8e5d3bae236d4a
--- /dev/null
+++ b/image_classification/README.MD
@@ -0,0 +1,74 @@
+# Image Classification with the High-Level API
+
+## Dataset Preparation
+Before training, make sure the [ImageNet dataset](http://image-net.org/download) has been downloaded and extracted into a suitable directory. The prepared dataset should have the following directory structure:
+
+```bash
+/path/to/imagenet
+  train
+    n01440764
+      xxx.jpg
+      ...
+    n01443537
+      xxx.jpg
+      ...
+    ...
+  val
+    n01440764
+      xxx.jpg
+      ...
+    n01443537
+      xxx.jpg
+      ...
+    ...
+```
+
+
+## Training
+### Single-GPU Training
+Run the following command to start training:
+```bash
+python -u main.py --arch resnet50 -d /path/to/imagenet
+```
+
+### Multi-GPU Training
+Run the following command to start multi-GPU training:
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --arch resnet50 -d /path/to/imagenet
+```
+
+## Evaluation
+
+### Single-GPU Evaluation
+Run the following command to evaluate:
+```bash
+python -u main.py --arch resnet50 -d --eval-only /path/to/imagenet
+```
+
+### Multi-GPU Evaluation
+Run the following command for multi-GPU evaluation:
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --arch resnet50 --eval-only /path/to/imagenet
+```
+
+
+## Arguments
+
+* **arch**: name of the model to train or evaluate
+* **device**: device to run on, 'gpu' or 'cpu'; default: 'gpu'
+* **dynamic**: whether to run in dygraph (imperative) mode
+* **epoch**: number of training epochs; default: 120
+* **learning-rate**: initial learning rate; default: 0.1
+* **batch-size**: batch size per GPU card; default: 64
+* **output-dir**: directory to save model checkpoints; default: 'output'
+* **num-workers**: number of dataloader worker processes; default: 4
+* **resume**: checkpoint path to resume training from (see the example below); default: None
+* **eval-only**: run evaluation only; default: False
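+
+For example, to resume training from a previously saved checkpoint (the checkpoint prefix below is illustrative; use a prefix written under your `--output-dir`):
+
+```bash
+python -u main.py --arch resnet50 -d --resume output/resnet50/2020-04-01-12-00/10 /path/to/imagenet
+```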
+
+
+## Models
+
+| Model | top-1 acc (%) | top-5 acc (%) |
+| --- | --- | --- |
+| ResNet50 | 76.28 | 93.04 |
diff --git a/image_classification/imagenet_dataset.py b/image_classification/imagenet_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c019cd1d1a2fd697d401c4de6d3a79d476914c2
--- /dev/null
+++ b/image_classification/imagenet_dataset.py
@@ -0,0 +1,101 @@
+import os
+import cv2
+import math
+import random
+import numpy as np
+from paddle.fluid.io import Dataset
+
+
+def center_crop_resize(img):
+    h, w = img.shape[:2]
+    # crop a central square of 224/256 of the shorter side, then resize
+    c = int(224 / 256 * min((h, w)))
+    i = (h + 1 - c) // 2
+    j = (w + 1 - c) // 2
+    img = img[i:i + c, j:j + c, :]
+    return cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
+
+
+def random_crop_resize(img):
+    height, width = img.shape[:2]
+    area = height * width
+
+    for _ in range(10):
+        # sample a crop with random area (8%-100%) and aspect ratio (3/4-4/3)
+        target_area = random.uniform(0.08, 1.) * area
+        log_ratio = (math.log(3 / 4), math.log(4 / 3))
+        aspect_ratio = math.exp(random.uniform(*log_ratio))
+
+        w = int(round(math.sqrt(target_area * aspect_ratio)))
+        h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+        if w <= width and h <= height:
+            i = random.randint(0, height - h)
+            j = random.randint(0, width - w)
+            img = img[i:i + h, j:j + w, :]
+            return cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
+
+    # fall back to a center crop if no valid crop was found
+    return center_crop_resize(img)
+
+
+def random_flip(img):
+    if np.random.randint(0, 2) == 1:
+        img = img[:, ::-1, :]
+    return img
+
+
+def normalize_permute(img):
+    # transpose HWC -> CHW and convert BGR (cv2 default) to RGB
+    img = img.astype(np.float32).transpose((2, 0, 1))[::-1, ...]
+    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
+    std = np.array([58.395, 57.120, 57.375], dtype=np.float32)
+    invstd = 1. / std
+    for v, m, s in zip(img, mean, invstd):
+        v -= m
+        v *= s
+    return img
+
+
+def compose(functions):
+    def process(sample):
+        img, label = sample
+        for fn in functions:
+            img = fn(img)
+        return img, label
+
+    return process
+
+
+def image_folder(path):
+    valid_ext = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.webp')
+    classes = [
+        d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))
+    ]
+    classes.sort()
+    class_map = {cls: idx for idx, cls in enumerate(classes)}
+    samples = []
+    for cls in sorted(class_map.keys()):
+        d = os.path.join(path, cls)
+        for root, _, fnames in sorted(os.walk(d)):
+            for fname in sorted(fnames):
+                p = os.path.join(root, fname)
+                if os.path.splitext(p)[1].lower() in valid_ext:
+                    samples.append((p, [class_map[cls]]))
+    return samples
+
+
+class ImageNetDataset(Dataset):
+    def __init__(self, path, mode='train'):
+        self.samples = image_folder(path)
+        self.mode = mode
+        if self.mode == 'train':
+            self.transform = compose([
+                cv2.imread, random_crop_resize, random_flip, normalize_permute
+            ])
+        else:
+            self.transform = compose(
+                [cv2.imread, center_crop_resize, normalize_permute])
+
+    def __getitem__(self, idx):
+        return self.transform(self.samples[idx])
+
+    def __len__(self):
+        return len(self.samples)
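+
+
+# Minimal smoke test for the pipeline above (assumes the ImageNet layout from
+# the README); run `python imagenet_dataset.py` directly to check shapes.
+if __name__ == '__main__':
+    ds = ImageNetDataset('/path/to/imagenet/val', mode='val')
+    img, label = ds[0]
+    print(len(ds), img.shape, img.dtype, label)
+    # e.g. prints: 50000 (3, 224, 224) float32 [0]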
diff --git a/image_classification/main.py b/image_classification/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fb71e953ffc1637df37f9433c7abf7f1aa59154
--- /dev/null
+++ b/image_classification/main.py
@@ -0,0 +1,142 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+sys.path.append('../')
+
+import time
+import numpy as np
+import models
+import paddle.fluid as fluid
+
+from model import CrossEntropy, Input, set_device
+from imagenet_dataset import ImageNetDataset
+from paddle.fluid.dygraph.parallel import ParallelEnv
+from metrics import Accuracy
+
+
+def make_optimizer(step_per_epoch, parameter_list=None):
+    base_lr = FLAGS.lr
+    momentum = 0.9
+    weight_decay = 1e-4
+
+    # step decay at epochs 30/60/90, with 5 epochs of linear warmup
+    boundaries = [step_per_epoch * e for e in [30, 60, 90]]
+    values = [base_lr * (0.1**i) for i in range(len(boundaries) + 1)]
+    learning_rate = fluid.layers.piecewise_decay(
+        boundaries=boundaries, values=values)
+    learning_rate = fluid.layers.linear_lr_warmup(
+        learning_rate=learning_rate,
+        warmup_steps=5 * step_per_epoch,
+        start_lr=0.,
+        end_lr=base_lr)
+    optimizer = fluid.optimizer.Momentum(
+        learning_rate=learning_rate,
+        momentum=momentum,
+        regularization=fluid.regularizer.L2Decay(weight_decay),
+        parameter_list=parameter_list)
+    return optimizer
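+
+
+# As a worked example: with 4 cards and batch size 64 per card, step_per_epoch
+# is roughly 5005 on ImageNet-1k, so the schedule above warms the learning
+# rate linearly from 0 to 0.1 over the first 5 epochs, then decays it to
+# 0.01/0.001/0.0001 at epochs 30/60/90.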
+
+
+def main():
+    device = set_device(FLAGS.device)
+    if FLAGS.dynamic:
+        fluid.enable_dygraph(device)
+
+    # in eval-only mode, start from the released pretrained weights
+    model = models.__dict__[FLAGS.arch](pretrained=FLAGS.eval_only)
+
+    if FLAGS.resume is not None:
+        model.load(FLAGS.resume)
+
+    inputs = [Input([None, 3, 224, 224], 'float32', name='image')]
+    labels = [Input([None, 1], 'int64', name='label')]
+
+    train_dataset = ImageNetDataset(
+        os.path.join(FLAGS.data, 'train'), mode='train')
+    val_dataset = ImageNetDataset(os.path.join(FLAGS.data, 'val'), mode='val')
+
+    optim = make_optimizer(
+        int(np.ceil(len(train_dataset) * 1. / FLAGS.batch_size /
+                    ParallelEnv().nranks)),
+        parameter_list=model.parameters())
+
+    # bind the optimizer, loss, metrics and input/label specs to the model
+    model.prepare(optim, CrossEntropy(), Accuracy(topk=(1, 5)), inputs, labels)
+
+    if FLAGS.eval_only:
+        model.evaluate(
+            val_dataset,
+            batch_size=FLAGS.batch_size,
+            num_workers=FLAGS.num_workers)
+        return
+
+    output_dir = os.path.join(FLAGS.output_dir, FLAGS.arch,
+                              time.strftime('%Y-%m-%d-%H-%M',
+                                            time.localtime()))
+    if ParallelEnv().local_rank == 0 and not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    model.fit(train_dataset,
+              val_dataset,
+              batch_size=FLAGS.batch_size,
+              epochs=FLAGS.epoch,
+              save_dir=output_dir,
+              num_workers=FLAGS.num_workers)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser("ResNet Training on ImageNet")
+    parser.add_argument(
+        'data',
+        metavar='DIR',
+        help='path to dataset '
+        '(should have subdirectories named "train" and "val")')
+    parser.add_argument(
+        "--arch", type=str, default='resnet50', help="model name")
+    parser.add_argument(
+        "--device", type=str, default='gpu', help="device to run, cpu or gpu")
+    parser.add_argument(
+        "-d", "--dynamic", action='store_true', help="enable dygraph mode")
+    parser.add_argument(
+        "-e", "--epoch", default=120, type=int, help="number of epochs")
+    parser.add_argument(
+        '--lr',
+        '--learning-rate',
+        default=0.1,
+        type=float,
+        metavar='LR',
+        help='initial learning rate')
+    parser.add_argument(
+        "-b", "--batch-size", default=64, type=int, help="batch size")
+    parser.add_argument(
+        "-n", "--num-workers", default=4, type=int, help="dataloader workers")
+    parser.add_argument(
+        "--output-dir", type=str, default='output', help="save dir")
+    parser.add_argument(
+        "-r",
+        "--resume",
+        default=None,
+        type=str,
+        help="checkpoint path to resume")
+    parser.add_argument(
+        "--eval-only", action='store_true', help="only run evaluation")
+    FLAGS = parser.parse_args()
+    assert FLAGS.data, "error: must provide data path"
+    main()
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7928da7df031c790203c18c1933414ef485a038d
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1 @@
+from .resnet import *
diff --git a/models/download.py b/models/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..10d3fba390647c494448b83295901a8973d2aba8
--- /dev/null
+++ b/models/download.py
@@ -0,0 +1,147 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import os.path as osp
+import shutil
+import requests
+import tqdm
+import hashlib
+import time
+
+from paddle.fluid.dygraph.parallel import ParallelEnv
+
+import logging
+logger = logging.getLogger(__name__)
+
+__all__ = ['get_weights_path']
+
+WEIGHTS_HOME = osp.expanduser("~/.cache/paddle/hapi/weights")
+
+DOWNLOAD_RETRY_LIMIT = 3
+
+
+def get_weights_path(url, md5sum=None):
+    """Get the weights path under WEIGHTS_HOME; if the file does not
+    exist yet, download it from url first.
+    """
+    path, _ = get_path(url, WEIGHTS_HOME, md5sum)
+    return path
+
+
+def map_path(url, root_dir):
+    # map a download url to its target path under root_dir
+    fname = osp.split(url)[-1]
+    return osp.join(root_dir, fname)
+
+
+def get_path(url, root_dir, md5sum=None, check_exist=True):
+    """Download from the given url into root_dir.
+
+    If the file specified by url already exists under root_dir and its
+    md5 checks out, return its path directly; otherwise download it
+    from url and return the downloaded path.
+
+    url (str): download url
+    root_dir (str): root dir for downloading, it should be
+                    WEIGHTS_HOME or DATASET_HOME
+    md5sum (str): md5 sum of download package
+    """
+    # target path of the download under root_dir
+    fullpath = map_path(url, root_dir)
+
+    exist_flag = False
+    if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
+        exist_flag = True
+        if ParallelEnv().local_rank == 0:
+            logger.info("Found {}".format(fullpath))
+    else:
+        if ParallelEnv().local_rank == 0:
+            fullpath = _download(url, root_dir, md5sum)
+        else:
+            # non-zero ranks wait until rank 0 finishes downloading
+            while not osp.exists(fullpath):
+                time.sleep(1)
+    return fullpath, exist_flag
+
+
+def _download(url, path, md5sum=None):
+    """
+    Download from url, save to path.
+
+    url (str): download url
+    path (str): download to given path
+    """
+    if not osp.exists(path):
+        os.makedirs(path)
+
+    fname = osp.split(url)[-1]
+    fullname = osp.join(path, fname)
+    retry_cnt = 0
+
+    while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
+        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
+            retry_cnt += 1
+        else:
+            raise RuntimeError("Download from {} failed. "
+                               "Retry limit reached".format(url))
+        if ParallelEnv().local_rank == 0:
+            logger.info("Downloading {} from {}".format(fname, url))
+
+        req = requests.get(url, stream=True)
+        if req.status_code != 200:
+            raise RuntimeError("Downloading from {} failed with code "
+                               "{}!".format(url, req.status_code))
+
+        # to guard against interrupted downloads, write to tmp_fullname
+        # first and move it to fullname once the download has finished
+        tmp_fullname = fullname + "_tmp"
+        total_size = req.headers.get('content-length')
+        with open(tmp_fullname, 'wb') as f:
+            if total_size:
+                for chunk in tqdm.tqdm(
+                        req.iter_content(chunk_size=1024),
+                        total=(int(total_size) + 1023) // 1024,
+                        unit='KB'):
+                    f.write(chunk)
+            else:
+                for chunk in req.iter_content(chunk_size=1024):
+                    if chunk:
+                        f.write(chunk)
+        shutil.move(tmp_fullname, fullname)
+
+    return fullname
+
+
+def _md5check(fullname, md5sum=None):
+    if md5sum is None:
+        return True
+    if ParallelEnv().local_rank == 0:
+        logger.info("File {} md5 checking...".format(fullname))
+    md5 = hashlib.md5()
+    with open(fullname, 'rb') as f:
+        for chunk in iter(lambda: f.read(4096), b""):
+            md5.update(chunk)
+    calc_md5sum = md5.hexdigest()
+
+    if calc_md5sum != md5sum:
+        if ParallelEnv().local_rank == 0:
+            logger.info("File {} md5 check failed, {}(calc) != "
+                        "{}(base)".format(fullname, calc_md5sum, md5sum))
+        return False
+    return True
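+
+
+# Usage sketch: fetch the resnet50 weights declared in models/resnet.py (URL
+# and md5 taken from that file); the result is cached under WEIGHTS_HOME.
+#
+#   path = get_weights_path(
+#       'https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams',
+#       md5sum='0884c9087266496c41c60d14a96f8530')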
" + "Retry limit reached".format(url)) + if ParallelEnv().local_rank == 0: + logger.info("Downloading {} from {}".format(fname, url)) + + req = requests.get(url, stream=True) + if req.status_code != 200: + raise RuntimeError("Downloading from {} failed with code " + "{}!".format(url, req.status_code)) + + # For protecting download interupted, download to + # tmp_fullname firstly, move tmp_fullname to fullname + # after download finished + tmp_fullname = fullname + "_tmp" + total_size = req.headers.get('content-length') + with open(tmp_fullname, 'wb') as f: + if total_size: + for chunk in tqdm.tqdm( + req.iter_content(chunk_size=1024), + total=(int(total_size) + 1023) // 1024, + unit='KB'): + f.write(chunk) + else: + for chunk in req.iter_content(chunk_size=1024): + if chunk: + f.write(chunk) + shutil.move(tmp_fullname, fullname) + + return fullname + + +def _md5check(fullname, md5sum=None): + if md5sum is None: + return True + if ParallelEnv().local_rank == 0: + logger.info("File {} md5 checking...".format(fullname)) + md5 = hashlib.md5() + with open(fullname, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b""): + md5.update(chunk) + calc_md5sum = md5.hexdigest() + + if calc_md5sum != md5sum: + if ParallelEnv().local_rank == 0: + logger.info("File {} md5 check failed, {}(calc) != " + "{}(base)".format(fullname, calc_md5sum, md5sum)) + return False + return True diff --git a/models/resnet.py b/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..adf831c1b8778793a7cd86986db36bbe64f85c52 --- /dev/null +++ b/models/resnet.py @@ -0,0 +1,228 @@ +from __future__ import division +from __future__ import print_function + +import math +import paddle.fluid as fluid + +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear +from paddle.fluid.dygraph.container import Sequential + +from model import Model +from .download import get_weights_path + +__all__ = ['ResNet', 'resnet50', 'resnet101', 'resnet152'] + +model_urls = { + 'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams', + '0884c9087266496c41c60d14a96f8530') +} + + +class ConvBNLayer(fluid.dygraph.Layer): + def __init__(self, + num_channels, + num_filters, + filter_size, + stride=1, + groups=1, + act=None): + super(ConvBNLayer, self).__init__() + + self._conv = Conv2D( + num_channels=num_channels, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + bias_attr=False) + + self._batch_norm = BatchNorm(num_filters, act=act) + + def forward(self, inputs): + x = self._conv(inputs) + x = self._batch_norm(x) + + return x + + +class BasicBlock(fluid.dygraph.Layer): + expansion = 1 + + def __init__(self, num_channels, num_filters, stride, shortcut=True): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError( + 'BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError( + "Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = 
+
+
+class BottleneckBlock(fluid.dygraph.Layer):
+    def __init__(self, num_channels, num_filters, stride, shortcut=True):
+        super(BottleneckBlock, self).__init__()
+
+        self.conv0 = ConvBNLayer(
+            num_channels=num_channels,
+            num_filters=num_filters,
+            filter_size=1,
+            act='relu')
+        self.conv1 = ConvBNLayer(
+            num_channels=num_filters,
+            num_filters=num_filters,
+            filter_size=3,
+            stride=stride,
+            act='relu')
+        self.conv2 = ConvBNLayer(
+            num_channels=num_filters,
+            num_filters=num_filters * 4,
+            filter_size=1,
+            act=None)
+
+        if not shortcut:
+            self.short = ConvBNLayer(
+                num_channels=num_channels,
+                num_filters=num_filters * 4,
+                filter_size=1,
+                stride=stride)
+
+        self.shortcut = shortcut
+
+        self._num_channels_out = num_filters * 4
+
+    def forward(self, inputs):
+        x = self.conv0(inputs)
+        conv1 = self.conv1(x)
+        conv2 = self.conv2(conv1)
+
+        if self.shortcut:
+            short = inputs
+        else:
+            short = self.short(inputs)
+
+        x = fluid.layers.elementwise_add(x=short, y=conv2)
+
+        # fuse the trailing relu into the elementwise add
+        layer_helper = LayerHelper(self.full_name(), act='relu')
+        return layer_helper.append_activation(x)
+
+
+class ResNet(Model):
+    def __init__(self, Block, depth=50, num_classes=1000):
+        super(ResNet, self).__init__()
+
+        layer_config = {
+            50: [3, 4, 6, 3],
+            101: [3, 4, 23, 3],
+            152: [3, 8, 36, 3],
+        }
+        assert depth in layer_config.keys(), \
+            "supported depths are {} but the input depth is {}".format(
+                layer_config.keys(), depth)
+
+        layers = layer_config[depth]
+        num_in = [64, 256, 512, 1024]
+        num_out = [64, 128, 256, 512]
+
+        self.conv = ConvBNLayer(
+            num_channels=3,
+            num_filters=64,
+            filter_size=7,
+            stride=2,
+            act='relu')
+        self.pool = Pool2D(
+            pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
+
+        self.layers = []
+        for idx, num_blocks in enumerate(layers):
+            blocks = []
+            shortcut = False
+            for b in range(num_blocks):
+                block = Block(
+                    num_channels=num_in[idx] if b == 0 else num_out[idx] * 4,
+                    num_filters=num_out[idx],
+                    stride=2 if b == 0 and idx != 0 else 1,
+                    shortcut=shortcut)
+                blocks.append(block)
+                shortcut = True
+            layer = self.add_sublayer("layer_{}".format(idx),
+                                      Sequential(*blocks))
+            self.layers.append(layer)
+
+        self.global_pool = Pool2D(
+            pool_size=7, pool_type='avg', global_pooling=True)
+
+        self.fc_input_dim = num_out[-1] * 4 * 1 * 1
+        stdv = 1.0 / math.sqrt(self.fc_input_dim * 1.0)
+        self.fc = Linear(
+            self.fc_input_dim,
+            num_classes,
+            act='softmax',
+            param_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Uniform(-stdv, stdv)))
+
+    def forward(self, inputs):
+        x = self.conv(inputs)
+        x = self.pool(x)
+        for layer in self.layers:
+            x = layer(x)
+        x = self.global_pool(x)
+        x = fluid.layers.reshape(x, shape=[-1, self.fc_input_dim])
+        x = self.fc(x)
+        return x
+
+
+def _resnet(arch, Block, depth, pretrained):
+    model = ResNet(Block, depth)
+    if pretrained:
+        assert arch in model_urls, \
+            "{} has no pretrained weights yet, set pretrained=False".format(
+                arch)
+        weight_path = get_weights_path(model_urls[arch][0],
+                                       model_urls[arch][1])
+        assert weight_path.endswith(
+            '.pdparams'), "suffix of weights must be .pdparams"
+        # model.load expects the path prefix without the .pdparams suffix
+        model.load(weight_path[:-9])
+    return model
+
+
+def resnet50(pretrained=False):
+    return _resnet('resnet50', BottleneckBlock, 50, pretrained)
+
+
+def resnet101(pretrained=False):
+    return _resnet('resnet101', BottleneckBlock, 101, pretrained)
+
+
+def resnet152(pretrained=False):
+    return _resnet('resnet152', BottleneckBlock, 152, pretrained)
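+
+
+# Minimal usage sketch (dygraph mode, dummy input; pretrained=False avoids a
+# network download). The output is softmax probabilities of shape [1, 1000].
+#
+#   import numpy as np
+#   with fluid.dygraph.guard():
+#       net = resnet50(pretrained=False)
+#       x = fluid.dygraph.to_variable(
+#           np.random.rand(1, 3, 224, 224).astype('float32'))
+#       probs = net(x)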