From 42c97007f98097a0e819f6a07eabe32f482e869c Mon Sep 17 00:00:00 2001
From: Guanghua Yu <742925032@qq.com>
Date: Thu, 27 Apr 2023 11:55:25 +0800
Subject: [PATCH] fix reparameterization demo train (#1740)

---
Review notes (free-form section, not part of the commit message or the diff):
- Line breaks and indentation were restored from a whitespace-collapsed copy.
  Intra-line spacing and the indentation of train.py context/removed lines
  are inferred, so apply with `git apply --ignore-space-change` and verify.
- imagenet_reader.py was amended during review (see the in-code comments):
  _reader_creator now binds crop_size/resize_size, process_image no longer
  uses a bare `except`, random_crop defaults are tuples, and docstrings were
  added. The hunk header and diffstat were updated (245 -> 275 lines).
- train.py: only the --data_dir help text and two docstring typos changed.
- The commit id and the blob ids on the `index` lines still name the
  pre-amendment content.
- Follow-up: `distutils` is removed in Python 3.12 (PEP 632), so
  add_arguments will need a local replacement for distutils.util.strtobool.

 example/reparameterization/imagenet_reader.py | 275 ++++++++++++++++++
 example/reparameterization/train.py           |  93 +++++--
 2 files changed, 342 insertions(+), 26 deletions(-)
 create mode 100644 example/reparameterization/imagenet_reader.py

diff --git a/example/reparameterization/imagenet_reader.py b/example/reparameterization/imagenet_reader.py
new file mode 100644
index 00000000..06a9d001
--- /dev/null
+++ b/example/reparameterization/imagenet_reader.py
@@ -0,0 +1,275 @@
+import os
+import math
+import random
+import functools
+import numpy as np
+import paddle
+from PIL import Image, ImageEnhance
+from paddle.io import Dataset
+
+random.seed(0)
+np.random.seed(0)
+
+DATA_DIM = 224
+RESIZE_DIM = 256
+
+THREAD = 16
+BUF_SIZE = 10240
+
+DATA_DIR = 'data/ILSVRC2012/'
+DATA_DIR = os.path.join(os.path.split(os.path.realpath(__file__))[0], DATA_DIR)
+
+img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
+img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
+
+
+def resize_short(img, target_size):
+    """Resize `img` so that its shorter side equals `target_size`."""
+    percent = float(target_size) / min(img.size[0], img.size[1])
+    resized_width = int(round(img.size[0] * percent))
+    resized_height = int(round(img.size[1] * percent))
+    img = img.resize((resized_width, resized_height), Image.LANCZOS)
+    return img
+
+
+def crop_image(img, target_size, center):
+    """Crop a square of side `target_size`, centered or at a random offset."""
+    width, height = img.size
+    size = target_size
+    if center:
+        w_start = (width - size) // 2
+        h_start = (height - size) // 2
+    else:
+        w_start = np.random.randint(0, width - size + 1)
+        h_start = np.random.randint(0, height - size + 1)
+    w_end = w_start + size
+    h_end = h_start + size
+    img = img.crop((w_start, h_start, w_end, h_end))
+    return img
+
+
+def random_crop(img, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)):
+    """Randomly crop by area and aspect ratio, then resize to size x size."""
+    aspect_ratio = math.sqrt(np.random.uniform(*ratio))
+    w = 1. * aspect_ratio
+    h = 1. / aspect_ratio
+
+    bound = min((float(img.size[0]) / img.size[1]) / (w**2),
+                (float(img.size[1]) / img.size[0]) / (h**2))
+    scale_max = min(scale[1], bound)
+    scale_min = min(scale[0], bound)
+
+    target_area = img.size[0] * img.size[1] * np.random.uniform(
+        scale_min, scale_max)
+    target_size = math.sqrt(target_area)
+    w = int(target_size * w)
+    h = int(target_size * h)
+
+    i = np.random.randint(0, img.size[0] - w + 1)
+    j = np.random.randint(0, img.size[1] - h + 1)
+
+    img = img.crop((i, j, i + w, j + h))
+    img = img.resize((size, size), Image.LANCZOS)
+    return img
+
+
+def rotate_image(img):
+    """Rotate `img` by a random integer angle in [-10, 10] degrees."""
+    angle = np.random.randint(-10, 11)
+    img = img.rotate(angle)
+    return img
+
+
+def distort_color(img):
+    """Apply brightness, contrast and color jitter in a random order."""
+    def random_brightness(img, lower=0.5, upper=1.5):
+        e = np.random.uniform(lower, upper)
+        return ImageEnhance.Brightness(img).enhance(e)
+
+    def random_contrast(img, lower=0.5, upper=1.5):
+        e = np.random.uniform(lower, upper)
+        return ImageEnhance.Contrast(img).enhance(e)
+
+    def random_color(img, lower=0.5, upper=1.5):
+        e = np.random.uniform(lower, upper)
+        return ImageEnhance.Color(img).enhance(e)
+
+    ops = [random_brightness, random_contrast, random_color]
+    np.random.shuffle(ops)
+
+    img = ops[0](img)
+    img = ops[1](img)
+    img = ops[2](img)
+
+    return img
+
+
+def process_image(sample, mode, color_jitter, rotate, crop_size, resize_size):
+    """Load `sample[0]` and return it as a normalized CHW float32 array.
+
+    'train' applies random crop/flip (plus optional rotation and color
+    jitter); every other mode resizes the short side and center-crops.
+    Returns `(img, sample[1])` for 'train'/'val', `[img]` for 'test', and
+    None when the image cannot be opened.
+    """
+    img_path = sample[0]
+
+    try:
+        img = Image.open(img_path)
+    except Exception:
+        # Do not use a bare `except`: it would also swallow KeyboardInterrupt.
+        print(img_path, "not exists!")
+        return None
+    if mode == 'train':
+        if rotate:
+            img = rotate_image(img)
+        img = random_crop(img, crop_size)
+    else:
+        img = resize_short(img, target_size=resize_size)
+        img = crop_image(img, target_size=crop_size, center=True)
+    if mode == 'train':
+        if color_jitter:
+            img = distort_color(img)
+        if np.random.randint(0, 2) == 1:
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
+
+    if img.mode != 'RGB':
+        img = img.convert('RGB')
+
+    img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255
+    img -= img_mean
+    img /= img_std
+
+    if mode == 'train' or mode == 'val':
+        return img, sample[1]
+    elif mode == 'test':
+        return [img]
+
+
+def _reader_creator(file_list,
+                    mode,
+                    shuffle=False,
+                    color_jitter=False,
+                    rotate=False,
+                    data_dir=DATA_DIR,
+                    batch_size=1):
+    """Build an xmap reader that maps process_image over `file_list`."""
+    def reader():
+        try:
+            with open(file_list) as flist:
+                full_lines = [line.strip() for line in flist]
+                if shuffle:
+                    np.random.shuffle(full_lines)
+                lines = full_lines
+                for line in lines:
+                    if mode == 'train' or mode == 'val':
+                        img_path, label = line.split()
+                        img_path = os.path.join(data_dir, img_path)
+                        yield img_path, int(label)
+                    elif mode == 'test':
+                        img_path = os.path.join(data_dir, line)
+                        yield [img_path]
+        except Exception as e:
+            print("Reader failed!\n{}".format(str(e)))
+            os._exit(1)
+
+    # process_image has no defaults for crop_size/resize_size, so they must be
+    # bound here or every mapped sample raises TypeError.
+    mapper = functools.partial(
+        process_image,
+        mode=mode,
+        color_jitter=color_jitter,
+        rotate=rotate,
+        crop_size=DATA_DIM,
+        resize_size=RESIZE_DIM)
+
+    return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
+
+
+def train(data_dir=DATA_DIR):
+    """Return a shuffled reader over `train_list.txt` in `data_dir`."""
+    file_list = os.path.join(data_dir, 'train_list.txt')
+    return _reader_creator(
+        file_list,
+        'train',
+        shuffle=True,
+        color_jitter=False,
+        rotate=False,
+        data_dir=data_dir)
+
+
+def val(data_dir=DATA_DIR):
+    """Return an in-order reader over `val_list.txt` in `data_dir`."""
+    file_list = os.path.join(data_dir, 'val_list.txt')
+    return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir)
+
+
+def test(data_dir=DATA_DIR):
+    """Return an in-order reader over `test_list.txt` in `data_dir`."""
+    file_list = os.path.join(data_dir, 'test_list.txt')
+    return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir)
+
+
+class ImageNetDataset(Dataset):
+    """Map-style dataset over the list files found in `data_dir`."""
+
+    def __init__(self,
+                 data_dir=DATA_DIR,
+                 mode='train',
+                 crop_size=DATA_DIM,
+                 resize_size=RESIZE_DIM):
+        super(ImageNetDataset, self).__init__()
+        self.data_dir = data_dir
+        self.crop_size = crop_size
+        self.resize_size = resize_size
+        train_file_list = os.path.join(data_dir, 'train_list.txt')
+        val_file_list = os.path.join(data_dir, 'val_list.txt')
+        test_file_list = os.path.join(data_dir, 'test_list.txt')
+        self.mode = mode
+        if mode == 'train':
+            with open(train_file_list) as flist:
+                full_lines = [line.strip() for line in flist]
+                np.random.shuffle(full_lines)
+                lines = full_lines
+            self.data = [line.split() for line in lines]
+        else:
+            # NOTE(review): every non-train mode, including 'test', reads
+            # val_list.txt; test_file_list is never opened. Confirm intent.
+            with open(val_file_list) as flist:
+                lines = [line.strip() for line in flist]
+                self.data = [line.split() for line in lines]
+
+    def __getitem__(self, index):
+        """Return (image, label) for 'train'/'val' and [image] for 'test'."""
+        sample = self.data[index]
+        data_path = os.path.join(self.data_dir, sample[0])
+        if self.mode == 'train':
+            data, label = process_image(
+                [data_path, sample[1]],
+                mode='train',
+                color_jitter=False,
+                rotate=False,
+                crop_size=self.crop_size,
+                resize_size=self.resize_size)
+            return data, np.array([label]).astype('int64')
+        elif self.mode == 'val':
+            data, label = process_image(
+                [data_path, sample[1]],
+                mode='val',
+                color_jitter=False,
+                rotate=False,
+                crop_size=self.crop_size,
+                resize_size=self.resize_size)
+            return data, np.array([label]).astype('int64')
+        elif self.mode == 'test':
+            data = process_image(
+                [data_path, sample[1]],
+                mode='test',
+                color_jitter=False,
+                rotate=False,
+                crop_size=self.crop_size,
+                resize_size=self.resize_size)
+            return data
+
+    def __len__(self):
+        return len(self.data)
diff --git a/example/reparameterization/train.py b/example/reparameterization/train.py
index 50f55909..e5c93fc8 100644
--- a/example/reparameterization/train.py
+++ b/example/reparameterization/train.py
@@ -26,6 +26,8 @@ import math
 import time
 import random
 import numpy as np
