提交 3e8128cf 编写于 作者: Y Yang Zhang

Refactor `resnet` demo

上级 1faf669a
...@@ -27,88 +27,11 @@ import paddle ...@@ -27,88 +27,11 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.container import Sequential
from model import Model, CrossEntropy from model import Model, CrossEntropy
def center_crop_resize(img):
    """Crop the central square of ``img`` and resize it to 224x224.

    The square's side is ``int(224 / 256 * min(height, width))``.

    Args:
        img: array indexed as ``img[row, col, channel]``.

    Returns:
        The result of ``cv2.resize`` on the cropped region with target size
        ``(224, 224)`` and linear interpolation.
    """
    h, w = img.shape[:2]
    # Float literal keeps the ratio at 0.875 regardless of division semantics.
    c = int(224. / 256 * min((h, w)))
    i = (h + 1 - c) // 2
    j = (w + 1 - c) // 2
    img = img[i: i + c, j: j + c, :]
    # Pass the flag by keyword: cv2.resize's positional order is
    # (src, dsize, dst, fx, fy, interpolation), so extra positional
    # arguments would land in dst/fx/fy instead of the interpolation mode.
    return cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
def random_crop_resize(img):
    """Take a random-area, random-aspect crop of ``img`` and resize to 224x224.

    Up to 10 attempts are made. Each attempt draws a target area between 8%
    and 100% of the image area and an aspect ratio log-uniformly between 3/4
    and 4/3; the first crop that fits inside the image is used. If none fits,
    the result of ``center_crop_resize(img)`` is returned instead.

    Args:
        img: array indexed as ``img[row, col, channel]``.

    Returns:
        The resized crop produced by ``cv2.resize`` (or the center-crop
        fallback).
    """
    height, width = img.shape[:2]
    area = height * width
    # Loop-invariant bounds of the log-uniform aspect-ratio distribution.
    log_ratio = (math.log(3. / 4.), math.log(4. / 3.))
    for _ in range(10):
        target_area = random.uniform(0.08, 1.) * area
        aspect_ratio = math.exp(random.uniform(*log_ratio))
        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))
        # Reject crops that round down to zero pixels (possible on tiny
        # images) as well as crops larger than the image.
        if 0 < w <= width and 0 < h <= height:
            i = random.randint(0, height - h)
            j = random.randint(0, width - w)
            img = img[i: i + h, j: j + w, :]
            # Keyword form: positionally, the 5th argument of cv2.resize is
            # `fy`, not the interpolation flag.
            return cv2.resize(
                img, (224, 224), interpolation=cv2.INTER_LINEAR)
    return center_crop_resize(img)
def random_flip(img):
    """Mirror ``img`` along its second axis with probability 0.5.

    Args:
        img: array indexed as ``img[row, col, channel]``.

    Returns:
        Either ``img`` itself (no flip) or a view of it with the second axis
        reversed.
    """
    # The previous body flipped unconditionally, so the "random" augmentation
    # was a deterministic mirror of every sample.
    if random.random() < 0.5:
        return img[:, ::-1, :]
    return img
def normalize_permute(img):
    """Convert an image to float32 channel-first layout and normalize it.

    The input is cast to float32, transposed with axes ``(2, 0, 1)``, and its
    leading (channel) axis is reversed (BGR to RGB, per the original
    comment). Each channel then has its mean subtracted and is multiplied by
    the inverse of its standard deviation, in place.

    Args:
        img: array whose last of three axes is the channel axis.

    Returns:
        The normalized float32 array (a reversed view of the transposed copy).
    """
    channel_mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    channel_std = np.array([58.395, 57.120, 57.375], dtype=np.float32)
    scale = 1. / channel_std

    # astype() copies, so the in-place arithmetic below never touches the
    # caller's array.
    chw = img.astype(np.float32).transpose((2, 0, 1))
    flipped = chw[::-1, ...]
    for plane, offset, factor in zip(flipped, channel_mean, scale):
        # `plane` is a view into `flipped`; augmented assignment mutates it.
        plane -= offset
        plane *= factor
    return flipped
