提交 e7df41bb 编写于 作者: R ruri 提交者: qingqing01

Update image classification (#1499)

* Add python-cv2 reader for image classification.
* Update README_cn.md
上级 8c32619e
......@@ -84,7 +84,7 @@ python train.py \
Or can start the training step by running the ```run.sh```.
**data reader introduction:** Data reader is defined in ```reader.py```. In [training stage](#training-a-model), random crop and flipping are used, while center crop is used in [evaluation](#inference) and [inference](#inference) stages. Supported data augmentation includes:
**data reader introduction:** Data reader is defined in ```reader.py``` and ```reader_cv2.py```, Using CV2 reader can improve the speed of reading. In [training stage](#training-a-model), random crop and flipping are used, while center crop is used in [evaluation](#inference) and [inference](#inference) stages. Supported data augmentation includes:
* rotation
* color jitter
* random crop
......@@ -190,19 +190,20 @@ Models consists of two categories: Models with specified parameters names in mod
Models are trained by starting with learning rate ```0.1``` and decaying it by ```0.1``` after each pre-defined epoches, if not special introduced. Available top-1/top-5 validation accuracy on ImageNet 2012 are listed in table. Pretrained models can be downloaded by clicking related model names.
- Released models: specify parameter names
|model | top-1/top-5 accuracy
|- | -:
|[AlexNet](http://paddle-imagenet-models-name.bj.bcebos.com/AlexNet_pretrained.zip) | 56.34%/79.02%
|[VGG11](http://paddle-imagenet-models-name.bj.bcebos.com/VGG11_pretained.zip) | 68.86%/88.64%
|[MobileNetV1](http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.zip) | 70.7%/89.41%
|[ResNet50](http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.zip) | 76.46%/93.04%
|[ResNet101](http://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.zip) | 77.65%/93.71%
|model | top-1/top-5 accuracy(PIL)| top-1/top-5 accuracy(CV2) |
|- |:-: |:-:|
|[AlexNet](http://paddle-imagenet-models-name.bj.bcebos.com/AlexNet_pretrained.zip) | 56.71%/79.18% | 55.88%/78.65% |
|[VGG11](http://paddle-imagenet-models-name.bj.bcebos.com/VGG11_pretained.zip) | 68.92%/88.66% | 68.61%/88.60% |
|[MobileNetV1](http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.zip) | 70.91%/89.54% | 70.51%/89.35% |
|[ResNet50](http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.zip) | 76.35%/92.80% | 76.22%/92.92% |
|[ResNet101](http://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.zip) | 77.49%/93.57% | 77.56%/93.64% |
- Released models: not specify parameter names
|model | top-1/top-5 accuracy
|- | -:
|[ResNet152](http://paddle-imagenet-models.bj.bcebos.com/ResNet152_pretrained.zip) | 78.29%/94.11%
|[SE_ResNeXt50_32x4d](http://paddle-imagenet-models.bj.bcebos.com/se_resnext_50_model.tar) | 78.33%/93.96%
|model | top-1/top-5 accuracy(PIL)| top-1/top-5 accuracy(CV2) |
|- |:-: |:-:|
|[ResNet152](http://paddle-imagenet-models.bj.bcebos.com/ResNet152_pretrained.zip) | 78.18%/93.93% | 78.11%/94.04% |
|[SE_ResNeXt50_32x4d](http://paddle-imagenet-models.bj.bcebos.com/se_resnext_50_model.tar) | 78.32%/93.96% | 77.58%/93.73% |
......@@ -58,7 +58,7 @@ python train.py \
--model=SE_ResNeXt50_32x4d \
--batch_size=32 \
--total_images=1281167 \
--class_dim=1000
--class_dim=1000 \
--image_shape=3,224,224 \
--model_save_dir=output/ \
--with_mem_opt=False \
......@@ -79,8 +79,9 @@ python train.py \
* **lr**: initialized learning rate. Default: 0.1.
* **pretrained_model**: model path for pretraining. Default: None.
* **checkpoint**: the checkpoint path to resume. Default: None.
* **model_category**: the category of models, ("models"|"models_name"). Default:"models".
**数据读取器说明:** 数据读取器定义在```reader.py```中。[训练阶段](#training-a-model), 默认采用的增广方式是随机裁剪与水平翻转, 而在[评估](#inference)[推断](#inference)阶段用的默认方式是中心裁剪。当前支持的数据增广方式有:
**数据读取器说明:** 数据读取器定义在```reader.py``````reader_cv2.py```中, 一般, CV2 reader可以提高数据读取速度, reader(PIL)可以得到相对更高的精度, [训练阶段](#training-a-model), 默认采用的增广方式是随机裁剪与水平翻转, 而在[评估](#inference)[推断](#inference)阶段用的默认方式是中心裁剪。当前支持的数据增广方式有:
* 旋转
* 颜色抖动
* 随机裁剪
......@@ -183,27 +184,24 @@ Test-12-score: [15.040644], class [386]
```
## 已有模型及其性能
Models包括两种模型:带有参数名字的模型,和不带有参数名字的模型。通过设置 ```model_category = models_name```来训练带有参数名字的模型。
表格中列出了在"models"目录下支持的神经网络种类,并且给出了已完成训练的模型在ImageNet-2012验证集合上的top-1/top-5精度;如无特征说明,训练模型的初始学习率为```0.1```,每隔预定的epochs会下降```0.1```。预训练模型可以通过点击相应模型的名称进行下载。
|model | top-1/top-5 accuracy
|- | -:
|[AlexNet](http://paddle-imagenet-models.bj.bcebos.com/alexnet_model.tar) | 57.21%/79.72%
|VGG11 | -
|VGG13 | -
|VGG16 | -
|VGG19 | -
|GoogleNet | -
|InceptionV4 | -
|MobileNet | -
|[ResNet50](http://paddle-imagenet-models.bj.bcebos.com/resnet_50_model.tar) | 76.63%/93.10%
|ResNet101 | -
|ResNet152 | -
|[SE_ResNeXt50_32x4d](http://paddle-imagenet-models.bj.bcebos.com/se_resnext_50_model.tar) | 78.33%/93.96%
|SE_ResNeXt101_32x4d | -
|SE_ResNeXt152_32x4d | -
|DPN68 | -
|DPN92 | -
|DPN98 | -
|DPN107 | -
|DPN131 | -
- Released models: specify parameter names
|model | top-1/top-5 accuracy(PIL)| top-1/top-5 accuracy(CV2) |
|- |:-: |:-:|
|[AlexNet](http://paddle-imagenet-models-name.bj.bcebos.com/AlexNet_pretrained.zip) | 56.71%/79.18% | 55.88%/78.65% |
|[VGG11](http://paddle-imagenet-models-name.bj.bcebos.com/VGG11_pretained.zip) | 68.92%/88.66% | 68.61%/88.60% |
|[MobileNetV1](http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.zip) | 70.91%/89.54% | 70.51%/89.35% |
|[ResNet50](http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.zip) | 76.35%/92.80% | 76.22%/92.92% |
|[ResNet101](http://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.zip) | 77.49%/93.57% | 77.56%/93.64% |
- Released models: not specify parameter names
|model | top-1/top-5 accuracy(PIL)| top-1/top-5 accuracy(CV2) |
|- |:-: |:-:|
|[ResNet152](http://paddle-imagenet-models.bj.bcebos.com/ResNet152_pretrained.zip) | 78.18%/93.93% | 78.11%/94.04% |
|[SE_ResNeXt50_32x4d](http://paddle-imagenet-models.bj.bcebos.com/se_resnext_50_model.tar) | 78.32%/93.96% | 77.58%/93.73% |
......@@ -8,7 +8,8 @@ import sys
import paddle
import paddle.fluid as fluid
import models
import reader
#import reader_cv2 as reader
import reader as reader
import argparse
import functools
from models.learning_rate import cosine_decay
......@@ -83,7 +84,7 @@ def eval(args):
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
val_reader = paddle.batch(reader.val(), batch_size=args.batch_size)
val_reader = paddle.batch(reader.val(""), batch_size=args.batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
fetch_list = [avg_cost.name, acc_top1.name, acc_top5.name]
......
import os
import math
import random
import functools
import numpy as np
import paddle
import cv2
import io
random.seed(0)
np.random.seed(0)
DATA_DIM = 224
THREAD = 8
BUF_SIZE = 102400
DATA_DIR = 'data/ILSVRC2012'
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
def rotate_image(img):
""" rotate_image """
(h, w) = img.shape[:2]
center = (w / 2, h / 2)
angle = np.random.randint(-10, 11)
M = cv2.getRotationMatrix2D(center, angle, 1.0)
rotated = cv2.warpAffine(img, M, (w, h))
return rotated
def random_crop(img, size, scale=None, ratio=None):
""" random_crop """
scale = [0.08, 1.0] if scale is None else scale
ratio = [3. / 4., 4. / 3.] if ratio is None else ratio
aspect_ratio = math.sqrt(np.random.uniform(*ratio))
w = 1. * aspect_ratio
h = 1. / aspect_ratio
bound = min((float(img.shape[1]) / img.shape[0]) / (w**2),
(float(img.shape[0]) / img.shape[1]) / (h**2))
scale_max = min(scale[1], bound)
scale_min = min(scale[0], bound)
target_area = img.shape[0] * img.shape[1] * np.random.uniform(scale_min,
scale_max)
target_size = math.sqrt(target_area)
w = int(target_size * w)
h = int(target_size * h)
i = np.random.randint(0, img.size[0] - w + 1)
j = np.random.randint(0, img.size[1] - h + 1)
img = img[i:i + h, j:j + w, :]
resized = cv2.resize(img, (size, size))
return resized
def distort_color(img):
return img
def resize_short(img, target_size):
""" resize_short """
percent = float(target_size) / min(img.shape[0], img.shape[1])
resized_width = int(round(img.shape[1] * percent))
resized_height = int(round(img.shape[0] * percent))
resized = cv2.resize(img, (resized_width, resized_height))
return resized
def crop_image(img, target_size, center):
""" crop_image """
height, width = img.shape[:2]
size = target_size
if center == True:
w_start = (width - size) / 2
h_start = (height - size) / 2
else:
w_start = np.random.randint(0, width - size + 1)
h_start = np.random.randint(0, height - size + 1)
w_end = w_start + size
h_end = h_start + size
img = img[h_start:h_end, w_start:w_end, :]
return img
def process_image(sample,
mode,
color_jitter,
rotate,
crop_size=224,
mean=None,
std=None):
""" process_image """
mean = [0.485, 0.456, 0.406] if mean is None else mean
std = [0.229, 0.224, 0.225] if std is None else std
img_path = sample[0]
print('&' * 80)
print(img_path)
img = cv2.imread(img_path)
if mode == 'train':
if rotate:
img = rotate_image(img)
if crop_size > 0:
img = random_crop(img, crop_size)
if color_jitter:
img = distort_color(img)
if np.random.randint(0, 2) == 1:
img = img[:, ::-1, :]
else:
if crop_size > 0:
img = resize_short(img, crop_size)
img = crop_image(img, target_size=crop_size, center=True)
img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255
img_mean = np.array(mean).reshape((3, 1, 1))
img_std = np.array(std).reshape((3, 1, 1))
img -= img_mean
img /= img_std
if mode == 'train' or mode == 'val':
return (img, sample[1])
elif mode == 'test':
return (img, )
def image_mapper(**kwargs):
""" image_mapper """
return functools.partial(process_image, **kwargs)
def _reader_creator(file_list,
mode,
shuffle=False,
color_jitter=False,
rotate=False,
data_dir=DATA_DIR):
def reader():
with open(file_list) as flist:
full_lines = [line.strip() for line in flist]
if shuffle:
np.random.shuffle(lines)
if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'):
# distributed mode if the env var `PADDLE_TRAINING_ROLE` exits
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS", "1"))
per_node_lines = len(full_lines) // trainer_count
lines = full_lines[trainer_id * per_node_lines:(trainer_id + 1)
* per_node_lines]
print(
"read images from %d, length: %d, lines length: %d, total: %d"
% (trainer_id * per_node_lines, per_node_lines, len(lines),
len(full_lines)))
else:
lines = full_lines
for line in lines:
if mode == 'train' or mode == 'val':
img_path, label = line.split()
img_path = img_path.replace("JPEG", "jpeg")
img_path = os.path.join(data_dir, img_path)
yield img_path, int(label)
elif mode == 'test':
img_path = os.path.join(DATA_DIR, line)
yield [img_path]
image_mapper = functools.partial(
process_image,
mode=mode,
color_jitter=color_jitter,
rotate=color_jitter,
crop_size=224)
reader = paddle.reader.xmap_readers(
image_mapper, reader, THREAD, BUF_SIZE, order=False)
return reader
def train(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'train_list.txt')
return _reader_creator(
file_list,
'train',
shuffle=True,
color_jitter=False,
rotate=False,
data_dir=data_dir)
def val(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir)
def test(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册