diff --git a/dygraph/ppdet/modeling/architectures/meta_arch.py b/dygraph/ppdet/modeling/architectures/meta_arch.py index a82292a213e3ee9be81ebfab5c4dbdc722f4a79e..fb2f5bd1a5f43c55e4c72cd2823250708371bbab 100644 --- a/dygraph/ppdet/modeling/architectures/meta_arch.py +++ b/dygraph/ppdet/modeling/architectures/meta_arch.py @@ -12,10 +12,14 @@ __all__ = ['BaseArch'] @register class BaseArch(nn.Layer): - def __init__(self): + def __init__(self, data_format='NCHW'): super(BaseArch, self).__init__() + self.data_format = data_format def forward(self, inputs): + if self.data_format == 'NHWC': + image = inputs['image'] + inputs['image'] = paddle.transpose(image, [0, 2, 3, 1]) self.inputs = inputs self.model_arch() diff --git a/dygraph/ppdet/modeling/architectures/yolo.py b/dygraph/ppdet/modeling/architectures/yolo.py index a9d2a6e764530e65eeaf569d6e2110edd2ae9e98..bf6c19ecfe26a07c42e75d38e4acd9b1c443b499 100644 --- a/dygraph/ppdet/modeling/architectures/yolo.py +++ b/dygraph/ppdet/modeling/architectures/yolo.py @@ -11,14 +11,16 @@ __all__ = ['YOLOv3'] @register class YOLOv3(BaseArch): __category__ = 'architecture' + __shared__ = ['data_format'] __inject__ = ['post_process'] def __init__(self, backbone='DarkNet', neck='YOLOv3FPN', yolo_head='YOLOv3Head', - post_process='BBoxPostProcess'): - super(YOLOv3, self).__init__() + post_process='BBoxPostProcess', + data_format='NCHW'): + super(YOLOv3, self).__init__(data_format=data_format) self.backbone = backbone self.neck = neck self.yolo_head = yolo_head diff --git a/dygraph/ppdet/modeling/backbones/darknet.py b/dygraph/ppdet/modeling/backbones/darknet.py index 9dd5e07d1cead6702768e690ceab380cf17fd545..7981306a912b7d893e9c63b76b0aee9247019ba0 100755 --- a/dygraph/ppdet/modeling/backbones/darknet.py +++ b/dygraph/ppdet/modeling/backbones/darknet.py @@ -35,7 +35,8 @@ class ConvBNLayer(nn.Layer): norm_type='bn', norm_decay=0., act="leaky", - name=None): + name=None, + data_format='NCHW'): super(ConvBNLayer, self).__init__() self.conv = nn.Conv2D( @@ -46,9 +47,14 @@ class ConvBNLayer(nn.Layer): padding=padding, groups=groups, weight_attr=ParamAttr(name=name + '.conv.weights'), + data_format=data_format, bias_attr=False) self.batch_norm = batch_norm( - ch_out, norm_type=norm_type, norm_decay=norm_decay, name=name) + ch_out, + norm_type=norm_type, + norm_decay=norm_decay, + name=name, + data_format=data_format) self.act = act def forward(self, inputs): @@ -68,7 +74,8 @@ class DownSample(nn.Layer): padding=1, norm_type='bn', norm_decay=0., - name=None): + name=None, + data_format='NCHW'): super(DownSample, self).__init__() @@ -80,6 +87,7 @@ class DownSample(nn.Layer): padding=padding, norm_type=norm_type, norm_decay=norm_decay, + data_format=data_format, name=name) self.ch_out = ch_out @@ -89,7 +97,13 @@ class DownSample(nn.Layer): class BasicBlock(nn.Layer): - def __init__(self, ch_in, ch_out, norm_type='bn', norm_decay=0., name=None): + def __init__(self, + ch_in, + ch_out, + norm_type='bn', + norm_decay=0., + name=None, + data_format='NCHW'): super(BasicBlock, self).__init__() self.conv1 = ConvBNLayer( @@ -100,6 +114,7 @@ class BasicBlock(nn.Layer): padding=0, norm_type=norm_type, norm_decay=norm_decay, + data_format=data_format, name=name + '.0') self.conv2 = ConvBNLayer( ch_in=ch_out, @@ -109,6 +124,7 @@ class BasicBlock(nn.Layer): padding=1, norm_type=norm_type, norm_decay=norm_decay, + data_format=data_format, name=name + '.1') def forward(self, inputs): @@ -125,7 +141,8 @@ class Blocks(nn.Layer): count, norm_type='bn', norm_decay=0., - name=None): + name=None, + data_format='NCHW'): super(Blocks, self).__init__() self.basicblock0 = BasicBlock( @@ -133,6 +150,7 @@ class Blocks(nn.Layer): ch_out, norm_type=norm_type, norm_decay=norm_decay, + data_format=data_format, name=name + '.0') self.res_out_list = [] for i in range(1, count): @@ -144,6 +162,7 @@ class Blocks(nn.Layer): ch_out, norm_type=norm_type, norm_decay=norm_decay, + data_format=data_format, name=block_name)) self.res_out_list.append(res_out) self.ch_out = ch_out @@ -161,7 +180,7 @@ DarkNet_cfg = {53: ([1, 2, 8, 8, 4])} @register @serializable class DarkNet(nn.Layer): - __shared__ = ['norm_type'] + __shared__ = ['norm_type', 'data_format'] def __init__(self, depth=53, @@ -169,7 +188,8 @@ class DarkNet(nn.Layer): return_idx=[2, 3, 4], num_stages=5, norm_type='bn', - norm_decay=0.): + norm_decay=0., + data_format='NCHW'): super(DarkNet, self).__init__() self.depth = depth self.freeze_at = freeze_at @@ -185,6 +205,7 @@ class DarkNet(nn.Layer): padding=1, norm_type=norm_type, norm_decay=norm_decay, + data_format=data_format, name='yolo_input') self.downsample0 = DownSample( @@ -192,6 +213,7 @@ class DarkNet(nn.Layer): ch_out=32 * 2, norm_type=norm_type, norm_decay=norm_decay, + data_format=data_format, name='yolo_input.downsample') self._out_channels = [] @@ -208,6 +230,7 @@ class DarkNet(nn.Layer): stage, norm_type=norm_type, norm_decay=norm_decay, + data_format=data_format, name=name)) self.darknet_conv_block_list.append(conv_block) if i in return_idx: @@ -221,6 +244,7 @@ class DarkNet(nn.Layer): ch_out=32 * (2**(i + 2)), norm_type=norm_type, norm_decay=norm_decay, + data_format=data_format, name=down_name)) self.downsample_list.append(downsample) diff --git a/dygraph/ppdet/modeling/heads/yolo_head.py b/dygraph/ppdet/modeling/heads/yolo_head.py index d6453a3a4ba1f2b1f3de5c1ea969f260062328da..723bf4fc6e541021a3d0f7c3a782f843b4272fff 100644 --- a/dygraph/ppdet/modeling/heads/yolo_head.py +++ b/dygraph/ppdet/modeling/heads/yolo_head.py @@ -16,7 +16,7 @@ def _de_sigmoid(x, eps=1e-7): @register class YOLOv3Head(nn.Layer): - __shared__ = ['num_classes'] + __shared__ = ['num_classes', 'data_format'] __inject__ = ['loss'] def __init__(self, @@ -26,7 +26,8 @@ class YOLOv3Head(nn.Layer): num_classes=80, loss='YOLOv3Loss', iou_aware=False, - iou_aware_factor=0.4): + iou_aware_factor=0.4, + data_format='NCHW'): super(YOLOv3Head, self).__init__() self.num_classes = num_classes self.loss = loss @@ -36,6 +37,7 @@ class YOLOv3Head(nn.Layer): self.parse_anchor(anchors, anchor_masks) self.num_outputs = len(self.anchors) + self.data_format = data_format self.yolo_outputs = [] for i in range(len(self.anchors)): @@ -53,6 +55,7 @@ class YOLOv3Head(nn.Layer): kernel_size=1, stride=1, padding=0, + data_format=data_format, weight_attr=ParamAttr(name=name + '.conv.weights'), bias_attr=ParamAttr( name=name + '.conv.bias', regularizer=L2Decay(0.)))) @@ -73,6 +76,8 @@ class YOLOv3Head(nn.Layer): yolo_outputs = [] for i, feat in enumerate(feats): yolo_output = self.yolo_outputs[i](feat) + if self.data_format == 'NHWC': + yolo_output = paddle.transpose(yolo_output, [0, 3, 1, 2]) yolo_outputs.append(yolo_output) if self.training: diff --git a/dygraph/ppdet/modeling/necks/yolo_fpn.py b/dygraph/ppdet/modeling/necks/yolo_fpn.py index f89b320532400169e7c357a07f8247700663d9aa..77b6d885df71c105dc391cee6782ff7b863c2673 100644 --- a/dygraph/ppdet/modeling/necks/yolo_fpn.py +++ b/dygraph/ppdet/modeling/necks/yolo_fpn.py @@ -26,7 +26,7 @@ __all__ = ['YOLOv3FPN', 'PPYOLOFPN'] class YoloDetBlock(nn.Layer): - def __init__(self, ch_in, channel, norm_type, name): + def __init__(self, ch_in, channel, norm_type, name, data_format='NCHW'): super(YoloDetBlock, self).__init__() self.ch_in = ch_in self.channel = channel @@ -51,6 +51,7 @@ class YoloDetBlock(nn.Layer): filter_size=filter_size, padding=(filter_size - 1) // 2, norm_type=norm_type, + data_format=data_format, name=name + post_name)) self.tip = ConvBNLayer( @@ -59,6 +60,7 @@ class YoloDetBlock(nn.Layer): filter_size=3, padding=1, norm_type=norm_type, + data_format=data_format, name=name + '.tip') def forward(self, inputs): @@ -68,7 +70,14 @@ class YoloDetBlock(nn.Layer): class SPP(nn.Layer): - def __init__(self, ch_in, ch_out, k, pool_size, norm_type, name): + def __init__(self, + ch_in, + ch_out, + k, + pool_size, + norm_type, + name, + data_format='NCHW'): super(SPP, self).__init__() self.pool = [] for size in pool_size: @@ -78,10 +87,17 @@ class SPP(nn.Layer): kernel_size=size, stride=1, padding=size // 2, + data_format=data_format, ceil_mode=False)) self.pool.append(pool) self.conv = ConvBNLayer( - ch_in, ch_out, k, padding=k // 2, norm_type=norm_type, name=name) + ch_in, + ch_out, + k, + padding=k // 2, + norm_type=norm_type, + name=name, + data_format=data_format) def forward(self, x): outs = [x] @@ -93,30 +109,46 @@ class SPP(nn.Layer): class DropBlock(nn.Layer): - def __init__(self, block_size, keep_prob, name): + def __init__(self, block_size, keep_prob, name, data_format='NCHW'): super(DropBlock, self).__init__() self.block_size = block_size self.keep_prob = keep_prob self.name = name + self.data_format = data_format def forward(self, x): if not self.training or self.keep_prob == 1: return x else: gamma = (1. - self.keep_prob) / (self.block_size**2) - for s in x.shape[2:]: + if self.data_format == 'NCHW': + shape = x.shape[2:] + else: + shape = x.shape[1:3] + for s in shape: gamma *= s / (s - self.block_size + 1) matrix = paddle.cast(paddle.rand(x.shape, x.dtype) < gamma, x.dtype) mask_inv = F.max_pool2d( - matrix, self.block_size, stride=1, padding=self.block_size // 2) + matrix, + self.block_size, + stride=1, + padding=self.block_size // 2, + data_format=self.data_format) mask = 1. - mask_inv y = x * mask * (mask.numel() / mask.sum()) return y class CoordConv(nn.Layer): - def __init__(self, ch_in, ch_out, filter_size, padding, norm_type, name): + def __init__(self, + ch_in, + ch_out, + filter_size, + padding, + norm_type, + name, + data_format='NCHW'): super(CoordConv, self).__init__() self.conv = ConvBNLayer( ch_in + 2, @@ -124,36 +156,53 @@ class CoordConv(nn.Layer): filter_size=filter_size, padding=padding, norm_type=norm_type, + data_format=data_format, name=name) + self.data_format = data_format def forward(self, x): b = x.shape[0] - h = x.shape[2] - w = x.shape[3] + if self.data_format == 'NCHW': + h = x.shape[2] + w = x.shape[3] + else: + h = x.shape[1] + w = x.shape[2] gx = paddle.arange(w, dtype='float32') / (w - 1.) * 2.0 - 1. - gx = gx.reshape([1, 1, 1, w]).expand([b, 1, h, w]) + if self.data_format == 'NCHW': + gx = gx.reshape([1, 1, 1, w]).expand([b, 1, h, w]) + else: + gx = gx.reshape([1, 1, w, 1]).expand([b, h, w, 1]) gx.stop_gradient = True gy = paddle.arange(h, dtype='float32') / (h - 1.) * 2.0 - 1. - gy = gy.reshape([1, 1, h, 1]).expand([b, 1, h, w]) + if self.data_format == 'NCHW': + gy = gy.reshape([1, 1, h, 1]).expand([b, 1, h, w]) + else: + gy = gy.reshape([1, h, 1, 1]).expand([b, h, w, 1]) gy.stop_gradient = True - y = paddle.concat([x, gx, gy], axis=1) + if self.data_format == 'NCHW': + y = paddle.concat([x, gx, gy], axis=1) + else: + y = paddle.concat([x, gx, gy], axis=-1) y = self.conv(y) return y class PPYOLODetBlock(nn.Layer): - def __init__(self, cfg, name): + def __init__(self, cfg, name, data_format='NCHW'): super(PPYOLODetBlock, self).__init__() self.conv_module = nn.Sequential() for idx, (conv_name, layer, args, kwargs) in enumerate(cfg[:-1]): - kwargs.update(name='{}.{}'.format(name, conv_name)) + kwargs.update( + name='{}.{}'.format(name, conv_name), data_format=data_format) self.conv_module.add_sublayer(conv_name, layer(*args, **kwargs)) conv_name, layer, args, kwargs = cfg[-1] - kwargs.update(name='{}.{}'.format(name, conv_name)) + kwargs.update( + name='{}.{}'.format(name, conv_name), data_format=data_format) self.tip = layer(*args, **kwargs) def forward(self, inputs): @@ -165,9 +214,12 @@ class PPYOLODetBlock(nn.Layer): @register @serializable class YOLOv3FPN(nn.Layer): - __shared__ = ['norm_type'] + __shared__ = ['norm_type', 'data_format'] - def __init__(self, in_channels=[256, 512, 1024], norm_type='bn'): + def __init__(self, + in_channels=[256, 512, 1024], + norm_type='bn', + data_format='NCHW'): super(YOLOv3FPN, self).__init__() assert len(in_channels) > 0, "in_channels length should > 0" self.in_channels = in_channels @@ -176,6 +228,7 @@ class YOLOv3FPN(nn.Layer): self._out_channels = [] self.yolo_blocks = [] self.routes = [] + self.data_format = data_format for i in range(self.num_blocks): name = 'yolo_block.{}'.format(i) in_channel = in_channels[-i - 1] @@ -187,6 +240,7 @@ class YOLOv3FPN(nn.Layer): in_channel, channel=512 // (2**i), norm_type=norm_type, + data_format=data_format, name=name)) self.yolo_blocks.append(yolo_block) # tip layer output channel doubled @@ -203,6 +257,7 @@ class YOLOv3FPN(nn.Layer): stride=1, padding=0, norm_type=norm_type, + data_format=data_format, name=name)) self.routes.append(route) @@ -212,13 +267,17 @@ class YOLOv3FPN(nn.Layer): yolo_feats = [] for i, block in enumerate(blocks): if i > 0: - block = paddle.concat([route, block], axis=1) + if self.data_format == 'NCHW': + block = paddle.concat([route, block], axis=1) + else: + block = paddle.concat([route, block], axis=-1) route, tip = self.yolo_blocks[i](block) yolo_feats.append(tip) if i < self.num_blocks - 1: route = self.routes[i](route) - route = F.interpolate(route, scale_factor=2.) + route = F.interpolate( + route, scale_factor=2., data_format=self.data_format) return yolo_feats @@ -234,9 +293,13 @@ class YOLOv3FPN(nn.Layer): @register @serializable class PPYOLOFPN(nn.Layer): - __shared__ = ['norm_type'] + __shared__ = ['norm_type', 'data_format'] - def __init__(self, in_channels=[512, 1024, 2048], norm_type='bn', **kwargs): + def __init__(self, + in_channels=[512, 1024, 2048], + norm_type='bn', + data_format='NCHW', + **kwargs): super(PPYOLOFPN, self).__init__() assert len(in_channels) > 0, "in_channels length should > 0" self.in_channels = in_channels @@ -332,6 +395,7 @@ class PPYOLOFPN(nn.Layer): stride=1, padding=0, norm_type=norm_type, + data_format=data_format, name=name)) self.routes.append(route) @@ -341,13 +405,17 @@ class PPYOLOFPN(nn.Layer): yolo_feats = [] for i, block in enumerate(blocks): if i > 0: - block = paddle.concat([route, block], axis=1) + if self.data_format == 'NCHW': + block = paddle.concat([route, block], axis=1) + else: + block = paddle.concat([route, block], axis=-1) route, tip = self.yolo_blocks[i](block) yolo_feats.append(tip) if i < self.num_blocks - 1: route = self.routes[i](route) - route = F.interpolate(route, scale_factor=2.) + route = F.interpolate( + route, scale_factor=2., data_format=self.data_format) return yolo_feats diff --git a/dygraph/ppdet/modeling/ops.py b/dygraph/ppdet/modeling/ops.py index 97b64fb1c2e511d52392ebdf2739bfa1a6604df8..ef961dd1c220f8cee5a46745259add9a3d0cbbe1 100644 --- a/dygraph/ppdet/modeling/ops.py +++ b/dygraph/ppdet/modeling/ops.py @@ -44,7 +44,12 @@ __all__ = [ ] -def batch_norm(ch, norm_type='bn', norm_decay=0., initializer=None, name=None): +def batch_norm(ch, + norm_type='bn', + norm_decay=0., + initializer=None, + name=None, + data_format='NCHW'): bn_name = name + '.bn' if norm_type == 'sync_bn': batch_norm = nn.SyncBatchNorm @@ -58,7 +63,8 @@ def batch_norm(ch, norm_type='bn', norm_decay=0., initializer=None, name=None): initializer=initializer, regularizer=L2Decay(norm_decay)), bias_attr=ParamAttr( - name=bn_name + '.offset', regularizer=L2Decay(norm_decay))) + name=bn_name + '.offset', regularizer=L2Decay(norm_decay)), + data_format=data_format) @paddle.jit.not_to_static