def compose(functions):
    """Build a sample transform that chains ``functions`` over the image.

    Args:
        functions: iterable of callables, each taking and returning an image.

    Returns:
        A callable that takes an ``(img, label)`` pair, applies every
        function to ``img`` in order, and returns ``(img, label)`` with the
        label untouched.
    """
    def process(sample):
        data, target = sample
        for transform in functions:
            data = transform(data)
        return (data, target)

    return process
def image_folder(path, shuffle=False):
    """Index an image directory laid out as one sub-directory per class.

    Every immediate sub-directory of ``path`` is a class; classes are
    numbered by their sorted names. Files below a class directory (walked
    recursively) are kept when their lower-cased extension is a known image
    extension.

    Args:
        path: root directory containing the class sub-directories.
        shuffle: if true, the sample list is shuffled once, when this
            function is called.

    Returns:
        A zero-argument callable that returns a generator yielding
        ``(file_path, class_index)`` tuples.
    """
    valid_ext = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.webp')

    class_names = sorted(
        entry for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry)))

    samples = []
    for label, class_name in enumerate(class_names):
        class_dir = os.path.join(path, class_name)
        for root, _, fnames in sorted(os.walk(class_dir)):
            for fname in sorted(fnames):
                file_path = os.path.join(root, fname)
                extension = os.path.splitext(file_path)[1].lower()
                if extension in valid_ext:
                    samples.append((file_path, label))

    if shuffle:
        random.shuffle(samples)

    def iterator():
        for sample in samples:
            yield sample

    return iterator
