From dab4becc9084aab487c0a0ab6a600f52c6fdd1b7 Mon Sep 17 00:00:00 2001 From: FDInSky <48318485+FDInSky@users.noreply.github.com> Date: Mon, 13 Jul 2020 11:08:34 +0800 Subject: [PATCH] add train&eval (#1042) add train&eval add reader padding add new config --- configs/cascade_rcnn_r50_1x.yml | 125 +++++++++ configs/faster_rcnn_r50_1x.yml | 107 +++++++ configs/faster_reader.yml | 95 +++++++ configs/mask_rcnn_r50_1x.yml | 127 +++++++++ configs/mask_reader.yml | 101 +++++++ configs/yolov3_darknet.yml | 75 +++++ configs/yolov3_reader.yml | 111 ++++++++ ppdet/core/workspace.py | 23 +- ppdet/data/transform/batch_operators.py | 51 +++- ppdet/optimizer.py | 161 ++--------- ppdet/utils/checkpoint.py | 356 +++++------------------- ppdet/utils/data_structure.py | 2 +- ppdet/utils/eval_utils.py | 283 ++++--------------- tools/__init__.py | 0 tools/eval.py | 93 +++++++ tools/train.py | 198 +++++++++++++ 16 files changed, 1247 insertions(+), 661 deletions(-) create mode 100644 configs/cascade_rcnn_r50_1x.yml create mode 100644 configs/faster_rcnn_r50_1x.yml create mode 100644 configs/faster_reader.yml create mode 100644 configs/mask_rcnn_r50_1x.yml create mode 100644 configs/mask_reader.yml create mode 100644 configs/yolov3_darknet.yml create mode 100644 configs/yolov3_reader.yml create mode 100644 tools/__init__.py create mode 100755 tools/eval.py create mode 100755 tools/train.py diff --git a/configs/cascade_rcnn_r50_1x.yml b/configs/cascade_rcnn_r50_1x.yml new file mode 100644 index 000000000..e6e1f7d6b --- /dev/null +++ b/configs/cascade_rcnn_r50_1x.yml @@ -0,0 +1,125 @@ +architecture: CascadeRCNN +use_gpu: true +max_iters: 180000 +log_smooth_window: 50 +save_dir: output +snapshot_iter: 10000 +pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/dygraph/resnet50.pdparams +metric: COCO +weights: output/cascade_rcnn_r50_1x/model_final +num_classes: 81 +num_stages: 3 +open_debug: False + +# Model Achitecture +CascadeRCNN: + # model anchor info flow + anchor: AnchorRPN + proposal: Proposal + mask: Mask + # model feat info flow + backbone: ResNet + rpn_head: RPNHead + bbox_head: BBoxHead + mask_head: MaskHead + +ResNet: + norm_type: 'affine' + depth: 50 + freeze_at: 'res2' + +RPNHead: + rpn_feat: + name: RPNFeat + feat_in: 1024 + feat_out: 1024 + anchor_per_position: 15 + +BBoxHead: + bbox_feat: + name: BBoxFeat + feat_in: 1024 + feat_out: 512 + roi_extractor: + resolution: 14 + sampling_ratio: 0 + spatial_scale: 0.0625 + extractor_type: 'RoIAlign' + +MaskHead: + mask_feat: + name: MaskFeat + feat_in: 2048 + feat_out: 256 + feat_in: 256 + resolution: 14 + +AnchorRPN: + anchor_generator: + name: AnchorGeneratorRPN + anchor_sizes: [32, 64, 128, 256, 512] + aspect_ratios: [0.5, 1.0, 2.0] + stride: [16.0, 16.0] + variance: [1.0, 1.0, 1.0, 1.0] + anchor_target_generator: + name: AnchorTargetGeneratorRPN + batch_size_per_im: 256 + fg_fraction: 0.5 + negative_overlap: 0.3 + positive_overlap: 0.7 + straddle_thresh: 0.0 + +Proposal: + proposal_generator: + name: ProposalGenerator + min_size: 0.0 + nms_thresh: 0.7 + train_pre_nms_top_n: 2000 + train_post_nms_top_n: 2000 + infer_pre_nms_top_n: 2000 + infer_post_nms_top_n: 2000 + return_rois_num: True + proposal_target_generator: + name: ProposalTargetGenerator + batch_size_per_im: 512 + bbox_reg_weights: [[0.1, 0.1, 0.2, 0.2],[0.05, 0.05, 0.1, 0.1],[0.333333, 0.333333, 0.666666, 0.666666]] + bg_thresh_hi: [0.5, 0.6, 0.7] + bg_thresh_lo: [0.0, 0.0, 0.0] + fg_thresh: [0.5, 0.6, 0.7] + fg_fraction: 0.25 + bbox_post_process: # used in infer + name: BBoxPostProcess + # decode -> clip -> nms + decode_clip_nms: + name: DecodeClipNms + keep_top_k: 100 + score_threshold: 0.05 + nms_threshold: 0.5 + +Mask: + mask_target_generator: + name: MaskTargetGenerator + resolution: 14 + mask_post_process: + name: MaskPostProcess + +# Train +LearningRate: + base_lr: 0.01 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [120000, 160000] + - !LinearWarmup + start_factor: 0.3333333333333333 + steps: 500 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0001 + type: L2 + +_READER_: 'mask_reader.yml' diff --git a/configs/faster_rcnn_r50_1x.yml b/configs/faster_rcnn_r50_1x.yml new file mode 100644 index 000000000..d36b45abd --- /dev/null +++ b/configs/faster_rcnn_r50_1x.yml @@ -0,0 +1,107 @@ +architecture: FasterRCNN +use_gpu: true +max_iters: 180000 +log_smooth_window: 50 +save_dir: output +snapshot_iter: 10000 +pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/dygraph/resnet50.pdparams +metric: COCO +weights: output/faster_rcnn_r50_1x/model_final +num_classes: 81 +open_debug: False + +# Model Achitecture +FasterRCNN: + # model anchor info flow + anchor: AnchorRPN + proposal: Proposal + # model feat info flow + backbone: ResNet + rpn_head: RPNHead + bbox_head: BBoxHead + +ResNet: + depth: 50 + norm_type: 'affine' + freeze_at: 'res2' + +RPNHead: + rpn_feat: + name: RPNFeat + feat_in: 1024 + feat_out: 1024 + anchor_per_position: 15 + +BBoxHead: + bbox_feat: + name: BBoxFeat + roi_extractor: + name: RoIExtractor + resolution: 14 + sampling_ratio: 0 + spatial_scale: 0.0625 + extractor_type: 'RoIAlign' + feat_out: 512 + +AnchorRPN: + anchor_generator: + name: AnchorGeneratorRPN + anchor_sizes: [32, 64, 128, 256, 512] + aspect_ratios: [0.5, 1.0, 2.0] + stride: [16.0, 16.0] + variance: [1.0, 1.0, 1.0, 1.0] + anchor_target_generator: + name: AnchorTargetGeneratorRPN + batch_size_per_im: 256 + fg_fraction: 0.5 + negative_overlap: 0.3 + positive_overlap: 0.7 + straddle_thresh: 0.0 + +Proposal: + proposal_generator: + name: ProposalGenerator + min_size: 0.0 + nms_thresh: 0.7 + train_pre_nms_top_n: 12000 + train_post_nms_top_n: 2000 + infer_pre_nms_top_n: 12000 # used in infer + infer_post_nms_top_n: 2000 # used in infer + return_rois_num: True + proposal_target_generator: + name: ProposalTargetGenerator + batch_size_per_im: 512 + bbox_reg_weights: [[0.1, 0.1, 0.2, 0.2],] + bg_thresh_hi: [0.5,] + bg_thresh_lo: [0.0,] + fg_thresh: [0.5,] + fg_fraction: 0.25 + bbox_post_process: # used in infer + name: BBoxPostProcess + # decode -> clip -> nms + decode_clip_nms: + name: DecodeClipNms + keep_top_k: 100 + score_threshold: 0.05 + nms_threshold: 0.5 + +# Train +LearningRate: + base_lr: 0.01 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [120000, 160000] + - !LinearWarmup + start_factor: 0.3333333333333333 + steps: 500 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0001 + type: L2 + +_READER_: 'faster_reader.yml' diff --git a/configs/faster_reader.yml b/configs/faster_reader.yml new file mode 100644 index 000000000..e31610685 --- /dev/null +++ b/configs/faster_reader.yml @@ -0,0 +1,95 @@ +TrainReader: + inputs_def: + fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd'] + dataset: + !COCODataSet + image_dir: train2017 + anno_path: annotations/instances_train2017.json + dataset_dir: dataset/coco + sample_transforms: + - !DecodeImage + to_rgb: True + - !RandomFlipImage + prob: 0.5 + - !NormalizeImage + is_channel_first: false + is_scale: true + mean: [0.485,0.456,0.406] + std: [0.229, 0.224,0.225] + - !ResizeImage + target_size: 800 + max_size: 1333 + interp: 1 + use_cv2: true + - !Permute + to_bgr: false + channel_first: true + batch_transforms: + - !PadBatch + pad_to_stride: 0 + use_padded_im_info: False + pad_gt: true + batch_size: 1 + shuffle: true + worker_num: 2 + use_process: false + +EvalReader: + inputs_def: + fields: ['image', 'im_info', 'im_id', 'im_shape'] + dataset: + !COCODataSet + image_dir: val2017 + anno_path: annotations/instances_val2017.json + dataset_dir: dataset/coco + sample_transforms: + - !DecodeImage + to_rgb: true + - !NormalizeImage + is_channel_first: false + is_scale: true + mean: [0.485,0.456,0.406] + std: [0.229, 0.224,0.225] + - !ResizeImage + interp: 1 + max_size: 1333 + target_size: 800 + use_cv2: true + - !Permute + channel_first: true + to_bgr: false + batch_transforms: + - !PadBatch + pad_to_stride: 32 + use_padded_im_info: false + pad_gt: True + batch_size: 2 + shuffle: false + drop_empty: false + worker_num: 2 + +TestReader: + inputs_def: + fields: ['image', 'im_info', 'im_id', 'im_shape'] + dataset: + !ImageFolder + anno_path: annotations/instances_val2017.json + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !NormalizeImage + is_channel_first: false + is_scale: true + mean: [0.485,0.456,0.406] + std: [0.229, 0.224,0.225] + - !ResizeImage + interp: 1 + max_size: 1333 + target_size: 800 + use_cv2: true + - !Permute + channel_first: true + to_bgr: false + batch_size: 1 + shuffle: false diff --git a/configs/mask_rcnn_r50_1x.yml b/configs/mask_rcnn_r50_1x.yml new file mode 100644 index 000000000..7f089140a --- /dev/null +++ b/configs/mask_rcnn_r50_1x.yml @@ -0,0 +1,127 @@ +architecture: MaskRCNN +use_gpu: true +max_iters: 180000 +log_smooth_window: 50 +save_dir: output +snapshot_iter: 10000 +pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/dygraph/resnet50.pdparams +metric: COCO +weights: output/mask_rcnn_r50_1x/model_final +num_classes: 81 +open_debug: False + +# Model Achitecture +MaskRCNN: + # model anchor info flow + anchor: AnchorRPN + proposal: Proposal + mask: Mask + # model feat info flow + backbone: ResNet + rpn_head: RPNHead + bbox_head: BBoxHead + mask_head: MaskHead + +ResNet: + norm_type: 'affine' + depth: 50 + freeze_at: 'res2' + +RPNHead: + rpn_feat: + name: RPNFeat + feat_in: 1024 + feat_out: 1024 + anchor_per_position: 15 + +BBoxHead: + bbox_feat: + name: BBoxFeat + roi_extractor: + name: RoIExtractor + resolution: 14 + sampling_ratio: 0 + spatial_scale: 0.0625 + extractor_type: 'RoIAlign' + feat_in: 1024 + feat_out: 512 + +MaskHead: + mask_feat: + name: MaskFeat + feat_in: 2048 + feat_out: 256 + mask_stages: 1 + feat_in: 256 + resolution: 14 + mask_stages: 1 + +AnchorRPN: + anchor_generator: + name: AnchorGeneratorRPN + anchor_sizes: [32, 64, 128, 256, 512] + aspect_ratios: [0.5, 1.0, 2.0] + stride: [16.0, 16.0] + variance: [1.0, 1.0, 1.0, 1.0] + anchor_target_generator: + name: AnchorTargetGeneratorRPN + batch_size_per_im: 256 + fg_fraction: 0.5 + negative_overlap: 0.3 + positive_overlap: 0.7 + straddle_thresh: 0.0 + +Proposal: + proposal_generator: + name: ProposalGenerator + min_size: 0.0 + nms_thresh: 0.7 + train_pre_nms_top_n: 12000 + train_post_nms_top_n: 2000 + infer_pre_nms_top_n: 12000 + infer_post_nms_top_n: 2000 + return_rois_num: True + proposal_target_generator: + name: ProposalTargetGenerator + batch_size_per_im: 512 + bbox_reg_weights: [[0.1, 0.1, 0.2, 0.2],] + bg_thresh_hi: [0.5,] + bg_thresh_lo: [0.0,] + fg_thresh: [0.5,] + fg_fraction: 0.25 + bbox_post_process: # used in infer + name: BBoxPostProcess + # decode -> clip -> nms + decode_clip_nms: + name: DecodeClipNms + keep_top_k: 100 + score_threshold: 0.05 + nms_threshold: 0.5 + +Mask: + mask_target_generator: + name: MaskTargetGenerator + resolution: 14 + mask_post_process: + name: MaskPostProcess + +# Train +LearningRate: + base_lr: 0.01 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [120000, 160000] + - !LinearWarmup + start_factor: 0.3333333333333333 + steps: 500 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0001 + type: L2 + +_READER_: 'mask_reader.yml' diff --git a/configs/mask_reader.yml b/configs/mask_reader.yml new file mode 100644 index 000000000..5280abac3 --- /dev/null +++ b/configs/mask_reader.yml @@ -0,0 +1,101 @@ +TrainReader: + inputs_def: + fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd', 'gt_mask'] + dataset: + !COCODataSet + image_dir: train2017 + anno_path: annotations/instances_train2017.json + dataset_dir: dataset/coco + sample_transforms: + - !DecodeImage + to_rgb: true + - !RandomFlipImage + prob: 0.5 + is_mask_flip: true + - !NormalizeImage + is_channel_first: false + is_scale: true + mean: [0.485,0.456,0.406] + std: [0.229, 0.224,0.225] + - !ResizeImage + target_size: 512 + max_size: 512 + interp: 1 + use_cv2: true + - !Permute + to_bgr: false + channel_first: true + batch_transforms: + - !PadBatch + pad_to_stride: 32 + use_padded_im_info: false + pad_gt: True + batch_size: 1 + shuffle: true + worker_num: 2 + drop_last: false + use_process: false + +EvalReader: + inputs_def: + fields: ['image', 'im_info', 'im_id', 'im_shape'] + # for voc + #fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_difficult'] + dataset: + !COCODataSet + image_dir: val2017 + anno_path: annotations/instances_val2017.json + dataset_dir: dataset/coco + sample_transforms: + - !DecodeImage + to_rgb: true + - !NormalizeImage + is_channel_first: false + is_scale: true + mean: [0.485,0.456,0.406] + std: [0.229, 0.224,0.225] + - !ResizeImage + interp: 1 + max_size: 1333 + target_size: 800 + use_cv2: true + - !Permute + channel_first: true + to_bgr: false + batch_transforms: + - !PadBatch + pad_to_stride: 32 + use_padded_im_info: false + pad_gt: True + batch_size: 1 + shuffle: false + drop_last: false + drop_empty: false + worker_num: 2 + +TestReader: + inputs_def: + fields: ['image', 'im_info', 'im_id', 'im_shape'] + dataset: + !ImageFolder + anno_path: annotations/instances_val2017.json + sample_transforms: + - !DecodeImage + to_rgb: true + with_mixup: false + - !NormalizeImage + is_channel_first: false + is_scale: true + mean: [0.485,0.456,0.406] + std: [0.229, 0.224,0.225] + - !ResizeImage + interp: 1 + max_size: 1333 + target_size: 800 + use_cv2: true + - !Permute + channel_first: true + to_bgr: false + batch_size: 1 + shuffle: false + drop_last: false diff --git a/configs/yolov3_darknet.yml b/configs/yolov3_darknet.yml new file mode 100644 index 000000000..7a1215def --- /dev/null +++ b/configs/yolov3_darknet.yml @@ -0,0 +1,75 @@ +architecture: YOLOv3 +use_gpu: true +max_iters: 500000 +log_smooth_window: 20 +save_dir: output +snapshot_iter: 10000 +metric: COCO +pretrain_weights: https://paddlemodels.bj.bcebos.com/yolo/darknet53.pdparams +weights: output/yolov3_darknet/model_final +num_classes: 80 +use_fine_grained_loss: false +open_debug: False + +YOLOv3: + anchor: AnchorYOLO + backbone: DarkNet + yolo_head: YOLOv3Head + +DarkNet: + depth: 53 + +YOLOv3Head: + yolo_feat: + name: YOLOFeat + feat_in_list: [1024, 768, 384] + anchor_per_position: 3 + +AnchorYOLO: + anchor_generator: + name: AnchorGeneratorYOLO + anchors: [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326] + anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]] + anchor_target_generator: + name: AnchorTargetGeneratorYOLO + ignore_thresh: 0.7 + downsample_ratio: 32 + label_smooth: true + anchor_post_process: + name: BBoxPostProcessYOLO + # decode -> clip + yolo_box: + name: YOLOBox + conf_thresh: 0.005 + downsample_ratio: 32 + clip_bbox: True + nms: + name: MultiClassNMS + keep_top_k: 100 + score_threshold: 0.01 + nms_threshold: 0.45 + nms_top_k: 1000 + normalized: false + background_label: -1 + +LearningRate: + base_lr: 0.001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: + - 400000 + - 450000 + - !LinearWarmup + start_factor: 0. + steps: 4000 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0005 + type: L2 + +_READER_: 'yolov3_reader.yml' diff --git a/configs/yolov3_reader.yml b/configs/yolov3_reader.yml new file mode 100644 index 000000000..2a8463f1e --- /dev/null +++ b/configs/yolov3_reader.yml @@ -0,0 +1,111 @@ +TrainReader: + inputs_def: + fields: ['image', 'gt_bbox', 'gt_class', 'gt_score'] + num_max_boxes: 50 + dataset: + !COCODataSet + image_dir: train2017 + anno_path: annotations/instances_train2017.json + dataset_dir: dataset/coco + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: True + with_mixup: True + - !MixupImage + alpha: 1.5 + beta: 1.5 + - !ColorDistort {} + - !RandomExpand + fill_value: [123.675, 116.28, 103.53] + - !RandomCrop {} + - !RandomFlipImage + is_normalized: false + - !NormalizeBox {} + - !PadBox + num_max_boxes: 50 + - !BboxXYXY2XYWH {} + batch_transforms: + - !RandomShape + sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608] + random_inter: True + - !NormalizeImage + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + is_scale: True + is_channel_first: false + - !Permute + to_bgr: false + channel_first: True + # Gt2YoloTarget is only used when use_fine_grained_loss set as true, + # this operator will be deleted automatically if use_fine_grained_loss + # is set as false + - !Gt2YoloTarget + anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]] + anchors: [[10, 13], [16, 30], [33, 23], + [30, 61], [62, 45], [59, 119], + [116, 90], [156, 198], [373, 326]] + downsample_ratios: [32, 16, 8] + batch_size: 8 + shuffle: true + mixup_epoch: 250 + drop_last: true + worker_num: 8 + bufsize: 16 + use_process: true + + +EvalReader: + inputs_def: + fields: ['image', 'im_size', 'im_id'] + num_max_boxes: 50 + dataset: + !COCODataSet + image_dir: val2017 + anno_path: annotations/instances_val2017.json + dataset_dir: dataset/coco + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: True + - !ResizeImage + target_size: 608 + interp: 2 + - !NormalizeImage + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + is_scale: True + is_channel_first: false + - !PadBox + num_max_boxes: 50 + - !Permute + to_bgr: false + channel_first: True + batch_size: 8 + drop_empty: false + worker_num: 8 + bufsize: 16 + +TestReader: + inputs_def: + image_shape: [3, 608, 608] + fields: ['image', 'im_size', 'im_id'] + dataset: + !ImageFolder + anno_path: annotations/instances_val2017.json + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: True + - !ResizeImage + target_size: 608 + interp: 2 + - !NormalizeImage + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + is_scale: True + is_channel_first: false + - !Permute + to_bgr: false + channel_first: True + batch_size: 1 diff --git a/ppdet/core/workspace.py b/ppdet/core/workspace.py index cbcdcefc4..b151b15aa 100644 --- a/ppdet/core/workspace.py +++ b/ppdet/core/workspace.py @@ -199,13 +199,13 @@ def create(cls_or_name, **kwargs): config.update(kwargs) config.validate() cls = getattr(config.pymodule, name) - kwargs = {} kwargs.update(global_config[name]) # parse `shared` annoation of registered modules if getattr(config, 'shared', None): for k in config.shared: + target_key = config[k] shared_conf = config.schema[k].default assert isinstance(shared_conf, SharedConfig) @@ -225,9 +225,22 @@ def create(cls_or_name, **kwargs): # optional dependency if target_key is None: continue - # also accept dictionaries and serialized objects + if isinstance(target_key, dict) or hasattr(target_key, '__dict__'): - continue + if 'name' not in target_key.keys(): + continue + inject_name = str(target_key['name']) + if inject_name not in global_config: + raise ValueError( + "Missing injection name {} and check it's name in cfg file". + format(k)) + target = global_config[inject_name] + for i, v in target_key.items(): + if i == 'name': + continue + target[i] = v + if isinstance(target, SchemaDict): + kwargs[k] = create(inject_name) elif isinstance(target_key, str): if target_key not in global_config: raise ValueError("Missing injection config:", target_key) @@ -235,10 +248,10 @@ def create(cls_or_name, **kwargs): if isinstance(target, SchemaDict): kwargs[k] = create(target_key) elif hasattr(target, '__dict__'): # serialized object - kwargs[k] = target + kwargs[k] = new_dict else: raise ValueError("Unsupported injection type:", target_key) # prevent modification of global config values of reference types # (e.g., list, dict) from within the created module instances - kwargs = copy.deepcopy(kwargs) + #kwargs = copy.deepcopy(kwargs) return cls(**kwargs) diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py index 8da0b1e35..068614ab5 100644 --- a/ppdet/data/transform/batch_operators.py +++ b/ppdet/data/transform/batch_operators.py @@ -47,10 +47,11 @@ class PadBatch(BaseOperator): height and width is divisible by `pad_to_stride`. """ - def __init__(self, pad_to_stride=0, use_padded_im_info=True): + def __init__(self, pad_to_stride=0, use_padded_im_info=True, pad_gt=False): super(PadBatch, self).__init__() self.pad_to_stride = pad_to_stride self.use_padded_im_info = use_padded_im_info + self.pad_gt = pad_gt def __call__(self, samples, context=None): """ @@ -60,9 +61,9 @@ class PadBatch(BaseOperator): coarsest_stride = self.pad_to_stride if coarsest_stride == 0: return samples + max_shape = np.array([data['image'].shape for data in samples]).max( axis=0) - if coarsest_stride > 0: max_shape[1] = int( np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride) @@ -79,6 +80,52 @@ class PadBatch(BaseOperator): data['image'] = padding_im if self.use_padded_im_info: data['im_info'][:2] = max_shape[1:3] + + if self.pad_gt: + gt_num = [] + if data['gt_poly'] is not None and len(data['gt_poly']) > 0: + pad_mask = True + else: + pad_mask = False + + if pad_mask: + poly_num = [] + poly_part_num = [] + point_num = [] + for data in samples: + gt_num.append(data['gt_bbox'].shape[0]) + if pad_mask: + poly_num.append(len(data['gt_poly'])) + for poly in data['gt_poly']: + poly_part_num.append(int(len(poly))) + for p_p in poly: + point_num.append(int(len(p_p) / 2)) + gt_num_max = max(gt_num) + gt_box_data = np.zeros([gt_num_max, 4]) + gt_class_data = np.zeros([gt_num_max]) + is_crowd_data = np.ones([gt_num_max]) + + if pad_mask: + poly_num_max = max(poly_num) + poly_part_num_max = max(poly_part_num) + point_num_max = max(point_num) + gt_masks_data = -np.ones( + [poly_num_max, poly_part_num_max, point_num_max, 2]) + + for i, data in enumerate(samples): + gt_num = data['gt_bbox'].shape[0] + gt_box_data[0:gt_num, :] = data['gt_bbox'] + gt_class_data[0:gt_num] = np.squeeze(data['gt_class']) + is_crowd_data[0:gt_num] = np.squeeze(data['is_crowd']) + if pad_mask: + for j, poly in enumerate(data['gt_poly']): + for k, p_p in enumerate(poly): + pp_np = np.array(p_p).reshape(-1, 2) + gt_masks_data[j, k, :pp_np.shape[0], :] = pp_np + data['gt_poly'] = gt_masks_data + data['gt_bbox'] = gt_box_data + data['gt_class'] = gt_class_data + data['is_crowd_data'] = is_crowd_data return samples diff --git a/ppdet/optimizer.py b/ppdet/optimizer.py index b3506e55a..2016cda13 100644 --- a/ppdet/optimizer.py +++ b/ppdet/optimizer.py @@ -43,8 +43,7 @@ class PiecewiseDecay(object): milestones (list): steps at which to decay learning rate """ - def __init__(self, gamma=[0.1, 0.1], milestones=[60000, 80000], - values=None): + def __init__(self, gamma=[0.1, 0.01], milestones=[60000, 80000]): super(PiecewiseDecay, self).__init__() if type(gamma) is not list: self.gamma = [] @@ -53,126 +52,16 @@ class PiecewiseDecay(object): else: self.gamma = gamma self.milestones = milestones - self.values = values - def __call__(self, base_lr=None, learning_rate=None): - if self.values is not None: - return fluid.layers.piecewise_decay(self.milestones, self.values) - assert base_lr is not None, "either base LR or values should be provided" - values = [base_lr] - for g in self.gamma: - new_lr = base_lr * g - values.append(new_lr) - return fluid.layers.piecewise_decay(self.milestones, values) + def __call__(self, base_lr=None, boundary=None, value=None): + if boundary is not None: + boundary.extend(self.milestones) + if value is not None: + for i in self.gamma: + value.append(base_lr * i) -@serializable -class PolynomialDecay(object): - """ - Applies polynomial decay to the initial learning rate. - Args: - max_iter (int): The learning rate decay steps. - end_lr (float): End learning rate. - power (float): Polynomial attenuation coefficient - """ - - def __init__(self, max_iter=180000, end_lr=0.0001, power=1.0): - super(PolynomialDecay).__init__() - self.max_iter = max_iter - self.end_lr = end_lr - self.power = power - - def __call__(self, base_lr=None, learning_rate=None): - assert base_lr is not None, "either base LR or values should be provided" - lr = fluid.layers.polynomial_decay(base_lr, self.max_iter, self.end_lr, - self.power) - return lr - - -@serializable -class ExponentialDecay(object): - """ - Applies exponential decay to the learning rate. - Args: - max_iter (int): The learning rate decay steps. - decay_rate (float): The learning rate decay rate. - """ - - def __init__(self, max_iter, decay_rate): - super(ExponentialDecay).__init__() - self.max_iter = max_iter - self.decay_rate = decay_rate - - def __call__(self, base_lr=None, learning_rate=None): - assert base_lr is not None, "either base LR or values should be provided" - lr = fluid.layers.exponential_decay(base_lr, self.max_iter, - self.decay_rate) - return lr - - -@serializable -class CosineDecay(object): - """ - Cosine learning rate decay - - Args: - max_iters (float): max iterations for the training process. - if you commbine cosine decay with warmup, it is recommended that - the max_iter is much larger than the warmup iter - """ - - def __init__(self, max_iters=180000): - self.max_iters = max_iters - - def __call__(self, base_lr=None, learning_rate=None): - assert base_lr is not None, "either base LR or values should be provided" - lr = fluid.layers.cosine_decay(base_lr, 1, self.max_iters) - return lr - - -@serializable -class CosineDecayWithSkip(object): - """ - Cosine decay, with explicit support for warm up - - Args: - total_steps (int): total steps over which to apply the decay - skip_steps (int): skip some steps at the beginning, e.g., warm up - """ - - def __init__(self, total_steps, skip_steps=None): - super(CosineDecayWithSkip, self).__init__() - assert (not skip_steps or skip_steps > 0), \ - "skip steps must be greater than zero" - assert total_steps > 0, "total step must be greater than zero" - assert (not skip_steps or skip_steps < total_steps), \ - "skip steps must be smaller than total steps" - self.total_steps = total_steps - self.skip_steps = skip_steps - - def __call__(self, base_lr=None, learning_rate=None): - steps = _decay_step_counter() - total = self.total_steps - if self.skip_steps is not None: - total -= self.skip_steps - - lr = fluid.layers.tensor.create_global_var( - shape=[1], - value=base_lr, - dtype='float32', - persistable=True, - name="learning_rate") - - def decay(): - cos_lr = base_lr * .5 * (cos(steps * (math.pi / total)) + 1) - fluid.layers.tensor.assign(input=cos_lr, output=lr) - - if self.skip_steps is None: - decay() - else: - skipped = steps >= self.skip_steps - fluid.layers.cond(skipped, decay) - return lr + return fluid.dygraph.PiecewiseDecay(boundary, value, begin=0, step=1) @serializable @@ -190,14 +79,17 @@ class LinearWarmup(object): self.steps = steps self.start_factor = start_factor - def __call__(self, base_lr, learning_rate): - start_lr = base_lr * self.start_factor - - return fluid.layers.linear_lr_warmup( - learning_rate=learning_rate, - warmup_steps=self.steps, - start_lr=start_lr, - end_lr=base_lr) + def __call__(self, base_lr): + boundary = [] + value = [] + for i in range(self.steps): + alpha = i / self.steps + factor = self.start_factor * (1 - alpha) + alpha + lr = base_lr * factor + value.append(lr) + if i > 0: + boundary.append(i) + return boundary, value @register @@ -219,10 +111,12 @@ class LearningRate(object): self.schedulers = schedulers def __call__(self): - lr = None - for sched in self.schedulers: - lr = sched(self.base_lr, lr) - return lr + # TODO: split warmup & decay + # warmup + boundary, value = self.schedulers[1](self.base_lr) + # decay + decay_lr = self.schedulers[0](self.base_lr, boundary, value) + return decay_lr @register @@ -246,21 +140,24 @@ class OptimizerBuilder(): self.regularizer = regularizer self.optimizer = optimizer - def __call__(self, learning_rate): + def __call__(self, learning_rate, params=None): if self.clip_grad_by_norm is not None: fluid.clip.set_gradient_clip( clip=fluid.clip.GradientClipByGlobalNorm( clip_norm=self.clip_grad_by_norm)) + if self.regularizer: reg_type = self.regularizer['type'] + 'Decay' reg_factor = self.regularizer['factor'] regularization = getattr(regularizer, reg_type)(reg_factor) else: regularization = None + optim_args = self.optimizer.copy() optim_type = optim_args['type'] del optim_args['type'] op = getattr(optimizer, optim_type) return op(learning_rate=learning_rate, + parameter_list=params, regularization=regularization, **optim_args) diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py index 42fe8194d..3ee6c328a 100644 --- a/ppdet/utils/checkpoint.py +++ b/ppdet/utils/checkpoint.py @@ -1,303 +1,87 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals -import errno import os -import shutil -import tempfile import time -import numpy as np import re +import numpy as np import paddle.fluid as fluid - from .download import get_weights_path -import logging -logger = logging.getLogger(__name__) - -__all__ = [ - 'load_checkpoint', - 'load_and_fusebn', - 'load_params', - 'save', -] - - -def is_url(path): - """ - Whether path is URL. - Args: - path (string): URL string or not. - """ - return path.startswith('http://') or path.startswith('https://') - -def _get_weight_path(path): - env = os.environ - if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env: - trainer_id = int(env['PADDLE_TRAINER_ID']) - num_trainers = int(env['PADDLE_TRAINERS_NUM']) - if num_trainers <= 1: - path = get_weights_path(path) +def get_ckpt_path(path): + if path.startswith('http://') or path.startswith('https://'): + env = os.environ + if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env: + trainer_id = int(env['PADDLE_TRAINER_ID']) + num_trainers = int(env['PADDLE_TRAINERS_NUM']) + if num_trainers <= 1: + path = get_weights_path(path) + else: + from ppdet.utils.download import map_path, WEIGHTS_HOME + weight_path = map_path(path, WEIGHTS_HOME) + lock_path = weight_path + '.lock' + if not os.path.exists(weight_path): + try: + os.makedirs(os.path.dirname(weight_path)) + except OSError as e: + if e.errno != errno.EEXIST: + raise + with open(lock_path, 'w'): # touch + os.utime(lock_path, None) + if trainer_id == 0: + get_weights_path(path) + os.remove(lock_path) + else: + while os.path.exists(lock_path): + time.sleep(1) + path = weight_path else: - from ppdet.utils.download import map_path, WEIGHTS_HOME - weight_path = map_path(path, WEIGHTS_HOME) - lock_path = weight_path + '.lock' - if not os.path.exists(weight_path): - try: - os.makedirs(os.path.dirname(weight_path)) - except OSError as e: - if e.errno != errno.EEXIST: - raise - with open(lock_path, 'w'): # touch - os.utime(lock_path, None) - if trainer_id == 0: - get_weights_path(path) - os.remove(lock_path) - else: - while os.path.exists(lock_path): - time.sleep(1) - path = weight_path - else: - path = get_weights_path(path) - return path - - -def _load_state(path): - if os.path.exists(path + '.pdopt'): - # XXX another hack to ignore the optimizer state - tmp = tempfile.mkdtemp() - dst = os.path.join(tmp, os.path.basename(os.path.normpath(path))) - shutil.copy(path + '.pdparams', dst + '.pdparams') - state = fluid.io.load_program_state(dst) - shutil.rmtree(tmp) - else: - state = fluid.io.load_program_state(path) - return state - + path = get_weights_path(path) -def _strip_postfix(path): - path, ext = os.path.splitext(path) - assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \ - "Unknown postfix {} from weights".format(ext) return path -def load_params(exe, prog, path, ignore_params=[]): - """ - Load model from the given path. - Args: - exe (fluid.Executor): The fluid.Executor object. - prog (fluid.Program): load weight to which Program object. - path (string): URL string or loca model path. - ignore_params (list): ignore variable to load when finetuning. - It can be specified by finetune_exclude_pretrained_params - and the usage can refer to docs/advanced_tutorials/TRANSFER_LEARNING.md - """ - - if is_url(path): - path = _get_weight_path(path) - - path = _strip_postfix(path) - if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')): - raise ValueError("Model pretrain path {} does not " - "exists.".format(path)) - - logger.debug('Loading parameters from {}...'.format(path)) - - ignore_set = set() - state = _load_state(path) - - # ignore the parameter which mismatch the shape - # between the model and pretrain weight. - all_var_shape = {} - for block in prog.blocks: - for param in block.all_parameters(): - all_var_shape[param.name] = param.shape - ignore_set.update([ - name for name, shape in all_var_shape.items() - if name in state and shape != state[name].shape - ]) - - if ignore_params: - all_var_names = [var.name for var in prog.list_vars()] - ignore_list = filter( - lambda var: any([re.match(name, var) for name in ignore_params]), - all_var_names) - ignore_set.update(list(ignore_list)) - - if len(ignore_set) > 0: - for k in ignore_set: - if k in state: - logger.warning('variable {} not used'.format(k)) - del state[k] - fluid.io.set_program_state(prog, state) - - -def load_checkpoint(exe, prog, path): - """ - Load model from the given path. - Args: - exe (fluid.Executor): The fluid.Executor object. - prog (fluid.Program): load weight to which Program object. - path (string): URL string or loca model path. - """ - if is_url(path): - path = _get_weight_path(path) - - path = _strip_postfix(path) - if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')): - raise ValueError("Model pretrain path {} does not " - "exists.".format(path)) - fluid.load(prog, path, executor=exe) - - -def global_step(scope=None): - """ - Load global step in scope. - Args: - scope (fluid.Scope): load global step from which scope. If None, - from default global_scope(). - - Returns: - global step: int. - """ - if scope is None: - scope = fluid.global_scope() - v = scope.find_var('@LR_DECAY_COUNTER@') - step = np.array(v.get_tensor())[0] if v else 0 - return step - - -def save(exe, prog, path): - """ - Load model from the given path. - Args: - exe (fluid.Executor): The fluid.Executor object. - prog (fluid.Program): save weight from which Program object. - path (string): the path to save model. - """ - if os.path.isdir(path): - shutil.rmtree(path) - logger.info('Save model to {}.'.format(path)) - fluid.save(prog, path) - - -def load_and_fusebn(exe, prog, path): - """ - Fuse params of batch norm to scale and bias. - - Args: - exe (fluid.Executor): The fluid.Executor object. - prog (fluid.Program): save weight from which Program object. - path (string): the path to save model. - """ - logger.debug('Load model and fuse batch norm if have from {}...'.format( - path)) - - if is_url(path): - path = _get_weight_path(path) - - if not os.path.exists(path): - raise ValueError("Model path {} does not exists.".format(path)) - - # Since the program uses affine-channel, there is no running mean and var - # in the program, here append running mean and var. - # NOTE, the params of batch norm should be like: - # x_scale - # x_offset - # x_mean - # x_variance - # x is any prefix - mean_variances = set() - bn_vars = [] - state = _load_state(path) - - def check_mean_and_bias(prefix): - m = prefix + 'mean' - v = prefix + 'variance' - return v in state and m in state - - has_mean_bias = True - - with fluid.program_guard(prog, fluid.Program()): - for block in prog.blocks: - ops = list(block.ops) - if not has_mean_bias: - break - for op in ops: - if op.type == 'affine_channel': - # remove 'scale' as prefix - scale_name = op.input('Scale')[0] # _scale - bias_name = op.input('Bias')[0] # _offset - prefix = scale_name[:-5] - mean_name = prefix + 'mean' - variance_name = prefix + 'variance' - if not check_mean_and_bias(prefix): - has_mean_bias = False - break - - bias = block.var(bias_name) - - mean_vb = block.create_var( - name=mean_name, - type=bias.type, - shape=bias.shape, - dtype=bias.dtype) - variance_vb = block.create_var( - name=variance_name, - type=bias.type, - shape=bias.shape, - dtype=bias.dtype) - - mean_variances.add(mean_vb) - mean_variances.add(variance_vb) - - bn_vars.append( - [scale_name, bias_name, mean_name, variance_name]) - - if not has_mean_bias: - fluid.io.set_program_state(prog, state) - logger.warning( - "There is no paramters of batch norm in model {}. " - "Skip to fuse batch norm. And load paramters done.".format(path)) - return - - fluid.load(prog, path, exe) - eps = 1e-5 - for names in bn_vars: - scale_name, bias_name, mean_name, var_name = names - - scale = fluid.global_scope().find_var(scale_name).get_tensor() - bias = fluid.global_scope().find_var(bias_name).get_tensor() - mean = fluid.global_scope().find_var(mean_name).get_tensor() - var = fluid.global_scope().find_var(var_name).get_tensor() - - scale_arr = np.array(scale) - bias_arr = np.array(bias) - mean_arr = np.array(mean) - var_arr = np.array(var) - - bn_std = np.sqrt(np.add(var_arr, eps)) - new_scale = np.float32(np.divide(scale_arr, bn_std)) - new_bias = bias_arr - mean_arr * new_scale - - # fuse to scale and bias in affine_channel - scale.set(new_scale, exe.place) - bias.set(new_bias, exe.place) +def load_dygraph_ckpt(model, + optimizer, + pretrain_ckpt=None, + ckpt=None, + ckpt_type='pretrain', + exclude_params=[], + open_debug=False): + + if ckpt_type == 'pretrain': + ckpt = pretrain_ckpt + ckpt = get_ckpt_path(ckpt) + if ckpt is not None and os.path.exists(ckpt): + param_state_dict, optim_state_dict = fluid.load_dygraph(ckpt) + if open_debug: + print("Loading Weights: ", param_state_dict.keys()) + + if len(exclude_params) != 0: + for k in exclude_params: + param_state_dict.pop(k, None) + + if ckpt_type == 'pretrain': + model.backbone.set_dict(param_state_dict) + elif ckpt_type == 'finetune': + model.set_dict(param_state_dict, use_structured_name=True) + else: + model.set_dict(param_state_dict) + + if ckpt_type == 'resume': + if optim_state_dict is None: + print("Can't Resume Last Training's Optimizer State!!!") + else: + optimizer.set_dict(optim_state_dict) + return model + + +def save_dygraph_ckpt(model, optimizer, save_dir): + if not os.path.exists(save_dir): + os.makedirs(save_dir) + fluid.dygraph.save_dygraph(model.state_dict(), save_dir) + fluid.dygraph.save_dygraph(optimizer.state_dict(), save_dir) + print("Save checkpoint:", save_dir) diff --git a/ppdet/utils/data_structure.py b/ppdet/utils/data_structure.py index 05d845c64..a600af32b 100644 --- a/ppdet/utils/data_structure.py +++ b/ppdet/utils/data_structure.py @@ -35,7 +35,7 @@ class BufferDict(dict): def debug(self, dshape=True, dvalue=True, dtype=False): if self['open_debug']: - if self['debug_names'] is None: + if 'debug_names' not in self.keys(): ditems = self.keys() else: ditems = self['debug_names'] diff --git a/ppdet/utils/eval_utils.py b/ppdet/utils/eval_utils.py index 8ba53838d..b06accad1 100644 --- a/ppdet/utils/eval_utils.py +++ b/ppdet/utils/eval_utils.py @@ -1,242 +1,7 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - from __future__ import absolute_import from __future__ import division from __future__ import print_function -import logging -import numpy as np -import os -import time - -import paddle.fluid as fluid - -from .voc_eval import bbox_eval as voc_bbox_eval -from .post_process import mstest_box_post_process, mstest_mask_post_process, box_flip - -__all__ = ['parse_fetches', 'eval_run', 'eval_results', 'json_eval_results'] - -logger = logging.getLogger(__name__) - - -def parse_fetches(fetches, prog=None, extra_keys=None): - """ - Parse fetch variable infos from model fetches, - values for fetch_list and keys for stat - """ - keys, values = [], [] - cls = [] - for k, v in fetches.items(): - if hasattr(v, 'name'): - keys.append(k) - #v.persistable = True - values.append(v.name) - else: - cls.append(v) - - if prog is not None and extra_keys is not None: - for k in extra_keys: - try: - v = fluid.framework._get_var(k, prog) - keys.append(k) - values.append(v.name) - except Exception: - pass - - return keys, values, cls - - -def length2lod(length_lod): - offset_lod = [0] - for i in length_lod: - offset_lod.append(offset_lod[-1] + i) - return [offset_lod] - - -def get_sub_feed(input, place): - new_dict = {} - res_feed = {} - key_name = ['bbox', 'im_info', 'im_id', 'im_shape', 'bbox_flip'] - for k in key_name: - if k in input.keys(): - new_dict[k] = input[k] - for k in input.keys(): - if 'image' in k: - new_dict[k] = input[k] - for k, v in new_dict.items(): - data_t = fluid.LoDTensor() - data_t.set(v[0], place) - if 'bbox' in k: - lod = length2lod(v[1][0]) - data_t.set_lod(lod) - res_feed[k] = data_t - return res_feed - - -def clean_res(result, keep_name_list): - clean_result = {} - for k in result.keys(): - if k in keep_name_list: - clean_result[k] = result[k] - result.clear() - return clean_result - - -def eval_run(exe, - compile_program, - loader, - keys, - values, - cls, - cfg=None, - sub_prog=None, - sub_keys=None, - sub_values=None, - resolution=None): - """ - Run evaluation program, return program outputs. - """ - iter_id = 0 - results = [] - if len(cls) != 0: - values = [] - for i in range(len(cls)): - _, accum_map = cls[i].get_map_var() - cls[i].reset(exe) - values.append(accum_map) - - images_num = 0 - start_time = time.time() - has_bbox = 'bbox' in keys - - try: - loader.start() - while True: - outs = exe.run(compile_program, - fetch_list=values, - return_numpy=False) - res = { - k: (np.array(v), v.recursive_sequence_lengths()) - for k, v in zip(keys, outs) - } - multi_scale_test = getattr(cfg, 'MultiScaleTEST', None) - mask_multi_scale_test = multi_scale_test and 'Mask' in cfg.architecture - - if multi_scale_test: - post_res = mstest_box_post_process(res, multi_scale_test, - cfg.num_classes) - res.update(post_res) - if mask_multi_scale_test: - place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() - sub_feed = get_sub_feed(res, place) - sub_prog_outs = exe.run(sub_prog, - feed=sub_feed, - fetch_list=sub_values, - return_numpy=False) - sub_prog_res = { - k: (np.array(v), v.recursive_sequence_lengths()) - for k, v in zip(sub_keys, sub_prog_outs) - } - post_res = mstest_mask_post_process(sub_prog_res, cfg) - res.update(post_res) - if multi_scale_test: - res = clean_res( - res, ['im_info', 'bbox', 'im_id', 'im_shape', 'mask']) - if 'mask' in res: - from ppdet.utils.post_process import mask_encode - res['mask'] = mask_encode(res, resolution) - post_config = getattr(cfg, 'PostProcess', None) - if 'Corner' in cfg.architecture and post_config is not None: - from ppdet.utils.post_process import corner_post_process - corner_post_process(res, post_config, cfg.num_classes) - results.append(res) - if iter_id % 100 == 0: - logger.info('Test iter {}'.format(iter_id)) - iter_id += 1 - if len(res['bbox'][1]) == 0: - has_bbox = False - images_num += len(res['bbox'][1][0]) if has_bbox else 1 - except (StopIteration, fluid.core.EOFException): - loader.reset() - logger.info('Test finish iter {}'.format(iter_id)) - - end_time = time.time() - fps = images_num / (end_time - start_time) - if has_bbox: - logger.info('Total number of images: {}, inference time: {} fps.'. - format(images_num, fps)) - else: - logger.info('Total iteration: {}, inference time: {} batch/s.'.format( - images_num, fps)) - - return results - - -def eval_results(results, - metric, - num_classes, - resolution=None, - is_bbox_normalized=False, - output_directory=None, - map_type='11point', - dataset=None, - save_only=False): - """Evaluation for evaluation program results""" - box_ap_stats = [] - if metric == 'COCO': - from ppdet.utils.coco_eval import proposal_eval, bbox_eval, mask_eval - anno_file = dataset.get_anno() - with_background = dataset.with_background - if 'proposal' in results[0]: - output = 'proposal.json' - if output_directory: - output = os.path.join(output_directory, 'proposal.json') - proposal_eval(results, anno_file, output) - if 'bbox' in results[0]: - output = 'bbox.json' - if output_directory: - output = os.path.join(output_directory, 'bbox.json') - - box_ap_stats = bbox_eval( - results, - anno_file, - output, - with_background, - is_bbox_normalized=is_bbox_normalized, - save_only=save_only) - - if 'mask' in results[0]: - output = 'mask.json' - if output_directory: - output = os.path.join(output_directory, 'mask.json') - mask_eval( - results, anno_file, output, resolution, save_only=save_only) - else: - if 'accum_map' in results[-1]: - res = np.mean(results[-1]['accum_map'][0]) - logger.info('mAP: {:.2f}'.format(res * 100.)) - box_ap_stats.append(res * 100.) - elif 'bbox' in results[0]: - box_ap = voc_bbox_eval( - results, - num_classes, - is_bbox_normalized=is_bbox_normalized, - map_type=map_type) - box_ap_stats.append(box_ap) - return box_ap_stats - def json_eval_results(metric, json_directory=None, dataset=None): """ @@ -259,3 +24,51 @@ def json_eval_results(metric, json_directory=None, dataset=None): cocoapi_eval(v_json, coco_eval_style[i], anno_file=anno_file) else: logger.info("{} not exists!".format(v_json)) + + +def coco_eval_results(outs_res=None, + include_mask=False, + batch_size=1, + dataset=None): + print("start evaluate bbox using coco api") + import io + import six + import json + from pycocotools.coco import COCO + from pycocotools.cocoeval import COCOeval + from ppdet.py_op.post_process import get_det_res, get_seg_res + anno_file = os.path.join(dataset.dataset_dir, dataset.anno_path) + cocoGt = COCO(anno_file) + catid = {i + 1: v for i, v in enumerate(cocoGt.getCatIds())} + + if outs_res is not None and len(outs_res) > 0: + det_res = [] + for outs in outs_res: + det_res += get_det_res(outs['bbox_nums'], outs['bbox'], + outs['im_id'], outs['im_shape'], catid, + batch_size) + + with io.open("bbox_eval.json", 'w') as outfile: + encode_func = unicode if six.PY2 else str + outfile.write(encode_func(json.dumps(det_res))) + + cocoDt = cocoGt.loadRes("bbox_eval.json") + cocoEval = COCOeval(cocoGt, cocoDt, 'bbox') + cocoEval.evaluate() + cocoEval.accumulate() + cocoEval.summarize() + + if outs_res is not None and len(outs_res) > 0 and include_mask: + seg_res = [] + for outs in outs_res: + seg_res += get_seg_res(outs['bbox_nums'], outs['mask'], + outs['im_id'], catid, batch_size) + + with io.open("mask_eval.json", 'w') as outfile: + encode_func = unicode if six.PY2 else str + outfile.write(encode_func(json.dumps(seg_res))) + + cocoSg = cocoGt.loadRes("mask_eval.json") + cocoEval = COCOeval(cocoGt, cocoSg, 'bbox') + cocoEval.evaluate() + cocoEval.accumulate() diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tools/eval.py b/tools/eval.py new file mode 100755 index 000000000..bec9b8fc0 --- /dev/null +++ b/tools/eval.py @@ -0,0 +1,93 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +import time +# ignore numba warning +import warnings +warnings.filterwarnings('ignore') +import random +import numpy as np +import paddle.fluid as fluid +from ppdet.core.workspace import load_config, merge_config, create +from ppdet.utils.check import check_gpu, check_version, check_config +from ppdet.utils.cli import ArgsParser +from ppdet.utils.eval_utils import coco_eval_results +from ppdet.data.reader import create_reader + + +def parse_args(): + parser = ArgsParser() + parser.add_argument( + "--output_eval", + default=None, + type=str, + help="Evaluation directory, default is current directory.") + + parser.add_argument( + '--json_eval', action='store_true', default=False, help='') + + parser.add_argument( + '--use_gpu', action='store_true', default=False, help='') + + args = parser.parse_args() + return args + + +def run(FLAGS, cfg): + + # Model + main_arch = cfg.architecture + model = create(cfg.architecture, mode='infer', open_debug=cfg.open_debug) + + # Init Model + if os.path.isfile(cfg.weights): + param_state_dict, opti_state_dict = fluid.load_dygraph(cfg.weights) + model.set_dict(param_state_dict) + + # Data Reader + if FLAGS.use_gpu: + devices_num = 1 + else: + devices_num = int(os.environ.get('CPU_NUM', 1)) + eval_reader = create_reader(cfg.EvalReader, devices_num=devices_num) + + # Run Eval + outs_res = [] + for iter_id, data in enumerate(eval_reader()): + start_time = time.time() + + # forward + model.eval() + outs = model(data, cfg['EvalReader']['inputs_def']['fields']) + outs_res.append(outs) + + # log + cost_time = time.time() - start_time + print("Eval iter: {}, time: {}".format(iter_id, cost_time)) + + # Metric + coco_eval_results( + outs_res, + include_mask=True if 'MaskHed' in cfg else False, + dataset=cfg['EvalReader']['dataset']) + + +def main(): + FLAGS = parse_args() + + cfg = load_config(FLAGS.config) + merge_config(FLAGS.opt) + check_config(cfg) + check_gpu(cfg.use_gpu) + check_version() + + place = fluid.CUDAPlace(fluid.dygraph.parallel.Env() + .dev_id) if cfg.use_gpu else fluid.CPUPlace() + + with fluid.dygraph.guard(place): + run(FLAGS, cfg) + + +if __name__ == '__main__': + main() diff --git a/tools/train.py b/tools/train.py new file mode 100755 index 000000000..3c1865ede --- /dev/null +++ b/tools/train.py @@ -0,0 +1,198 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +import time +# ignore numba warning +import warnings +warnings.filterwarnings('ignore') +import random +import numpy as np +import paddle.fluid as fluid +from ppdet.core.workspace import load_config, merge_config, create +from ppdet.data.reader import create_reader +from ppdet.utils.check import check_gpu, check_version, check_config +from ppdet.utils.cli import ArgsParser +from ppdet.utils.checkpoint import load_dygraph_ckpt, save_dygraph_ckpt + + +def parse_args(): + parser = ArgsParser() + parser.add_argument( + "-ckpt_type", + default='pretrain', + type=str, + help="Loading Checkpoints only support 'pretrain', 'finetune', 'resume'." + ) + + parser.add_argument( + "--fp16", + action='store_true', + default=False, + help="Enable mixed precision training.") + parser.add_argument( + "--loss_scale", + default=8., + type=float, + help="Mixed precision training loss scale.") + parser.add_argument( + "--eval", + action='store_true', + default=False, + help="Whether to perform evaluation in train") + parser.add_argument( + "--output_eval", + default=None, + type=str, + help="Evaluation directory, default is current directory.") + parser.add_argument( + "--use_tb", + type=bool, + default=False, + help="whether to record the data to Tensorboard.") + parser.add_argument( + '--tb_log_dir', + type=str, + default="tb_log_dir/scalar", + help='Tensorboard logging directory for scalar.') + parser.add_argument( + "--enable_ce", + type=bool, + default=False, + help="If set True, enable continuous evaluation job." + "This flag is only used for internal test.") + parser.add_argument( + "--use_gpu", action='store_true', default=False, help="data parallel") + parser.add_argument( + "--use_parallel", + action='store_true', + default=False, + help="data parallel") + + parser.add_argument( + '--is_profiler', + type=int, + default=0, + help='The switch of profiler tools. (used for benchmark)') + + args = parser.parse_args() + return args + + +def run(FLAGS, cfg): + env = os.environ + FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env + if FLAGS.dist: + trainer_id = int(env['PADDLE_TRAINER_ID']) + local_seed = (99 + trainer_id) + random.seed(local_seed) + np.random.seed(local_seed) + + if FLAGS.enable_ce or cfg.open_debug: + fluid.default_startup_program().random_seed = 1000 + fluid.default_main_program().random_seed = 1000 + random.seed(0) + np.random.seed(0) + + if FLAGS.use_parallel: + strategy = fluid.dygraph.parallel.prepare_context() + parallel_log = "Note: use parallel " + + # Model + main_arch = cfg.architecture + model = create(cfg.architecture, mode='train', open_debug=cfg.open_debug) + + # Parallel Model + if FLAGS.use_parallel: + #strategy = fluid.dygraph.parallel.prepare_context() + model = fluid.dygraph.parallel.DataParallel(model, strategy) + parallel_log += "with data parallel!" + print(parallel_log) + + # Optimizer + lr = create('LearningRate')() + optimizer = create('OptimizerBuilder')(lr, model.parameters()) + + # Init Model & Optimzer + model = load_dygraph_ckpt( + model, + optimizer, + cfg.pretrain_weights, + cfg.weights, + FLAGS.ckpt_type, + open_debug=cfg.open_debug) + + # Data Reader + start_iter = 0 + if cfg.use_gpu: + devices_num = fluid.core.get_cuda_device_count( + ) if FLAGS.use_parallel else 1 + else: + devices_num = int(os.environ.get('CPU_NUM', 1)) + + train_reader = create_reader( + cfg.TrainReader, (cfg.max_iters - start_iter) * devices_num, + cfg, + devices_num=devices_num) + + # Run Train + for iter_id, data in enumerate(train_reader()): + start_time = time.time() + + # Model Forward + model.train() + outputs = model(data, cfg['TrainReader']['inputs_def']['fields']) + + # Model Backward + loss = outputs['loss'] + if FLAGS.use_parallel: + loss = model.scale_loss(loss) + loss.backward() + model.apply_collective_grads() + else: + loss.backward() + optimizer.minimize(loss) + model.clear_gradients() + + # Log state + cost_time = time.time() - start_time + # TODO: check this method + curr_lr = optimizer.current_step_lr() + log_info = "iter: {}, time: {:.4f}, lr: {:.6f}".format( + iter_id, cost_time, curr_lr) + for k, v in outputs.items(): + log_info += ", {}: {:.6f}".format(k, v.numpy()[0]) + print(log_info) + + # Debug + if cfg.open_debug and iter_id > 10: + break + + # Save Stage + if iter_id > 0 and iter_id % int(cfg.snapshot_iter) == 0: + cfg_name = os.path.basename(FLAGS.config).split('.')[0] + save_name = str( + iter_id) if iter_id != cfg.max_iters - 1 else "model_final" + save_dir = os.path.join(cfg.save_dir, cfg_name, save_name) + save_dygraph_ckpt(model, optimizer, save_dir) + + +def main(): + FLAGS = parse_args() + + cfg = load_config(FLAGS.config) + merge_config(FLAGS.opt) + check_config(cfg) + check_gpu(cfg.use_gpu) + check_version() + + place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \ + if FLAGS.use_parallel else fluid.CUDAPlace(0) \ + if cfg.use_gpu else fluid.CPUPlace() + + with fluid.dygraph.guard(place): + run(FLAGS, cfg) + + +if __name__ == "__main__": + main() -- GitLab