Commit 33d384b4 · Author: Leon · Committer: qingqing01

Add model for WebVision 2018 (#2775)

* Add a model for WebVision 2018.
Parent: 2df26c04
# WebVision Image Classification 2018 Challenge
The goal of this challenge is to advance the field of learning knowledge and representations from web data. Web data contains not only huge numbers of images but also rich meta information about them, which can be exploited to learn good representations and models.
For more details, see [WebVision2018](https://www.vision.ee.ethz.ch/webvision/challenge.html).
By observing the web data, we find five key challenges: imbalanced class sizes, high intra-class diversity and inter-class similarity, imprecise instances,
insufficient representative instances, and ambiguous class labels. To alleviate them, we assume that every training instance has
the potential to contribute positively, and we reduce data bias and noise by reweighting the influence of each instance according to
its class size, large instance clusters, its confidence, small instance bags, and its labels. In this manner, the influence of bias and noise in the
web data is gradually reduced, which steadily improves the performance of our model, URNet. Experimental results in the WebVision 2018
challenge, with 16 million noisy training images from 5000 classes, show that our approach outperforms state-of-the-art models and ranks first
in the image classification task. For details of our solution, see our paper [URNet](https://arxiv.org/abs/1811.00700).
## 1.Prepare data
We provide a script that downloads and preprocesses the validation set.
```
cd data
sh download_webvision2018.sh
```
Note that the server hosting the WebVision data reboots every day at midnight (Zurich time), so a download running at that time may be interrupted. You might want to replace wget with a tool that can resume interrupted downloads.
## 2.Environment installation
cuDNN >= 7, CUDA 8/9, PaddlePaddle >= 1.3, Python 2.7. For more details, see [PaddlePaddle](https://github.com/paddlepaddle/paddle).
## 3.Download pretrained model
| Model | Acc@1 | Acc@5
| - | - | -
| [ResNeXt101_32x4d](https://paddlemodels.bj.bcebos.com/webvision/ResNeXt101_32x4d_Released.tar.gz) | 53.4% | 77.1%
## 4.Test image
```
sh run.sh
```
or
```
export CUDA_VISIBLE_DEVICES=$GPU_ID
export FLAGS_fraction_of_gpu_memory_to_use=1.0
python infer.py --model ResNeXt101_32x4d \
--pretrained_model $PRETRAINEDMODELPATH \
--class_dim 5000 \
--img_path $IMGPATH \
--img_list $IMGLIST \
--use_gpu True
```
This prints the top-1 predicted class and its score for each image.
## 5.Evaluation
```
export CUDA_VISIBLE_DEVICES=$GPU_ID
export FLAGS_fraction_of_gpu_memory_to_use=1.0
python eval.py --model ResNeXt101_32x4d \
--pretrained_model $PRETRAINEDMODELPATH \
--class_dim 5000 \
--img_path $IMGPATH \
--img_list $IMGLIST \
--use_gpu True
```
This prints the class-averaged Acc@1 and Acc@5.
wget https://data.vision.ee.ethz.ch/cvl/webvision2018/val_images_resized.tar
tar -xvf val_images_resized.tar
rm val_images_resized.tar
wget https://data.vision.ee.ethz.ch/cvl/webvision2018/val_filelist.txt
mv val_images_resized val
mv val_filelist.txt val_list.txt
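The list file is read line by line further down, with the image path first and the integer label last (this is how `reader.py` and `eval.py` split each line). A minimal sketch for sanity-checking the file after download follows; `check_list_file` is a hypothetical helper and is not part of this commit.

```python
import os


def check_list_file(list_path, data_dir=None):
    """Return (num_lines, num_classes); raise ValueError on a malformed line."""
    labels = set()
    count = 0
    with open(list_path) as f:
        for lineno, line in enumerate(f, 1):
            fields = line.strip().split()
            if len(fields) != 2:
                raise ValueError("line %d: expected '<path> <label>'" % lineno)
            img_path, label = fields
            # Optionally verify that the image exists under data_dir.
            if data_dir is not None and not os.path.exists(
                    os.path.join(data_dir, img_path)):
                raise ValueError("line %d: missing image %s" % (lineno, img_path))
            labels.add(int(label))
            count += 1
    return count, len(labels)
```

Calling `check_list_file("val_list.txt", "val")` from the `data` directory would check both the line format and the presence of every listed image, assuming the paths in the list are relative to `val`.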
import os
import numpy as np
import time
import sys
import paddle
import paddle.fluid as fluid
import models
import reader
import argparse
import functools
from utils import add_arguments, print_arguments, accuracy
import math
# Python 2 only: make UTF-8 the default string encoding.
reload(sys)
sys.setdefaultencoding('utf-8')
parser = argparse.ArgumentParser(description=__doc__)
# yapf: disable
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('class_dim', int, 5000, "Class number.")
add_arg('image_shape', str, "3,224,224", "Input image size")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('model', str, "ResNeXt101_32x4d", "Set the network to use.")
add_arg('img_list', str, None, "List file of the validation set.")
add_arg('img_path', str, None, "Directory of the validation images.")
# yapf: enable
model_list = [m for m in dir(models) if "__" not in m]
def eval(args):
    # parameters from arguments
    class_dim = args.class_dim
    model_name = args.model
    pretrained_model = args.pretrained_model
    image_shape = [int(m) for m in args.image_shape.split(",")]
    assert model_name in model_list, "{} is not in lists: {}".format(args.model,
                                                                     model_list)
    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')

    # model definition
    model = models.__dict__[model_name]()
    # compare strings by value, not by identity
    if model_name == "GoogleNet":
        out, _, _ = model.net(input=image, class_dim=class_dim)
    else:
        out = model.net(input=image, class_dim=class_dim)
    test_program = fluid.default_main_program().clone(for_test=True)
    fetch_list = [out.name]

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    if pretrained_model:
        # load only the variables that have a file in the checkpoint directory
        def if_exist(var):
            return os.path.exists(os.path.join(pretrained_model, var.name))
        fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)

    test_batch_size = args.batch_size
    img_size = image_shape[1]
    test_reader = paddle.batch(reader.test(args, img_size), batch_size=test_batch_size)
    feeder = fluid.DataFeeder(place=place, feed_list=[image])

    # the ground-truth label is the last field of each line in the list file
    targets = []
    with open(args.img_list, 'r') as f:
        for line in f:
            targets.append(line.strip().split()[-1])
    targets = np.array(targets, dtype=np.int64)

    preds = []
    TOPK = 5
    for batch_id, data in enumerate(test_reader()):
        all_result = exe.run(test_program,
                             fetch_list=fetch_list,
                             feed=feeder.feed(data))
        # indices of the TOPK highest scores per image, best first
        pred_label = np.argsort(-all_result[0], 1)[:, :TOPK]
        print("Test-{0}".format(batch_id))
        preds.append(pred_label)
    preds = np.vstack(preds)
    top1, top5 = accuracy(targets, preds)
    print("top1:{:.4f} top5:{:.4f}".format(top1, top5))
def main():
    args = parser.parse_args()
    print_arguments(args)
    eval(args)


if __name__ == '__main__':
    main()
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import sys
import math
import numpy as np
import argparse
import functools
import paddle
import paddle.fluid as fluid
import reader
import models
import utils
from utils.utility import add_arguments,print_arguments
parser = argparse.ArgumentParser(description=__doc__)
# yapf: disable
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('class_dim', int, 5000, "Class number.")
add_arg('image_shape', str, "3,224,224", "Input image size")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('model', str, "ResNeXt101_32x4d", "Set the network to use.")
add_arg('save_inference', bool, False, "Whether to save inference model or not")
add_arg('resize_short_size', int, 256, "Set resize short size")
add_arg('img_list', str, None, "list of valset")
add_arg('img_path', str, None, "path of valset")
# yapf: enable
def infer(args):
    # parameters from arguments
    class_dim = args.class_dim
    model_name = args.model
    save_inference = args.save_inference
    pretrained_model = args.pretrained_model
    image_shape = [int(m) for m in args.image_shape.split(",")]
    model_list = [m for m in dir(models) if "__" not in m]
    assert model_name in model_list, "{} is not in lists: {}".format(args.model,
                                                                     model_list)
    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')

    # model definition
    model = models.__dict__[model_name]()
    if model_name == "GoogleNet":
        out, _, _ = model.net(input=image, class_dim=class_dim)
    else:
        out = model.net(input=image, class_dim=class_dim)
    test_program = fluid.default_main_program().clone(for_test=True)
    fetch_list = [out.name]

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    fluid.io.load_persistables(exe, pretrained_model)

    if save_inference:
        fluid.io.save_inference_model(
            dirname=model_name,
            feeded_var_names=['image'],
            main_program=test_program,
            target_vars=out,
            executor=exe,
            model_filename='model',
            params_filename='params')
        print("model: ", model_name, " is already saved")
        exit(0)

    test_batch_size = 1
    img_size = image_shape[1]
    test_reader = paddle.batch(reader.test(args, img_size), batch_size=test_batch_size)
    feeder = fluid.DataFeeder(place=place, feed_list=[image])

    TOPK = 1
    for batch_id, data in enumerate(test_reader()):
        result = exe.run(test_program,
                         fetch_list=fetch_list,
                         feed=feeder.feed(data))
        result = result[0][0]
        pred_label = np.argsort(result)[::-1][:TOPK]
        print("Test-{0}-score: {1}, class {2}"
              .format(batch_id, result[pred_label], pred_label))
        sys.stdout.flush()
def main():
    args = parser.parse_args()
    print_arguments(args)
    infer(args)


if __name__ == '__main__':
    main()
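`infer.py` reports only the single best class because `TOPK = 1`. The same argsort idiom extends to any `k`; the sketch below uses plain NumPy on a made-up score vector, and `topk_predictions` is a hypothetical helper, not a function from this commit.

```python
import numpy as np


def topk_predictions(scores, k=5):
    """Return (labels, scores) of the k highest-scoring classes, best first."""
    scores = np.asarray(scores)
    # argsort is ascending, so reverse it and keep the first k indices
    order = np.argsort(scores)[::-1][:k]
    return order, scores[order]


# toy softmax output over four classes
probs = np.array([0.05, 0.60, 0.10, 0.25])
labels, top_scores = topk_predictions(probs, k=2)
print(labels, top_scores)
```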
from .resnext_32x4d import ResNeXt50_32x4d, ResNeXt101_32x4d, ResNeXt152_32x4d
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import math
from paddle.fluid.param_attr import ParamAttr
__all__ = ["ResNeXt", "ResNeXt50_32x4d", "ResNeXt101_32x4d", "ResNeXt152_32x4d"]
train_parameters = {
    "input_size": [3, 224, 224],
    "input_mean": [0.485, 0.456, 0.406],
    "input_std": [0.229, 0.224, 0.225],
    "learning_strategy": {
        "name": "piecewise_decay",
        "batch_size": 256,
        "epochs": [30, 60, 90],
        "steps": [0.1, 0.01, 0.001, 0.0001]
    }
}
class ResNeXt():
    def __init__(self, layers=50):
        self.params = train_parameters
        self.layers = layers

    def net(self, input, class_dim=1000):
        layers = self.layers
        supported_layers = [50, 101, 152]
        assert layers in supported_layers, \
            "supported layers are {} but input layer is {}".format(supported_layers, layers)
        if layers == 50:
            depth = [3, 4, 6, 3]
        elif layers == 101:
            depth = [3, 4, 23, 3]
        elif layers == 152:
            depth = [3, 8, 36, 3]
        num_filters = [256, 512, 1024, 2048]
        cardinality = 32

        conv = self.conv_bn_layer(
            input=input,
            num_filters=64,
            filter_size=7,
            stride=2,
            act='relu',
            name="res_conv1")
        conv = fluid.layers.pool2d(
            input=conv,
            pool_size=3,
            pool_stride=2,
            pool_padding=1,
            pool_type='max')

        for block in range(len(depth)):
            for i in range(depth[block]):
                if layers in [101, 152] and block == 2:
                    if i == 0:
                        conv_name = "res" + str(block + 2) + "a"
                    else:
                        conv_name = "res" + str(block + 2) + "b" + str(i)
                else:
                    conv_name = "res" + str(block + 2) + chr(97 + i)
                conv = self.bottleneck_block(
                    input=conv,
                    num_filters=num_filters[block],
                    stride=2 if i == 0 and block != 0 else 1,
                    cardinality=cardinality,
                    name=conv_name)

        pool = fluid.layers.pool2d(
            input=conv, pool_size=7, pool_type='avg', global_pooling=True)
        stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
        out = fluid.layers.fc(input=pool,
                              size=class_dim,
                              act='softmax',
                              param_attr=fluid.param_attr.ParamAttr(
                                  initializer=fluid.initializer.Uniform(-stdv, stdv), name='fc_weights'),
                              bias_attr=fluid.param_attr.ParamAttr(name='fc_offset'))
        return out

    def conv_bn_layer(self,
                      input,
                      num_filters,
                      filter_size,
                      stride=1,
                      groups=1,
                      act=None,
                      name=None):
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            param_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False,
            name=name + '.conv2d.output.1')
        if name == "conv1":
            bn_name = "bn_" + name
        else:
            bn_name = "bn" + name[3:]
        return fluid.layers.batch_norm(
            input=conv,
            act=act,
            name=bn_name + '.output.1',
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance', )

    def shortcut(self, input, ch_out, stride, name):
        ch_in = input.shape[1]
        if ch_in != ch_out or stride != 1:
            return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
        else:
            return input

    def bottleneck_block(self, input, num_filters, stride, cardinality, name):
        conv0 = self.conv_bn_layer(
            input=input,
            num_filters=num_filters,
            filter_size=1,
            act='relu',
            name=name + "_branch2a")
        conv1 = self.conv_bn_layer(
            input=conv0,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            groups=cardinality,
            act='relu',
            name=name + "_branch2b")
        conv2 = self.conv_bn_layer(
            input=conv1,
            num_filters=num_filters,
            filter_size=1,
            act=None,
            name=name + "_branch2c")
        short = self.shortcut(
            input, num_filters, stride, name=name + "_branch1")
        return fluid.layers.elementwise_add(
            x=short, y=conv2, act='relu', name=name + ".add.output.5")
def ResNeXt50_32x4d():
    model = ResNeXt(layers=50)
    return model


def ResNeXt101_32x4d():
    model = ResNeXt(layers=101)
    return model


def ResNeXt152_32x4d():
    model = ResNeXt(layers=152)
    return model
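Parameter names such as `res4b1_branch2a_weights` are derived from the block names built in `net`, so a checkpoint only loads if the names match. The sketch below reproduces just that naming loop in plain Python to make the scheme easy to inspect; `block_names` is a hypothetical helper and does not build any network.

```python
def block_names(layers):
    """List bottleneck block names in the order the network builds them."""
    depth = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[layers]
    names = []
    for block in range(len(depth)):
        for i in range(depth[block]):
            if layers in [101, 152] and block == 2:
                # deep third stage: "a", then "b1", "b2", ...
                suffix = "a" if i == 0 else "b" + str(i)
            else:
                # all other stages: "a", "b", "c", ...
                suffix = chr(97 + i)
            names.append("res" + str(block + 2) + suffix)
    return names


print(block_names(50)[:4])
print(block_names(101)[7:10])
```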
import os
import math
import random
import functools
import numpy as np
import paddle
import cv2
import io
random.seed(0)
np.random.seed(0)
THREAD = 8
BUF_SIZE = 128
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
def rotate_image(img):
    """ rotate_image """
    (h, w) = img.shape[:2]
    center = (w / 2, h / 2)
    angle = np.random.randint(-10, 11)
    M = cv2.getRotationMatrix2D(center, angle, 1.0)
    rotated = cv2.warpAffine(img, M, (w, h))
    return rotated
def random_crop(img, size, scale=None, ratio=None):
    """ random_crop """
    scale = [0.08, 1.0] if scale is None else scale
    ratio = [3. / 4., 4. / 3.] if ratio is None else ratio
    aspect_ratio = math.sqrt(np.random.uniform(*ratio))
    w = 1. * aspect_ratio
    h = 1. / aspect_ratio
    bound = min((float(img.shape[1]) / img.shape[0]) / (w ** 2),
                (float(img.shape[0]) / img.shape[1]) / (h ** 2))
    scale_max = min(scale[1], bound)
    scale_min = min(scale[0], bound)
    target_area = img.shape[0] * img.shape[1] * np.random.uniform(scale_min,
                                                                  scale_max)
    target_size = math.sqrt(target_area)
    w = int(target_size * w)
    h = int(target_size * h)
    # img is an H x W x C array: i indexes rows (height), j indexes columns (width)
    i = np.random.randint(0, img.shape[0] - h + 1)
    j = np.random.randint(0, img.shape[1] - w + 1)
    img = img[i:i+h, j:j+w, :]
    resized = cv2.resize(img, (size, size),
                         interpolation=cv2.INTER_CUBIC)
    return resized
def distort_color(img):
    return img


def resize_short(img, target_size):
    """ resize_short """
    percent = float(target_size) / min(img.shape[0], img.shape[1])
    resized_width = int(round(img.shape[1] * percent))
    resized_height = int(round(img.shape[0] * percent))
    resized = cv2.resize(img, (resized_width, resized_height),
                         interpolation=cv2.INTER_CUBIC)
    return resized
def crop_image(img, target_size, center):
    """ crop_image """
    height, width = img.shape[:2]
    size = target_size
    if center:
        # integer division keeps the slice indices integral
        w_start = (width - size) // 2
        h_start = (height - size) // 2
    else:
        w_start = np.random.randint(0, width - size + 1)
        h_start = np.random.randint(0, height - size + 1)
    w_end = w_start + size
    h_end = h_start + size
    img = img[h_start:h_end, w_start:w_end, :]
    return img
def process_image(sample, mode, color_jitter, rotate,
                  crop_size=224, mean=None, std=None):
    """ process_image """
    mean = [0.485, 0.456, 0.406] if mean is None else mean
    std = [0.229, 0.224, 0.225] if std is None else std
    img_path = sample[0]
    img = cv2.imread(img_path)
    img = cv2.resize(img, (crop_size, crop_size))
    # BGR -> RGB, HWC -> CHW, scale to [0, 1]
    img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255
    img_mean = np.array(mean).reshape((3, 1, 1))
    img_std = np.array(std).reshape((3, 1, 1))
    img -= img_mean
    img /= img_std
    return (img, )


def image_mapper(**kwargs):
    """ image_mapper """
    return functools.partial(process_image, **kwargs)
def _reader_creator(file_list,
                    mode,
                    shuffle=False,
                    color_jitter=False,
                    rotate=False,
                    data_dir=None,
                    crop_size=224):
    def reader():
        with open(file_list) as flist:
            full_lines = [line.strip() for line in flist]
            # define `lines` before the optional shuffle
            lines = full_lines
            if shuffle:
                np.random.shuffle(lines)
            for line in lines:
                img_path, label = line.strip().split()
                img_path = os.path.join(data_dir, img_path)
                yield [img_path]

    image_mapper = functools.partial(process_image,
        mode=mode, color_jitter=color_jitter, rotate=rotate, crop_size=crop_size)
    reader = paddle.reader.xmap_readers(
        image_mapper, reader, THREAD, BUF_SIZE, order=True)
    return reader
def create_img_reader(args):
    def reader():
        img_path = args.img_path
        yield [img_path]
    return reader


def test(settings, crop_size):
    file_list = settings.img_list
    data_dir = settings.img_path
    return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir, crop_size=crop_size)
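At test time `process_image` resizes the image, flips BGR to RGB, moves channels first, scales to [0, 1], and normalizes with the per-channel mean and standard deviation. The sketch below shows only the array arithmetic on a synthetic image, so it runs without OpenCV; `normalize_bgr` is a hypothetical name, and the synthetic array stands in for the output of `cv2.imread` plus `cv2.resize`.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406], dtype="float32").reshape((3, 1, 1))
STD = np.array([0.229, 0.224, 0.225], dtype="float32").reshape((3, 1, 1))


def normalize_bgr(img_bgr):
    """HWC uint8 BGR image -> CHW float32 RGB image, mean/std normalized."""
    # reverse the channel axis (BGR -> RGB), move channels first, scale to [0, 1]
    img = img_bgr[:, :, ::-1].astype("float32").transpose((2, 0, 1)) / 255
    return (img - MEAN) / STD


# synthetic all-white 224x224 image (assumption: stands in for a decoded file)
fake = np.full((224, 224, 3), 255, dtype=np.uint8)
out = normalize_bgr(fake)
print(out.shape, out.dtype)
```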
export CUDA_VISIBLE_DEVICES=0
export FLAGS_fraction_of_gpu_memory_to_use=1.0
python infer.py \
--model ResNeXt101_32x4d \
--class_dim 5000 \
--pretrained_model ./ckpt/ResNeXt101_32x4d_Release/ \
--img_list ./data/val_list.txt \
--img_path ./data/val/ \
--use_gpu True
from .utility import add_arguments, print_arguments
from .class_accuracy import accuracy
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
def accuracy(targets, preds):
    """Get the class-level top1 and top5 of model.

    Usage:

    .. code-block:: python

        top1, top5 = accuracy(targets, preds)

    :param targets: ground-truth class index of each sample.
    :type targets: numpy.array
    :param preds: top-5 predicted class indices of each sample, best first.
    :type preds: numpy.array
    """
    top1 = np.zeros((5000,), dtype=np.float32)
    top5 = np.zeros((5000,), dtype=np.float32)
    count = np.zeros((5000,), dtype=np.float32)
    for index in range(targets.shape[0]):
        target = targets[index]
        if target == preds[index, 0]:
            top1[target] += 1
            top5[target] += 1
        elif np.sum(target == preds[index, :5]):
            top5[target] += 1
        count[target] += 1
    return (top1 / (count + 1e-12)).mean(), (top5 / (count + 1e-12)).mean()
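`accuracy()` averages per-class hit rates, so a rare class weighs as much as a frequent one, and the result can differ from the plain fraction of correct images. The sketch below is a vectorized variant on a toy input; `class_mean_accuracy` and its `num_classes` and `k` parameters are assumptions for illustration, whereas the function above hard-codes 5000 classes and top-5.

```python
import numpy as np


def class_mean_accuracy(targets, preds, num_classes, k=5):
    """Mean over classes of per-class top-1 and top-k accuracy."""
    targets = np.asarray(targets)
    preds = np.asarray(preds)
    hit1 = (preds[:, 0] == targets).astype(np.float64)
    hitk = (preds[:, :k] == targets[:, None]).any(axis=1).astype(np.float64)
    count = np.bincount(targets, minlength=num_classes).astype(np.float64)
    top1 = np.bincount(targets, weights=hit1, minlength=num_classes)
    topk = np.bincount(targets, weights=hitk, minlength=num_classes)
    # classes with no samples contribute 0 to the mean, as in accuracy()
    return (top1 / (count + 1e-12)).mean(), (topk / (count + 1e-12)).mean()


# class 0 has three samples (two correct at top-1), class 1 has one (wrong at top-1)
targets = [0, 0, 0, 1]
preds = [[0, 1], [0, 1], [1, 0], [0, 1]]
t1, tk = class_mean_accuracy(targets, preds, num_classes=2, k=2)
print(t1, tk)
```

Here the per-image top-1 accuracy is 2/4, while the class-averaged value is the mean of 2/3 and 0.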
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import distutils.util
import numpy as np
import six
def print_arguments(args):
    """Print argparse's arguments.

    Usage:

    .. code-block:: python

        parser = argparse.ArgumentParser()
        parser.add_argument("name", default="John", type=str, help="User name.")
        args = parser.parse_args()
        print_arguments(args)

    :param args: Input argparse.Namespace for printing.
    :type args: argparse.Namespace
    """
    print("------------- Configuration Arguments -------------")
    for arg, value in sorted(six.iteritems(vars(args))):
        print("%25s : %s" % (arg, value))
    print("----------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Add argparse's argument.

    Usage:

    .. code-block:: python

        parser = argparse.ArgumentParser()
        add_arguments("name", str, "John", "User name.", parser)
        args = parser.parse_args()
    """
    # bool("False") is True, so parse boolean flags with strtobool instead
    type = distutils.util.strtobool if type == bool else type
    argparser.add_argument(
        "--" + argname,
        default=default,
        type=type,
        help=help + ' Default: %(default)s.',
        **kwargs)
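`add_arguments` swaps `bool` for `distutils.util.strtobool` because `argparse` calls the type on the raw string, and `bool("False")` is `True`. The sketch below shows the pitfall next to a hand-written parser; `str2bool` is a hypothetical stand-in, which may be useful since `distutils` is not shipped with Python 3.12 and later.

```python
import argparse


def str2bool(text):
    """Parse common true/false spellings (hypothetical stand-in for strtobool)."""
    value = text.strip().lower()
    if value in ("y", "yes", "t", "true", "on", "1"):
        return True
    if value in ("n", "no", "f", "false", "off", "0"):
        return False
    raise argparse.ArgumentTypeError("invalid boolean value: %r" % text)


parser = argparse.ArgumentParser()
parser.add_argument("--use_gpu", type=str2bool, default=True)
# pitfall, for comparison: any non-empty string is truthy
parser.add_argument("--naive", type=bool, default=True)

args = parser.parse_args(["--use_gpu", "False", "--naive", "False"])
print(args.use_gpu, args.naive)
```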