class ConvBNLayer(fluid.dygraph.Layer): class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self, def __init__(self,
num_channels, num_channels,
...@@ -204,8 +127,8 @@ class ResNet(Model): ...@@ -204,8 +127,8 @@ class ResNet(Model):
layer_config.keys(), depth) layer_config.keys(), depth)
layers = layer_config[depth] layers = layer_config[depth]
num_channels = [64, 256, 512, 1024] num_in = [64, 256, 512, 1024]
num_filters = [64, 128, 256, 512] num_out = [64, 128, 256, 512]
self.conv = ConvBNLayer( self.conv = ConvBNLayer(
num_channels=3, num_channels=3,
...@@ -219,26 +142,28 @@ class ResNet(Model): ...@@ -219,26 +142,28 @@ class ResNet(Model):
pool_padding=1, pool_padding=1,
pool_type='max') pool_type='max')
self.blocks = [] self.layers = []
for b in range(len(layers)): for idx, num_blocks in enumerate(layers):
blocks = []
shortcut = False shortcut = False
for i in range(layers[b]): for b in range(num_blocks):
block = self.add_sublayer( block = BottleneckBlock(
'layer_{}_{}'.format(b, i), num_channels=num_in[idx] if b == 0 else num_out[idx] * 4,
BottleneckBlock( num_filters=num_out[idx],
num_channels=num_channels[b] stride=2 if b == 0 and idx != 0 else 1,
if i == 0 else num_filters[b] * 4, shortcut=shortcut)
num_filters=num_filters[b], blocks.append(block)
stride=2 if i == 0 and b != 0 else 1,
shortcut=shortcut))
self.blocks.append(block)
shortcut = True shortcut = True
layer = self.add_sublayer(
"layer_{}".format(idx),
Sequential(*blocks))
self.layers.append(layer)
self.global_pool = Pool2D( self.global_pool = Pool2D(
pool_size=7, pool_type='avg', global_pooling=True) pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(2048 * 1.0) stdv = 1.0 / math.sqrt(2048 * 1.0)
self.fc_input_dim = num_filters[len(num_filters) - 1] * 4 * 1 * 1 self.fc_input_dim = num_out[-1] * 4 * 1 * 1
self.fc = Linear(self.fc_input_dim, self.fc = Linear(self.fc_input_dim,
num_classes, num_classes,
act='softmax', act='softmax',
...@@ -249,8 +174,8 @@ class ResNet(Model): ...@@ -249,8 +174,8 @@ class ResNet(Model):
def forward(self, inputs): def forward(self, inputs):
x = self.conv(inputs) x = self.conv(inputs)
x = self.pool(x) x = self.pool(x)
for block in self.blocks: for layer in self.layers:
x = block(x) x = layer(x)
x = self.global_pool(x) x = self.global_pool(x)
x = fluid.layers.reshape(x, shape=[-1, self.fc_input_dim]) x = fluid.layers.reshape(x, shape=[-1, self.fc_input_dim])
x = self.fc(x) x = self.fc(x)
...@@ -289,11 +214,88 @@ def accuracy(pred, label, topk=(1, )): ...@@ -289,11 +214,88 @@ def accuracy(pred, label, topk=(1, )):
return res return res
def center_crop_resize(img):
    """Crop the central square of ``img`` and resize it to 224x224.

    The square's side is ``int(224 / 256 * min(height, width))``.

    Args:
        img: array indexed as ``img[row, col, channel]``.

    Returns:
        The result of ``cv2.resize`` on the cropped region with target size
        ``(224, 224)`` and linear interpolation.
    """
    h, w = img.shape[:2]
    # Float literal keeps the ratio at 0.875 regardless of division semantics.
    c = int(224. / 256 * min((h, w)))
    i = (h + 1 - c) // 2
    j = (w + 1 - c) // 2
    img = img[i: i + c, j: j + c, :]
    # Pass the flag by keyword: cv2.resize's positional order is
    # (src, dsize, dst, fx, fy, interpolation), so extra positional
    # arguments would land in dst/fx/fy instead of the interpolation mode.
    return cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
def random_crop_resize(img):
    """Take a random-area, random-aspect crop of ``img`` and resize to 224x224.

    Up to 10 attempts are made. Each attempt draws a target area between 8%
    and 100% of the image area and an aspect ratio log-uniformly between 3/4
    and 4/3; the first crop that fits inside the image is used. If none fits,
    the result of ``center_crop_resize(img)`` is returned instead.

    Args:
        img: array indexed as ``img[row, col, channel]``.

    Returns:
        The resized crop produced by ``cv2.resize`` (or the center-crop
        fallback).
    """
    height, width = img.shape[:2]
    area = height * width
    # Loop-invariant bounds of the log-uniform aspect-ratio distribution.
    log_ratio = (math.log(3. / 4.), math.log(4. / 3.))
    for _ in range(10):
        target_area = random.uniform(0.08, 1.) * area
        aspect_ratio = math.exp(random.uniform(*log_ratio))
        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))
        # Reject crops that round down to zero pixels (possible on tiny
        # images) as well as crops larger than the image.
        if 0 < w <= width and 0 < h <= height:
            i = random.randint(0, height - h)
            j = random.randint(0, width - w)
            img = img[i: i + h, j: j + w, :]
            # Keyword form: positionally, the 5th argument of cv2.resize is
            # `fy`, not the interpolation flag.
            return cv2.resize(
                img, (224, 224), interpolation=cv2.INTER_LINEAR)
    return center_crop_resize(img)
def random_flip(img):
    """Mirror ``img`` along its second axis with probability 0.5.

    Args:
        img: array indexed as ``img[row, col, channel]``.

    Returns:
        Either ``img`` itself (no flip) or a view of it with the second axis
        reversed.
    """
    # The previous body flipped unconditionally, so the "random" augmentation
    # was a deterministic mirror of every sample.
    if random.random() < 0.5:
        return img[:, ::-1, :]
    return img
def normalize_permute(img):
    """Convert an image to float32 channel-first layout and normalize it.

    The input is cast to float32, transposed with axes ``(2, 0, 1)``, and its
    leading (channel) axis is reversed (BGR to RGB, per the original
    comment). Each channel then has its mean subtracted and is multiplied by
    the inverse of its standard deviation, in place.

    Args:
        img: array whose last of three axes is the channel axis.

    Returns:
        The normalized float32 array (a reversed view of the transposed copy).
    """
    channel_mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    channel_std = np.array([58.395, 57.120, 57.375], dtype=np.float32)
    scale = 1. / channel_std

    # astype() copies, so the in-place arithmetic below never touches the
    # caller's array.
    chw = img.astype(np.float32).transpose((2, 0, 1))
    flipped = chw[::-1, ...]
    for plane, offset, factor in zip(flipped, channel_mean, scale):
        # `plane` is a view into `flipped`; augmented assignment mutates it.
        plane -= offset
        plane *= factor
    return flipped
def compose(functions):
    """Build a sample transform that chains ``functions`` over the image.

    Args:
        functions: iterable of callables, each taking and returning an image.

    Returns:
        A callable that takes an ``(img, label)`` pair, applies every
        function to ``img`` in order, and returns ``(img, label)`` with the
        label untouched.
    """
    def process(sample):
        data, target = sample
        for transform in functions:
            data = transform(data)
        return (data, target)

    return process
def image_folder(path, shuffle=False):
    """Index an image directory laid out as one sub-directory per class.

    Every immediate sub-directory of ``path`` is a class; classes are
    numbered by their sorted names. Files below a class directory (walked
    recursively) are kept when their lower-cased extension is a known image
    extension.

    Args:
        path: root directory containing the class sub-directories.
        shuffle: if true, the shared sample list is reshuffled in place each
            time the returned iterator starts, so every pass sees a new
            order.

    Returns:
        A zero-argument callable that returns a generator yielding
        ``(file_path, class_index)`` tuples.
    """
    valid_ext = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.webp')

    class_names = sorted(
        entry for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry)))

    samples = []
    for label, class_name in enumerate(class_names):
        class_dir = os.path.join(path, class_name)
        for root, _, fnames in sorted(os.walk(class_dir)):
            for fname in sorted(fnames):
                file_path = os.path.join(root, fname)
                extension = os.path.splitext(file_path)[1].lower()
                if extension in valid_ext:
                    samples.append((file_path, label))

    def iterator():
        # Shuffling happens lazily, on first next() of each new generator.
        if shuffle:
            random.shuffle(samples)
        for sample in samples:
            yield sample

    return iterator
