diff --git a/PaddleCV/image_classification/eval.py b/PaddleCV/image_classification/eval.py index 49c2862f5335d90deed65bb26082460346dc386f..254ec21d480a9e52e3f25f4cfcd8ab20567ac14f 100644 --- a/PaddleCV/image_classification/eval.py +++ b/PaddleCV/image_classification/eval.py @@ -34,7 +34,7 @@ parser = argparse.ArgumentParser(description=__doc__) add_arg = functools.partial(add_arguments, argparser=parser) # yapf: disable add_arg('data_dir', str, "./data/ILSVRC2012/", "The ImageNet datset") -add_arg('batch_size', int, 256, "batch size on the all devices.") +add_arg('batch_size', int, 256, "batch size on all the devices.") add_arg('use_gpu', bool, True, "Whether to use GPU or not.") add_arg('class_dim', int, 1000, "Class number.") parser.add_argument("--pretrained_model", default=None, required=True, type=str, help="The path to load pretrained model") diff --git a/PaddleCV/image_classification/infer.py b/PaddleCV/image_classification/infer.py index 36ceba6308678ec7b6bef6e69109e640c03c307d..8df267427dac2513472d917d0a6821e2c27f3e7d 100644 --- a/PaddleCV/image_classification/infer.py +++ b/PaddleCV/image_classification/infer.py @@ -23,6 +23,7 @@ import math import numpy as np import argparse import functools +import re import paddle import paddle.fluid as fluid @@ -51,7 +52,7 @@ add_arg('interpolation', int, None, "The interpolation mode" add_arg('padding_type', str, "SAME", "Padding type of convolution") add_arg('use_se', bool, True, "Whether to use Squeeze-and-Excitation module for EfficientNet.") add_arg('image_path', str, None, "single image path") -add_arg('batch_size', int, 8, "batch_size on all devices") +add_arg('batch_size', int, 8, "batch_size on all the devices") add_arg('save_json_path', str, None, "save output to a json file") # yapf: enable @@ -101,8 +102,9 @@ def infer(args): exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) + places = fluid.framework.cuda_places() compiled_program = fluid.compiler.CompiledProgram( - test_program).with_data_parallel() + test_program).with_data_parallel(places=places) fluid.io.load_persistables(exe, args.pretrained_model) if args.save_inference: @@ -119,42 +121,68 @@ def infer(args): imagenet_reader = reader.ImageNetReader() test_reader = imagenet_reader.test(settings=args) - - feeder = fluid.DataFeeder(place=place, feed_list=[image]) - test_reader = feeder.decorate_reader(test_reader, multi_devices=True) + feeder = fluid.DataFeeder(place=places, feed_list=[image]) TOPK = args.topk if os.path.exists(args.class_map_path): print("The map of readable label and numerical label has been found!") - f = open(args.class_map_path) - label_dict = {} - for item in f.readlines(): - key = item.split(" ")[0] - value = [l.replace("\n", "") for l in item.split(" ")[1:]] - label_dict[key] = value + with open(args.class_map_path) as f: + label_dict = {} + strinfo = re.compile(r"\d+ ") + for item in f.readlines(): + key = item.split(" ")[0] + value = [ + strinfo.sub("", l).replace("\n", "") + for l in item.split(", ") + ] + label_dict[key] = value + + info = {} + parallel_data = [] + parallel_id = [] + place_num = paddle.fluid.core.get_cuda_device_count() for batch_id, data in enumerate(test_reader()): - result = exe.run(compiled_program, fetch_list=fetch_list, feed=data) - result = result[0][0] - pred_label = np.argsort(result)[::-1][:TOPK] - - if os.path.exists(args.class_map_path): - readable_pred_label = [] - for label in pred_label: - readable_pred_label.append(label_dict[str(label)]) - print(readable_pred_label) - info = "Test-{0}-score: {1}, class{2} {3}".format( - batch_id, result[pred_label], pred_label, readable_pred_label) - else: - info = "Test-{0}-score: {1}, class{2}".format( - batch_id, result[pred_label], pred_label) - print(info) - if args.save_json_path: - save_json(info, args.save_json_path) - - sys.stdout.flush() - if args.image_path: - os.remove(".tmp.txt") + image_data = [[items[0]] for items in data] + image_id = [items[1] for items in data] + + parallel_id.append(image_id) + parallel_data.append(image_data) + + if place_num == len(parallel_data): + result = exe.run( + compiled_program, + fetch_list=fetch_list, + feed=list(feeder.feed_parallel(parallel_data, place_num))) + for i, res in enumerate(result[0]): + pred_label = np.argsort(res)[::-1][:TOPK] + real_id = str(np.array(parallel_id).flatten()[i]) + _, real_id = os.path.split(real_id) + + if os.path.exists(args.class_map_path): + readable_pred_label = [] + for label in pred_label: + readable_pred_label.append(label_dict[str(label)]) + + info[real_id] = {} + info[real_id]['score'], info[real_id]['class'], info[ + real_id]['class_name'] = str(res[pred_label]), str( + pred_label), readable_pred_label + else: + info[real_id] = {} + info[real_id]['score'], info[real_id]['class'] = str(res[ + pred_label]), str(pred_label) + + print(real_id, info[real_id]) + sys.stdout.flush() + + if args.save_json_path: + save_json(info, args.save_json_path) + + parallel_data = [] + parallel_id = [] + if args.image_path: + os.remove(".tmp.txt") def main(): diff --git a/PaddleCV/image_classification/reader.py b/PaddleCV/image_classification/reader.py index ffde5e353cfd1edd2c2d7779948e42d1d2b829c9..6f7e097c072588eb5c750bcf63feab0abf76bfd3 100644 --- a/PaddleCV/image_classification/reader.py +++ b/PaddleCV/image_classification/reader.py @@ -240,7 +240,7 @@ def process_image(sample, settings, mode, color_jitter, rotate): if mode == 'train' or mode == 'val': return (img, sample[1]) elif mode == 'test': - return (img, ) + return (img, sample[0]) def process_batch_data(input_data, settings, mode, color_jitter, rotate): @@ -262,6 +262,23 @@ class ImageNetReader: assert isinstance(seed, int), "shuffle seed must be int" self.shuffle_seed = seed + def _get_single_card_bs(self, settings, mode): + if settings.use_gpu: + if mode == "val" and settings.test_batch_size: + single_card_bs = settings.test_batch_size // paddle.fluid.core.get_cuda_device_count( + ) + else: + single_card_bs = settings.batch_size // paddle.fluid.core.get_cuda_device_count( + ) + else: + if mode == "val" and settings.test_batch_size: + single_card_bs = settings.test_batch_size // int( + os.environ.get('CPU_NUM', 1)) + else: + single_card_bs = settings.batch_size // int( + os.environ.get('CPU_NUM', 1)) + return single_card_bs + def _reader_creator(self, settings, file_list, @@ -272,12 +289,7 @@ class ImageNetReader: data_dir=None): num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) - if settings.use_gpu: - batch_size = settings.batch_size // paddle.fluid.core.get_cuda_device_count( - ) - else: - batch_size = settings.batch_size // int( - os.environ.get('CPU_NUM', 1)) + batch_size = self._get_single_card_bs(settings, mode) def reader(): def read_file_list(): @@ -304,12 +316,15 @@ class ImageNetReader: full_lines = [] for i in range(settings.same_feed): full_lines.append(temp_file) + for line in full_lines: + img_path, label = line.split() img_path = os.path.join(data_dir, img_path) batch_data.append([img_path, int(label)]) if len(batch_data) == batch_size: if mode == 'train' or mode == 'val' or mode == 'test': + yield batch_data batch_data = [] diff --git a/PaddleCV/image_classification/utils/utility.py b/PaddleCV/image_classification/utils/utility.py index cc3b8499d9d2021a7babcbb4f38f057e22f4855f..4f38cc50a0c22761d80e0cf7312616ed918f4b6a 100644 --- a/PaddleCV/image_classification/utils/utility.py +++ b/PaddleCV/image_classification/utils/utility.py @@ -102,8 +102,8 @@ def parse_args(): parser.add_argument('--image_shape', nargs='+', type=int, default=[3, 224, 224], help="The shape of image") add_arg('num_epochs', int, 120, "The number of total epochs.") add_arg('class_dim', int, 1000, "The number of total classes.") - add_arg('batch_size', int, 8, "Minibatch size on all devices.") - add_arg('test_batch_size', int, 16, "Test batch size on all devices.") + add_arg('batch_size', int, 8, "Minibatch size on all the devices.") + add_arg('test_batch_size', int, None, "Test batch size on all the devices.") add_arg('lr', float, 0.1, "The learning rate.") add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.") add_arg('l2_decay', float, 1e-4, "The l2_decay parameter.") @@ -287,10 +287,35 @@ def init_model(exe, args, program): print("Finish initing model from %s" % (args.checkpoint)) if args.pretrained_model: + # yapf: disable + + #XXX: should rename all models' final fc layers name as final_fc_weights and final_fc_offset! + final_fc_name = [ + "fc8_weights","fc8_offset", #alexnet + "fc_weights","fc_offset", #darknet, densenet, dpn, hrnet, mobilenet_v3, res2net, res2net_vd, resnext, resnext_vd, xception + #efficient + "out","out_offset", "out1","out1_offset", "out2","out2_offset", #googlenet + "final_fc_weights", "final_fc_offset", #inception_v4 + "fc7_weights", "fc7_offset", #mobilenetv1 + "fc10_weights", "fc10_offset", #mobilenetv2 + "fc_0", #resnet, resnet_vc, resnet_vd + "fc.weight", "fc.bias", #resnext101_wsl + "fc6_weights", "fc6_offset", #se_resnet_vd, se_resnext, se_resnext_vd, shufflenet_v2, shufflenet_v2_swish, + #squeezenet + "fc8_weights", "fc8_offset", #vgg + "fc_bias" #"fc_weights", xception_deeplab + ] + # yapf: enable def is_parameter(var): - return isinstance(var, fluid.framework.Parameter) and ( - not ("fc_0" in var.name)) and os.path.exists( + fc_exclude_flag = False + for item in final_fc_name: + if item in var.name: + fc_exclude_flag = True + + return isinstance( + var, fluid.framework. + Parameter) and not fc_exclude_flag and os.path.exists( os.path.join(args.pretrained_model, var.name)) print("Load pretrain weights from {}, exclude fc layer.".format( @@ -314,7 +339,7 @@ def save_model(args, exe, train_prog, info): def save_json(info, path): """ save eval result or infer result to file as json format. """ - with open(path, 'a') as f: + with open(path, 'w') as f: json.dump(info, f) @@ -493,7 +518,7 @@ def best_strategy_compiled(args, exec_strategy.num_threads = 1 compiled_program = fluid.CompiledProgram(program).with_data_parallel( - loss_name=loss.name if mode == "train" else loss, + loss_name=loss.name if mode == "train" else None, share_vars_from=share_prog if mode == "val" else None, build_strategy=build_strategy, exec_strategy=exec_strategy)