+import distutils.util
+import six
 from paddle.distributed import ParallelEnv
 from paddle.static import load_program_state
 from paddle.vision.models import mobilenet_v1
@@ -35,31 +37,49 @@ from paddleslim.dygraph.rep import Reparameter, DBBRepConfig, ACBRepConfig
 
 sys.path.append(os.path.join(os.path.dirname("__file__")))
 from optimizer import create_optimizer
-sys.path.append(
-    os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir))
-from utility import add_arguments, print_arguments
 
 _logger = get_logger(__name__, level=logging.INFO)
 
-parser = argparse.ArgumentParser(description=__doc__)
-add_arg = functools.partial(add_arguments, argparser=parser)
-# yapf: disable
-add_arg('batch_size', int, 64, "Single Card Minibatch size.")
-add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
-add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
-add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
-add_arg('l2_decay', float, 0.00003, "The l2_decay parameter.")
-add_arg('ls_epsilon', float, 0.0, "Label smooth epsilon.")
-add_arg('use_pact', bool, False, "Whether to use PACT method.")
-add_arg('ce_test', bool, False, "Whether to CE test.")
-add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
-add_arg('num_epochs', int, 120, "The number of total epochs.")
-add_arg('total_images', int, 1281167, "The number of total training images.")
-add_arg('data', str, "imagenet", "Which data to use. 'cifar10' or 'imagenet'")
-add_arg('log_period', int, 10, "Log period in batches.")
-add_arg('model_save_dir', str, "./output_models", "model save directory.")
-parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
-# yapf: enable
+
+def print_arguments(args):
+    """Print argparse's arguments.
+
+    Usage:
+
+    .. code-block:: python
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument("name", default="John", type=str, help="User name.")
+        args = parser.parse_args()
+        print_arguments(args)
+
+    :param args: Input argparse.Namespace for printing.
+    :type args: argparse.Namespace
+    """
+    print("----------- Configuration Arguments -----------")
+    for arg, value in sorted(six.iteritems(vars(args))):
+        print("%s: %s" % (arg, value))
+    print("------------------------------------------------")
+
+
+def add_arguments(argname, type, default, help, argparser, **kwargs):
+    """Add argparse's argument.
+
+    Usage:
+
+    .. code-block:: python
+
+        parser = argparse.ArgumentParser()
+        add_arguments("name", str, "John", "User name.", parser)
+        args = parser.parse_args()
+    """
+    type = distutils.util.strtobool if type == bool else type
+    argparser.add_argument(
+        "--" + argname,
+        default=default,
+        type=type,
+        help=help + ' Default: %(default)s.',
+        **kwargs)
 
 
 def load_dygraph_pretrain(model, path=None, load_static_weights=False):
