未验证 提交 42c97007 编写于 作者: G Guanghua Yu 提交者: GitHub

fix reparameterization demo train (#1740)

上级 4389a804
import os
import math
import random
import functools
import numpy as np
import paddle
from PIL import Image, ImageEnhance
from paddle.io import Dataset
random.seed(0)
np.random.seed(0)
DATA_DIM = 224
RESIZE_DIM = 256
THREAD = 16
BUF_SIZE = 10240
DATA_DIR = 'data/ILSVRC2012/'
DATA_DIR = os.path.join(os.path.split(os.path.realpath(__file__))[0], DATA_DIR)
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
def resize_short(img, target_size):
percent = float(target_size) / min(img.size[0], img.size[1])
resized_width = int(round(img.size[0] * percent))
resized_height = int(round(img.size[1] * percent))
img = img.resize((resized_width, resized_height), Image.LANCZOS)
return img
def crop_image(img, target_size, center):
width, height = img.size
size = target_size
if center == True:
w_start = (width - size) // 2
h_start = (height - size) // 2
else:
w_start = np.random.randint(0, width - size + 1)
h_start = np.random.randint(0, height - size + 1)
w_end = w_start + size
h_end = h_start + size
img = img.crop((w_start, h_start, w_end, h_end))
return img
def random_crop(img, size, scale=[0.08, 1.0], ratio=[3. / 4., 4. / 3.]):
aspect_ratio = math.sqrt(np.random.uniform(*ratio))
w = 1. * aspect_ratio
h = 1. / aspect_ratio
bound = min((float(img.size[0]) / img.size[1]) / (w**2),
(float(img.size[1]) / img.size[0]) / (h**2))
scale_max = min(scale[1], bound)
scale_min = min(scale[0], bound)
target_area = img.size[0] * img.size[1] * np.random.uniform(
scale_min, scale_max)
target_size = math.sqrt(target_area)
w = int(target_size * w)
h = int(target_size * h)
i = np.random.randint(0, img.size[0] - w + 1)
j = np.random.randint(0, img.size[1] - h + 1)
img = img.crop((i, j, i + w, j + h))
img = img.resize((size, size), Image.LANCZOS)
return img
def rotate_image(img):
angle = np.random.randint(-10, 11)
img = img.rotate(angle)
return img
def distort_color(img):
def random_brightness(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Brightness(img).enhance(e)
def random_contrast(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Contrast(img).enhance(e)
def random_color(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Color(img).enhance(e)
ops = [random_brightness, random_contrast, random_color]
np.random.shuffle(ops)
img = ops[0](img)
img = ops[1](img)
img = ops[2](img)
return img
def process_image(sample, mode, color_jitter, rotate, crop_size, resize_size):
img_path = sample[0]
try:
img = Image.open(img_path)
except:
print(img_path, "not exists!")
return None
if mode == 'train':
if rotate: img = rotate_image(img)
img = random_crop(img, crop_size)
else:
img = resize_short(img, target_size=resize_size)
img = crop_image(img, target_size=crop_size, center=True)
if mode == 'train':
if color_jitter:
img = distort_color(img)
if np.random.randint(0, 2) == 1:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if img.mode != 'RGB':
img = img.convert('RGB')
img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255
img -= img_mean
img /= img_std
if mode == 'train' or mode == 'val':
return img, sample[1]
elif mode == 'test':
return [img]
def _reader_creator(file_list,
mode,
shuffle=False,
color_jitter=False,
rotate=False,
data_dir=DATA_DIR,
batch_size=1):
def reader():
try:
with open(file_list) as flist:
full_lines = [line.strip() for line in flist]
if shuffle:
np.random.shuffle(full_lines)
lines = full_lines
for line in lines:
if mode == 'train' or mode == 'val':
img_path, label = line.split()
img_path = os.path.join(data_dir, img_path)
yield img_path, int(label)
elif mode == 'test':
img_path = os.path.join(data_dir, line)
yield [img_path]
except Exception as e:
print("Reader failed!\n{}".format(str(e)))
os._exit(1)
mapper = functools.partial(
process_image, mode=mode, color_jitter=color_jitter, rotate=rotate)
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def train(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'train_list.txt')
return _reader_creator(
file_list,
'train',
shuffle=True,
color_jitter=False,
rotate=False,
data_dir=data_dir)
def val(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir)
def test(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'test_list.txt')
return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir)
class ImageNetDataset(Dataset):
def __init__(self,
data_dir=DATA_DIR,
mode='train',
crop_size=DATA_DIM,
resize_size=RESIZE_DIM):
super(ImageNetDataset, self).__init__()
self.data_dir = data_dir
self.crop_size = crop_size
self.resize_size = resize_size
train_file_list = os.path.join(data_dir, 'train_list.txt')
val_file_list = os.path.join(data_dir, 'val_list.txt')
test_file_list = os.path.join(data_dir, 'test_list.txt')
self.mode = mode
if mode == 'train':
with open(train_file_list) as flist:
full_lines = [line.strip() for line in flist]
np.random.shuffle(full_lines)
lines = full_lines
self.data = [line.split() for line in lines]
else:
with open(val_file_list) as flist:
lines = [line.strip() for line in flist]
self.data = [line.split() for line in lines]
def __getitem__(self, index):
sample = self.data[index]
data_path = os.path.join(self.data_dir, sample[0])
if self.mode == 'train':
data, label = process_image(
[data_path, sample[1]],
mode='train',
color_jitter=False,
rotate=False,
crop_size=self.crop_size,
resize_size=self.resize_size)
return data, np.array([label]).astype('int64')
elif self.mode == 'val':
data, label = process_image(
[data_path, sample[1]],
mode='val',
color_jitter=False,
rotate=False,
crop_size=self.crop_size,
resize_size=self.resize_size)
return data, np.array([label]).astype('int64')
elif self.mode == 'test':
data = process_image(
[data_path, sample[1]],
mode='test',
color_jitter=False,
rotate=False,
crop_size=self.crop_size,
resize_size=self.resize_size)
return data
def __len__(self):
return len(self.data)
...@@ -26,6 +26,8 @@ import math ...@@ -26,6 +26,8 @@ import math
import time import time
import random import random
import numpy as np import numpy as np
import distutils.util
import six
from paddle.distributed import ParallelEnv from paddle.distributed import ParallelEnv
from paddle.static import load_program_state from paddle.static import load_program_state
from paddle.vision.models import mobilenet_v1 from paddle.vision.models import mobilenet_v1
...@@ -35,31 +37,49 @@ from paddleslim.dygraph.rep import Reparameter, DBBRepConfig, ACBRepConfig ...@@ -35,31 +37,49 @@ from paddleslim.dygraph.rep import Reparameter, DBBRepConfig, ACBRepConfig
sys.path.append(os.path.join(os.path.dirname("__file__"))) sys.path.append(os.path.join(os.path.dirname("__file__")))
from optimizer import create_optimizer from optimizer import create_optimizer
sys.path.append(
os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir))
from utility import add_arguments, print_arguments
_logger = get_logger(__name__, level=logging.INFO) _logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser) def print_arguments(args):
# yapf: disable """Print argparse's arguments.
add_arg('batch_size', int, 64, "Single Card Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.") Usage:
add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.") .. code-block:: python
add_arg('l2_decay', float, 0.00003, "The l2_decay parameter.")
add_arg('ls_epsilon', float, 0.0, "Label smooth epsilon.") parser = argparse.ArgumentParser()
add_arg('use_pact', bool, False, "Whether to use PACT method.") parser.add_argument("name", default="Jonh", type=str, help="User name.")
add_arg('ce_test', bool, False, "Whether to CE test.") args = parser.parse_args()
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.") print_arguments(args)
add_arg('num_epochs', int, 120, "The number of total epochs.")
add_arg('total_images', int, 1281167, "The number of total training images.") :param args: Input argparse.Namespace for printing.
add_arg('data', str, "imagenet", "Which data to use. 'cifar10' or 'imagenet'") :type args: argparse.Namespace
add_arg('log_period', int, 10, "Log period in batches.") """
add_arg('model_save_dir', str, "./output_models", "model save directory.") print("----------- Configuration Arguments -----------")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step") for arg, value in sorted(six.iteritems(vars(args))):
# yapf: enable print("%s: %s" % (arg, value))
print("------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
"""Add argparse's argument.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
add_argument("name", str, "Jonh", "User name.", parser)
args = parser.parse_args()
"""
type = distutils.util.strtobool if type == bool else type
argparser.add_argument(
"--" + argname,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
def load_dygraph_pretrain(model, path=None, load_static_weights=False): def load_dygraph_pretrain(model, path=None, load_static_weights=False):
...@@ -110,8 +130,9 @@ def train(args): ...@@ -110,8 +130,9 @@ def train(args):
args.total_images = 50000 args.total_images = 50000
elif args.data == "imagenet": elif args.data == "imagenet":
import imagenet_reader as reader import imagenet_reader as reader
train_dataset = reader.ImageNetDataset(mode='train') train_dataset = reader.ImageNetDataset(
val_dataset = reader.ImageNetDataset(mode='val') data_dir=args.data_dir, mode='train')
val_dataset = reader.ImageNetDataset(data_dir=args.data_dir, mode='val')
class_dim = 1000 class_dim = 1000
image_shape = "3,224,224" image_shape = "3,224,224"
else: else:
...@@ -313,11 +334,31 @@ def train(args): ...@@ -313,11 +334,31 @@ def train(args):
]) ])
def main(): def main(parser):
args = parser.parse_args() args = parser.parse_args()
print_arguments(args) print_arguments(args)
train(args) train(args)
if __name__ == '__main__': if __name__ == '__main__':
main() parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64, "Single Card Minibatch size.")
add_arg('data_dir', str, "dataset/ILSVRC2012/", "Single Card Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
add_arg('l2_decay', float, 0.00003, "The l2_decay parameter.")
add_arg('ls_epsilon', float, 0.0, "Label smooth epsilon.")
add_arg('use_pact', bool, False, "Whether to use PACT method.")
add_arg('ce_test', bool, False, "Whether to CE test.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('num_epochs', int, 120, "The number of total epochs.")
add_arg('total_images', int, 1281167, "The number of total training images.")
add_arg('data', str, "imagenet", "Which data to use. 'cifar10' or 'imagenet'")
add_arg('log_period', int, 10, "Log period in batches.")
add_arg('model_save_dir', str, "./output_models", "model save directory.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
# yapf: enable
main(parser)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册