From 08b441af62c281d52d5e84b4056a453a830d7c9b Mon Sep 17 00:00:00 2001 From: wangmeng28 Date: Wed, 20 Dec 2017 15:51:47 +0800 Subject: [PATCH] Add xception model for image classification --- image_classification/infer.py | 7 +- image_classification/train.py | 7 +- image_classification/xception.py | 193 +++++++++++++++++++++++++++++++ 3 files changed, 205 insertions(+), 2 deletions(-) create mode 100644 image_classification/xception.py diff --git a/image_classification/infer.py b/image_classification/infer.py index 659c4f2a..1ae5da2c 100644 --- a/image_classification/infer.py +++ b/image_classification/infer.py @@ -26,7 +26,10 @@ def main(): parser.add_argument( 'model', help='The model for image classification', - choices=['alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet']) + choices=[ + 'alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet', + 'xception' + ]) parser.add_argument( 'params_path', help='The file which stores the parameters') args = parser.parse_args() @@ -49,6 +52,8 @@ def main(): out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM) elif args.model == 'googlenet': out, _, _ = googlenet.googlenet(image, class_dim=CLASS_DIM) + elif args.model == 'xception': + out = xception.xception(image, class_dim=CLASS_DIM) # load parameters with gzip.open(args.params_path, 'r') as f: diff --git a/image_classification/train.py b/image_classification/train.py index 12a582db..c45eed77 100644 --- a/image_classification/train.py +++ b/image_classification/train.py @@ -19,7 +19,10 @@ def main(): parser.add_argument( 'model', help='The model for image classification', - choices=['alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet']) + choices=[ + 'alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet', + 'xception' + ]) args = parser.parse_args() # PaddlePaddle init @@ -52,6 +55,8 @@ def main(): input=out2, label=lbl, coeff=0.3) paddle.evaluator.classification_error(input=out2, label=lbl) extra_layers = [loss1, loss2] + elif args.model == 'xception': + out = xception.xception(image, class_dim=CLASS_DIM) cost = paddle.layer.classification_cost(input=out, label=lbl) diff --git a/image_classification/xception.py b/image_classification/xception.py new file mode 100644 index 00000000..41c11b83 --- /dev/null +++ b/image_classification/xception.py @@ -0,0 +1,193 @@ +import paddle.v2 as paddle + +__all__ = ['xception'] + + +def img_separable_conv_bn(name, input, num_channels, num_out_channels, + filter_size, stride, padding, act): + conv = paddle.networks.img_separable_conv( + name=name, + input=input, + num_channels=num_channels, + num_out_channels=num_out_channels, + filter_size=filter_size, + stride=stride, + padding=padding, + act=paddle.activation.Linear()) + norm = paddle.layer.batch_norm(name=name + '_norm', input=conv, act=act) + return norm + + +def img_conv_bn(name, input, num_channels, num_filters, filter_size, stride, + padding, act): + conv = paddle.layer.img_conv( + name=name, + input=input, + num_channels=num_channels, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + act=paddle.activation.Linear()) + norm = paddle.layer.batch_norm(name=name + '_norm', input=conv, act=act) + return norm + + +def conv_block0(input, + group, + num_channels, + num_filters, + num_filters2=None, + filter_size=3, + pool_padding=0, + entry_relu=True): + if num_filters2 is None: + num_filters2 = num_filters + + if entry_relu: + act_input = paddle.layer.mixed( + input=paddle.layer.identity_projection(input=input), + act=paddle.activation.Relu()) + else: + act_input = input + conv0 = img_separable_conv_bn( + name='xception_block{0}_conv0'.format(group), + input=act_input, + num_channels=num_channels, + num_out_channels=num_filters, + filter_size=filter_size, + stride=1, + padding=(filter_size - 1) / 2, + act=paddle.activation.Relu()) + conv1 = img_separable_conv_bn( + name='xception_block{0}_conv1'.format(group), + input=conv0, + num_channels=num_filters, + num_out_channels=num_filters2, + filter_size=filter_size, + stride=1, + padding=(filter_size - 1) / 2, + act=paddle.activation.Linear()) + pool0 = paddle.layer.img_pool( + name='xception_block{0}_pool'.format(group), + input=conv1, + pool_size=3, + stride=2, + padding=pool_padding, + num_channels=num_filters2, + pool_type=paddle.pooling.CudnnMax()) + + shortcut = img_conv_bn( + name='xception_block{0}_shortcut'.format(group), + input=input, + num_channels=num_channels, + num_filters=num_filters2, + filter_size=1, + stride=2, + padding=0, + act=paddle.activation.Linear()) + + return paddle.layer.addto( + input=[pool0, shortcut], act=paddle.activation.Linear()) + + +def conv_block1(input, group, num_channels, num_filters, filter_size=3): + act_input = paddle.layer.mixed( + input=paddle.layer.identity_projection(input=input), + act=paddle.activation.Relu()) + conv0 = img_separable_conv_bn( + name='xception_block{0}_conv0'.format(group), + input=act_input, + num_channels=num_channels, + num_out_channels=num_filters, + filter_size=filter_size, + stride=1, + padding=(filter_size - 1) / 2, + act=paddle.activation.Relu()) + conv1 = img_separable_conv_bn( + name='xception_block{0}_conv1'.format(group), + input=conv0, + num_channels=num_filters, + num_out_channels=num_filters, + filter_size=filter_size, + stride=1, + padding=(filter_size - 1) / 2, + act=paddle.activation.Relu()) + conv2 = img_separable_conv_bn( + name='xception_block{0}_conv2'.format(group), + input=conv1, + num_channels=num_filters, + num_out_channels=num_filters, + filter_size=filter_size, + stride=1, + padding=(filter_size - 1) / 2, + act=paddle.activation.Linear()) + + shortcut = input + return paddle.layer.addto( + input=[conv2, shortcut], act=paddle.activation.Linear()) + + +def xception(input, class_dim): + conv = img_conv_bn( + name='xception_conv0', + input=input, + num_channels=3, + num_filters=32, + filter_size=3, + stride=2, + padding=1, + act=paddle.activation.Relu()) + conv = img_conv_bn( + name='xception_conv1', + input=conv, + num_channels=32, + num_filters=64, + filter_size=3, + stride=1, + padding=1, + act=paddle.activation.Relu()) + conv = conv_block0( + input=conv, group=2, num_channels=64, num_filters=128, entry_relu=False) + conv = conv_block0(input=conv, group=3, num_channels=128, num_filters=256) + conv = conv_block0(input=conv, group=4, num_channels=256, num_filters=728) + for group in range(5, 13): + conv = conv_block1( + input=conv, group=group, num_channels=728, num_filters=728) + conv = conv_block0( + input=conv, + group=13, + num_channels=728, + num_filters=728, + num_filters2=1024) + conv = img_separable_conv_bn( + name='xception_conv14', + input=conv, + num_channels=1024, + num_out_channels=1536, + filter_size=3, + stride=1, + padding=1, + act=paddle.activation.Relu()) + conv = img_separable_conv_bn( + name='xception_conv15', + input=conv, + num_channels=1536, + num_out_channels=2048, + filter_size=3, + stride=1, + padding=1, + act=paddle.activation.Relu()) + pool = paddle.layer.img_pool( + name='xception_global_pool', + input=conv, + pool_size=7, + stride=1, + num_channels=2048, + pool_type=paddle.pooling.CudnnAvg()) + out = paddle.layer.fc( + name='xception_fc', + input=pool, + size=class_dim, + act=paddle.activation.Softmax()) + return out -- GitLab