@@ -110,8 +130,9 @@ def train(args):
         args.total_images = 50000
     elif args.data == "imagenet":
         import imagenet_reader as reader
-        train_dataset = reader.ImageNetDataset(mode='train')
-        val_dataset = reader.ImageNetDataset(mode='val')
+        train_dataset = reader.ImageNetDataset(
+            data_dir=args.data_dir, mode='train')
+        val_dataset = reader.ImageNetDataset(data_dir=args.data_dir, mode='val')
         class_dim = 1000
         image_shape = "3,224,224"
     else:
@@ -313,11 +334,31 @@ def train(args):
         ])
 
 
-def main():
+def main(parser):
     args = parser.parse_args()
     print_arguments(args)
     train(args)
 
 
 if __name__ == '__main__':
-    main()
+    parser = argparse.ArgumentParser(description=__doc__)
+    add_arg = functools.partial(add_arguments, argparser=parser)
+    # yapf: disable
+    add_arg('batch_size', int, 64, "Single Card Minibatch size.")
+    add_arg('data_dir', str, "dataset/ILSVRC2012/", "The path to the ImageNet dataset directory.")
+    add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
+    add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
+    add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
+    add_arg('l2_decay', float, 0.00003, "The l2_decay parameter.")
+    add_arg('ls_epsilon', float, 0.0, "Label smooth epsilon.")
+    add_arg('use_pact', bool, False, "Whether to use PACT method.")
+    add_arg('ce_test', bool, False, "Whether to CE test.")
+    add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
+    add_arg('num_epochs', int, 120, "The number of total epochs.")
+    add_arg('total_images', int, 1281167, "The number of total training images.")
+    add_arg('data', str, "imagenet", "Which data to use. 'cifar10' or 'imagenet'")
+    add_arg('log_period', int, 10, "Log period in batches.")
+    add_arg('model_save_dir', str, "./output_models", "model save directory.")
+    parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
+    # yapf: enable
+    main(parser)
-- 
GitLab