def run(model, loader, mode='train'): def run(model, loader, mode='train'):
total_loss = 0.0 total_loss = 0.0
total_acc1 = 0.0 total_acc1 = 0.0
total_acc5 = 0.0 total_acc5 = 0.0
num_steps = 0
device_ids = list(range(FLAGS.num_devices)) device_ids = list(range(FLAGS.num_devices))
for idx, batch in enumerate(loader()): for idx, batch in enumerate(loader()):
outputs, losses = getattr(model, mode)( outputs, losses = getattr(model, mode)(
...@@ -303,12 +305,10 @@ def run(model, loader, mode='train'): ...@@ -303,12 +305,10 @@ def run(model, loader, mode='train'):
total_loss += np.sum(losses) total_loss += np.sum(losses)
total_acc1 += top1 total_acc1 += top1
total_acc5 += top5 total_acc5 += top5
num_steps += 1
if idx % 10 == 0: if idx % 10 == 0:
print("{:04d}: loss {:0.3f} top1: {:0.3f}% top5: {:0.3f}%".format( print("{:04d}: loss {:0.3f} top1: {:0.3f}% top5: {:0.3f}%".format(
idx, total_loss / num_steps, idx, total_loss / (idx + 1), total_acc1 / (idx + 1),
total_acc1 / num_steps, total_acc5 / num_steps)) total_acc5 / (idx + 1)))
num_steps += 1
def main(): def main():
...@@ -357,8 +357,8 @@ def main(): ...@@ -357,8 +357,8 @@ def main():
with guard: with guard:
model = ResNet() model = ResNet()
sgd = make_optimizer(parameter_list=model.parameters()) optim = make_optimizer(parameter_list=model.parameters())
model.prepare(sgd, CrossEntropy()) model.prepare(optim, CrossEntropy())
for e in range(epoch): for e in range(epoch):
print("======== train epoch {} ========".format(e)) print("======== train epoch {} ========".format(e))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册