diff --git a/configs/mask_rcnn_r50_fpn_1x.yml b/configs/mask_rcnn_r50_fpn_1x.yml index a1c90e3c02909d426d722d9b1004e9d76e6f8fbc..30b97fde037af92409d283a3b80ce35afa806a79 100644 --- a/configs/mask_rcnn_r50_fpn_1x.yml +++ b/configs/mask_rcnn_r50_fpn_1x.yml @@ -130,7 +130,7 @@ LearningRate: gamma: 0.1 milestones: [120000, 160000] - !LinearWarmup - start_factor: 0.3333 + start_factor: 0.3333333 steps: 500 OptimizerBuilder: diff --git a/configs/yolov3_darknet.yml b/configs/yolov3_darknet.yml index 7a1215def9fabffef2c2a4de6552636abddf27bd..15e899a9285f048bc32835952d0c604fb8699790 100644 --- a/configs/yolov3_darknet.yml +++ b/configs/yolov3_darknet.yml @@ -3,13 +3,13 @@ use_gpu: true max_iters: 500000 log_smooth_window: 20 save_dir: output -snapshot_iter: 10000 +snapshot_iter: 50000 metric: COCO -pretrain_weights: https://paddlemodels.bj.bcebos.com/yolo/darknet53.pdparams +pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_pretrained.tar weights: output/yolov3_darknet/model_final num_classes: 80 use_fine_grained_loss: false -open_debug: False +load_static_weights: True YOLOv3: anchor: AnchorYOLO @@ -18,11 +18,15 @@ YOLOv3: DarkNet: depth: 53 + return_idx: [2, 3, 4] YOLOv3Head: yolo_feat: name: YOLOFeat feat_in_list: [1024, 768, 384] + ignore_thresh: 0.7 + downsample: 32 + label_smooth: true anchor_per_position: 3 AnchorYOLO: @@ -30,11 +34,6 @@ AnchorYOLO: name: AnchorGeneratorYOLO anchors: [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326] anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]] - anchor_target_generator: - name: AnchorTargetGeneratorYOLO - ignore_thresh: 0.7 - downsample_ratio: 32 - label_smooth: true anchor_post_process: name: BBoxPostProcessYOLO # decode -> clip diff --git a/configs/yolov3_reader.yml b/configs/yolov3_reader.yml index 2a8463f1e6c2cb598ea4a55c6289f5b04b290d4a..5e11364b0ad831e68513f60ea68ca511e1018a8b 100644 --- a/configs/yolov3_reader.yml +++ b/configs/yolov3_reader.yml @@ -27,8 +27,8 @@ TrainReader: - !BboxXYXY2XYWH {} batch_transforms: - !RandomShape - sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608] - random_inter: True + sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608] + random_inter: True - !NormalizeImage mean: [0.485, 0.456, 0.406] std: [0.229, 0.224, 0.225] @@ -50,8 +50,8 @@ TrainReader: shuffle: true mixup_epoch: 250 drop_last: true - worker_num: 8 - bufsize: 16 + worker_num: 4 + bufsize: 4 use_process: true @@ -81,7 +81,7 @@ EvalReader: - !Permute to_bgr: false channel_first: True - batch_size: 8 + batch_size: 1 drop_empty: false worker_num: 8 bufsize: 16 diff --git a/ppdet/data/reader.py b/ppdet/data/reader.py index 7d808b58986a3b0941c007c3bd6621f77ba9c5f5..5cd9f3971136aaa108fda3db39b0615408f951cb 100644 --- a/ppdet/data/reader.py +++ b/ppdet/data/reader.py @@ -201,7 +201,7 @@ class Reader(object): use_fine_grained_loss=False, num_classes=80, bufsize=-1, - memsize='3G', + memsize='500M', inputs_def=None, devices_num=1): self._dataset = dataset diff --git a/ppdet/modeling/architecture/yolo.py b/ppdet/modeling/architecture/yolo.py index c26ba8f5b2bd67d8ae6ffe2bcf549c570e5341d3..f5045f60e245a3acc0f1fe80187af30c0b5e3374 100644 --- a/ppdet/modeling/architecture/yolo.py +++ b/ppdet/modeling/architecture/yolo.py @@ -17,38 +17,34 @@ class YOLOv3(BaseArch): 'yolo_head', ] - def __init__(self, anchor, backbone, yolo_head, *args, **kwargs): - super(YOLOv3, self).__init__(*args, **kwargs) + def __init__(self, anchor, backbone, yolo_head): + super(YOLOv3, self).__init__() self.anchor = anchor self.backbone = backbone self.yolo_head = yolo_head def model_arch(self, ): # Backbone - bb_out = self.backbone(self.gbd) - self.gbd.update(bb_out) + body_feats = self.backbone(self.inputs) # YOLO Head - yolo_head_out = self.yolo_head(self.gbd) - self.gbd.update(yolo_head_out) + self.yolo_head_out = self.yolo_head(body_feats) # Anchor - anchor_out = self.anchor(self.gbd) - self.gbd.update(anchor_out) - - if self.gbd['mode'] == 'infer': - bbox_out = self.anchor.post_process(self.gbd) - self.gbd.update(bbox_out) + self.anchors, self.anchor_masks, self.mask_anchors = self.anchor() def loss(self, ): - yolo_loss = self.yolo_head.loss(self.gbd) - out = {'loss': yolo_loss} - return out + yolo_loss = self.yolo_head.loss(self.inputs, self.yolo_head_out, + self.anchors, self.anchor_masks, + self.mask_anchors) + return yolo_loss def infer(self, ): + bbox, bbox_num = self.anchor.post_process( + self.inputs['im_size'], self.yolo_head_out, self.mask_anchors) outs = { - "bbox": self.gbd['predicted_bbox'].numpy(), - "bbox_nums": self.gbd['predicted_bbox_nums'], - 'im_id': self.gbd['im_id'].numpy() + "bbox": bbox.numpy(), + "bbox_num": bbox_num, + 'im_id': self.inputs['im_id'].numpy() } return outs diff --git a/ppdet/modeling/backbone/darknet.py b/ppdet/modeling/backbone/darknet.py index ed34f0763b55af8fac482b4cbecc8524dc378870..cce2647237890c95dc1d8e9329b55b7cf653f45e 100755 --- a/ppdet/modeling/backbone/darknet.py +++ b/ppdet/modeling/backbone/darknet.py @@ -16,7 +16,8 @@ class ConvBNLayer(Layer): stride=1, groups=1, padding=0, - act="leaky"): + act="leaky", + name=None): super(ConvBNLayer, self).__init__() self.conv = Conv2D( @@ -26,18 +27,18 @@ class ConvBNLayer(Layer): stride=stride, padding=padding, groups=groups, - param_attr=ParamAttr( - initializer=fluid.initializer.Normal(0., 0.02)), + param_attr=ParamAttr(name=name + '.conv.weights'), bias_attr=False, act=None) + bn_name = name + '.bn' self.batch_norm = BatchNorm( num_channels=ch_out, param_attr=ParamAttr( - initializer=fluid.initializer.Normal(0., 0.02), - regularizer=L2Decay(0.)), + name=bn_name + '.scale', regularizer=L2Decay(0.)), bias_attr=ParamAttr( - initializer=fluid.initializer.Constant(0.0), - regularizer=L2Decay(0.))) + name=bn_name + '.offset', regularizer=L2Decay(0.)), + moving_mean_name=bn_name + '.mean', + moving_variance_name=bn_name + '.var') self.act = act @@ -50,7 +51,13 @@ class ConvBNLayer(Layer): class DownSample(Layer): - def __init__(self, ch_in, ch_out, filter_size=3, stride=2, padding=1): + def __init__(self, + ch_in, + ch_out, + filter_size=3, + stride=2, + padding=1, + name=None): super(DownSample, self).__init__() @@ -59,7 +66,8 @@ class DownSample(Layer): ch_out=ch_out, filter_size=filter_size, stride=stride, - padding=padding) + padding=padding, + name=name) self.ch_out = ch_out def forward(self, inputs): @@ -68,13 +76,23 @@ class DownSample(Layer): class BasicBlock(Layer): - def __init__(self, ch_in, ch_out): + def __init__(self, ch_in, ch_out, name=None): super(BasicBlock, self).__init__() self.conv1 = ConvBNLayer( - ch_in=ch_in, ch_out=ch_out, filter_size=1, stride=1, padding=0) + ch_in=ch_in, + ch_out=ch_out, + filter_size=1, + stride=1, + padding=0, + name=name + '.0') self.conv2 = ConvBNLayer( - ch_in=ch_out, ch_out=ch_out * 2, filter_size=3, stride=1, padding=1) + ch_in=ch_out, + ch_out=ch_out * 2, + filter_size=3, + stride=1, + padding=1, + name=name + '.1') def forward(self, inputs): conv1 = self.conv1(inputs) @@ -84,14 +102,16 @@ class BasicBlock(Layer): class Blocks(Layer): - def __init__(self, ch_in, ch_out, count): + def __init__(self, ch_in, ch_out, count, name=None): super(Blocks, self).__init__() - self.basicblock0 = BasicBlock(ch_in, ch_out) + self.basicblock0 = BasicBlock(ch_in, ch_out, name=name + '.0') self.res_out_list = [] for i in range(1, count): - res_out = self.add_sublayer("basic_block_%d" % (i), - BasicBlock(ch_out * 2, ch_out)) + block_name = '{}.{}'.format(name, i) + res_out = self.add_sublayer( + block_name, BasicBlock( + ch_out * 2, ch_out, name=block_name)) self.res_out_list.append(res_out) self.ch_out = ch_out @@ -108,31 +128,46 @@ DarkNet_cfg = {53: ([1, 2, 8, 8, 4])} @register @serializable class DarkNet(Layer): - def __init__(self, depth=53, mode='train'): + def __init__(self, + depth=53, + freeze_at=-1, + return_idx=[2, 3, 4], + num_stages=5): super(DarkNet, self).__init__() self.depth = depth - self.mode = mode - self.stages = DarkNet_cfg[self.depth][0:5] + self.freeze_at = freeze_at + self.return_idx = return_idx + self.num_stages = num_stages + self.stages = DarkNet_cfg[self.depth][0:num_stages] self.conv0 = ConvBNLayer( - ch_in=3, ch_out=32, filter_size=3, stride=1, padding=1) + ch_in=3, + ch_out=32, + filter_size=3, + stride=1, + padding=1, + name='yolo_input') - self.downsample0 = DownSample(ch_in=32, ch_out=32 * 2) + self.downsample0 = DownSample( + ch_in=32, ch_out=32 * 2, name='yolo_input.downsample') - self.darknet53_conv_block_list = [] + self.darknet_conv_block_list = [] self.downsample_list = [] ch_in = [64, 128, 256, 512, 1024] for i, stage in enumerate(self.stages): - conv_block = self.add_sublayer("stage_%d" % (i), - Blocks( - int(ch_in[i]), 32 * (2**i), - stage)) - self.darknet53_conv_block_list.append(conv_block) - for i in range(len(self.stages) - 1): + name = 'stage.{}'.format(i) + conv_block = self.add_sublayer( + name, Blocks( + int(ch_in[i]), 32 * (2**i), stage, name=name)) + self.darknet_conv_block_list.append(conv_block) + for i in range(num_stages - 1): + down_name = 'stage.{}.downsample'.format(i) downsample = self.add_sublayer( - "stage_%d_downsample" % i, + down_name, DownSample( - ch_in=32 * (2**(i + 1)), ch_out=32 * (2**(i + 2)))) + ch_in=32 * (2**(i + 1)), + ch_out=32 * (2**(i + 2)), + name=down_name)) self.downsample_list.append(downsample) def forward(self, inputs): @@ -141,10 +176,12 @@ class DarkNet(Layer): out = self.conv0(x) out = self.downsample0(out) blocks = [] - for i, conv_block_i in enumerate(self.darknet53_conv_block_list): + for i, conv_block_i in enumerate(self.darknet_conv_block_list): out = conv_block_i(out) - blocks.append(out) - if i < len(self.stages) - 1: + if i == self.freeze_at: + out.stop_gradient = True + if i in self.return_idx: + blocks.append(out) + if i < self.num_stages - 1: out = self.downsample_list[i](out) - outs = {'darknet_outs': blocks[-1:-4:-1]} - return outs + return blocks diff --git a/ppdet/modeling/bbox.py b/ppdet/modeling/bbox.py index 235f4dc55eb54bf42297e8f0768307c98419eed6..fefc3b51efc9da7e9cb4c2bf6907be95e9106140 100644 --- a/ppdet/modeling/bbox.py +++ b/ppdet/modeling/bbox.py @@ -79,27 +79,25 @@ class BBoxPostProcessYOLO(object): self.decode = decode self.clip = clip - def __call__(self, inputs): + def __call__(self, im_size, yolo_head_out, mask_anchors): # TODO: split yolo_box into 2 steps # decode # clip boxes_list = [] scores_list = [] - for i, out in enumerate(inputs['yolo_outs']): - boxes, scores = self.yolo_box(out, inputs['im_size'], - inputs['mask_anchors'][i], i, - "yolo_box_" + str(i)) + for i, head_out in enumerate(yolo_head_out): + boxes, scores = self.yolo_box(head_out, im_size, mask_anchors[i], + self.num_classes, i) boxes_list.append(boxes) scores_list.append(fluid.layers.transpose(scores, perm=[0, 2, 1])) yolo_boxes = fluid.layers.concat(boxes_list, axis=1) yolo_scores = fluid.layers.concat(scores_list, axis=2) - nmsed_bbox = self.nms(bboxes=yolo_boxes, scores=yolo_scores) + bbox = self.nms(bboxes=yolo_boxes, scores=yolo_scores) # TODO: parse the lod of nmsed_bbox # default batch size is 1 - bbox_nums = np.array([0, int(nmsed_bbox.shape[0])], dtype=np.int32) - outs = {"predicted_bbox_nums": bbox_nums, "predicted_bbox": nmsed_bbox} - return outs + bbox_num = np.array([int(bbox.shape[0])], dtype=np.int32) + return bbox, bbox_num @register @@ -168,32 +166,18 @@ class AnchorRPN(object): @register class AnchorYOLO(object): - __inject__ = [ - 'anchor_generator', 'anchor_target_generator', 'anchor_post_process' - ] + __inject__ = ['anchor_generator', 'anchor_post_process'] - def __init__(self, anchor_generator, anchor_target_generator, - anchor_post_process): + def __init__(self, anchor_generator, anchor_post_process): super(AnchorYOLO, self).__init__() self.anchor_generator = anchor_generator - self.anchor_target_generator = anchor_target_generator self.anchor_post_process = anchor_post_process - def __call__(self, inputs): - outs = self.generate_anchors(inputs) - return outs - - def generate_anchors(self, inputs): - outs = self.anchor_generator(inputs['yolo_outs']) - outs['anchor_module'] = self - return outs - - def generate_anchors_target(self, inputs): - outs = self.anchor_target_generator() - return outs + def __call__(self): + return self.anchor_generator() - def post_process(self, inputs): - return self.anchor_post_process(inputs) + def post_process(self, im_size, yolo_head_out, mask_anchors): + return self.anchor_post_process(im_size, yolo_head_out, mask_anchors) @register diff --git a/ppdet/modeling/head/mask_head.py b/ppdet/modeling/head/mask_head.py index 3ab92daa4096761987db9b2b52965eb3da72fa67..47f202a1b99d9da73962ed13ebbe157eccfd738e 100644 --- a/ppdet/modeling/head/mask_head.py +++ b/ppdet/modeling/head/mask_head.py @@ -37,7 +37,7 @@ class MaskFeat(Layer): mask_conv.add_sublayer( conv_name, Conv2D( - num_channels=feat_in if j == 1 else feat_out, + num_channels=feat_in if j == 0 else feat_out, num_filters=feat_out, filter_size=3, act='relu', diff --git a/ppdet/modeling/head/yolo_head.py b/ppdet/modeling/head/yolo_head.py index 56b1bafab6011326ab3601a31fd54f7bc5fdfcd6..cae1f4024e889464d9401d5a0b1d5ca57e1f1b56 100644 --- a/ppdet/modeling/head/yolo_head.py +++ b/ppdet/modeling/head/yolo_head.py @@ -1,150 +1,135 @@ import paddle.fluid as fluid +import paddle from paddle.fluid.dygraph import Layer from paddle.fluid.param_attr import ParamAttr from paddle.fluid.initializer import Normal from paddle.fluid.regularizer import L2Decay from paddle.fluid.dygraph.nn import Conv2D, BatchNorm +from paddle.fluid.dygraph import Sequential from ppdet.core.workspace import register from ..backbone.darknet import ConvBNLayer class YoloDetBlock(Layer): - def __init__(self, ch_in, channel): + def __init__(self, ch_in, channel, name): super(YoloDetBlock, self).__init__() - + self.ch_in = ch_in + self.channel = channel assert channel % 2 == 0, \ "channel {} cannot be divided by 2".format(channel) - - self.conv0 = ConvBNLayer( - ch_in=ch_in, ch_out=channel, filter_size=1, stride=1, padding=0) - - self.conv1 = ConvBNLayer( - ch_in=channel, - ch_out=channel * 2, - filter_size=3, - stride=1, - padding=1) - - self.conv2 = ConvBNLayer( - ch_in=channel * 2, - ch_out=channel, - filter_size=1, - stride=1, - padding=0) - - self.conv3 = ConvBNLayer( - ch_in=channel, - ch_out=channel * 2, - filter_size=3, - stride=1, - padding=1) - - self.route = ConvBNLayer( - ch_in=channel * 2, - ch_out=channel, - filter_size=1, - stride=1, - padding=0) + conv_def = [ + ['conv0', ch_in, channel, 1, '.0.0'], + ['conv1', channel, channel * 2, 3, '.0.1'], + ['conv2', channel * 2, channel, 1, '.1.0'], + ['conv3', channel, channel * 2, 3, '.1.1'], + ['route', channel * 2, channel, 1, '.2'], + #['tip', channel, channel * 2, 3], + ] + + self.conv_module = Sequential() + for idx, (conv_name, ch_in, ch_out, filter_size, + post_name) in enumerate(conv_def): + self.conv_module.add_sublayer( + conv_name, + ConvBNLayer( + ch_in=ch_in, + ch_out=ch_out, + filter_size=filter_size, + padding=(filter_size - 1) // 2, + name=name + post_name)) self.tip = ConvBNLayer( ch_in=channel, ch_out=channel * 2, filter_size=3, - stride=1, - padding=1) + padding=1, + name=name + '.tip') def forward(self, inputs): - out = self.conv0(inputs) - out = self.conv1(out) - out = self.conv2(out) - out = self.conv3(out) - route = self.route(out) + route = self.conv_module(inputs) tip = self.tip(route) return route, tip -class Upsample(Layer): - def __init__(self, scale=2): - super(Upsample, self).__init__() - self.scale = scale - - def forward(self, inputs): - # get dynamic upsample output shape - shape_nchw = fluid.layers.shape(inputs) - shape_hw = fluid.layers.slice( - shape_nchw, axes=[0], starts=[2], ends=[4]) - shape_hw.stop_gradient = True - in_shape = fluid.layers.cast(shape_hw, dtype='int32') - out_shape = in_shape * self.scale - out_shape.stop_gradient = True - - # reisze by actual_shape - out = fluid.layers.resize_nearest( - input=inputs, scale=self.scale, actual_shape=out_shape) - return out - - @register class YOLOFeat(Layer): - def __init__(self, feat_in_list=[1024, 768, 384]): + __shared__ = ['num_levels'] + + def __init__(self, feat_in_list=[1024, 768, 384], num_levels=3): super(YOLOFeat, self).__init__() self.feat_in_list = feat_in_list self.yolo_blocks = [] self.route_blocks = [] - for i in range(3): + self.num_levels = num_levels + for i in range(self.num_levels): + name = 'yolo_block.{}'.format(i) yolo_block = self.add_sublayer( - "yolo_det_block_%d" % (i), + name, YoloDetBlock( - feat_in_list[i], channel=512 // (2**i))) + feat_in_list[i], channel=512 // (2**i), name=name)) self.yolo_blocks.append(yolo_block) - if i < 2: + if i < self.num_levels - 1: + name = 'yolo_transition.{}'.format(i) route = self.add_sublayer( - "route_%d" % i, + name, ConvBNLayer( ch_in=512 // (2**i), ch_out=256 // (2**i), filter_size=1, stride=1, - padding=0)) + padding=0, + name=name)) self.route_blocks.append(route) - self.upsample = Upsample() - def forward(self, inputs): + def forward(self, body_feats): + assert len(body_feats) == self.num_levels + body_feats = body_feats[::-1] yolo_feats = [] - for i, block in enumerate(inputs['darknet_outs']): + for i, block in enumerate(body_feats): if i > 0: block = fluid.layers.concat(input=[route, block], axis=1) - route, tip = self.yolo_blocks[i](block) yolo_feats.append(tip) - if i < 2: + if i < self.num_levels - 1: route = self.route_blocks[i](route) - route = self.upsample(route) + route = fluid.layers.resize_nearest(route, scale=2.) - outs = {'yolo_feat': yolo_feats} - return outs + return yolo_feats @register class YOLOv3Head(Layer): - __shared__ = ['num_classes'] + __shared__ = ['num_classes', 'num_levels', 'use_fine_grained_loss'] __inject__ = ['yolo_feat'] - def __init__(self, yolo_feat, num_classes=80, anchor_per_position=3): + def __init__(self, + yolo_feat, + num_classes=80, + anchor_per_position=3, + num_levels=3, + use_fine_grained_loss=False, + ignore_thresh=0.7, + downsample=32, + label_smooth=True): super(YOLOv3Head, self).__init__() self.num_classes = num_classes self.anchor_per_position = anchor_per_position self.yolo_feat = yolo_feat - - self.yolo_outs = [] - for i in range(3): + self.num_levels = num_levels + self.use_fine_grained_loss = use_fine_grained_loss + self.ignore_thresh = ignore_thresh + self.downsample = downsample + self.label_smooth = label_smooth + self.yolo_out_list = [] + for i in range(num_levels): # TODO: optim here #num_filters = len(cfg.anchor_masks[i]) * (self.num_classes + 5) num_filters = self.anchor_per_position * (self.num_classes + 5) + name = 'yolo_output.{}'.format(i) yolo_out = self.add_sublayer( - "yolo_out_%d" % (i), + name, Conv2D( num_channels=1024 // (2**i), num_filters=num_filters, @@ -152,44 +137,38 @@ class YOLOv3Head(Layer): stride=1, padding=0, act=None, - param_attr=ParamAttr( - initializer=fluid.initializer.Normal(0., 0.02)), + param_attr=ParamAttr(name=name + '.conv.weights'), bias_attr=ParamAttr( - initializer=fluid.initializer.Constant(0.0), - regularizer=L2Decay(0.)))) - self.yolo_outs.append(yolo_out) + name=name + '.conv.bias', regularizer=L2Decay(0.)))) + self.yolo_out_list.append(yolo_out) + + def forward(self, body_feats): + assert len(body_feats) == self.num_levels + yolo_feats = self.yolo_feat(body_feats) + yolo_head_out = [] + for i, feat in enumerate(yolo_feats): + yolo_out = self.yolo_out_list[i](feat) + yolo_head_out.append(yolo_out) + return yolo_head_out + + def loss(self, inputs, head_out, anchors, anchor_masks, mask_anchors): + if self.use_fine_grained_loss: + raise NotImplementedError - def forward(self, inputs): - outs = self.yolo_feat(inputs) - x = outs['yolo_feat'] - yolo_out_list = [] - for i, yolo_f in enumerate(x): - yolo_out = self.yolo_outs[i](yolo_f) - yolo_out_list.append(yolo_out) - outs.update({"yolo_outs": yolo_out_list}) - return outs - - def loss(self, inputs): - if callable(inputs['anchor_module']): - yolo_targets = inputs['anchor_module'].generate_anchors_target( - inputs) yolo_losses = [] - for i, out in enumerate(inputs['yolo_outs']): - # TODO: split yolov3_loss into small ops - # 1. compute target 2. loss + for i, out in enumerate(head_out): loss = fluid.layers.yolov3_loss( x=out, gt_box=inputs['gt_bbox'], gt_label=inputs['gt_class'], gt_score=inputs['gt_score'], - anchors=inputs['anchors'], - anchor_mask=inputs['anchor_masks'][i], + anchors=anchors, + anchor_mask=anchor_masks[i], class_num=self.num_classes, - ignore_thresh=yolo_targets['ignore_thresh'], - downsample_ratio=yolo_targets['downsample_ratio'] // 2**i, - use_label_smooth=yolo_targets['label_smooth'], + ignore_thresh=self.ignore_thresh, + downsample_ratio=self.downsample // 2**i, + use_label_smooth=self.label_smooth, name='yolo_loss_' + str(i)) loss = fluid.layers.reduce_mean(loss) yolo_losses.append(loss) - yolo_loss = sum(yolo_losses) - return yolo_loss + return {'loss': sum(yolo_losses)} diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py index 5d7a112939a9ddbde947f6aa78507b34edcc732a..7f2edbac1d894828847536e4ffd2ee2232c161ae 100644 --- a/ppdet/modeling/ops.py +++ b/ppdet/modeling/ops.py @@ -99,20 +99,16 @@ class AnchorGeneratorYOLO(object): self.anchors = anchors self.anchor_masks = anchor_masks - def __call__(self, yolo_outs): + def __call__(self): + anchor_num = len(self.anchors) mask_anchors = [] - for i, _ in enumerate(yolo_outs): + for i in range(len(self.anchor_masks)): mask_anchor = [] for m in self.anchor_masks[i]: - mask_anchor.append(self.anchors[2 * m]) - mask_anchor.append(self.anchors[2 * m + 1]) + assert m < anchor_num, "anchor mask index overflow" + mask_anchor.extend(self.anchors[2 * m:2 * m + 2]) mask_anchors.append(mask_anchor) - outs = { - "anchors": self.anchors, - "anchor_masks": self.anchor_masks, - "mask_anchors": mask_anchors - } - return outs + return self.anchors, self.anchor_masks, mask_anchors @register @@ -305,6 +301,7 @@ class RoIExtractor(object): self.canconical_level, self.canonical_size, rois_num=rois_num) + rois_feat_list = [] for lvl in range(self.start_level, self.end_level + 1): roi_feat = fluid.layers.roi_align( @@ -381,24 +378,19 @@ class MultiClassNMS(object): @register @serializable class YOLOBox(object): - __shared__ = ['num_classes'] - def __init__( self, - num_classes=80, conf_thresh=0.005, downsample_ratio=32, clip_bbox=True, ): - self.num_classes = num_classes self.conf_thresh = conf_thresh self.downsample_ratio = downsample_ratio self.clip_bbox = clip_bbox - def __call__(self, x, img_size, anchors, stage=0, name=None): - - outs = fluid.layers.yolo_box(x, img_size, anchors, self.num_classes, + def __call__(self, x, img_size, anchors, num_classes, stage=0): + outs = fluid.layers.yolo_box(x, img_size, anchors, num_classes, self.conf_thresh, self.downsample_ratio // - 2**stage, self.clip_bbox, name) + 2**stage, self.clip_bbox) return outs diff --git a/ppdet/optimizer.py b/ppdet/optimizer.py index 2016cda13384385aeec88b667d7543090d66f24a..b2c73c88d1cdaa0d6c30032d9926a0e38142e921 100644 --- a/ppdet/optimizer.py +++ b/ppdet/optimizer.py @@ -82,7 +82,7 @@ class LinearWarmup(object): def __call__(self, base_lr): boundary = [] value = [] - for i in range(self.steps): + for i in range(self.steps + 1): alpha = i / self.steps factor = self.start_factor * (1 - alpha) + alpha lr = base_lr * factor diff --git a/ppdet/py_op/bbox.py b/ppdet/py_op/bbox.py index dc34a77cd2b3f09a76fb0ca1476458855ebcc301..33f0c8837658d73609a6e9d6be37de2eae399b51 100755 --- a/ppdet/py_op/bbox.py +++ b/ppdet/py_op/bbox.py @@ -190,7 +190,7 @@ def nms_with_decode(bboxes, rois_n = bboxes_v[start:end, :] # box rois_n = rois_n / im_info[i][2] # scale rois_n = delta2bbox(bbox_deltas_n, rois_n, variance_v) - rois_n = clip_bbox(rois_n, im_info[i][:2] / im_info[i][2]) + rois_n = clip_bbox(rois_n, np.round(im_info[i][:2] / im_info[i][2])) cls_boxes = [[] for _ in range(class_nums)] scores_n = bbox_probs_v[start:end, :] for j in range(1, class_nums): diff --git a/ppdet/py_op/post_process.py b/ppdet/py_op/post_process.py index c6bf354e47453fade8c9fc88dea42f92df90b478..2049bffd0df80e45d61eba35c3ce77395ac8ad42 100755 --- a/ppdet/py_op/post_process.py +++ b/ppdet/py_op/post_process.py @@ -136,7 +136,8 @@ def get_det_res(bboxes, bbox_nums, image_id, num_id_to_cat_id_map): k = 0 for i in range(len(bbox_nums)): image_id = int(image_id[i][0]) - + if bboxes.shape == (1, 1): + continue det_nums = bbox_nums[i] for j in range(det_nums): dt = bboxes[k] diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py index c2c41947f4cf1c76eeea4a72f5e9178e129058e9..419f5d6112db6390cf94106df194d1b532b1fd56 100644 --- a/ppdet/utils/checkpoint.py +++ b/ppdet/utils/checkpoint.py @@ -3,10 +3,12 @@ from __future__ import division from __future__ import print_function from __future__ import unicode_literals +import errno import os import time import re import numpy as np +import paddle import paddle.fluid as fluid from .download import get_weights_path @@ -88,9 +90,10 @@ def load_dygraph_ckpt(model, return model -def save_dygraph_ckpt(model, optimizer, save_dir): +def save_dygraph_ckpt(model, optimizer, save_dir, save_name): if not os.path.exists(save_dir): os.makedirs(save_dir) - fluid.dygraph.save_dygraph(model.state_dict(), save_dir) - fluid.dygraph.save_dygraph(optimizer.state_dict(), save_dir) + save_path = os.path.join(save_dir, save_name) + fluid.dygraph.save_dygraph(model.state_dict(), save_path) + fluid.dygraph.save_dygraph(optimizer.state_dict(), save_path) print("Save checkpoint:", save_dir) diff --git a/tools/eval.py b/tools/eval.py index d437cd7cb5b3d3e466f66e3846a9b171ed3d1181..b39573d01873d921c798c98f9ab70b99de4f2483 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -20,6 +20,10 @@ from ppdet.utils.cli import ArgsParser from ppdet.utils.eval_utils import coco_eval_results from ppdet.data.reader import create_reader from ppdet.utils.checkpoint import load_dygraph_ckpt, save_dygraph_ckpt +import logging +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) def parse_args(): @@ -58,22 +62,26 @@ def run(FLAGS, cfg): # Run Eval outs_res = [] + start_time = time.time() + sample_num = 0 for iter_id, data in enumerate(eval_reader()): - start_time = time.time() - # forward model.eval() outs = model(data, cfg['EvalReader']['inputs_def']['fields'], 'infer') outs_res.append(outs) # log - cost_time = time.time() - start_time - print("Eval iter: {}, time: {}".format(iter_id, cost_time)) + sample_num += len(data) + if iter_id % 100 == 0: + logger.info("Eval iter: {}".format(iter_id)) + cost_time = time.time() - start_time + logger.info('Total sample number: {}, averge FPS: {}'.format( + sample_num, sample_num / cost_time)) # Metric coco_eval_results( outs_res, - include_mask=True if 'MaskHead' in cfg else False, + include_mask=True if getattr(cfg, 'MaskHead', None) else False, dataset=cfg['EvalReader']['dataset']) diff --git a/tools/train.py b/tools/train.py index 27f13323be3ab7d4f8b82a6d01f71a2a0fcc9ade..76a860909dd2e1c6a593e441e34d4103082cf910 100755 --- a/tools/train.py +++ b/tools/train.py @@ -113,7 +113,7 @@ def run(FLAGS, cfg): optimizer, cfg.pretrain_weights, ckpt_type=FLAGS.ckpt_type, - load_static_weights=cfg.load_static_weights) + load_static_weights=cfg.get('load_static_weights', False)) # Parallel Model if ParallelEnv().nranks > 1: @@ -177,8 +177,8 @@ def run(FLAGS, cfg): cfg_name = os.path.basename(FLAGS.config).split('.')[0] save_name = str( iter_id) if iter_id != cfg.max_iters - 1 else "model_final" - save_dir = os.path.join(cfg.save_dir, cfg_name, save_name) - save_dygraph_ckpt(model, optimizer, save_dir) + save_dir = os.path.join(cfg.save_dir, cfg_name) + save_dygraph_ckpt(model, optimizer, save_dir, save_name) def main():