diff --git a/PaddleCV/Research/webvision2018/README.md b/PaddleCV/Research/webvision2018/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7c1818f4a1694ac4d2890a675f65f6fb5fc2c800 --- /dev/null +++ b/PaddleCV/Research/webvision2018/README.md @@ -0,0 +1,59 @@ +# WebVision Image Classification 2018 Challenge +The goal of this challenge is to advance the area of learning knowledge and representation from web data. The web data not only contains huge numbers of visual images, but also rich meta information concerning these visual data, which could be exploited to learn good representations and models. +More detail [[WebVision2018](https://www.vision.ee.ethz.ch/webvision/challenge.html)]. + +By observing the web data, we find that there are five key challenges, i.e., imbalanced class sizes, high intra-classes diversity and inter-class similarity, imprecise instances, +insufficient representative instances, and ambiguous class labels. To alleviate these challenges, we assume that every training instance has +the potential to contribute positively by alleviating the data bias and noise via reweighting the influence of each instance according to different +class sizes, large instance clusters, its confidence, small instance bags and the labels. In this manner, the influence of bias and noise in the +web data can be gradually alleviated, leading to the steadily improving performance of URNet. Experimental results in the WebVision 2018 +challenge with 16 million noisy training images from 5000 classes show that our approach outperforms state-of-the-art models and ranks the first +place in the image classification task. The detail of our solution can refer to our paper[[URNet](https://arxiv.org/abs/1811.00700)]. + +## 1.Prepare data +We have provided a download + preprocess script of valset data. +``` +cd data +sh download_webvision2018.sh +``` +Note that the server hosting Webvision Data reboots every day at midnight (Zurich time). You might want to change wget to something else. + +## 2.Environment installation +Cudnn >= 7, CUDA 8/9, PaddlePaddle version >= 1.3, python version 2.7 (More detail [[PaddlePaddle](https://github.com/paddlepaddle/paddle)]) + +## 3.Download pretrained model +| Model | Acc@1 | Acc@5 +| - | - | - +| [ResNeXt101_32x4d](https://paddlemodels.bj.bcebos.com/webvision/ResNeXt101_32x4d_Released.tar.gz) | 53.4% | 77.1% + +## 4.Test image +``` +sh run.sh +``` +or +``` +export CUDA_VISIBLE_DEVICES=$GPU_ID +export FLAGS_fraction_of_gpu_memory_to_use=1.0 +python infer.py --model ResNeXt101_32x4d \ + --pretrained_model $PRETRAINEDMODELPATH \ + --class_dim 5000 \ + --img_path $IMGPATH \ + --img_list $IMGLIST \ + --use_gpu True +``` + +You will get the predictions of images. +## 5.Evaluation +``` +export CUDA_VISIBLE_DEVICES=$GPU_ID +export FLAGS_fraction_of_gpu_memory_to_use=1.0 +python eval.py --model ResNeXt101_32x4d \ + --pretrained_model $PRETRAINEDMODELPATH \ + --class_dim 5000 \ + --img_path $IMGPATH \ + --img_list $IMGLIST \ + --use_gpu True + +``` +You will get the Acc@1 and Acc@5. + diff --git a/PaddleCV/Research/webvision2018/data/download_webvision2018.sh b/PaddleCV/Research/webvision2018/data/download_webvision2018.sh new file mode 100644 index 0000000000000000000000000000000000000000..fa09f78c6790b255fd303962e2762cc03616fe23 --- /dev/null +++ b/PaddleCV/Research/webvision2018/data/download_webvision2018.sh @@ -0,0 +1,6 @@ +wget https://data.vision.ee.ethz.ch/cvl/webvision2018/val_images_resized.tar +tar -xvf val_images_resized.tar +rm val_images_resized.tar +wget https://data.vision.ee.ethz.ch/cvl/webvision2018/val_filelist.txt +mv val_images_resized val +mv val_filelist.txt val_list.txt diff --git a/PaddleCV/Research/webvision2018/eval.py b/PaddleCV/Research/webvision2018/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..ced76de93c69e05a8e001aaa011bac3a352378fa --- /dev/null +++ b/PaddleCV/Research/webvision2018/eval.py @@ -0,0 +1,102 @@ +import os +import numpy as np +import time +import sys +import paddle +import paddle.fluid as fluid +import models +import reader +import argparse +import functools +from utils import add_arguments, print_arguments, accuracy +import math +import sys +reload(sys) +sys.setdefaultencoding('utf-8') + +parser = argparse.ArgumentParser(description=__doc__) +# yapf: disable +add_arg = functools.partial(add_arguments, argparser=parser) +add_arg('batch_size', int, 32, "Minibatch size.") +add_arg('use_gpu', bool, True, "Whether to use GPU or not.") +add_arg('class_dim', int, 5000, "Class number.") +add_arg('image_shape', str, "3,224,224", "Input image size") +add_arg('pretrained_model', str, None, "Whether to use pretrained model.") +add_arg('model', str, "ResNeXt101_32x4d", "Set the network to use.") +add_arg('img_list', str, "None", "list of valset.") +add_arg('img_path', str, "NOne", "path of valset.") +# yapf: enable + +model_list = [m for m in dir(models) if "__" not in m] + + +def eval(args): + # parameters from arguments + class_dim = args.class_dim + model_name = args.model + pretrained_model = args.pretrained_model + image_shape = [int(m) for m in args.image_shape.split(",")] + + assert model_name in model_list, "{} is not in lists: {}".format(args.model, + model_list) + + image = fluid.layers.data(name='image', shape=image_shape, dtype='float32') + + # model definition + model = models.__dict__[model_name]() + + if model_name is "GoogleNet": + out, _, _ = model.net(input=image, class_dim=class_dim) + else: + out = model.net(input=image, class_dim=class_dim) + + test_program = fluid.default_main_program().clone(for_test=True) + + fetch_list = [out.name] + + place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + if pretrained_model: + + def if_exist(var): + return os.path.exists(os.path.join(pretrained_model, var.name)) + + fluid.io.load_vars(exe, pretrained_model, predicate=if_exist) + + + test_batch_size = args.batch_size + + img_size = image_shape[1] + test_reader = paddle.batch(reader.test(args, img_size), batch_size=test_batch_size) + feeder = fluid.DataFeeder(place=place, feed_list=[image]) + + targets = [] + with open(args.img_list, 'r') as f: + for line in f.readlines(): + targets.append(line.strip().split()[-1]) + targets = np.array(targets, dtype=np.int) + + preds = [] + TOPK = 5 + + for batch_id, data in enumerate(test_reader()): + all_result = exe.run(test_program, + fetch_list=fetch_list, + feed=feeder.feed(data)) + pred_label = np.argsort(-all_result[0], 1)[:,:5] + print("Test-{0}".format(batch_id)) + preds.append(pred_label) + preds = np.vstack(preds) + top1, top5 = accuracy(targets, preds) + print("top1:{:.4f} top5:{:.4f}".format(top1,top5)) + +def main(): + args = parser.parse_args() + print_arguments(args) + eval(args) + + +if __name__ == '__main__': + main() diff --git a/PaddleCV/Research/webvision2018/infer.py b/PaddleCV/Research/webvision2018/infer.py new file mode 100755 index 0000000000000000000000000000000000000000..ab2be147ac466a37003902dff6717c3063ab57e6 --- /dev/null +++ b/PaddleCV/Research/webvision2018/infer.py @@ -0,0 +1,100 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import time +import sys +import math +import numpy as np +import argparse +import functools + +import paddle +import paddle.fluid as fluid +import reader +import models +import utils +from utils.utility import add_arguments,print_arguments + + +parser = argparse.ArgumentParser(description=__doc__) +# yapf: disable +add_arg = functools.partial(add_arguments, argparser=parser) +add_arg('use_gpu', bool, True, "Whether to use GPU or not.") +add_arg('class_dim', int, 5000, "Class number.") +add_arg('image_shape', str, "3,224,224", "Input image size") +add_arg('pretrained_model', str, None, "Whether to use pretrained model.") +add_arg('model', str, "ResNeXt101_32x4d", "Set the network to use.") +add_arg('save_inference', bool, False, "Whether to save inference model or not") +add_arg('resize_short_size', int, 256, "Set resize short size") +add_arg('img_list', str, None, "list of valset") +add_arg('img_path', str, None, "path of valset") +# yapf: enable + +def infer(args): + # parameters from arguments + class_dim = args.class_dim + model_name = args.model + save_inference = args.save_inference + pretrained_model = args.pretrained_model + image_shape = [int(m) for m in args.image_shape.split(",")] + model_list = [m for m in dir(models) if "__" not in m] + assert model_name in model_list, "{} is not in lists: {}".format(args.model, + model_list) + + image = fluid.layers.data(name='image', shape=image_shape, dtype='float32') + + # model definition + model = models.__dict__[model_name]() + if model_name == "GoogleNet": + out, _, _ = model.net(input=image, class_dim=class_dim) + else: + out = model.net(input=image, class_dim=class_dim) + + test_program = fluid.default_main_program().clone(for_test=True) + + fetch_list = [out.name] + + place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + fluid.io.load_persistables(exe, pretrained_model) + if save_inference: + fluid.io.save_inference_model( + dirname=model_name, + feeded_var_names=['image'], + main_program=test_program, + target_vars=out, + executor=exe, + model_filename='model', + params_filename='params') + print("model: ",model_name," is already saved") + exit(0) + test_batch_size = 1 + img_size = image_shape[1] + test_reader = paddle.batch(reader.test(args, img_size), batch_size=test_batch_size) + feeder = fluid.DataFeeder(place=place, feed_list=[image]) + + TOPK = 1 + for batch_id, data in enumerate(test_reader()): + result = exe.run(test_program, + fetch_list=fetch_list, + feed=feeder.feed(data)) + + result = result[0][0] + pred_label = np.argsort(result)[::-1][:TOPK] + print("Test-{0}-score: {1}, class {2}" + .format(batch_id, result[pred_label], pred_label)) + sys.stdout.flush() + + +def main(): + args = parser.parse_args() + print_arguments(args) + infer(args) + + +if __name__ == '__main__': + main() diff --git a/PaddleCV/Research/webvision2018/models/__init__.py b/PaddleCV/Research/webvision2018/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..56e5f333d0d218cfa8401331cce3d0be21701dff --- /dev/null +++ b/PaddleCV/Research/webvision2018/models/__init__.py @@ -0,0 +1 @@ +from .resnext_32x4d import ResNeXt50_32x4d, ResNeXt101_32x4d, ResNeXt152_32x4d diff --git a/PaddleCV/Research/webvision2018/models/resnext_32x4d.py b/PaddleCV/Research/webvision2018/models/resnext_32x4d.py new file mode 100644 index 0000000000000000000000000000000000000000..ec6b85f63d4ff510ab93dfd1c908258c10107881 --- /dev/null +++ b/PaddleCV/Research/webvision2018/models/resnext_32x4d.py @@ -0,0 +1,165 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import paddle +import paddle.fluid as fluid +import math +from paddle.fluid.param_attr import ParamAttr + +__all__ = ["ResNeXt", "ResNeXt50_32x4d", "ResNeXt101_32x4d", "ResNeXt152_32x4d"] + +train_parameters = { + "input_size": [3, 224, 224], + "input_mean": [0.485, 0.456, 0.406], + "input_std": [0.229, 0.224, 0.225], + "learning_strategy": { + "name": "piecewise_decay", + "batch_size": 256, + "epochs": [30, 60, 90], + "steps": [0.1, 0.01, 0.001, 0.0001] + } +} + + +class ResNeXt(): + def __init__(self, layers=50): + self.params = train_parameters + self.layers = layers + + def net(self, input, class_dim=1000): + layers = self.layers + supported_layers = [50, 101, 152] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format(supported_layers, layers) + + if layers == 50: + depth = [3, 4, 6, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + num_filters = [256, 512, 1024, 2048] + cardinality = 32 + + conv = self.conv_bn_layer( + input=input, + num_filters=64, + filter_size=7, + stride=2, + act='relu', + name="res_conv1") + conv = fluid.layers.pool2d( + input=conv, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') + + for block in range(len(depth)): + for i in range(depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name = "res" + str(block + 2) + "a" + else: + conv_name = "res" + str(block + 2) + "b" + str(i) + else: + conv_name = "res" + str(block + 2) + chr(97 + i) + conv = self.bottleneck_block( + input=conv, + num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + cardinality=cardinality, + name=conv_name) + + pool = fluid.layers.pool2d( + input=conv, pool_size=7, pool_type='avg', global_pooling=True) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + out = fluid.layers.fc(input=pool, + size=class_dim, + act='softmax', + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv),name='fc_weights'), + bias_attr=fluid.param_attr.ParamAttr(name='fc_offset')) + return out + + def conv_bn_layer(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + name=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + param_attr=ParamAttr(name=name + "_weights"), + bias_attr=False, + name=name + '.conv2d.output.1') + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + return fluid.layers.batch_norm( + input=conv, + act=act, + name=bn_name + '.output.1', + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance', ) + + def shortcut(self, input, ch_out, stride, name): + ch_in = input.shape[1] + if ch_in != ch_out or stride != 1: + return self.conv_bn_layer(input, ch_out, 1, stride, name=name) + else: + return input + + def bottleneck_block(self, input, num_filters, stride, cardinality, name): + conv0 = self.conv_bn_layer( + input=input, + num_filters=num_filters, + filter_size=1, + act='relu', + name=name + "_branch2a") + conv1 = self.conv_bn_layer( + input=conv0, + num_filters=num_filters, + filter_size=3, + stride=stride, + groups=cardinality, + act='relu', + name=name + "_branch2b") + conv2 = self.conv_bn_layer( + input=conv1, + num_filters=num_filters, + filter_size=1, + act=None, + name=name + "_branch2c") + + short = self.shortcut( + input, num_filters, stride, name=name + "_branch1") + + return fluid.layers.elementwise_add( + x=short, y=conv2, act='relu', name=name + ".add.output.5") + + +def ResNeXt50_32x4d(): + model = ResNeXt(layers=50) + return model + + +def ResNeXt101_32x4d(): + model = ResNeXt(layers=101) + return model + + +def ResNeXt152_32x4d(): + model = ResNeXt(layers=152) + return model diff --git a/PaddleCV/Research/webvision2018/reader.py b/PaddleCV/Research/webvision2018/reader.py new file mode 100644 index 0000000000000000000000000000000000000000..9c424c97e0f3b128e03c72ef0a9ba43ce8078023 --- /dev/null +++ b/PaddleCV/Research/webvision2018/reader.py @@ -0,0 +1,144 @@ +import os +import math +import random +import functools +import numpy as np +import paddle +import cv2 +import io + +random.seed(0) +np.random.seed(0) + +THREAD = 8 +BUF_SIZE = 128 + +img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) +img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) + +def rotate_image(img): + """ rotate_image """ + (h, w) = img.shape[:2] + center = (w / 2, h / 2) + angle = np.random.randint(-10, 11) + M = cv2.getRotationMatrix2D(center, angle, 1.0) + rotated = cv2.warpAffine(img, M, (w, h)) + return rotated + +def random_crop(img, size, scale=None, ratio=None): + """ random_crop """ + scale = [0.08, 1.0] if scale is None else scale + ratio = [3. / 4., 4. / 3.] if ratio is None else ratio + + aspect_ratio = math.sqrt(np.random.uniform(*ratio)) + w = 1. * aspect_ratio + h = 1. / aspect_ratio + + bound = min((float(img.shape[1]) / img.shape[0]) / (w ** 2), + (float(img.shape[0]) / img.shape[1]) / (h ** 2)) + scale_max = min(scale[1], bound) + scale_min = min(scale[0], bound) + + target_area = img.shape[0] * img.shape[1] * np.random.uniform(scale_min, + scale_max) + target_size = math.sqrt(target_area) + w = int(target_size * w) + h = int(target_size * h) + + i = np.random.randint(0, img.size[0] - w + 1) + j = np.random.randint(0, img.size[1] - h + 1) + + img = img[i:i+h, j:j+w, :] + resized = cv2.resize(img, (size, size), + interpolation=cv2.INTER_CUBIC + ) + return resized + +def distort_color(img): + return img + +def resize_short(img, target_size): + """ resize_short """ + percent = float(target_size) / min(img.shape[0], img.shape[1]) + resized_width = int(round(img.shape[1] * percent)) + resized_height = int(round(img.shape[0] * percent)) + resized = cv2.resize(img, (resized_width, resized_height), + interpolation=cv2.INTER_CUBIC + ) + return resized + +def crop_image(img, target_size, center): + """ crop_image """ + height, width = img.shape[:2] + size = target_size + if center == True: + w_start = (width - size) / 2 + h_start = (height - size) / 2 + else: + w_start = np.random.randint(0, width - size + 1) + h_start = np.random.randint(0, height - size + 1) + w_end = w_start + size + h_end = h_start + size + img = img[h_start:h_end, w_start:w_end, :] + return img + +def process_image(sample, mode, color_jitter, rotate, + crop_size=224, mean=None, std=None): + """ process_image """ + + mean = [0.485, 0.456, 0.406] if mean is None else mean + std = [0.229, 0.224, 0.225] if std is None else std + + + img_path = sample[0] + img = cv2.imread(img_path) + img = cv2.resize(img, (crop_size, crop_size)) + + img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255 + img_mean = np.array(mean).reshape((3, 1, 1)) + img_std = np.array(std).reshape((3, 1, 1)) + img -= img_mean + img /= img_std + + return (img, ) + +def image_mapper(**kwargs): + """ image_mapper """ + return functools.partial(process_image, **kwargs) + +def _reader_creator(file_list, + mode, + shuffle=False, + color_jitter=False, + rotate=False, + data_dir=None, + crop_size=224): + def reader(): + + with open(file_list) as flist: + full_lines = [line.strip() for line in flist] + if shuffle: + np.random.shuffle(lines) + lines = full_lines + for line in lines: + img_path, label = line.strip().split() + img_path = os.path.join(data_dir, img_path) + yield [img_path] + + + image_mapper = functools.partial(process_image, + mode=mode, color_jitter=color_jitter, rotate=rotate, crop_size=crop_size) + reader = paddle.reader.xmap_readers( + image_mapper, reader, THREAD, BUF_SIZE, order=True) + return reader + +def create_img_reader(args): + def reader(): + img_path = args.img_path + yield [img_path] + return reader + +def test(settings, crop_size): + file_list = settings.img_list + data_dir = settings.img_path + return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir, crop_size=crop_size) diff --git a/PaddleCV/Research/webvision2018/run.sh b/PaddleCV/Research/webvision2018/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..06b1d77925c4d1118f30115aa1d3e4461a586825 --- /dev/null +++ b/PaddleCV/Research/webvision2018/run.sh @@ -0,0 +1,10 @@ +export CUDA_VISIBLE_DEVICES=0 +export FLAGS_fraction_of_gpu_memory_to_use=1.0 +python infer.py \ + --model ResNeXt101_32x4d \ + --class_dim 5000 \ + --pretrained ./ckpt/ResNeXt101_32x4d_Release/ \ + --img_list ./data/val_list.txt \ + --img_path ./data/val/ \ + --use_gpu True + diff --git a/PaddleCV/Research/webvision2018/utils/__init__.py b/PaddleCV/Research/webvision2018/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3900e439e072a35d3df7a6f6e486c0402c0f4e4f --- /dev/null +++ b/PaddleCV/Research/webvision2018/utils/__init__.py @@ -0,0 +1,2 @@ +from .utility import add_arguments, print_arguments +from .class_accuracy import accuracy diff --git a/PaddleCV/Research/webvision2018/utils/class_accuracy.py b/PaddleCV/Research/webvision2018/utils/class_accuracy.py new file mode 100644 index 0000000000000000000000000000000000000000..30c1b824ab187cee462a6e5dbecaab0e7acf3dd3 --- /dev/null +++ b/PaddleCV/Research/webvision2018/utils/class_accuracy.py @@ -0,0 +1,35 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +import numpy as np + +def accuracy(targets, preds): + """Get the class-level top1 and top5 of model. + + Usage: + + .. code-blcok::python + + top1, top5 = accuracy(targets, preds) + + :params args: evaluate the prediction of model. + :type args: numpy.array + + """ + top1 = np.zeros((5000,), dtype=np.float32) + top5 = np.zeros((5000,), dtype=np.float32) + count = np.zeros((5000,), dtype=np.float32) + + for index in range(targets.shape[0]): + target = targets[index] + if target == preds[index,0]: + top1[target] += 1 + top5[target] += 1 + elif np.sum(target == preds[index,:5]): + top5[target] += 1 + + count[target] += 1 + return (top1/(count+1e-12)).mean(), (top5/(count+1e-12)).mean() diff --git a/PaddleCV/Research/webvision2018/utils/utility.py b/PaddleCV/Research/webvision2018/utils/utility.py new file mode 100644 index 0000000000000000000000000000000000000000..c399c288e00da569728af51f0cfddd750a0d4df1 --- /dev/null +++ b/PaddleCV/Research/webvision2018/utils/utility.py @@ -0,0 +1,47 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import distutils.util +import numpy as np +import six + +def print_arguments(args): + """Print argparse's arguments. + + Usage: + + .. code-block:: python + + parser = argparse.ArgumentParser() + parser.add_argument("name", default="Jonh", type=str, help="User name.") + args = parser.parse_args() + print_arguments(args) + + :param args: Input argparse.Namespace for printing. + :type args: argparse.Namespace + """ + print("------------- Configuration Arguments -------------") + for arg, value in sorted(six.iteritems(vars(args))): + print("%25s : %s" % (arg, value)) + print("----------------------------------------------------") + + +def add_arguments(argname, type, default, help, argparser, **kwargs): + """Add argparse's argument. + + Usage: + + .. code-block:: python + + parser = argparse.ArgumentParser() + add_argument("name", str, "Jonh", "User name.", parser) + args = parser.parse_args() + """ + type = distutils.util.strtobool if type == bool else type + argparser.add_argument( + "--" + argname, + default=default, + type=type, + help=help + ' Default: %(default)s.', + **kwargs)