diff --git a/yolov3/darknet.py b/yolov3/darknet.py index 412457b25eb9d3a2d22df44de389ee144f831652..bacb039ccba57d7eb040d3b60f1ef5e31e28f355 100644 --- a/yolov3/darknet.py +++ b/yolov3/darknet.py @@ -12,14 +12,14 @@ #See the License for the specific language governing permissions and #limitations under the License. +import paddle import paddle.fluid as fluid from paddle.fluid.param_attr import ParamAttr from paddle.fluid.regularizer import L2Decay +from paddle.static import InputSpec from paddle.fluid.dygraph.nn import Conv2D, BatchNorm - -from paddle.incubate.hapi.model import Model -from paddle.incubate.hapi.download import get_weights_path_from_url +from paddle.utils.download import get_weights_path_from_url __all__ = ['DarkNet', 'darknet53'] @@ -131,7 +131,7 @@ class LayerWarp(fluid.dygraph.Layer): DarkNet_cfg = {53: ([1, 2, 8, 8, 4])} -class DarkNet(Model): +class DarkNet(fluid.dygraph.Layer): """DarkNet model from `"YOLOv3: An Incremental Improvement" `_ @@ -190,7 +190,8 @@ def _darknet(num_layers=53, input_channels=3, pretrained=True): weight_path = get_weights_path_from_url(*(pretrain_infos[num_layers])) assert weight_path.endswith('.pdparams'), \ "suffix of weight must be .pdparams" - model.load(weight_path[:-9]) + weight_dict, _ = fluid.load_dygraph(weight_path[:-9]) + model.set_dict(weight_dict) return model diff --git a/yolov3/infer.py b/yolov3/infer.py index a2640fec1f1bb1d9844a10df79910b8e62408091..347a9103ed7f1496d5e0794f5ae790e214e5497d 100644 --- a/yolov3/infer.py +++ b/yolov3/infer.py @@ -20,12 +20,11 @@ import argparse import numpy as np from PIL import Image +import paddle from paddle import fluid from paddle.fluid.optimizer import Momentum from paddle.io import DataLoader -from paddle.incubate.hapi.model import Model, Input, set_device - from modeling import yolov3_darknet53, YoloLoss from transforms import * from utils import print_arguments @@ -36,6 +35,7 @@ logger = logging.getLogger(__name__) IMAGE_MEAN = [0.485, 0.456, 0.406] IMAGE_STD = [0.229, 0.224, 0.225] +NUM_MAX_BOXES = 50 def get_save_image_name(output_dir, image_path): @@ -62,24 +62,18 @@ def load_labels(label_list, with_background=True): def main(): - device = set_device(FLAGS.device) - fluid.enable_dygraph(device) if FLAGS.dynamic else None - - inputs = [ - Input( - [None, 1], 'int64', name='img_id'), Input( - [None, 2], 'int32', name='img_shape'), Input( - [None, 3, None, None], 'float32', name='image') - ] + device = paddle.set_device(FLAGS.device) + paddle.disable_static(device) if FLAGS.dynamic else None cat2name = load_labels(FLAGS.label_list, with_background=False) model = yolov3_darknet53( num_classes=len(cat2name), + num_max_boxes=NUM_MAX_BOXES, model_mode='test', pretrained=FLAGS.weights is None) - model.prepare(inputs=inputs, device=FLAGS.device) + model.prepare() if FLAGS.weights is not None: model.load(FLAGS.weights, reset_optimizer=True) diff --git a/yolov3/main.py b/yolov3/main.py index 5d44967e30863bb627cc3c26fecd4473072196fe..a8f0c61133106f1b0d7a32ceacab12f7811017f8 100644 --- a/yolov3/main.py +++ b/yolov3/main.py @@ -21,13 +21,11 @@ import os import numpy as np +import paddle from paddle import fluid from paddle.fluid.optimizer import Momentum -from paddle.io import DataLoader - -from paddle.incubate.hapi.model import Model, Input, set_device -from paddle.incubate.hapi.distributed import DistributedBatchSampler -from paddle.incubate.hapi.vision.transforms import Compose, BatchCompose +from paddle.io import DataLoader, DistributedBatchSampler +from paddle.vision.transforms import Compose, BatchCompose from modeling import yolov3_darknet53, YoloLoss from coco import COCODataset @@ -61,22 +59,8 @@ def make_optimizer(step_per_epoch, parameter_list=None): def main(): - device = set_device(FLAGS.device) - fluid.enable_dygraph(device) if FLAGS.dynamic else None - - inputs = [ - Input( - [None, 1], 'int64', name='img_id'), Input( - [None, 2], 'int32', name='img_shape'), Input( - [None, 3, None, None], 'float32', name='image') - ] - - labels = [ - Input( - [None, NUM_MAX_BOXES, 4], 'float32', name='gt_bbox'), Input( - [None, NUM_MAX_BOXES], 'int32', name='gt_label'), Input( - [None, NUM_MAX_BOXES], 'float32', name='gt_score') - ] + device = paddle.set_device(FLAGS.device) + paddle.disable_static(device) if FLAGS.dynamic else None if not FLAGS.eval_only: # training mode train_transform = Compose([ @@ -129,6 +113,7 @@ def main(): pretrained = FLAGS.eval_only and FLAGS.weights is None model = yolov3_darknet53( num_classes=dataset.num_classes, + num_max_boxes=NUM_MAX_BOXES, model_mode='eval' if FLAGS.eval_only else 'train', pretrained=pretrained) @@ -140,11 +125,7 @@ def main(): len(batch_sampler), parameter_list=model.parameters()) model.prepare( - optim, - YoloLoss(num_classes=dataset.num_classes), - inputs=inputs, - labels=labels, - device=FLAGS.device) + optimizer=optim, loss=YoloLoss(num_classes=dataset.num_classes)) # NOTE: we implement COCO metric of YOLOv3 model here, separately # from 'prepare' and 'fit' framework for follwing reason: diff --git a/yolov3/modeling.py b/yolov3/modeling.py index 4bebe2359565fb697c074455e577b4ba516e2d36..9cba02e4d9ad09180c0d5c5e6b75d0efa4401102 100644 --- a/yolov3/modeling.py +++ b/yolov3/modeling.py @@ -15,14 +15,15 @@ from __future__ import division from __future__ import print_function +import paddle import paddle.fluid as fluid from paddle.fluid.dygraph.nn import Conv2D, BatchNorm from paddle.fluid.param_attr import ParamAttr from paddle.fluid.regularizer import L2Decay -from paddle.incubate.hapi.model import Model -from paddle.incubate.hapi.loss import Loss -from paddle.incubate.hapi.download import get_weights_path_from_url +from paddle.static import InputSpec +from paddle.utils.download import get_weights_path_from_url + from darknet import darknet53 __all__ = ['YoloLoss', 'YOLOv3', 'yolov3_darknet53'] @@ -125,7 +126,7 @@ class YoloDetectionBlock(fluid.dygraph.Layer): return route, tip -class YOLOv3(Model): +class YOLOv3(fluid.dygraph.Layer): """YOLOv3 model from `"YOLOv3: An Incremental Improvement" `_ @@ -194,25 +195,13 @@ class YOLOv3(Model): act='leaky_relu')) self.route_blocks.append(route) - def extract_feats(self, inputs): - out = self.backbone.conv0(inputs) - out = self.backbone.downsample0(out) - blocks = [] - for i, conv_block_i in enumerate( - self.backbone.darknet53_conv_block_list): - out = conv_block_i(out) - blocks.append(out) - if i < len(self.backbone.stages) - 1: - out = self.backbone.downsample_list[i](out) - return blocks[-1:-4:-1] - def forward(self, img_id, img_shape, inputs): outputs = [] boxes = [] scores = [] downsample = 32 - feats = self.extract_feats(inputs) + feats = self.backbone(inputs) route = None for idx, feat in enumerate(feats): if idx > 0: @@ -267,7 +256,7 @@ class YOLOv3(Model): return outputs + preds -class YoloLoss(Loss): +class YoloLoss(fluid.dygraph.Layer): def __init__(self, num_classes=80, num_max_boxes=50): super(YoloLoss, self).__init__() self.num_classes = num_classes @@ -279,11 +268,16 @@ class YoloLoss(Loss): ] self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] - def forward(self, outputs, labels): + def forward(self, *inputs): downsample = 32 - gt_box, gt_label, gt_score = labels losses = [] + if len(inputs) == 6: + output1, output2, output3, gt_box, gt_label, gt_score = inputs + elif len(inputs) == 8: + output1, output2, output3, img_id, bbox, gt_box, gt_label, gt_score = inputs + + outputs = [output1, output2, output3] for idx, out in enumerate(outputs): if idx == 3: break # debug anchor_mask = self.anchor_masks[idx] @@ -306,9 +300,23 @@ class YoloLoss(Loss): def _yolov3_darknet(num_layers=53, num_classes=80, + num_max_boxes=50, model_mode='train', pretrained=True): - model = YOLOv3(num_classes, model_mode) + inputs = [ + InputSpec( + [None, 1], 'int64', name='img_id'), InputSpec( + [None, 2], 'int32', name='img_shape'), InputSpec( + [None, 3, None, None], 'float32', name='image') + ] + labels = [ + InputSpec( + [None, num_max_boxes, 4], 'float32', name='gt_bbox'), InputSpec( + [None, num_max_boxes], 'int32', name='gt_label'), InputSpec( + [None, num_max_boxes], 'float32', name='gt_score') + ] + net = YOLOv3(num_classes, model_mode) + model = paddle.Model(net, inputs, labels) if pretrained: assert num_layers in pretrain_infos.keys(), \ "YOLOv3-DarkNet{} do not have pretrained weights now, " \ @@ -320,11 +328,15 @@ def _yolov3_darknet(num_layers=53, return model -def yolov3_darknet53(num_classes=80, model_mode='train', pretrained=True): +def yolov3_darknet53(num_classes=80, + num_max_boxes=50, + model_mode='train', + pretrained=True): """YOLOv3 model with 53-layer DarkNet as backbone Args: num_classes (int): class number, default 80. + num_classes (int): max bbox number in a image, default 50. model_mode (str): 'train', 'eval', 'test' mode, network structure will be diffrent in the output layer and data, in 'train' mode, no output layer append, in 'eval' and 'test', output feature @@ -334,4 +346,5 @@ def yolov3_darknet53(num_classes=80, model_mode='train', pretrained=True): pretrained (bool): If True, returns a model with pre-trained model on COCO, default True """ - return _yolov3_darknet(53, num_classes, model_mode, pretrained) + return _yolov3_darknet(53, num_classes, num_max_boxes, model_mode, + pretrained)