Commit 131410a0, authored by LielinJiang

add models

Parent: 6d9e77b9
# Image Classification with the High-level API
## Dataset preparation
Before starting training, make sure you have downloaded and extracted the [ImageNet dataset](http://image-net.org/download) into a suitable directory. The prepared dataset should have the following directory structure:
```bash
/path/to/imagenet
    train
        n01440764
            xxx.jpg
            ...
        n01443537
            xxx.jpg
            ...
        ...
    val
        n01440764
            xxx.jpg
            ...
        n01443537
            xxx.jpg
            ...
        ...
```
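Before launching training, you can quickly confirm the layout with a few lines of Python (illustrative only; the path below is a placeholder):
```python
# Quick sanity check of the layout shown above (placeholder path).
import os

root = '/path/to/imagenet'
for split in ('train', 'val'):
    split_dir = os.path.join(root, split)
    classes = [d for d in os.listdir(split_dir)
               if os.path.isdir(os.path.join(split_dir, d))]
    print(split, len(classes), 'class folders')  # 1000 for the full ImageNet-1k split
```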
## Training
### Single-card training
Run the following command to start training:
```bash
python -u main.py --arch resnet50 /path/to/imagenet -d
```
### Multi-card training
Run the following command to start training:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --arch resnet50 -d /path/to/imagenet
```
## Evaluation
### Single-card evaluation
Run the following command to evaluate:
```bash
python -u main.py --arch resnet50 -d --eval-only /path/to/imagenet
```
### Multi-card evaluation
Run the following command for multi-card evaluation:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --arch resnet50 --eval-only /path/to/imagenet
```
## Arguments
* **arch**: name of the model to train or evaluate
* **device**: device to run on, 'gpu' or 'cpu'; default: 'gpu'
* **dynamic**: whether to run in dynamic graph (dygraph) mode
* **epoch**: number of training epochs; default: 120
* **learning-rate**: learning rate; default: 0.1
* **batch-size**: batch size per card; default: 64
* **output-dir**: directory where model files are saved; default: 'output'
* **num-workers**: number of dataloader worker processes; default: 4
* **resume**: checkpoint path to resume training from; default: None
* **eval-only**: run evaluation only; default: False
## Models
| Model | top1 acc | top5 acc |
| --- | --- | --- |
| ResNet50 | 76.28 | 93.04 |
from __future__ import division

import os
import cv2
import math
import random
import numpy as np
from paddle.fluid.io import Dataset
def center_crop_resize(img):
h, w = img.shape[:2]
c = int(224 / 256 * min((h, w)))
i = (h + 1 - c) // 2
j = (w + 1 - c) // 2
img = img[i:i + c, j:j + c, :]
    return cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
def random_crop_resize(img):
height, width = img.shape[:2]
area = height * width
for attempt in range(10):
target_area = random.uniform(0.08, 1.) * area
log_ratio = (math.log(3 / 4), math.log(4 / 3))
aspect_ratio = math.exp(random.uniform(*log_ratio))
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if w <= width and h <= height:
i = random.randint(0, height - h)
j = random.randint(0, width - w)
img = img[i:i + h, j:j + w, :]
            return cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
return center_crop_resize(img)
def random_flip(img):
if np.random.randint(0, 2) == 1:
img = img[:, ::-1, :]
return img
def normalize_permute(img):
    # HWC/BGR (as returned by cv2.imread) -> CHW/RGB, then normalize with
    # the ImageNet per-channel mean and std (in 0-255 scale).
    img = img.astype(np.float32).transpose((2, 0, 1))[::-1, ...]
    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    std = np.array([58.395, 57.120, 57.375], dtype=np.float32)
    img -= mean.reshape((3, 1, 1))
    img *= (1. / std).reshape((3, 1, 1))
    return img
def compose(functions):
def process(sample):
img, label = sample
for fn in functions:
img = fn(img)
return img, label
return process
def image_folder(path):
valid_ext = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.webp')
classes = [
d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))
]
classes.sort()
class_map = {cls: idx for idx, cls in enumerate(classes)}
samples = []
for dir in sorted(class_map.keys()):
d = os.path.join(path, dir)
for root, _, fnames in sorted(os.walk(d)):
for fname in sorted(fnames):
p = os.path.join(root, fname)
if os.path.splitext(p)[1].lower() in valid_ext:
samples.append((p, [class_map[dir]]))
return samples
class ImageNetDataset(Dataset):
def __init__(self, path, mode='train'):
self.samples = image_folder(path)
self.mode = mode
if self.mode == 'train':
self.transform = compose([
cv2.imread, random_crop_resize, random_flip, normalize_permute
])
else:
self.transform = compose(
[cv2.imread, center_crop_resize, normalize_permute])
def __getitem__(self, idx):
return self.transform(self.samples[idx])
def __len__(self):
return len(self.samples)
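A quick sanity check for the `ImageNetDataset` class above; this is illustrative only, the path is a placeholder and assumes the directory layout described in the README:
```python
# Build the training split and inspect one preprocessed sample.
from imagenet_dataset import ImageNetDataset

train_dataset = ImageNetDataset('/path/to/imagenet/train', mode='train')
print(len(train_dataset))      # number of (image, label) samples found on disk

img, label = train_dataset[0]  # cv2.imread -> random crop/flip -> normalize_permute
print(img.shape, img.dtype)    # (3, 224, 224) float32, CHW, RGB, normalized
print(label)                   # e.g. [0], index of the class folder
```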
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import argparse
import contextlib
import os
import sys
sys.path.append('../')
import time
import math
import numpy as np
import models
import paddle.fluid as fluid
from model import CrossEntropy, Input, set_device
from imagenet_dataset import ImageNetDataset
from distributed import DistributedBatchSampler
from paddle.fluid.dygraph.parallel import ParallelEnv
from metrics import Accuracy
from paddle.fluid.io import BatchSampler, DataLoader
def make_optimizer(step_per_epoch, parameter_list=None):
base_lr = FLAGS.lr
momentum = 0.9
weight_decay = 1e-4
boundaries = [step_per_epoch * e for e in [30, 60, 90]]
values = [base_lr * (0.1**i) for i in range(len(boundaries) + 1)]
learning_rate = fluid.layers.piecewise_decay(
boundaries=boundaries, values=values)
learning_rate = fluid.layers.linear_lr_warmup(
learning_rate=learning_rate,
warmup_steps=5 * step_per_epoch,
start_lr=0.,
end_lr=base_lr)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=momentum,
regularization=fluid.regularizer.L2Decay(weight_decay),
parameter_list=parameter_list)
return optimizer
def main():
device = set_device(FLAGS.device)
fluid.enable_dygraph(device) if FLAGS.dynamic else None
model = models.__dict__[FLAGS.arch](pretrained=FLAGS.eval_only)
if FLAGS.resume is not None:
model.load(FLAGS.resume)
inputs = [Input([None, 3, 224, 224], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
train_dataset = ImageNetDataset(
os.path.join(FLAGS.data, 'train'), mode='train')
val_dataset = ImageNetDataset(os.path.join(FLAGS.data, 'val'), mode='val')
optim = make_optimizer(
np.ceil(
len(train_dataset) * 1. / FLAGS.batch_size / ParallelEnv().nranks),
parameter_list=model.parameters())
model.prepare(optim, CrossEntropy(), Accuracy(topk=(1, 5)), inputs, labels)
if FLAGS.eval_only:
model.evaluate(
val_dataset,
batch_size=FLAGS.batch_size,
num_workers=FLAGS.num_workers)
return
output_dir = os.path.join(FLAGS.output_dir, FLAGS.arch,
time.strftime('%Y-%m-%d-%H-%M',
time.localtime()))
if ParallelEnv().local_rank == 0 and not os.path.exists(output_dir):
os.makedirs(output_dir)
model.fit(train_dataset,
val_dataset,
batch_size=FLAGS.batch_size,
epochs=FLAGS.epoch,
save_dir=output_dir,
num_workers=FLAGS.num_workers)
if __name__ == '__main__':
parser = argparse.ArgumentParser("Resnet Training on ImageNet")
parser.add_argument(
'data',
metavar='DIR',
        help='path to dataset '
        '(should have subdirectories named "train" and "val")')
parser.add_argument(
"--arch", type=str, default='resnet50', help="model name")
parser.add_argument(
"--device", type=str, default='gpu', help="device to run, cpu or gpu")
parser.add_argument(
"-d", "--dynamic", action='store_true', help="enable dygraph mode")
parser.add_argument(
"-e", "--epoch", default=120, type=int, help="number of epoch")
parser.add_argument(
'--lr',
'--learning-rate',
default=0.1,
type=float,
metavar='LR',
help='initial learning rate')
parser.add_argument(
"-b", "--batch-size", default=64, type=int, help="batch size")
parser.add_argument(
"-n", "--num-workers", default=4, type=int, help="dataloader workers")
parser.add_argument(
"--output-dir", type=str, default='output', help="save dir")
parser.add_argument(
"-r",
"--resume",
default=None,
type=str,
help="checkpoint path to resume")
    parser.add_argument(
        "--eval-only", action='store_true', help="run evaluation only")
FLAGS = parser.parse_args()
assert FLAGS.data, "error: must provide data path"
main()
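For reference, `make_optimizer` above builds the standard ImageNet step schedule: a 5-epoch linear warmup to the base learning rate, followed by 10x drops at epochs 30, 60 and 90. A small illustration of the values it passes to `piecewise_decay` (the step count is an example, not something computed by main.py):
```python
# Illustrative only: assume 5000 optimizer steps per epoch and the default base LR.
step_per_epoch = 5000
base_lr = 0.1

boundaries = [step_per_epoch * e for e in [30, 60, 90]]              # steps where the LR drops
values = [base_lr * (0.1 ** i) for i in range(len(boundaries) + 1)]  # LR in each interval
print(boundaries)  # [150000, 300000, 450000]
print(values)      # [0.1, 0.01, 0.001, 0.0001] (up to float rounding)
# linear_lr_warmup ramps the LR from 0 to base_lr over the first 5 * step_per_epoch steps.
```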
from .resnet import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import os.path as osp
import shutil
import requests
import tqdm
import hashlib
import time
from paddle.fluid.dygraph.parallel import ParallelEnv
import logging
logger = logging.getLogger(__name__)
__all__ = ['get_weights_path']
WEIGHTS_HOME = osp.expanduser("~/.cache/paddle/hapi/weights")
DOWNLOAD_RETRY_LIMIT = 3
def get_weights_path(url, md5sum=None):
"""Get weights path from WEIGHT_HOME, if not exists,
download it from url.
"""
path, _ = get_path(url, WEIGHTS_HOME, md5sum)
return path
def map_path(url, root_dir):
# parse path after download under root_dir
fname = osp.split(url)[-1]
fpath = fname
return osp.join(root_dir, fpath)
def get_path(url, root_dir, md5sum=None, check_exist=True):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
from url and decompress it, return the path.
url (str): download url
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package
"""
# parse path after download to decompress under root_dir
fullpath = map_path(url, root_dir)
exist_flag = False
if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
exist_flag = True
if ParallelEnv().local_rank == 0:
logger.info("Found {}".format(fullpath))
else:
if ParallelEnv().local_rank == 0:
fullpath = _download(url, root_dir, md5sum)
else:
while not os.path.exists(fullpath):
time.sleep(1)
return fullpath, exist_flag
def _download(url, path, md5sum=None):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
"""
if not osp.exists(path):
os.makedirs(path)
fname = osp.split(url)[-1]
fullname = osp.join(path, fname)
retry_cnt = 0
while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
if ParallelEnv().local_rank == 0:
logger.info("Downloading {} from {}".format(fname, url))
req = requests.get(url, stream=True)
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
        # To guard against an interrupted download, write to tmp_fullname
        # first and move it to fullname once the download has finished.
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
for chunk in tqdm.tqdm(
req.iter_content(chunk_size=1024),
total=(int(total_size) + 1023) // 1024,
unit='KB'):
f.write(chunk)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
def _md5check(fullname, md5sum=None):
if md5sum is None:
return True
if ParallelEnv().local_rank == 0:
logger.info("File {} md5 checking...".format(fullname))
md5 = hashlib.md5()
with open(fullname, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
md5.update(chunk)
calc_md5sum = md5.hexdigest()
if calc_md5sum != md5sum:
if ParallelEnv().local_rank == 0:
logger.info("File {} md5 check failed, {}(calc) != "
"{}(base)".format(fullname, calc_md5sum, md5sum))
return False
return True
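A small usage sketch for `get_weights_path`; the URL and md5 below are the resnet50 entry registered in `model_urls` in the ResNet file further down, and the import assumes this file is importable as `models.download`:
```python
# Fetch (or reuse from the cache) the resnet50 weights and derive the prefix
# that Model.load() expects, the same way _resnet() does below.
from models.download import get_weights_path

url = 'https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams'
md5 = '0884c9087266496c41c60d14a96f8530'

weight_path = get_weights_path(url, md5)  # cached under ~/.cache/paddle/hapi/weights
prefix = weight_path[:-len('.pdparams')]  # strip the suffix before Model.load(prefix)
print(prefix)
```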
from __future__ import division
from __future__ import print_function
import math
import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.container import Sequential
from model import Model
from .download import get_weights_path
__all__ = ['ResNet', 'resnet50', 'resnet101', 'resnet152']
model_urls = {
'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams',
'0884c9087266496c41c60d14a96f8530')
}
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
bias_attr=False)
self._batch_norm = BatchNorm(num_filters, act=act)
def forward(self, inputs):
x = self._conv(inputs)
x = self._batch_norm(x)
return x
class BasicBlock(fluid.dygraph.Layer):
    expansion = 1

    def __init__(self, num_channels, num_filters, stride, shortcut=True):
        super(BasicBlock, self).__init__()
        # Two 3x3 conv-bn layers; the first one downsamples when stride != 1.
        self.conv0 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act='relu')
        self.conv1 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters,
            filter_size=3,
            act=None)
        if not shortcut:
            # Projection shortcut to match the spatial size and channel count.
            self.short = ConvBNLayer(
                num_channels=num_channels,
                num_filters=num_filters,
                filter_size=1,
                stride=stride)
        self.shortcut = shortcut
        self._num_channels_out = num_filters

    def forward(self, inputs):
        x = self.conv0(inputs)
        conv1 = self.conv1(x)
        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)
        x = fluid.layers.elementwise_add(x=short, y=conv1)
        layer_helper = LayerHelper(self.full_name(), act='relu')
        return layer_helper.append_activation(x)
class BottleneckBlock(fluid.dygraph.Layer):
def __init__(self, num_channels, num_filters, stride, shortcut=True):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act='relu')
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None)
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
stride=stride)
self.shortcut = shortcut
self._num_channels_out = num_filters * 4
def forward(self, inputs):
x = self.conv0(inputs)
conv1 = self.conv1(x)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
x = fluid.layers.elementwise_add(x=short, y=conv2)
layer_helper = LayerHelper(self.full_name(), act='relu')
return layer_helper.append_activation(x)
# return fluid.layers.relu(x)
class ResNet(Model):
def __init__(self, Block, depth=50, num_classes=1000):
super(ResNet, self).__init__()
layer_config = {
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
152: [3, 8, 36, 3],
}
        assert depth in layer_config.keys(), \
            "supported depths are {}, but got depth={}".format(
                list(layer_config.keys()), depth)
layers = layer_config[depth]
num_in = [64, 256, 512, 1024]
num_out = [64, 128, 256, 512]
self.conv = ConvBNLayer(
num_channels=3,
num_filters=64,
filter_size=7,
stride=2,
act='relu')
self.pool = Pool2D(
pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
self.layers = []
for idx, num_blocks in enumerate(layers):
blocks = []
shortcut = False
for b in range(num_blocks):
block = Block(
num_channels=num_in[idx] if b == 0 else num_out[idx] * 4,
num_filters=num_out[idx],
stride=2 if b == 0 and idx != 0 else 1,
shortcut=shortcut)
blocks.append(block)
shortcut = True
layer = self.add_sublayer("layer_{}".format(idx),
Sequential(*blocks))
self.layers.append(layer)
self.global_pool = Pool2D(
pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(2048 * 1.0)
self.fc_input_dim = num_out[-1] * 4 * 1 * 1
self.fc = Linear(
self.fc_input_dim,
num_classes,
act='softmax',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
def forward(self, inputs):
x = self.conv(inputs)
x = self.pool(x)
for layer in self.layers:
x = layer(x)
x = self.global_pool(x)
x = fluid.layers.reshape(x, shape=[-1, self.fc_input_dim])
x = self.fc(x)
return x
def _resnet(arch, Block, depth, pretrained):
model = ResNet(Block, depth)
if pretrained:
        assert arch in model_urls, \
            "no pretrained weights are available for {}, " \
            "please set pretrained=False".format(arch)
weight_path = get_weights_path(model_urls[arch][0],
model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams"
model.load(weight_path[:-9])
return model
def resnet50(pretrained=False):
return _resnet('resnet50', BottleneckBlock, 50, pretrained)
def resnet101(pretrained=False):
return _resnet('resnet101', BottleneckBlock, 101, pretrained)
def resnet152(pretrained=False):
return _resnet('resnet152', BottleneckBlock, 152, pretrained)
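Finally, a minimal evaluation sketch that mirrors the `--eval-only` path of main.py; it assumes the same `sys.path` setup as main.py (so the hapi `model` and `metrics` modules are importable) and uses placeholder paths and optimizer settings:
```python
# Minimal sketch mirroring main.py's --eval-only path; paths and the optimizer
# settings are placeholders, not part of the original scripts.
import paddle.fluid as fluid
from model import CrossEntropy, Input, set_device
from metrics import Accuracy
from imagenet_dataset import ImageNetDataset
import models

device = set_device('gpu')
fluid.enable_dygraph(device)            # same as passing -d to main.py

net = models.resnet50(pretrained=True)  # downloads weights via get_weights_path
inputs = [Input([None, 3, 224, 224], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
optim = fluid.optimizer.Momentum(
    learning_rate=0.1, momentum=0.9, parameter_list=net.parameters())
net.prepare(optim, CrossEntropy(), Accuracy(topk=(1, 5)), inputs, labels)

val_dataset = ImageNetDataset('/path/to/imagenet/val', mode='val')
net.evaluate(val_dataset, batch_size=64, num_workers=4)
```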