def default_collate_fn(batch):
    """
    Default batch collating function for :code:`paddle.io.DataLoader`.

    The input is a list of samples, one element per sample. A sample may be
    a number, a numpy array, a string/bytes object, a dictionary or a
    list/tuple of those. The structure is walked recursively and each field
    is gathered across the batch, e.g. for the following input data:

        [{'image': np.array(shape=[3, 224, 224]), 'label': 1},
         {'image': np.array(shape=[3, 224, 224]), 'label': 3},
         {'image': np.array(shape=[3, 224, 224]), 'label': 4},
         {'image': np.array(shape=[3, 224, 224]), 'label': 5},]

    every number and numpy array field is zipped together and stacked into
    the batch field as follows:

        {'image': np.array(shape=[4, 3, 224, 224]),
         'label': np.array([1, 3, 4, 5])}

    The type of the first sample decides how the whole batch is handled.

    Args:
        batch(list of sample data): batch should be a non-empty list of
            sample data.

    Returns:
        Batched data: numpy arrays are stacked along a new leading axis,
        numbers are packed into a numpy array, strings/bytes are returned
        as the incoming collection unchanged, dictionaries are collated per
        key and sequences are collated per position (returned as a list).

    Raises:
        IndexError: if ``batch`` is empty.
        RuntimeError: if sequence samples do not all have the same length.
        TypeError: if the sample type is not one of the types listed above.
            Note that ``paddle.Tensor`` samples are not handled by this
            function and end up here as well.
    """
    sample = batch[0]

    if isinstance(sample, np.ndarray):
        return np.stack(batch, axis=0)

    if isinstance(sample, numbers.Number):
        return np.array(batch)

    # Must be tested before Sequence: str and bytes are sequences themselves
    # and would otherwise be split into single characters.
    if isinstance(sample, (str, bytes)):
        return batch

    if isinstance(sample, Mapping):
        # Keys are taken from the first sample; every other sample is
        # expected to provide the same keys.
        return {
            key: default_collate_fn([d[key] for d in batch])
            for key in sample
        }

    if isinstance(sample, Sequence):
        sample_fields_num = len(sample)
        if not all(len(item) == sample_fields_num for item in batch):
            raise RuntimeError(
                "fields number not same among samples in a batch")
        # zip(*batch) regroups the data field by field across samples.
        return [default_collate_fn(fields) for fields in zip(*batch)]

    raise TypeError("batch data can only contain: numpy.ndarray, dict, list, "
                    "number, str, bytes, but got {}".format(type(sample)))
b/ppdet/modeling/assigners/atss_assigner.py index 2a641d4cc41928671d13dd14cddb73295d0f8386..6406d7bce5b796c125cd489886e9622a3f4ede97 100644 --- a/ppdet/modeling/assigners/atss_assigner.py +++ b/ppdet/modeling/assigners/atss_assigner.py @@ -22,8 +22,7 @@ import paddle.nn as nn import paddle.nn.functional as F from ppdet.core.workspace import register -from ..ops import iou_similarity -from ..bbox_utils import iou_similarity as batch_iou_similarity +from ..bbox_utils import iou_similarity, batch_iou_similarity from ..bbox_utils import bbox_center from .utils import (check_points_inside_bboxes, compute_max_iou_anchor, compute_max_iou_gt) diff --git a/ppdet/modeling/assigners/task_aligned_assigner.py b/ppdet/modeling/assigners/task_aligned_assigner.py index 1a82c15237a07d3993460629ccb8317466da87fb..949e899192f65082c4cf422ae283e6a11834d3f6 100644 --- a/ppdet/modeling/assigners/task_aligned_assigner.py +++ b/ppdet/modeling/assigners/task_aligned_assigner.py @@ -21,7 +21,7 @@ import paddle.nn as nn import paddle.nn.functional as F from ppdet.core.workspace import register -from ..bbox_utils import iou_similarity +from ..bbox_utils import batch_iou_similarity from .utils import (gather_topk_anchors, check_points_inside_bboxes, compute_max_iou_anchor) diff --git a/ppdet/modeling/backbones/swin_transformer.py b/ppdet/modeling/backbones/swin_transformer.py index bd541784e32d321edd4db0082e6622cea8734590..e5f48bc6fdb7e5592b772c42b2881a5e17cbdc38 100644 --- a/ppdet/modeling/backbones/swin_transformer.py +++ b/ppdet/modeling/backbones/swin_transformer.py @@ -482,8 +482,7 @@ class BasicLayer(nn.Layer): # calculate attention mask for SW-MSA Hp = int(np.ceil(H / self.window_size)) * self.window_size Wp = int(np.ceil(W / self.window_size)) * self.window_size - img_mask = paddle.fluid.layers.zeros( - [1, Hp, Wp, 1], dtype='float32') # 1 Hp Wp 1 + img_mask = paddle.zeros([1, Hp, Wp, 1], dtype='float32') # 1 Hp Wp 1 h_slices = (slice(0, -self.window_size), slice(-self.window_size, 
def iou_similarity(box1, box2, eps=1e-10):
    """Compute the pairwise IoU between two sets of boxes.

    Each box is read as two corner points: columns 0:2 and columns 2:4.
    Every box of ``box1`` is compared against every box of ``box2`` through
    broadcasting.

    Args:
        box1 (Tensor): box with the shape [M1, 4]
        box2 (Tensor): box with the shape [M2, 4]
        eps (float): small constant added to the union so the division
            never hits zero. Default: 1e-10

    Return:
        iou (Tensor): iou between box1 and box2 with the shape [M1, M2]
    """
    # Insert singleton axes so that the two sets broadcast against each other.
    rows = box1.unsqueeze(1)  # [M1, 4] -> [M1, 1, 4]
    cols = box2.unsqueeze(0)  # [M2, 4] -> [1, M2, 4]

    rows_min = rows[:, :, 0:2]
    rows_max = rows[:, :, 2:4]
    cols_min = cols[:, :, 0:2]
    cols_max = cols[:, :, 2:4]

    # Corners of the intersection rectangle for every (row, col) pair.
    inter_min = paddle.maximum(rows_min, cols_min)
    inter_max = paddle.minimum(rows_max, cols_max)

    # Negative extents (no overlap / degenerate box) are clipped to zero
    # before width and height are multiplied into an area.
    inter_area = (inter_max - inter_min).clip(0).prod(-1)
    rows_area = (rows_max - rows_min).clip(0).prod(-1)
    cols_area = (cols_max - cols_min).clip(0).prod(-1)

    denom = rows_area + cols_area - inter_area + eps
    return inter_area / denom
get_static_shape from ..initializer import normal_ @@ -726,8 +725,7 @@ class PicoHeadV2(GFLHead): loss_dfl = paddle.zeros([1]) avg_factor = flatten_assigned_scores.sum() - if paddle.fluid.core.is_compiled_with_dist( - ) and parallel_helper._is_parallel_ctx_initialized(): + if paddle.distributed.get_world_size() > 1: paddle.distributed.all_reduce(avg_factor) avg_factor = paddle.clip( avg_factor / paddle.distributed.get_world_size(), min=1) diff --git a/ppdet/modeling/heads/ppyoloe_head.py b/ppdet/modeling/heads/ppyoloe_head.py index 3babb2703c1991e56634f14d7a4578588f58b791..b012806a536fe5db2f06fddfd68610736f8fe58e 100644 --- a/ppdet/modeling/heads/ppyoloe_head.py +++ b/ppdet/modeling/heads/ppyoloe_head.py @@ -22,7 +22,7 @@ from ..losses import GIoULoss from ..initializer import bias_init_with_prob, constant_, normal_ from ..assigners.utils import generate_anchors_for_grid_cell from ppdet.modeling.backbones.cspresnet import ConvBNLayer -from ppdet.modeling.ops import get_static_shape, paddle_distributed_is_initialized, get_act_fn +from ppdet.modeling.ops import get_static_shape, get_act_fn from ppdet.modeling.layers import MultiClassNMS __all__ = ['PPYOLOEHead'] @@ -343,7 +343,7 @@ class PPYOLOEHead(nn.Layer): loss_cls = self._focal_loss(pred_scores, assigned_scores, alpha_l) assigned_scores_sum = assigned_scores.sum() - if paddle_distributed_is_initialized(): + if paddle.distributed.get_world_size() > 1: paddle.distributed.all_reduce(assigned_scores_sum) assigned_scores_sum = paddle.clip( assigned_scores_sum / paddle.distributed.get_world_size(), diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py index 8238e89734909b23dd0956b838d18d7f8cf019de..ba9a385e145040fb3301e1d670b2c1e33dc1c123 100644 --- a/ppdet/modeling/layers.py +++ b/ppdet/modeling/layers.py @@ -554,10 +554,15 @@ class YOLOBox(object): origin_shape = im_shape / scale_factor origin_shape = paddle.cast(origin_shape, 'int32') for i, head_out in enumerate(yolo_head_out): - boxes, scores = 
ops.yolo_box(head_out, origin_shape, anchors[i], - self.num_classes, self.conf_thresh, - self.downsample_ratio // 2**i, - self.clip_bbox, self.scale_x_y) + boxes, scores = paddle.vision.ops.yolo_box( + head_out, + origin_shape, + anchors[i], + self.num_classes, + self.conf_thresh, + self.downsample_ratio // 2**i, + self.clip_bbox, + scale_x_y=self.scale_x_y) boxes_list.append(boxes) scores_list.append(paddle.transpose(scores, perm=[0, 2, 1])) yolo_boxes = paddle.concat(boxes_list, axis=1) @@ -622,94 +627,6 @@ class SSDBox(object): return output_boxes, output_scores -@register -@serializable -class AnchorGrid(object): - """Generate anchor grid - - Args: - image_size (int or list): input image size, may be a single integer or - list of [h, w]. Default: 512 - min_level (int): min level of the feature pyramid. Default: 3 - max_level (int): max level of the feature pyramid. Default: 7 - anchor_base_scale: base anchor scale. Default: 4 - num_scales: number of anchor scales. Default: 3 - aspect_ratios: aspect ratios. 
default: [[1, 1], [1.4, 0.7], [0.7, 1.4]] - """ - - def __init__(self, - image_size=512, - min_level=3, - max_level=7, - anchor_base_scale=4, - num_scales=3, - aspect_ratios=[[1, 1], [1.4, 0.7], [0.7, 1.4]]): - super(AnchorGrid, self).__init__() - if isinstance(image_size, Integral): - self.image_size = [image_size, image_size] - else: - self.image_size = image_size - for dim in self.image_size: - assert dim % 2 ** max_level == 0, \ - "image size should be multiple of the max level stride" - self.min_level = min_level - self.max_level = max_level - self.anchor_base_scale = anchor_base_scale - self.num_scales = num_scales - self.aspect_ratios = aspect_ratios - - @property - def base_cell(self): - if not hasattr(self, '_base_cell'): - self._base_cell = self.make_cell() - return self._base_cell - - def make_cell(self): - scales = [2**(i / self.num_scales) for i in range(self.num_scales)] - scales = np.array(scales) - ratios = np.array(self.aspect_ratios) - ws = np.outer(scales, ratios[:, 0]).reshape(-1, 1) - hs = np.outer(scales, ratios[:, 1]).reshape(-1, 1) - anchors = np.hstack((-0.5 * ws, -0.5 * hs, 0.5 * ws, 0.5 * hs)) - return anchors - - def make_grid(self, stride): - cell = self.base_cell * stride * self.anchor_base_scale - x_steps = np.arange(stride // 2, self.image_size[1], stride) - y_steps = np.arange(stride // 2, self.image_size[0], stride) - offset_x, offset_y = np.meshgrid(x_steps, y_steps) - offset_x = offset_x.flatten() - offset_y = offset_y.flatten() - offsets = np.stack((offset_x, offset_y, offset_x, offset_y), axis=-1) - offsets = offsets[:, np.newaxis, :] - return (cell + offsets).reshape(-1, 4) - - def generate(self): - return [ - self.make_grid(2**l) - for l in range(self.min_level, self.max_level + 1) - ] - - def __call__(self): - if not hasattr(self, '_anchor_vars'): - anchor_vars = [] - helper = LayerHelper('anchor_grid') - for idx, l in enumerate(range(self.min_level, self.max_level + 1)): - stride = 2**l - anchors = self.make_grid(stride) - 
var = helper.create_parameter( - attr=ParamAttr(name='anchors_{}'.format(idx)), - shape=anchors.shape, - dtype='float32', - stop_gradient=True, - default_initializer=NumpyArrayInitializer(anchors)) - anchor_vars.append(var) - var.persistable = True - self._anchor_vars = anchor_vars - - return self._anchor_vars - - @register @serializable class FCOSBox(object): diff --git a/ppdet/modeling/losses/ssd_loss.py b/ppdet/modeling/losses/ssd_loss.py index 62aecc1f33a104531edc2a77015e27847bb92506..2ab94f2b5bbf1f31fe47d186a92ac805cdf6daf3 100644 --- a/ppdet/modeling/losses/ssd_loss.py +++ b/ppdet/modeling/losses/ssd_loss.py @@ -20,8 +20,7 @@ import paddle import paddle.nn as nn import paddle.nn.functional as F from ppdet.core.workspace import register -from ..ops import iou_similarity -from ..bbox_utils import bbox2delta +from ..bbox_utils import iou_similarity, bbox2delta __all__ = ['SSDLoss'] diff --git a/ppdet/modeling/losses/yolo_loss.py b/ppdet/modeling/losses/yolo_loss.py index 657959cd7e55cf43d6362f03e1a4c1204b814c07..1ba05f2c8eae530e44e20d21375f7cf9b9cd1fb0 100644 --- a/ppdet/modeling/losses/yolo_loss.py +++ b/ppdet/modeling/losses/yolo_loss.py @@ -21,7 +21,7 @@ import paddle.nn as nn import paddle.nn.functional as F from ppdet.core.workspace import register -from ..bbox_utils import decode_yolo, xywh2xyxy, iou_similarity +from ..bbox_utils import decode_yolo, xywh2xyxy, batch_iou_similarity __all__ = ['YOLOv3Loss'] @@ -80,7 +80,7 @@ class YOLOv3Loss(nn.Layer): gwh = gbox[:, :, 0:2] + gbox[:, :, 2:4] * 0.5 gbox = paddle.concat([gxy, gwh], axis=-1) - iou = iou_similarity(pbox, gbox) + iou = batch_iou_similarity(pbox, gbox) iou.stop_gradient = True iou_max = iou.max(2) # [N, M1] iou_mask = paddle.cast(iou_max <= self.ignore_thresh, dtype=pbox.dtype) diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py index 0a79d7fbf903a4307f024f125f221783a1ae3bb3..baf83ed73d270288dbbed8ae714b65d681cfab7c 100644 --- a/ppdet/modeling/ops.py +++ b/ppdet/modeling/ops.py @@ -19,23 
+19,17 @@ from paddle import ParamAttr from paddle.regularizer import L2Decay from paddle import _C_ops -from paddle.fluid.framework import Variable, in_dygraph_mode -from paddle.fluid import core -from paddle.fluid.dygraph import parallel_helper -from paddle.fluid.layer_helper import LayerHelper -from paddle.fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype +from paddle import in_dynamic_mode +from paddle.common_ops_import import Variable, LayerHelper, check_variable_and_dtype, check_type, check_dtype __all__ = [ 'roi_pool', 'roi_align', 'prior_box', 'generate_proposals', - 'iou_similarity', 'box_coder', - 'yolo_box', 'multiclass_nms', 'distribute_fpn_proposals', - 'collect_fpn_proposals', 'matrix_nms', 'batch_norm', 'mish', @@ -189,7 +183,7 @@ def roi_pool(input, output_size = (output_size, output_size) pooled_height, pooled_width = output_size - if in_dygraph_mode(): + if in_dynamic_mode(): assert rois_num is not None, "rois_num should not be None in dygraph mode." pool_out, argmaxes = _C_ops.roi_pool( input, rois, rois_num, "pooled_height", pooled_height, @@ -296,7 +290,7 @@ def roi_align(input, pooled_height, pooled_width = output_size - if in_dygraph_mode(): + if in_dynamic_mode(): assert rois_num is not None, "rois_num should not be None in dygraph mode." align_out = _C_ops.roi_align( input, rois, rois_num, "pooled_height", pooled_height, @@ -332,183 +326,6 @@ def roi_align(input, return align_out -@paddle.jit.not_to_static -def iou_similarity(x, y, box_normalized=True, name=None): - """ - Computes intersection-over-union (IOU) between two box lists. - Box list 'X' should be a LoDTensor and 'Y' is a common Tensor, - boxes in 'Y' are shared by all instance of the batched inputs of X. 
- Given two boxes A and B, the calculation of IOU is as follows: - - $$ - IOU(A, B) = - \\frac{area(A\\cap B)}{area(A)+area(B)-area(A\\cap B)} - $$ - - Args: - x (Tensor): Box list X is a 2-D Tensor with shape [N, 4] holds N - boxes, each box is represented as [xmin, ymin, xmax, ymax], - the shape of X is [N, 4]. [xmin, ymin] is the left top - coordinate of the box if the input is image feature map, they - are close to the origin of the coordinate system. - [xmax, ymax] is the right bottom coordinate of the box. - The data type is float32 or float64. - y (Tensor): Box list Y holds M boxes, each box is represented as - [xmin, ymin, xmax, ymax], the shape of X is [N, 4]. - [xmin, ymin] is the left top coordinate of the box if the - input is image feature map, and [xmax, ymax] is the right - bottom coordinate of the box. The data type is float32 or float64. - box_normalized(bool): Whether treat the priorbox as a normalized box. - Set true by default. - name(str, optional): For detailed information, please refer - to :ref:`api_guide_Name`. Usually name is no need to set and - None by default. - - Returns: - Tensor: The output of iou_similarity op, a tensor with shape [N, M] - representing pairwise iou scores. The data type is same with x. - - Examples: - .. 
code-block:: python - - import paddle - from ppdet.modeling import ops - paddle.enable_static() - - x = paddle.static.data(name='x', shape=[None, 4], dtype='float32') - y = paddle.static.data(name='y', shape=[None, 4], dtype='float32') - iou = ops.iou_similarity(x=x, y=y) - """ - - if in_dygraph_mode(): - out = _C_ops.iou_similarity(x, y, 'box_normalized', box_normalized) - return out - else: - helper = LayerHelper("iou_similarity", **locals()) - out = helper.create_variable_for_type_inference(dtype=x.dtype) - - helper.append_op( - type="iou_similarity", - inputs={"X": x, - "Y": y}, - attrs={"box_normalized": box_normalized}, - outputs={"Out": out}) - return out - - -@paddle.jit.not_to_static -def collect_fpn_proposals(multi_rois, - multi_scores, - min_level, - max_level, - post_nms_top_n, - rois_num_per_level=None, - name=None): - """ - - **This OP only supports LoDTensor as input**. Concat multi-level RoIs - (Region of Interest) and select N RoIs with respect to multi_scores. - This operation performs the following steps: - - 1. Choose num_level RoIs and scores as input: num_level = max_level - min_level - 2. Concat multi-level RoIs and scores - 3. Sort scores and select post_nms_top_n scores - 4. Gather RoIs by selected indices from scores - 5. Re-sort RoIs by corresponding batch_id - - Args: - multi_rois(list): List of RoIs to collect. Element in list is 2-D - LoDTensor with shape [N, 4] and data type is float32 or float64, - N is the number of RoIs. - multi_scores(list): List of scores of RoIs to collect. Element in list - is 2-D LoDTensor with shape [N, 1] and data type is float32 or - float64, N is the number of RoIs. - min_level(int): The lowest level of FPN layer to collect - max_level(int): The highest level of FPN layer to collect - post_nms_top_n(int): The number of selected RoIs - rois_num_per_level(list, optional): The List of RoIs' numbers. 
- Each element is 1-D Tensor which contains the RoIs' number of each - image on each level and the shape is [B] and data type is - int32, B is the number of images. If it is not None then return - a 1-D Tensor contains the output RoIs' number of each image and - the shape is [B]. Default: None - name(str, optional): For detailed information, please refer - to :ref:`api_guide_Name`. Usually name is no need to set and - None by default. - - Returns: - Variable: - - fpn_rois(Variable): 2-D LoDTensor with shape [N, 4] and data type is - float32 or float64. Selected RoIs. - - rois_num(Tensor): 1-D Tensor contains the RoIs's number of each - image. The shape is [B] and data type is int32. B is the number of - images. - - Examples: - .. code-block:: python - - import paddle - from ppdet.modeling import ops - paddle.enable_static() - multi_rois = [] - multi_scores = [] - for i in range(4): - multi_rois.append(paddle.static.data( - name='roi_'+str(i), shape=[None, 4], dtype='float32', lod_level=1)) - for i in range(4): - multi_scores.append(paddle.static.data( - name='score_'+str(i), shape=[None, 1], dtype='float32', lod_level=1)) - - fpn_rois = ops.collect_fpn_proposals( - multi_rois=multi_rois, - multi_scores=multi_scores, - min_level=2, - max_level=5, - post_nms_top_n=2000) - """ - check_type(multi_rois, 'multi_rois', list, 'collect_fpn_proposals') - check_type(multi_scores, 'multi_scores', list, 'collect_fpn_proposals') - num_lvl = max_level - min_level + 1 - input_rois = multi_rois[:num_lvl] - input_scores = multi_scores[:num_lvl] - - if in_dygraph_mode(): - assert rois_num_per_level is not None, "rois_num_per_level should not be None in dygraph mode." 
- attrs = ('post_nms_topN', post_nms_top_n) - output_rois, rois_num = _C_ops.collect_fpn_proposals( - input_rois, input_scores, rois_num_per_level, *attrs) - return output_rois, rois_num - - else: - helper = LayerHelper('collect_fpn_proposals', **locals()) - dtype = helper.input_dtype('multi_rois') - check_dtype(dtype, 'multi_rois', ['float32', 'float64'], - 'collect_fpn_proposals') - output_rois = helper.create_variable_for_type_inference(dtype) - output_rois.stop_gradient = True - - inputs = { - 'MultiLevelRois': input_rois, - 'MultiLevelScores': input_scores, - } - outputs = {'FpnRois': output_rois} - if rois_num_per_level is not None: - inputs['MultiLevelRoIsNum'] = rois_num_per_level - rois_num = helper.create_variable_for_type_inference(dtype='int32') - rois_num.stop_gradient = True - outputs['RoisNum'] = rois_num - else: - rois_num = None - helper.append_op( - type='collect_fpn_proposals', - inputs=inputs, - outputs=outputs, - attrs={'post_nms_topN': post_nms_top_n}) - return output_rois, rois_num - - @paddle.jit.not_to_static def distribute_fpn_proposals(fpn_rois, min_level, @@ -587,7 +404,7 @@ def distribute_fpn_proposals(fpn_rois, """ num_lvl = max_level - min_level + 1 - if in_dygraph_mode(): + if in_dynamic_mode(): assert rois_num is not None, "rois_num should not be None in dygraph mode." attrs = ('min_level', min_level, 'max_level', max_level, 'refer_level', refer_level, 'refer_scale', refer_scale, 'pixel_offset', @@ -638,143 +455,6 @@ def distribute_fpn_proposals(fpn_rois, return multi_rois, restore_ind, rois_num_per_level -@paddle.jit.not_to_static -def yolo_box( - x, - origin_shape, - anchors, - class_num, - conf_thresh, - downsample_ratio, - clip_bbox=True, - scale_x_y=1., - name=None, ): - """ - - This operator generates YOLO detection boxes from output of YOLOv3 network. 
- - The output of previous network is in shape [N, C, H, W], while H and W - should be the same, H and W specify the grid size, each grid point predict - given number boxes, this given number, which following will be represented as S, - is specified by the number of anchors. In the second dimension(the channel - dimension), C should be equal to S * (5 + class_num), class_num is the object - category number of source dataset(such as 80 in coco dataset), so the - second(channel) dimension, apart from 4 box location coordinates x, y, w, h, - also includes confidence score of the box and class one-hot key of each anchor - box. - Assume the 4 location coordinates are :math:`t_x, t_y, t_w, t_h`, the box - predictions should be as follows: - $$ - b_x = \\sigma(t_x) + c_x - $$ - $$ - b_y = \\sigma(t_y) + c_y - $$ - $$ - b_w = p_w e^{t_w} - $$ - $$ - b_h = p_h e^{t_h} - $$ - in the equation above, :math:`c_x, c_y` is the left top corner of current grid - and :math:`p_w, p_h` is specified by anchors. - The logistic regression value of the 5th channel of each anchor prediction boxes - represents the confidence score of each prediction box, and the logistic - regression value of the last :attr:`class_num` channels of each anchor prediction - boxes represents the classifcation scores. Boxes with confidence scores less than - :attr:`conf_thresh` should be ignored, and box final scores is the product of - confidence scores and classification scores. - $$ - score_{pred} = score_{conf} * score_{class} - $$ - - Args: - x (Tensor): The input tensor of YoloBox operator is a 4-D tensor with shape of [N, C, H, W]. - The second dimension(C) stores box locations, confidence score and - classification one-hot keys of each anchor box. Generally, X should be the output of YOLOv3 network. - The data type is float32 or float64. - origin_shape (Tensor): The image size tensor of YoloBox operator, This is a 2-D tensor with shape of [N, 2]. 
- This tensor holds height and width of each input image used for resizing output box in input image - scale. The data type is int32. - anchors (list|tuple): The anchor width and height, it will be parsed pair by pair. - class_num (int): The number of classes to predict. - conf_thresh (float): The confidence scores threshold of detection boxes. Boxes with confidence scores - under threshold should be ignored. - downsample_ratio (int): The downsample ratio from network input to YoloBox operator input, - so 32, 16, 8 should be set for the first, second, and thrid YoloBox operators. - clip_bbox (bool): Whether clip output bonding box in Input(ImgSize) boundary. Default true. - scale_x_y (float): Scale the center point of decoded bounding box. Default 1.0. - name (string): The default value is None. Normally there is no need - for user to set this property. For more information, - please refer to :ref:`api_guide_Name` - - Returns: - boxes Tensor: A 3-D tensor with shape [N, M, 4], the coordinates of boxes, N is the batch num, - M is output box number, and the 3rd dimension stores [xmin, ymin, xmax, ymax] coordinates of boxes. - scores Tensor: A 3-D tensor with shape [N, M, :attr:`class_num`], the coordinates of boxes, N is the batch num, - M is output box number. - - Raises: - TypeError: Attr anchors of yolo box must be list or tuple - TypeError: Attr class_num of yolo box must be an integer - TypeError: Attr conf_thresh of yolo box must be a float number - - Examples: - - .. 
code-block:: python - - import paddle - from ppdet.modeling import ops - - paddle.enable_static() - x = paddle.static.data(name='x', shape=[None, 255, 13, 13], dtype='float32') - img_size = paddle.static.data(name='img_size',shape=[None, 2],dtype='int64') - anchors = [10, 13, 16, 30, 33, 23] - boxes,scores = ops.yolo_box(x=x, img_size=img_size, class_num=80, anchors=anchors, - conf_thresh=0.01, downsample_ratio=32) - """ - helper = LayerHelper('yolo_box', **locals()) - - if not isinstance(anchors, list) and not isinstance(anchors, tuple): - raise TypeError("Attr anchors of yolo_box must be list or tuple") - if not isinstance(class_num, int): - raise TypeError("Attr class_num of yolo_box must be an integer") - if not isinstance(conf_thresh, float): - raise TypeError("Attr ignore_thresh of yolo_box must be a float number") - - if in_dygraph_mode(): - attrs = ('anchors', anchors, 'class_num', class_num, 'conf_thresh', - conf_thresh, 'downsample_ratio', downsample_ratio, 'clip_bbox', - clip_bbox, 'scale_x_y', scale_x_y) - boxes, scores = _C_ops.yolo_box(x, origin_shape, *attrs) - return boxes, scores - else: - boxes = helper.create_variable_for_type_inference(dtype=x.dtype) - scores = helper.create_variable_for_type_inference(dtype=x.dtype) - - attrs = { - "anchors": anchors, - "class_num": class_num, - "conf_thresh": conf_thresh, - "downsample_ratio": downsample_ratio, - "clip_bbox": clip_bbox, - "scale_x_y": scale_x_y, - } - - helper.append_op( - type='yolo_box', - inputs={ - "X": x, - "ImgSize": origin_shape, - }, - outputs={ - 'Boxes': boxes, - 'Scores': scores, - }, - attrs=attrs) - return boxes, scores - - @paddle.jit.not_to_static def prior_box(input, image, @@ -877,7 +557,7 @@ def prior_box(input, max_sizes = [max_sizes] cur_max_sizes = max_sizes - if in_dygraph_mode(): + if in_dynamic_mode(): attrs = ('min_sizes', min_sizes, 'aspect_ratios', aspect_ratios, 'variances', variance, 'flip', flip, 'clip', clip, 'step_w', steps[0], 'step_h', steps[1], 'offset', 
offset, @@ -1022,7 +702,7 @@ def multiclass_nms(bboxes, """ helper = LayerHelper('multiclass_nms3', **locals()) - if in_dygraph_mode(): + if in_dynamic_mode(): attrs = ('background_label', background_label, 'score_threshold', score_threshold, 'nms_top_k', nms_top_k, 'nms_threshold', nms_threshold, 'keep_top_k', keep_top_k, 'nms_eta', nms_eta, @@ -1163,7 +843,7 @@ def matrix_nms(bboxes, check_type(gaussian_sigma, 'gaussian_sigma', float, 'matrix_nms') check_type(background_label, 'background_label', int, 'matrix_nms') - if in_dygraph_mode(): + if in_dynamic_mode(): attrs = ('background_label', background_label, 'score_threshold', score_threshold, 'post_threshold', post_threshold, 'nms_top_k', nms_top_k, 'gaussian_sigma', gaussian_sigma, 'use_gaussian', @@ -1208,111 +888,6 @@ def matrix_nms(bboxes, return output, rois_num, index -def bipartite_match(dist_matrix, - match_type=None, - dist_threshold=None, - name=None): - """ - - This operator implements a greedy bipartite matching algorithm, which is - used to obtain the matching with the maximum distance based on the input - distance matrix. For input 2D matrix, the bipartite matching algorithm can - find the matched column for each row (matched means the largest distance), - also can find the matched row for each column. And this operator only - calculate matched indices from column to row. For each instance, - the number of matched indices is the column number of the input distance - matrix. **The OP only supports CPU**. - - There are two outputs, matched indices and distance. - A simple description, this algorithm matched the best (maximum distance) - row entity to the column entity and the matched indices are not duplicated - in each row of ColToRowMatchIndices. If the column entity is not matched - any row entity, set -1 in ColToRowMatchIndices. - - NOTE: the input DistMat can be LoDTensor (with LoD) or Tensor. - If LoDTensor with LoD, the height of ColToRowMatchIndices is batch size. 
- If Tensor, the height of ColToRowMatchIndices is 1. - - NOTE: This API is a very low level API. It is used by :code:`ssd_loss` - layer. Please consider to use :code:`ssd_loss` instead. - - Args: - dist_matrix(Tensor): This input is a 2-D LoDTensor with shape - [K, M]. The data type is float32 or float64. It is pair-wise - distance matrix between the entities represented by each row and - each column. For example, assumed one entity is A with shape [K], - another entity is B with shape [M]. The dist_matrix[i][j] is the - distance between A[i] and B[j]. The bigger the distance is, the - better matching the pairs are. NOTE: This tensor can contain LoD - information to represent a batch of inputs. One instance of this - batch can contain different numbers of entities. - match_type(str, optional): The type of matching method, should be - 'bipartite' or 'per_prediction'. None ('bipartite') by default. - dist_threshold(float32, optional): If `match_type` is 'per_prediction', - this threshold is to determine the extra matching bboxes based - on the maximum distance, 0.5 by default. - name(str, optional): For detailed information, please refer - to :ref:`api_guide_Name`. Usually name is no need to set and - None by default. - - Returns: - Tuple: - - matched_indices(Tensor): A 2-D Tensor with shape [N, M]. The data - type is int32. N is the batch size. If match_indices[i][j] is -1, it - means B[j] does not match any entity in i-th instance. - Otherwise, it means B[j] is matched to row - match_indices[i][j] in i-th instance. The row number of - i-th instance is saved in match_indices[i][j]. - - matched_distance(Tensor): A 2-D Tensor with shape [N, M]. The data - type is float32. N is batch size. If match_indices[i][j] is -1, - match_distance[i][j] is also -1.0. Otherwise, assumed - match_distance[i][j] = d, and the row offsets of each instance - are called LoD. Then match_distance[i][j] = - dist_matrix[d+LoD[i]][j]. - - Examples: - - .. 
code-block:: python - import paddle - from ppdet.modeling import ops - from ppdet.modeling.utils import iou_similarity - - paddle.enable_static() - - x = paddle.static.data(name='x', shape=[None, 4], dtype='float32') - y = paddle.static.data(name='y', shape=[None, 4], dtype='float32') - iou = iou_similarity(x=x, y=y) - matched_indices, matched_dist = ops.bipartite_match(iou) - """ - check_variable_and_dtype(dist_matrix, 'dist_matrix', - ['float32', 'float64'], 'bipartite_match') - - if in_dygraph_mode(): - match_indices, match_distance = _C_ops.bipartite_match( - dist_matrix, "match_type", match_type, "dist_threshold", - dist_threshold) - return match_indices, match_distance - - helper = LayerHelper('bipartite_match', **locals()) - match_indices = helper.create_variable_for_type_inference(dtype='int32') - match_distance = helper.create_variable_for_type_inference( - dtype=dist_matrix.dtype) - helper.append_op( - type='bipartite_match', - inputs={'DistMat': dist_matrix}, - attrs={ - 'match_type': match_type, - 'dist_threshold': dist_threshold, - }, - outputs={ - 'ColToRowMatchIndices': match_indices, - 'ColToRowMatchDist': match_distance - }) - return match_indices, match_distance - - @paddle.jit.not_to_static def box_coder(prior_box, prior_box_var, @@ -1425,7 +1000,7 @@ def box_coder(prior_box, check_variable_and_dtype(target_box, 'target_box', ['float32', 'float64'], 'box_coder') - if in_dygraph_mode(): + if in_dynamic_mode(): if isinstance(prior_box_var, Variable): output_box = _C_ops.box_coder( prior_box, prior_box_var, target_box, "code_type", code_type, @@ -1550,7 +1125,7 @@ def generate_proposals(scores, rois, roi_probs = ops.generate_proposals(scores, bbox_deltas, im_shape, anchors, variances) """ - if in_dygraph_mode(): + if in_dynamic_mode(): assert return_rois_num, "return_rois_num should be True in dygraph mode." 
attrs = ('pre_nms_topN', pre_nms_top_n, 'post_nms_topN', post_nms_top_n, 'nms_thresh', nms_thresh, 'min_size', min_size, 'eta', eta, @@ -1656,8 +1231,3 @@ def get_static_shape(tensor): shape = paddle.shape(tensor) shape.stop_gradient = True return shape - - -def paddle_distributed_is_initialized(): - return core.is_compiled_with_dist( - ) and parallel_helper._is_parallel_ctx_initialized() diff --git a/ppdet/modeling/tests/test_base.py b/ppdet/modeling/tests/test_base.py index cbb9033b393a24167ec1ebc32a4d924fa564f929..451aa78e32ce0682f55a2ab0f9d1ea03e939e481 100644 --- a/ppdet/modeling/tests/test_base.py +++ b/ppdet/modeling/tests/test_base.py @@ -18,9 +18,7 @@ import unittest import contextlib import paddle -import paddle.fluid as fluid -from paddle.fluid.framework import Program -from paddle.fluid import core +from paddle.static import Program class LayerTest(unittest.TestCase): @@ -35,19 +33,17 @@ class LayerTest(unittest.TestCase): def _get_place(self, force_to_use_cpu=False): # this option for ops that only have cpu kernel if force_to_use_cpu: - return core.CPUPlace() + return 'cpu' else: - if core.is_compiled_with_cuda(): - return core.CUDAPlace(0) - return core.CPUPlace() + return paddle.device.get_device() @contextlib.contextmanager def static_graph(self): paddle.enable_static() - scope = fluid.core.Scope() + scope = paddle.static.Scope() program = Program() - with fluid.scope_guard(scope): - with fluid.program_guard(program): + with paddle.static.scope_guard(scope): + with paddle.static.program_guard(program): paddle.seed(self.seed) paddle.framework.random._manual_program_seed(self.seed) yield @@ -57,9 +53,9 @@ class LayerTest(unittest.TestCase): fetch_list, with_lod=False, force_to_use_cpu=False): - exe = fluid.Executor(self._get_place(force_to_use_cpu)) - exe.run(fluid.default_startup_program()) - return exe.run(fluid.default_main_program(), + exe = paddle.static.Executor(self._get_place(force_to_use_cpu)) + 
exe.run(paddle.static.default_startup_program()) + return exe.run(paddle.static.default_main_program(), feed=feed, fetch_list=fetch_list, return_numpy=(not with_lod)) @@ -67,8 +63,8 @@ class LayerTest(unittest.TestCase): @contextlib.contextmanager def dynamic_graph(self, force_to_use_cpu=False): paddle.disable_static() - with fluid.dygraph.guard( - self._get_place(force_to_use_cpu=force_to_use_cpu)): - paddle.seed(self.seed) - paddle.framework.random._manual_program_seed(self.seed) - yield + place = self._get_place(force_to_use_cpu=force_to_use_cpu) + paddle.device.set_device(place) + paddle.seed(self.seed) + paddle.framework.random._manual_program_seed(self.seed) + yield diff --git a/ppdet/modeling/tests/test_ops.py b/ppdet/modeling/tests/test_ops.py index d4b5747487d3ee49627e4fe8aecec31cf2759ae2..b83c0c7292a179cba78b360893f8d4c4bbf015a2 100644 --- a/ppdet/modeling/tests/test_ops.py +++ b/ppdet/modeling/tests/test_ops.py @@ -23,8 +23,6 @@ import unittest import numpy as np import paddle -import paddle.fluid as fluid -from paddle.fluid.dygraph import base import ppdet.modeling.ops as ops from ppdet.modeling.tests.test_base import LayerTest @@ -50,127 +48,6 @@ def softmax(x): return exps / np.sum(exps) -class TestCollectFpnProposals(LayerTest): - def test_collect_fpn_proposals(self): - multi_bboxes_np = [] - multi_scores_np = [] - rois_num_per_level_np = [] - for i in range(4): - bboxes_np = np.random.rand(5, 4).astype('float32') - scores_np = np.random.rand(5, 1).astype('float32') - rois_num = np.array([2, 3]).astype('int32') - multi_bboxes_np.append(bboxes_np) - multi_scores_np.append(scores_np) - rois_num_per_level_np.append(rois_num) - - with self.static_graph(): - multi_bboxes = [] - multi_scores = [] - rois_num_per_level = [] - for i in range(4): - bboxes = paddle.static.data( - name='rois' + str(i), - shape=[5, 4], - dtype='float32', - lod_level=1) - scores = paddle.static.data( - name='scores' + str(i), - shape=[5, 1], - dtype='float32', - lod_level=1) - 
rois_num = paddle.static.data( - name='rois_num' + str(i), shape=[None], dtype='int32') - - multi_bboxes.append(bboxes) - multi_scores.append(scores) - rois_num_per_level.append(rois_num) - - fpn_rois, rois_num = ops.collect_fpn_proposals( - multi_bboxes, - multi_scores, - 2, - 5, - 10, - rois_num_per_level=rois_num_per_level) - feed = {} - for i in range(4): - feed['rois' + str(i)] = multi_bboxes_np[i] - feed['scores' + str(i)] = multi_scores_np[i] - feed['rois_num' + str(i)] = rois_num_per_level_np[i] - fpn_rois_stat, rois_num_stat = self.get_static_graph_result( - feed=feed, fetch_list=[fpn_rois, rois_num], with_lod=True) - fpn_rois_stat = np.array(fpn_rois_stat) - rois_num_stat = np.array(rois_num_stat) - - with self.dynamic_graph(): - multi_bboxes_dy = [] - multi_scores_dy = [] - rois_num_per_level_dy = [] - for i in range(4): - bboxes_dy = base.to_variable(multi_bboxes_np[i]) - scores_dy = base.to_variable(multi_scores_np[i]) - rois_num_dy = base.to_variable(rois_num_per_level_np[i]) - multi_bboxes_dy.append(bboxes_dy) - multi_scores_dy.append(scores_dy) - rois_num_per_level_dy.append(rois_num_dy) - fpn_rois_dy, rois_num_dy = ops.collect_fpn_proposals( - multi_bboxes_dy, - multi_scores_dy, - 2, - 5, - 10, - rois_num_per_level=rois_num_per_level_dy) - fpn_rois_dy = fpn_rois_dy.numpy() - rois_num_dy = rois_num_dy.numpy() - - self.assertTrue(np.array_equal(fpn_rois_stat, fpn_rois_dy)) - self.assertTrue(np.array_equal(rois_num_stat, rois_num_dy)) - - def test_collect_fpn_proposals_error(self): - def generate_input(bbox_type, score_type, name): - multi_bboxes = [] - multi_scores = [] - for i in range(4): - bboxes = paddle.static.data( - name='rois' + name + str(i), - shape=[10, 4], - dtype=bbox_type, - lod_level=1) - scores = paddle.static.data( - name='scores' + name + str(i), - shape=[10, 1], - dtype=score_type, - lod_level=1) - multi_bboxes.append(bboxes) - multi_scores.append(scores) - return multi_bboxes, multi_scores - - with self.static_graph(): - bbox1 = 
paddle.static.data( - name='rois', shape=[5, 10, 4], dtype='float32', lod_level=1) - score1 = paddle.static.data( - name='scores', shape=[5, 10, 1], dtype='float32', lod_level=1) - bbox2, score2 = generate_input('int32', 'float32', '2') - self.assertRaises( - TypeError, - ops.collect_fpn_proposals, - multi_rois=bbox1, - multi_scores=score1, - min_level=2, - max_level=5, - post_nms_top_n=2000) - self.assertRaises( - TypeError, - ops.collect_fpn_proposals, - multi_rois=bbox2, - multi_scores=score2, - min_level=2, - max_level=5, - post_nms_top_n=2000) - - paddle.disable_static() - - class TestDistributeFpnProposals(LayerTest): def test_distribute_fpn_proposals(self): rois_np = np.random.rand(10, 4).astype('float32') @@ -200,8 +77,8 @@ class TestDistributeFpnProposals(LayerTest): output_stat_np.append(output_np) with self.dynamic_graph(): - rois_dy = base.to_variable(rois_np) - rois_num_dy = base.to_variable(rois_num_np) + rois_dy = paddle.to_tensor(rois_np) + rois_num_dy = paddle.to_tensor(rois_num_np) multi_rois_dy, restore_ind_dy, rois_num_per_level_dy = ops.distribute_fpn_proposals( fpn_rois=rois_dy, min_level=2, @@ -266,9 +143,9 @@ class TestROIAlign(LayerTest): with_lod=False) with self.dynamic_graph(): - inputs_dy = base.to_variable(inputs_np) - rois_dy = base.to_variable(rois_np) - rois_num_dy = base.to_variable(rois_num_np) + inputs_dy = paddle.to_tensor(inputs_np) + rois_dy = paddle.to_tensor(rois_np) + rois_num_dy = paddle.to_tensor(rois_num_np) output_dy = ops.roi_align( input=inputs_dy, @@ -326,9 +203,9 @@ class TestROIPool(LayerTest): with_lod=False) with self.dynamic_graph(): - inputs_dy = base.to_variable(inputs_np) - rois_dy = base.to_variable(rois_np) - rois_num_dy = base.to_variable(rois_num_np) + inputs_dy = paddle.to_tensor(inputs_np) + rois_dy = paddle.to_tensor(rois_np) + rois_num_dy = paddle.to_tensor(rois_num_np) output_dy, _ = ops.roi_pool( input=inputs_dy, @@ -355,134 +232,6 @@ class TestROIPool(LayerTest): paddle.disable_static() -class 
TestIoUSimilarity(LayerTest): - def test_iou_similarity(self): - b, c, h, w = 2, 12, 20, 20 - inputs_np = np.random.rand(b, c, h, w).astype('float32') - output_size = (7, 7) - x_np = make_rois(h, w, [20], output_size) - y_np = make_rois(h, w, [10], output_size) - with self.static_graph(): - x = paddle.static.data(name='x', shape=[20, 4], dtype='float32') - y = paddle.static.data(name='y', shape=[10, 4], dtype='float32') - - iou = ops.iou_similarity(x=x, y=y) - iou_np, = self.get_static_graph_result( - feed={ - 'x': x_np, - 'y': y_np, - }, fetch_list=[iou], with_lod=False) - - with self.dynamic_graph(): - x_dy = base.to_variable(x_np) - y_dy = base.to_variable(y_np) - - iou_dy = ops.iou_similarity(x=x_dy, y=y_dy) - iou_dy_np = iou_dy.numpy() - - self.assertTrue(np.array_equal(iou_np, iou_dy_np)) - - -class TestBipartiteMatch(LayerTest): - def test_bipartite_match(self): - distance = np.random.random((20, 10)).astype('float32') - with self.static_graph(): - x = paddle.static.data(name='x', shape=[20, 10], dtype='float32') - - match_indices, match_dist = ops.bipartite_match( - x, match_type='per_prediction', dist_threshold=0.5) - match_indices_np, match_dist_np = self.get_static_graph_result( - feed={'x': distance, }, - fetch_list=[match_indices, match_dist], - with_lod=False) - - with self.dynamic_graph(): - x_dy = base.to_variable(distance) - - match_indices_dy, match_dist_dy = ops.bipartite_match( - x_dy, match_type='per_prediction', dist_threshold=0.5) - match_indices_dy_np = match_indices_dy.numpy() - match_dist_dy_np = match_dist_dy.numpy() - - self.assertTrue(np.array_equal(match_indices_np, match_indices_dy_np)) - self.assertTrue(np.array_equal(match_dist_np, match_dist_dy_np)) - - -class TestYoloBox(LayerTest): - def test_yolo_box(self): - - # x shape [N C H W], C=K * (5 + class_num), class_num=10, K=2 - np_x = np.random.random([1, 30, 7, 7]).astype('float32') - np_origin_shape = np.array([[608, 608]], dtype='int32') - class_num = 10 - conf_thresh = 0.01 - 
downsample_ratio = 32 - scale_x_y = 1.2 - - # static - with self.static_graph(): - # x shape [N C H W], C=K * (5 + class_num), class_num=10, K=2 - x = paddle.static.data( - name='x', shape=[1, 30, 7, 7], dtype='float32') - origin_shape = paddle.static.data( - name='origin_shape', shape=[1, 2], dtype='int32') - - boxes, scores = ops.yolo_box( - x, - origin_shape, [10, 13, 30, 13], - class_num, - conf_thresh, - downsample_ratio, - scale_x_y=scale_x_y) - - boxes_np, scores_np = self.get_static_graph_result( - feed={ - 'x': np_x, - 'origin_shape': np_origin_shape, - }, - fetch_list=[boxes, scores], - with_lod=False) - - # dygraph - with self.dynamic_graph(): - x_dy = fluid.layers.assign(np_x) - origin_shape_dy = fluid.layers.assign(np_origin_shape) - - boxes_dy, scores_dy = ops.yolo_box( - x_dy, - origin_shape_dy, [10, 13, 30, 13], - 10, - 0.01, - 32, - scale_x_y=scale_x_y) - - boxes_dy_np = boxes_dy.numpy() - scores_dy_np = scores_dy.numpy() - - self.assertTrue(np.array_equal(boxes_np, boxes_dy_np)) - self.assertTrue(np.array_equal(scores_np, scores_dy_np)) - - def test_yolo_box_error(self): - with self.static_graph(): - # x shape [N C H W], C=K * (5 + class_num), class_num=10, K=2 - x = paddle.static.data( - name='x', shape=[1, 30, 7, 7], dtype='float32') - origin_shape = paddle.static.data( - name='origin_shape', shape=[1, 2], dtype='int32') - - self.assertRaises( - TypeError, - ops.yolo_box, - x, - origin_shape, [10, 13, 30, 13], - 10.123, - 0.01, - 32, - scale_x_y=1.2) - - paddle.disable_static() - - class TestPriorBox(LayerTest): def test_prior_box(self): input_np = np.random.rand(2, 10, 32, 32).astype('float32') @@ -509,8 +258,8 @@ class TestPriorBox(LayerTest): with_lod=False) with self.dynamic_graph(): - inputs_dy = base.to_variable(input_np) - image_dy = base.to_variable(image_np) + inputs_dy = paddle.to_tensor(input_np) + image_dy = paddle.to_tensor(image_np) box_dy, var_dy = ops.prior_box( input=inputs_dy, @@ -582,9 +331,9 @@ class 
TestMulticlassNms(LayerTest): nms_rois_num_np = np.array(nms_rois_num_np) with self.dynamic_graph(): - boxes_dy = base.to_variable(boxes_np) - scores_dy = base.to_variable(scores_np) - rois_num_dy = base.to_variable(rois_num_np) + boxes_dy = paddle.to_tensor(boxes_np) + scores_dy = paddle.to_tensor(scores_np) + rois_num_dy = paddle.to_tensor(rois_num_np) out_dy, index_dy, nms_rois_num_dy = ops.multiclass_nms( bboxes=boxes_dy, @@ -666,8 +415,8 @@ class TestMatrixNMS(LayerTest): with_lod=True) with self.dynamic_graph(): - boxes_dy = base.to_variable(boxes_np) - scores_dy = base.to_variable(scores_np) + boxes_dy = paddle.to_tensor(boxes_np) + scores_dy = paddle.to_tensor(scores_np) out_dy, index_dy, _ = ops.matrix_nms( bboxes=boxes_dy, @@ -737,9 +486,9 @@ class TestBoxCoder(LayerTest): # dygraph with self.dynamic_graph(): - prior_box_dy = base.to_variable(prior_box_np) - prior_box_var_dy = base.to_variable(prior_box_var_np) - target_box_dy = base.to_variable(target_box_np) + prior_box_dy = paddle.to_tensor(prior_box_np) + prior_box_var_dy = paddle.to_tensor(prior_box_var_np) + target_box_dy = paddle.to_tensor(target_box_np) boxes_dy = ops.box_coder( prior_box=prior_box_dy, @@ -808,11 +557,11 @@ class TestGenerateProposals(LayerTest): with_lod=True) with self.dynamic_graph(): - scores_dy = base.to_variable(scores_np) - bbox_deltas_dy = base.to_variable(bbox_deltas_np) - im_shape_dy = base.to_variable(im_shape_np) - anchors_dy = base.to_variable(anchors_np) - variances_dy = base.to_variable(variances_np) + scores_dy = paddle.to_tensor(scores_np) + bbox_deltas_dy = paddle.to_tensor(bbox_deltas_np) + im_shape_dy = paddle.to_tensor(im_shape_np) + anchors_dy = paddle.to_tensor(anchors_np) + variances_dy = paddle.to_tensor(variances_np) rois, roi_probs, rois_num = ops.generate_proposals( scores_dy, bbox_deltas_dy, diff --git a/ppdet/modeling/tests/test_yolov3_loss.py b/ppdet/modeling/tests/test_yolov3_loss.py index 
cec8bc940a4abb852d0b210b76ffe4386b8fc12e..433b3cf2cb95c2a1dd27da30ef9b99f3148e004f 100644 --- a/ppdet/modeling/tests/test_yolov3_loss.py +++ b/ppdet/modeling/tests/test_yolov3_loss.py @@ -17,7 +17,7 @@ from __future__ import division import unittest import paddle -from paddle import fluid +import paddle.nn.functional as F # add python path of PadleDetection to sys.path import os import sys @@ -27,19 +27,9 @@ if parent_path not in sys.path: from ppdet.modeling.losses import YOLOv3Loss from ppdet.data.transform.op_helper import jaccard_overlap +from ppdet.modeling.bbox_utils import iou_similarity import numpy as np - - -def _split_ioup(output, an_num, num_classes): - """ - Split output feature map to output, predicted iou - along channel dimension - """ - ioup = fluid.layers.slice(output, axes=[1], starts=[0], ends=[an_num]) - ioup = fluid.layers.sigmoid(ioup) - oriout = fluid.layers.slice( - output, axes=[1], starts=[an_num], ends=[an_num * (num_classes + 6)]) - return (ioup, oriout) +np.random.seed(0) def _split_output(output, an_num, num_classes): @@ -47,31 +37,31 @@ def _split_output(output, an_num, num_classes): Split output feature map to x, y, w, h, objectness, classification along channel dimension """ - x = fluid.layers.strided_slice( + x = paddle.strided_slice( output, axes=[1], starts=[0], ends=[output.shape[1]], strides=[5 + num_classes]) - y = fluid.layers.strided_slice( + y = paddle.strided_slice( output, axes=[1], starts=[1], ends=[output.shape[1]], strides=[5 + num_classes]) - w = fluid.layers.strided_slice( + w = paddle.strided_slice( output, axes=[1], starts=[2], ends=[output.shape[1]], strides=[5 + num_classes]) - h = fluid.layers.strided_slice( + h = paddle.strided_slice( output, axes=[1], starts=[3], ends=[output.shape[1]], strides=[5 + num_classes]) - obj = fluid.layers.strided_slice( + obj = paddle.strided_slice( output, axes=[1], starts=[4], @@ -81,14 +71,12 @@ def _split_output(output, an_num, num_classes): stride = output.shape[1] // an_num 
for m in range(an_num): clss.append( - fluid.layers.slice( + paddle.slice( output, axes=[1], starts=[stride * m + 5], ends=[stride * m + 5 + num_classes])) - cls = fluid.layers.transpose( - fluid.layers.stack( - clss, axis=1), perm=[0, 1, 3, 4, 2]) + cls = paddle.transpose(paddle.stack(clss, axis=1), perm=[0, 1, 3, 4, 2]) return (x, y, w, h, obj, cls) @@ -104,7 +92,7 @@ def _split_target(target): th = target[:, :, 3, :, :] tscale = target[:, :, 4, :, :] tobj = target[:, :, 5, :, :] - tcls = fluid.layers.transpose(target[:, :, 6:, :, :], perm=[0, 1, 3, 4, 2]) + tcls = paddle.transpose(target[:, :, 6:, :, :], perm=[0, 1, 3, 4, 2]) tcls.stop_gradient = True return (tx, ty, tw, th, tscale, tobj, tcls) @@ -115,9 +103,9 @@ def _calc_obj_loss(output, obj, tobj, gt_box, batch_size, anchors, num_classes, # objectness loss will be ignored, process as follows: # 1. get pred bbox, which is same with YOLOv3 infer mode, use yolo_box here # NOTE: img_size is set as 1.0 to get noramlized pred bbox - bbox, prob = fluid.layers.yolo_box( + bbox, prob = paddle.vision.ops.yolo_box( x=output, - img_size=fluid.layers.ones( + img_size=paddle.ones( shape=[batch_size, 2], dtype="int32"), anchors=anchors, class_num=num_classes, @@ -128,8 +116,8 @@ def _calc_obj_loss(output, obj, tobj, gt_box, batch_size, anchors, num_classes, # 2. 
split pred bbox and gt bbox by sample, calculate IoU between pred bbox # and gt bbox in each sample if batch_size > 1: - preds = fluid.layers.split(bbox, batch_size, dim=0) - gts = fluid.layers.split(gt_box, batch_size, dim=0) + preds = paddle.split(bbox, batch_size, axis=0) + gts = paddle.split(gt_box, batch_size, axis=0) else: preds = [bbox] gts = [gt_box] @@ -142,7 +130,7 @@ def _calc_obj_loss(output, obj, tobj, gt_box, batch_size, anchors, num_classes, y = box[:, 1] w = box[:, 2] h = box[:, 3] - return fluid.layers.stack( + return paddle.stack( [ x - w / 2., y - h / 2., @@ -150,28 +138,29 @@ def _calc_obj_loss(output, obj, tobj, gt_box, batch_size, anchors, num_classes, y + h / 2., ], axis=1) - pred = fluid.layers.squeeze(pred, axes=[0]) - gt = box_xywh2xyxy(fluid.layers.squeeze(gt, axes=[0])) - ious.append(fluid.layers.iou_similarity(pred, gt)) - iou = fluid.layers.stack(ious, axis=0) + pred = paddle.squeeze(pred, axis=[0]) + gt = box_xywh2xyxy(paddle.squeeze(gt, axis=[0])) + ious.append(iou_similarity(pred, gt)) + iou = paddle.stack(ious, axis=0) # 3. 
Get iou_mask by IoU between gt bbox and prediction bbox, # Get obj_mask by tobj(holds gt_score), calculate objectness loss - max_iou = fluid.layers.reduce_max(iou, dim=-1) - iou_mask = fluid.layers.cast(max_iou <= ignore_thresh, dtype="float32") - output_shape = fluid.layers.shape(output) + max_iou = paddle.max(iou, axis=-1) + iou_mask = paddle.cast(max_iou <= ignore_thresh, dtype="float32") + output_shape = paddle.shape(output) an_num = len(anchors) // 2 - iou_mask = fluid.layers.reshape(iou_mask, (-1, an_num, output_shape[2], - output_shape[3])) + iou_mask = paddle.reshape(iou_mask, (-1, an_num, output_shape[2], + output_shape[3])) iou_mask.stop_gradient = True # NOTE: tobj holds gt_score, obj_mask holds object existence mask - obj_mask = fluid.layers.cast(tobj > 0., dtype="float32") + obj_mask = paddle.cast(tobj > 0., dtype="float32") obj_mask.stop_gradient = True # For positive objectness grids, objectness loss should be calculated # For negative objectness grids, objectness loss is calculated only iou_mask == 1.0 - loss_obj = fluid.layers.sigmoid_cross_entropy_with_logits(obj, obj_mask) - loss_obj_pos = fluid.layers.reduce_sum(loss_obj * tobj, dim=[1, 2, 3]) - loss_obj_neg = fluid.layers.reduce_sum( - loss_obj * (1.0 - obj_mask) * iou_mask, dim=[1, 2, 3]) + obj_sigmoid = F.sigmoid(obj) + loss_obj = F.binary_cross_entropy(obj_sigmoid, obj_mask, reduction='none') + loss_obj_pos = paddle.sum(loss_obj * tobj, axis=[1, 2, 3]) + loss_obj_neg = paddle.sum(loss_obj * (1.0 - obj_mask) * iou_mask, + axis=[1, 2, 3]) return loss_obj_pos, loss_obj_neg @@ -194,45 +183,48 @@ def fine_grained_loss(output, scale_x_y = scale_x_y if (abs(scale_x_y - 1.0) < eps): - loss_x = fluid.layers.sigmoid_cross_entropy_with_logits( - x, tx) * tscale_tobj - loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3]) - loss_y = fluid.layers.sigmoid_cross_entropy_with_logits( - y, ty) * tscale_tobj - loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3]) + x = F.sigmoid(x) + y = F.sigmoid(y) + 
loss_x = F.binary_cross_entropy(x, tx, reduction='none') * tscale_tobj + loss_x = paddle.sum(loss_x, axis=[1, 2, 3]) + loss_y = F.binary_cross_entropy(y, ty, reduction='none') * tscale_tobj + loss_y = paddle.sum(loss_y, axis=[1, 2, 3]) else: - dx = scale_x_y * fluid.layers.sigmoid(x) - 0.5 * (scale_x_y - 1.0) - dy = scale_x_y * fluid.layers.sigmoid(y) - 0.5 * (scale_x_y - 1.0) - loss_x = fluid.layers.abs(dx - tx) * tscale_tobj - loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3]) - loss_y = fluid.layers.abs(dy - ty) * tscale_tobj - loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3]) + dx = scale_x_y * F.sigmoid(x) - 0.5 * (scale_x_y - 1.0) + dy = scale_x_y * F.sigmoid(y) - 0.5 * (scale_x_y - 1.0) + loss_x = paddle.abs(dx - tx) * tscale_tobj + loss_x = paddle.sum(loss_x, axis=[1, 2, 3]) + loss_y = paddle.abs(dy - ty) * tscale_tobj + loss_y = paddle.sum(loss_y, axis=[1, 2, 3]) # NOTE: we refined loss function of (w, h) as L1Loss - loss_w = fluid.layers.abs(w - tw) * tscale_tobj - loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3]) - loss_h = fluid.layers.abs(h - th) * tscale_tobj - loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3]) + loss_w = paddle.abs(w - tw) * tscale_tobj + loss_w = paddle.sum(loss_w, axis=[1, 2, 3]) + loss_h = paddle.abs(h - th) * tscale_tobj + loss_h = paddle.sum(loss_h, axis=[1, 2, 3]) loss_obj_pos, loss_obj_neg = _calc_obj_loss( output, obj, tobj, gt_box, batch_size, anchors, num_classes, downsample, ignore_thresh, scale_x_y) - loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls, tcls) - loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0) - loss_cls = fluid.layers.reduce_sum(loss_cls, dim=[1, 2, 3, 4]) + cls = F.sigmoid(cls) + loss_cls = F.binary_cross_entropy(cls, tcls, reduction='none') + tobj = paddle.unsqueeze(tobj, axis=-1) + + loss_cls = paddle.multiply(loss_cls, tobj) + loss_cls = paddle.sum(loss_cls, axis=[1, 2, 3, 4]) - loss_xys = fluid.layers.reduce_mean(loss_x + loss_y) - loss_whs = 
fluid.layers.reduce_mean(loss_w + loss_h) - loss_objs = fluid.layers.reduce_mean(loss_obj_pos + loss_obj_neg) - loss_clss = fluid.layers.reduce_mean(loss_cls) + loss_xys = paddle.mean(loss_x + loss_y) + loss_whs = paddle.mean(loss_w + loss_h) + loss_objs = paddle.mean(loss_obj_pos + loss_obj_neg) + loss_clss = paddle.mean(loss_cls) losses_all = { - "loss_xy": fluid.layers.sum(loss_xys), - "loss_wh": fluid.layers.sum(loss_whs), - "loss_loc": fluid.layers.sum(loss_xys) + fluid.layers.sum(loss_whs), - "loss_obj": fluid.layers.sum(loss_objs), - "loss_cls": fluid.layers.sum(loss_clss), + "loss_xy": paddle.sum(loss_xys), + "loss_wh": paddle.sum(loss_whs), + "loss_loc": paddle.sum(loss_xys) + paddle.sum(loss_whs), + "loss_obj": paddle.sum(loss_objs), + "loss_cls": paddle.sum(loss_clss), } return losses_all, x, y, tx, ty diff --git a/ppdet/utils/check.py b/ppdet/utils/check.py index 45da8857e50e59fafde91b6875eff738a9a7a286..58c48806c82d8616d65f02fec80725dadd45d52b 100644 --- a/ppdet/utils/check.py +++ b/ppdet/utils/check.py @@ -20,7 +20,7 @@ import sys import paddle import six -import paddle.version as fluid_version +import paddle.version as paddle_version from .logger import setup_logger logger = setup_logger(__name__) @@ -97,8 +97,8 @@ def check_version(version='2.0'): "Please make sure the version is good with your code.".format(version) version_installed = [ - fluid_version.major, fluid_version.minor, fluid_version.patch, - fluid_version.rc + paddle_version.major, paddle_version.minor, paddle_version.patch, + paddle_version.rc ] if version_installed == ['0', '0', '0', '0']: return diff --git a/static/deploy/android_demo/app/src/main/cpp/Native.cc b/static/deploy/android_demo/app/src/main/cpp/Native.cc index 1b2700a91c8b9bd2b6a186378b6bdc068e8927a9..e35c77209b2455cc23af7842aa60b7f9ee605869 100644 --- a/static/deploy/android_demo/app/src/main/cpp/Native.cc +++ b/static/deploy/android_demo/app/src/main/cpp/Native.cc @@ -26,17 +26,31 @@ extern "C" { */ JNIEXPORT jlong 
JNICALL Java_com_baidu_paddledetection_detection_Native_nativeInit( - JNIEnv *env, jclass thiz, jstring jModelDir, jstring jLabelPath, - jint cpuThreadNum, jstring jCPUPowerMode, jint inputWidth, jint inputHeight, - jfloatArray jInputMean, jfloatArray jInputStd, jfloat scoreThreshold) { + JNIEnv *env, + jclass thiz, + jstring jModelDir, + jstring jLabelPath, + jint cpuThreadNum, + jstring jCPUPowerMode, + jint inputWidth, + jint inputHeight, + jfloatArray jInputMean, + jfloatArray jInputStd, + jfloat scoreThreshold) { std::string modelDir = jstring_to_cpp_string(env, jModelDir); std::string labelPath = jstring_to_cpp_string(env, jLabelPath); std::string cpuPowerMode = jstring_to_cpp_string(env, jCPUPowerMode); std::vector inputMean = jfloatarray_to_float_vector(env, jInputMean); std::vector inputStd = jfloatarray_to_float_vector(env, jInputStd); - return reinterpret_cast( - new Pipeline(modelDir, labelPath, cpuThreadNum, cpuPowerMode, inputWidth, - inputHeight, inputMean, inputStd, scoreThreshold)); + return reinterpret_cast(new Pipeline(modelDir, + labelPath, + cpuThreadNum, + cpuPowerMode, + inputWidth, + inputHeight, + inputMean, + inputStd, + scoreThreshold)); } /* @@ -45,8 +59,9 @@ Java_com_baidu_paddledetection_detection_Native_nativeInit( * Signature: (J)Z */ JNIEXPORT jboolean JNICALL -Java_com_baidu_paddledetection_detection_Native_nativeRelease( - JNIEnv *env, jclass thiz, jlong ctx) { +Java_com_baidu_paddledetection_detection_Native_nativeRelease(JNIEnv *env, + jclass thiz, + jlong ctx) { if (ctx == 0) { return JNI_FALSE; } @@ -62,15 +77,21 @@ Java_com_baidu_paddledetection_detection_Native_nativeRelease( */ JNIEXPORT jboolean JNICALL Java_com_baidu_paddledetection_detection_Native_nativeProcess( - JNIEnv *env, jclass thiz, jlong ctx, jint inTextureId, jint outTextureId, - jint textureWidth, jint textureHeight, jstring jsavedImagePath) { + JNIEnv *env, + jclass thiz, + jlong ctx, + jint inTextureId, + jint outTextureId, + jint textureWidth, + jint 
textureHeight, + jstring jsavedImagePath) { if (ctx == 0) { return JNI_FALSE; } std::string savedImagePath = jstring_to_cpp_string(env, jsavedImagePath); Pipeline *pipeline = reinterpret_cast(ctx); - return pipeline->Process(inTextureId, outTextureId, textureWidth, - textureHeight, savedImagePath); + return pipeline->Process( + inTextureId, outTextureId, textureWidth, textureHeight, savedImagePath); } #ifdef __cplusplus diff --git a/static/deploy/android_demo/app/src/main/cpp/Native.h b/static/deploy/android_demo/app/src/main/cpp/Native.h index d595b8ea2c11c1b0a9219119abcc5678551c318f..c9bf0d9d46e370529e769ed6ab9986787f9f8715 100644 --- a/static/deploy/android_demo/app/src/main/cpp/Native.h +++ b/static/deploy/android_demo/app/src/main/cpp/Native.h @@ -50,8 +50,8 @@ inline jstring cpp_string_to_jstring(JNIEnv *env, std::string str) { env->GetMethodID(strClass, "", "([BLjava/lang/String;)V"); jbyteArray bytes = env->NewByteArray(strlen(data)); - env->SetByteArrayRegion(bytes, 0, strlen(data), - reinterpret_cast(data)); + env->SetByteArrayRegion( + bytes, 0, strlen(data), reinterpret_cast(data)); jstring encoding = env->NewStringUTF("UTF-8"); jstring res = (jstring)( @@ -64,21 +64,24 @@ inline jstring cpp_string_to_jstring(JNIEnv *env, std::string str) { return res; } -inline jfloatArray cpp_array_to_jfloatarray(JNIEnv *env, const float *buf, +inline jfloatArray cpp_array_to_jfloatarray(JNIEnv *env, + const float *buf, int64_t len) { jfloatArray result = env->NewFloatArray(len); env->SetFloatArrayRegion(result, 0, len, buf); return result; } -inline jintArray cpp_array_to_jintarray(JNIEnv *env, const int *buf, +inline jintArray cpp_array_to_jintarray(JNIEnv *env, + const int *buf, int64_t len) { jintArray result = env->NewIntArray(len); env->SetIntArrayRegion(result, 0, len, buf); return result; } -inline jbyteArray cpp_array_to_jbytearray(JNIEnv *env, const int8_t *buf, +inline jbyteArray cpp_array_to_jbytearray(JNIEnv *env, + const int8_t *buf, int64_t len) { 
jbyteArray result = env->NewByteArray(len); env->SetByteArrayRegion(result, 0, len, buf); diff --git a/static/deploy/android_demo/app/src/main/cpp/Pipeline.cc b/static/deploy/android_demo/app/src/main/cpp/Pipeline.cc index b3e3476d961cfd608186f9672786770b39da3268..da87a4d228e19162a866d2b090e6ddc9803e4026 100644 --- a/static/deploy/android_demo/app/src/main/cpp/Pipeline.cc +++ b/static/deploy/android_demo/app/src/main/cpp/Pipeline.cc @@ -14,13 +14,20 @@ #include "Pipeline.h" -Detector::Detector(const std::string &modelDir, const std::string &labelPath, - const int cpuThreadNum, const std::string &cpuPowerMode, - int inputWidth, int inputHeight, +Detector::Detector(const std::string &modelDir, + const std::string &labelPath, + const int cpuThreadNum, + const std::string &cpuPowerMode, + int inputWidth, + int inputHeight, const std::vector &inputMean, - const std::vector &inputStd, float scoreThreshold) - : inputWidth_(inputWidth), inputHeight_(inputHeight), inputMean_(inputMean), - inputStd_(inputStd), scoreThreshold_(scoreThreshold) { + const std::vector &inputStd, + float scoreThreshold) + : inputWidth_(inputWidth), + inputHeight_(inputHeight), + inputMean_(inputMean), + inputStd_(inputStd), + scoreThreshold_(scoreThreshold) { paddle::lite_api::MobileConfig config; config.set_model_from_file(modelDir + "/model.nb"); config.set_threads(cpuThreadNum); @@ -71,13 +78,16 @@ void Detector::Preprocess(const cv::Mat &rgbaImage) { inputTensor->Resize(inputShape); auto inputData = inputTensor->mutable_data(); cv::Mat resizedRGBAImage; - cv::resize(rgbaImage, resizedRGBAImage, - cv::Size(inputShape[3], inputShape[2])); + cv::resize( + rgbaImage, resizedRGBAImage, cv::Size(inputShape[3], inputShape[2])); cv::Mat resizedRGBImage; cv::cvtColor(resizedRGBAImage, resizedRGBImage, cv::COLOR_BGRA2RGB); resizedRGBImage.convertTo(resizedRGBImage, CV_32FC3, 1.0 / 255.0f); - NHWC3ToNC3HW(reinterpret_cast(resizedRGBImage.data), inputData, - inputMean_.data(), inputStd_.data(), 
inputShape[3], + NHWC3ToNC3HW(reinterpret_cast(resizedRGBImage.data), + inputData, + inputMean_.data(), + inputStd_.data(), + inputShape[3], inputShape[2]); // Set the size of input image auto sizeTensor = predictor_->GetInput(1); @@ -97,8 +107,7 @@ void Detector::Postprocess(std::vector *results) { auto class_id = static_cast(round(outputData[i])); // Confidence score auto score = outputData[i + 1]; - if (score < scoreThreshold_) - continue; + if (score < scoreThreshold_) continue; RESULT object; object.class_name = class_id >= 0 && class_id < labelList_.size() ? labelList_[class_id] @@ -115,8 +124,10 @@ void Detector::Postprocess(std::vector *results) { } } -void Detector::Predict(const cv::Mat &rgbaImage, std::vector *results, - double *preprocessTime, double *predictTime, +void Detector::Predict(const cv::Mat &rgbaImage, + std::vector *results, + double *preprocessTime, + double *predictTime, double *postprocessTime) { auto t = GetCurrentTime(); @@ -136,13 +147,23 @@ void Detector::Predict(const cv::Mat &rgbaImage, std::vector *results, LOGD("Detector postprocess costs %f ms", *postprocessTime); } -Pipeline::Pipeline(const std::string &modelDir, const std::string &labelPath, - const int cpuThreadNum, const std::string &cpuPowerMode, - int inputWidth, int inputHeight, +Pipeline::Pipeline(const std::string &modelDir, + const std::string &labelPath, + const int cpuThreadNum, + const std::string &cpuPowerMode, + int inputWidth, + int inputHeight, const std::vector &inputMean, - const std::vector &inputStd, float scoreThreshold) { - detector_.reset(new Detector(modelDir, labelPath, cpuThreadNum, cpuPowerMode, - inputWidth, inputHeight, inputMean, inputStd, + const std::vector &inputStd, + float scoreThreshold) { + detector_.reset(new Detector(modelDir, + labelPath, + cpuThreadNum, + cpuPowerMode, + inputWidth, + inputHeight, + inputMean, + inputStd, scoreThreshold)); } @@ -169,15 +190,24 @@ void Pipeline::VisualizeResults(const std::vector &results, 
cv::Point2d(boundingBox.x, boundingBox.y - round(textSize.height * 1.25f)), cv::Point2d(boundingBox.x + boundingBox.width, boundingBox.y), - object.fill_color, -1); - cv::putText(*rgbaImage, text, cv::Point2d(boundingBox.x, boundingBox.y), - fontFace, fontScale, cv::Scalar(255, 255, 255), fontThickness); + object.fill_color, + -1); + cv::putText(*rgbaImage, + text, + cv::Point2d(boundingBox.x, boundingBox.y), + fontFace, + fontScale, + cv::Scalar(255, 255, 255), + fontThickness); } } -void Pipeline::VisualizeStatus(double readGLFBOTime, double writeGLTextureTime, - double preprocessTime, double predictTime, - double postprocessTime, cv::Mat *rgbaImage) { +void Pipeline::VisualizeStatus(double readGLFBOTime, + double writeGLTextureTime, + double preprocessTime, + double predictTime, + double postprocessTime, + cv::Mat *rgbaImage) { char text[255]; cv::Scalar fontColor = cv::Scalar(255, 255, 255); int fontFace = cv::FONT_HERSHEY_PLAIN; @@ -188,47 +218,54 @@ void Pipeline::VisualizeStatus(double readGLFBOTime, double writeGLTextureTime, cv::getTextSize(text, fontFace, fontScale, fontThickness, nullptr); textSize.height *= 1.25f; cv::Point2d offset(10, textSize.height + 15); - cv::putText(*rgbaImage, text, offset, fontFace, fontScale, fontColor, - fontThickness); + cv::putText( + *rgbaImage, text, offset, fontFace, fontScale, fontColor, fontThickness); sprintf(text, "Write GLTexture time: %.1f ms", writeGLTextureTime); offset.y += textSize.height; - cv::putText(*rgbaImage, text, offset, fontFace, fontScale, fontColor, - fontThickness); + cv::putText( + *rgbaImage, text, offset, fontFace, fontScale, fontColor, fontThickness); sprintf(text, "Preprocess time: %.1f ms", preprocessTime); offset.y += textSize.height; - cv::putText(*rgbaImage, text, offset, fontFace, fontScale, fontColor, - fontThickness); + cv::putText( + *rgbaImage, text, offset, fontFace, fontScale, fontColor, fontThickness); sprintf(text, "Predict time: %.1f ms", predictTime); offset.y += textSize.height; 
- cv::putText(*rgbaImage, text, offset, fontFace, fontScale, fontColor, - fontThickness); + cv::putText( + *rgbaImage, text, offset, fontFace, fontScale, fontColor, fontThickness); sprintf(text, "Postprocess time: %.1f ms", postprocessTime); offset.y += textSize.height; - cv::putText(*rgbaImage, text, offset, fontFace, fontScale, fontColor, - fontThickness); + cv::putText( + *rgbaImage, text, offset, fontFace, fontScale, fontColor, fontThickness); } -bool Pipeline::Process(int inTexureId, int outTextureId, int textureWidth, - int textureHeight, std::string savedImagePath) { +bool Pipeline::Process(int inTexureId, + int outTextureId, + int textureWidth, + int textureHeight, + std::string savedImagePath) { static double readGLFBOTime = 0, writeGLTextureTime = 0; double preprocessTime = 0, predictTime = 0, postprocessTime = 0; // Read pixels from FBO texture to CV image cv::Mat rgbaImage; - CreateRGBAImageFromGLFBOTexture(textureWidth, textureHeight, &rgbaImage, - &readGLFBOTime); + CreateRGBAImageFromGLFBOTexture( + textureWidth, textureHeight, &rgbaImage, &readGLFBOTime); // Feed the image, run inference and parse the results std::vector results; - detector_->Predict(rgbaImage, &results, &preprocessTime, &predictTime, - &postprocessTime); + detector_->Predict( + rgbaImage, &results, &preprocessTime, &predictTime, &postprocessTime); // Visualize the objects to the origin image VisualizeResults(results, &rgbaImage); // Visualize the status(performance data) to the origin image - VisualizeStatus(readGLFBOTime, writeGLTextureTime, preprocessTime, - predictTime, postprocessTime, &rgbaImage); + VisualizeStatus(readGLFBOTime, + writeGLTextureTime, + preprocessTime, + predictTime, + postprocessTime, + &rgbaImage); // Dump modified image if savedImagePath is set if (!savedImagePath.empty()) { diff --git a/static/deploy/android_demo/app/src/main/cpp/Pipeline.h b/static/deploy/android_demo/app/src/main/cpp/Pipeline.h index 
91177d0417814cd60b01112674baf6387675f9a0..7033bdea45164367562c27c724aa6c47d68daef8 100644 --- a/static/deploy/android_demo/app/src/main/cpp/Pipeline.h +++ b/static/deploy/android_demo/app/src/main/cpp/Pipeline.h @@ -14,8 +14,6 @@ #pragma once -#include "Utils.h" -#include "paddle_api.h" #include #include #include @@ -24,6 +22,8 @@ #include #include #include +#include "Utils.h" +#include "paddle_api.h" struct RESULT { std::string class_name; @@ -36,24 +36,30 @@ struct RESULT { }; class Detector { -public: - explicit Detector(const std::string &modelDir, const std::string &labelPath, - const int cpuThreadNum, const std::string &cpuPowerMode, - int inputWidth, int inputHeight, + public: + explicit Detector(const std::string &modelDir, + const std::string &labelPath, + const int cpuThreadNum, + const std::string &cpuPowerMode, + int inputWidth, + int inputHeight, const std::vector &inputMean, - const std::vector &inputStd, float scoreThreshold); + const std::vector &inputStd, + float scoreThreshold); - void Predict(const cv::Mat &rgbImage, std::vector *results, - double *preprocessTime, double *predictTime, + void Predict(const cv::Mat &rgbImage, + std::vector *results, + double *preprocessTime, + double *predictTime, double *postprocessTime); -private: + private: std::vector LoadLabelList(const std::string &path); std::vector GenerateColorMap(int numOfClasses); void Preprocess(const cv::Mat &rgbaImage); void Postprocess(std::vector *results); -private: + private: int inputWidth_; int inputHeight_; std::vector inputMean_; @@ -65,36 +71,58 @@ private: }; class Pipeline { -public: - Pipeline(const std::string &modelDir, const std::string &labelPath, - const int cpuThreadNum, const std::string &cpuPowerMode, - int inputWidth, int inputHeight, const std::vector &inputMean, - const std::vector &inputStd, float scoreThreshold); + public: + Pipeline(const std::string &modelDir, + const std::string &labelPath, + const int cpuThreadNum, + const std::string &cpuPowerMode, + int 
inputWidth, + int inputHeight, + const std::vector &inputMean, + const std::vector &inputStd, + float scoreThreshold); - bool Process(int inTextureId, int outTextureId, int textureWidth, - int textureHeight, std::string savedImagePath); + bool Process(int inTextureId, + int outTextureId, + int textureWidth, + int textureHeight, + std::string savedImagePath); -private: + private: // Read pixels from FBO texture to CV image - void CreateRGBAImageFromGLFBOTexture(int textureWidth, int textureHeight, + void CreateRGBAImageFromGLFBOTexture(int textureWidth, + int textureHeight, cv::Mat *rgbaImage, double *readGLFBOTime) { auto t = GetCurrentTime(); rgbaImage->create(textureHeight, textureWidth, CV_8UC4); - glReadPixels(0, 0, textureWidth, textureHeight, GL_RGBA, GL_UNSIGNED_BYTE, + glReadPixels(0, + 0, + textureWidth, + textureHeight, + GL_RGBA, + GL_UNSIGNED_BYTE, rgbaImage->data); *readGLFBOTime = GetElapsedTime(t); LOGD("Read from FBO texture costs %f ms", *readGLFBOTime); } // Write back to texture2D - void WriteRGBAImageBackToGLTexture(const cv::Mat &rgbaImage, int textureId, + void WriteRGBAImageBackToGLTexture(const cv::Mat &rgbaImage, + int textureId, double *writeGLTextureTime) { auto t = GetCurrentTime(); glActiveTexture(GL_TEXTURE0); glBindTexture(GL_TEXTURE_2D, textureId); - glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, rgbaImage.cols, rgbaImage.rows, - GL_RGBA, GL_UNSIGNED_BYTE, rgbaImage.data); + glTexSubImage2D(GL_TEXTURE_2D, + 0, + 0, + 0, + rgbaImage.cols, + rgbaImage.rows, + GL_RGBA, + GL_UNSIGNED_BYTE, + rgbaImage.data); *writeGLTextureTime = GetElapsedTime(t); LOGD("Write back to texture2D costs %f ms", *writeGLTextureTime); } @@ -103,10 +131,13 @@ private: void VisualizeResults(const std::vector &results, cv::Mat *rgbaImage); // Visualize the status(performace data) to origin image - void VisualizeStatus(double readGLFBOTime, double writeGLTextureTime, - double preprocessTime, double predictTime, - double postprocessTime, cv::Mat *rgbaImage); + void 
VisualizeStatus(double readGLFBOTime, + double writeGLTextureTime, + double preprocessTime, + double predictTime, + double postprocessTime, + cv::Mat *rgbaImage); -private: + private: std::shared_ptr detector_; }; diff --git a/static/deploy/android_demo/app/src/main/cpp/Utils.cc b/static/deploy/android_demo/app/src/main/cpp/Utils.cc index 63ea54fd4b2de0e9822f5c392606c3e7dca1da0b..9157e93741e8b096a3b07d19f225376f745e8ee4 100644 --- a/static/deploy/android_demo/app/src/main/cpp/Utils.cc +++ b/static/deploy/android_demo/app/src/main/cpp/Utils.cc @@ -17,13 +17,16 @@ int64_t ShapeProduction(const std::vector &shape) { int64_t res = 1; - for (auto i : shape) - res *= i; + for (auto i : shape) res *= i; return res; } -void NHWC3ToNC3HW(const float *src, float *dst, const float *mean, - const float *std, int width, int height) { +void NHWC3ToNC3HW(const float *src, + float *dst, + const float *mean, + const float *std, + int width, + int height) { int size = height * width; float32x4_t vmean0 = vdupq_n_f32(mean ? mean[0] : 0.0f); float32x4_t vmean1 = vdupq_n_f32(mean ? mean[1] : 0.0f); @@ -58,8 +61,12 @@ void NHWC3ToNC3HW(const float *src, float *dst, const float *mean, } } -void NHWC1ToNC1HW(const float *src, float *dst, const float *mean, - const float *std, int width, int height) { +void NHWC1ToNC1HW(const float *src, + float *dst, + const float *mean, + const float *std, + int width, + int height) { int size = height * width; float32x4_t vmean = vdupq_n_f32(mean ? mean[0] : 0.0f); float32x4_t vscale = vdupq_n_f32(std ? 
(1.0f / std[0]) : 1.0f); diff --git a/static/deploy/android_demo/app/src/main/cpp/Utils.h b/static/deploy/android_demo/app/src/main/cpp/Utils.h index 74fa82a6423dd600f3aa711da4ca244d0d4c34eb..40254a9ce19649bfdd9b0a40445405001c2fcc95 100644 --- a/static/deploy/android_demo/app/src/main/cpp/Utils.h +++ b/static/deploy/android_demo/app/src/main/cpp/Utils.h @@ -14,11 +14,11 @@ #pragma once -#include "paddle_api.h" #include #include #include #include +#include "paddle_api.h" #define TAG "JNI" #define LOGD(...) __android_log_print(ANDROID_LOG_DEBUG, TAG, __VA_ARGS__) @@ -85,8 +85,16 @@ inline paddle::lite_api::PowerMode ParsePowerMode(std::string mode) { return paddle::lite_api::LITE_POWER_NO_BIND; } -void NHWC3ToNC3HW(const float *src, float *dst, const float *mean, - const float *std, int width, int height); +void NHWC3ToNC3HW(const float *src, + float *dst, + const float *mean, + const float *std, + int width, + int height); -void NHWC1ToNC1HW(const float *src, float *dst, const float *mean, - const float *std, int width, int height); +void NHWC1ToNC1HW(const float *src, + float *dst, + const float *mean, + const float *std, + int width, + int height); diff --git a/static/deploy/cpp/include/config_parser.h b/static/deploy/cpp/include/config_parser.h index f2102dcc9f2902d319790ebae705a6d3fa3a4993..b6af7be1ede5f16770cd9438bf1499d0523cdaa5 100644 --- a/static/deploy/cpp/include/config_parser.h +++ b/static/deploy/cpp/include/config_parser.h @@ -15,9 +15,9 @@ #pragma once #include -#include -#include #include +#include +#include #include "yaml-cpp/yaml.h" @@ -47,8 +47,7 @@ class ConfigPaser { mode_ = config["mode"].as(); } else { std::cerr << "Please set mode, " - << "support value : fluid/trt_fp16/trt_fp32." - << std::endl; + << "support value : fluid/trt_fp16/trt_fp32." 
<< std::endl; return false; } @@ -110,4 +109,3 @@ class ConfigPaser { }; } // namespace PaddleDetection - diff --git a/static/deploy/cpp/include/object_detector.h b/static/deploy/cpp/include/object_detector.h index b0173989dd80782f1243dce9250ca5f7aee634c5..fbe2541cdfee9624fc7948f6a21563eced712146 100644 --- a/static/deploy/cpp/include/object_detector.h +++ b/static/deploy/cpp/include/object_detector.h @@ -14,21 +14,20 @@ #pragma once -#include -#include +#include #include +#include #include -#include +#include #include -#include #include +#include -#include "paddle_inference_api.h" // NOLINT +#include "paddle_inference_api.h" // NOLINT -#include "include/preprocess_op.h" #include "include/config_parser.h" - +#include "include/preprocess_op.h" namespace PaddleDetection { // Object Detection Result @@ -41,48 +40,50 @@ struct ObjectResult { float confidence; }; - // Generate visualization colormap for each class std::vector GenerateColorMap(int num_class); - // Visualiztion Detection Result cv::Mat VisualizeResult(const cv::Mat& img, - const std::vector& results, - const std::vector& lable_list, - const std::vector& colormap); - + const std::vector& results, + const std::vector& lable_list, + const std::vector& colormap); class ObjectDetector { public: - explicit ObjectDetector(const std::string& model_dir, + explicit ObjectDetector(const std::string& model_dir, const std::string& device, - const std::string& run_mode="fluid", - const int gpu_id=0, - bool trt_calib_mode=false) { + const std::string& run_mode = "fluid", + const int gpu_id = 0, + bool trt_calib_mode = false) { config_.load_config(model_dir); threshold_ = config_.draw_threshold_; preprocessor_.Init(config_.preprocess_info_, config_.arch_); - LoadModel(model_dir, device, config_.min_subgraph_size_, 1, run_mode, gpu_id, trt_calib_mode); + LoadModel(model_dir, + device, + config_.min_subgraph_size_, + 1, + run_mode, + gpu_id, + trt_calib_mode); } // Load Paddle inference model - void LoadModel( - const 
std::string& model_dir, - const std::string& device, - const int min_subgraph_size, - const int batch_size = 1, - const std::string& run_mode = "fluid", - const int gpu_id=0, - bool trt_calib_mode=false); + void LoadModel(const std::string& model_dir, + const std::string& device, + const int min_subgraph_size, + const int batch_size = 1, + const std::string& run_mode = "fluid", + const int gpu_id = 0, + bool trt_calib_mode = false); // Run predictor void Predict(const cv::Mat& im, - const double threshold = 0.5, - const int warmup = 0, - const int repeats = 1, - const bool run_benchmark = false, - std::vector* result = nullptr); + const double threshold = 0.5, + const int warmup = 0, + const int repeats = 1, + const bool run_benchmark = false, + std::vector* result = nullptr); // Get Model Label list const std::vector& GetLabelList() const { @@ -93,9 +94,7 @@ class ObjectDetector { // Preprocess image and copy data to input buffer void Preprocess(const cv::Mat& image_mat); // Postprocess result - void Postprocess( - const cv::Mat& raw_mat, - std::vector* result); + void Postprocess(const cv::Mat& raw_mat, std::vector* result); std::unique_ptr predictor_; Preprocessor preprocessor_; diff --git a/static/deploy/cpp/include/preprocess_op.h b/static/deploy/cpp/include/preprocess_op.h index 1ec10061eaece8ab47cb2520701f3edf4bc057db..ae9c2d6ef0744666e718536465c8ef92a279ee98 100644 --- a/static/deploy/cpp/include/preprocess_op.h +++ b/static/deploy/cpp/include/preprocess_op.h @@ -16,15 +16,15 @@ #include -#include -#include -#include #include +#include #include +#include +#include #include -#include #include +#include namespace PaddleDetection { @@ -38,7 +38,7 @@ class ImageBlob { // Original image width, height, shrink in float format std::vector ori_im_size_f_; // Evaluation image width and height - std::vector eval_im_size_f_; + std::vector eval_im_size_f_; // Scale factor for image size to origin image size std::vector scale_factor_f_; }; @@ -50,7 +50,7 @@ class 
PreprocessOp { virtual void Run(cv::Mat* im, ImageBlob* data) = 0; }; -class InitInfo : public PreprocessOp{ +class InitInfo : public PreprocessOp { public: virtual void Init(const YAML::Node& item, const std::string& arch) {} virtual void Run(cv::Mat* im, ImageBlob* data); @@ -78,8 +78,8 @@ class Normalize : public PreprocessOp { class Permute : public PreprocessOp { public: virtual void Init(const YAML::Node& item, const std::string& arch) { - to_bgr_ = item["to_bgr"].as(); - is_channel_first_ = item["channel_first"].as(); + to_bgr_ = item["to_bgr"].as(); + is_channel_first_ = item["channel_first"].as(); } virtual void Run(cv::Mat* im, ImageBlob* data); @@ -97,11 +97,11 @@ class Resize : public PreprocessOp { arch_ = arch; interp_ = item["interp"].as(); max_size_ = item["max_size"].as(); - if (item["image_shape"].IsDefined()) { - image_shape_ = item["image_shape"].as>(); + if (item["image_shape"].IsDefined()) { + image_shape_ = item["image_shape"].as>(); } target_size_ = item["target_size"].as(); - } + } // Compute best resize scale for x-dimension, y-dimension std::pair GenerateScale(const cv::Mat& im); @@ -166,4 +166,3 @@ class Preprocessor { }; } // namespace PaddleDetection - diff --git a/static/deploy/cpp/src/main.cc b/static/deploy/cpp/src/main.cc index c1d5693978d1b667e344ab891f7839b50ca2b93f..c235935a4412d875ee9f593a3060e6dfe68e09a9 100644 --- a/static/deploy/cpp/src/main.cc +++ b/static/deploy/cpp/src/main.cc @@ -14,12 +14,12 @@ #include +#include +#include +#include #include #include #include -#include -#include -#include #ifdef _WIN32 #include @@ -29,25 +29,35 @@ #include #endif -#include "include/object_detector.h" #include - +#include "include/object_detector.h" DEFINE_string(model_dir, "", "Path of inference model"); DEFINE_string(image_file, "", "Path of input image"); DEFINE_string(video_path, "", "Path of input video"); -DEFINE_bool(use_gpu, false, "Deprecated, please use `--device` to set the device you want to run."); -DEFINE_string(device, 
"CPU", "Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU."); +DEFINE_bool( + use_gpu, + false, + "Deprecated, please use `--device` to set the device you want to run."); +DEFINE_string(device, + "CPU", + "Choose the device you want to run, it can be: CPU/GPU/XPU, " + "default is CPU."); DEFINE_bool(use_camera, false, "Use camera or not"); DEFINE_string(run_mode, "fluid", "Mode of running(fluid/trt_fp32/trt_fp16)"); DEFINE_int32(gpu_id, 0, "Device id of GPU to execute"); DEFINE_int32(camera_id, -1, "Device id of camera to predict"); -DEFINE_bool(run_benchmark, false, "Whether to predict a image_file repeatedly for benchmark"); +DEFINE_bool(run_benchmark, + false, + "Whether to predict a image_file repeatedly for benchmark"); DEFINE_double(threshold, 0.5, "Threshold of score."); DEFINE_string(output_dir, "output", "Directory of output visualization files."); -DEFINE_bool(trt_calib_mode, false, "If the model is produced by TRT offline quantitative calibration, trt_calib_mode need to set True"); +DEFINE_bool(trt_calib_mode, + false, + "If the model is produced by TRT offline quantitative calibration, " + "trt_calib_mode need to set True"); -static std::string DirName(const std::string &filepath) { +static std::string DirName(const std::string& filepath) { auto pos = filepath.rfind(OS_PATH_SEP); if (pos == std::string::npos) { return ""; @@ -55,7 +65,7 @@ static std::string DirName(const std::string &filepath) { return filepath.substr(0, pos); } -static bool PathExists(const std::string& path){ +static bool PathExists(const std::string& path) { #ifdef _WIN32 struct _stat buffer; return (_stat(path.c_str(), &buffer) == 0); @@ -92,9 +102,9 @@ void PredictVideo(const std::string& video_path, PaddleDetection::ObjectDetector* det) { // Open video cv::VideoCapture capture; - if (FLAGS_camera_id != -1){ + if (FLAGS_camera_id != -1) { capture.open(FLAGS_camera_id); - }else{ + } else { capture.open(video_path.c_str()); } if (!capture.isOpened()) { @@ 
-131,18 +141,20 @@ void PredictVideo(const std::string& video_path, break; } det->Predict(frame, 0.5, 0, 1, false, &result); - cv::Mat out_im = PaddleDetection::VisualizeResult( - frame, result, labels, colormap); + cv::Mat out_im = + PaddleDetection::VisualizeResult(frame, result, labels, colormap); for (const auto& item : result) { - printf("In frame id %d, we detect: class=%d confidence=%.2f rect=[%d %d %d %d]\n", - frame_id, - item.class_id, - item.confidence, - item.rect[0], - item.rect[1], - item.rect[2], - item.rect[3]); - } + printf( + "In frame id %d, we detect: class=%d confidence=%.2f rect=[%d %d %d " + "%d]\n", + frame_id, + item.class_id, + item.confidence, + item.rect[0], + item.rect[1], + item.rect[2], + item.rect[3]); + } video_out.write(out_im); frame_id += 1; } @@ -159,26 +171,24 @@ void PredictImage(const std::string& image_path, cv::Mat im = cv::imread(image_path, 1); // Store all detected result std::vector result; - if (run_benchmark) - { + if (run_benchmark) { det->Predict(im, threshold, 100, 100, run_benchmark, &result); - }else - { + } else { det->Predict(im, 0.5, 0, 1, run_benchmark, &result); for (const auto& item : result) { printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n", - item.class_id, - item.confidence, - item.rect[0], - item.rect[1], - item.rect[2], - item.rect[3]); + item.class_id, + item.confidence, + item.rect[0], + item.rect[1], + item.rect[2], + item.rect[3]); } // Visualization result auto labels = det->GetLabelList(); auto colormap = PaddleDetection::GenerateColorMap(labels.size()); - cv::Mat vis_img = PaddleDetection::VisualizeResult( - im, result, labels, colormap); + cv::Mat vis_img = + PaddleDetection::VisualizeResult(im, result, labels, colormap); std::vector compression_params; compression_params.push_back(CV_IMWRITE_JPEG_QUALITY); compression_params.push_back(95); @@ -195,30 +205,39 @@ void PredictImage(const std::string& image_path, int main(int argc, char** argv) { // Parsing command-line 
google::ParseCommandLineFlags(&argc, &argv, true); - if (FLAGS_model_dir.empty() - || (FLAGS_image_file.empty() && FLAGS_video_path.empty())) { + if (FLAGS_model_dir.empty() || + (FLAGS_image_file.empty() && FLAGS_video_path.empty())) { std::cout << "Usage: ./main --model_dir=/PATH/TO/INFERENCE_MODEL/ " - << "--image_file=/PATH/TO/INPUT/IMAGE/" << std::endl; + << "--image_file=/PATH/TO/INPUT/IMAGE/" << std::endl; return -1; } - if (!(FLAGS_run_mode == "fluid" || FLAGS_run_mode == "trt_fp32" - || FLAGS_run_mode == "trt_fp16" || FLAGS_run_mode == "trt_int8")) { - std::cout << "run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 'trt_int8'."; + if (!(FLAGS_run_mode == "fluid" || FLAGS_run_mode == "trt_fp32" || + FLAGS_run_mode == "trt_fp16" || FLAGS_run_mode == "trt_int8")) { + std::cout + << "run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 'trt_int8'."; return -1; } - transform(FLAGS_device.begin(),FLAGS_device.end(),FLAGS_device.begin(),::toupper); - if (!(FLAGS_device == "CPU" || FLAGS_device == "GPU" || FLAGS_device == "XPU")) { + transform(FLAGS_device.begin(), + FLAGS_device.end(), + FLAGS_device.begin(), + ::toupper); + if (!(FLAGS_device == "CPU" || FLAGS_device == "GPU" || + FLAGS_device == "XPU")) { std::cout << "device should be 'CPU', 'GPU' or 'XPU'."; return -1; } if (FLAGS_use_gpu) { - std::cout << "Deprecated, please use `--device` to set the device you want to run."; + std::cout << "Deprecated, please use `--device` to set the device you want " + "to run."; return -1; } // Load model and create a object detector - PaddleDetection::ObjectDetector det(FLAGS_model_dir, FLAGS_device, - FLAGS_run_mode, FLAGS_gpu_id, FLAGS_trt_calib_mode); + PaddleDetection::ObjectDetector det(FLAGS_model_dir, + FLAGS_device, + FLAGS_run_mode, + FLAGS_gpu_id, + FLAGS_trt_calib_mode); // Do inference on input video or image if (!FLAGS_video_path.empty() || FLAGS_use_camera) { PredictVideo(FLAGS_video_path, &det); @@ -226,7 +245,11 @@ int main(int argc, char** argv) { 
if (!PathExists(FLAGS_output_dir)) { MkDirs(FLAGS_output_dir); } - PredictImage(FLAGS_image_file, FLAGS_threshold, FLAGS_run_benchmark, &det, FLAGS_output_dir); + PredictImage(FLAGS_image_file, + FLAGS_threshold, + FLAGS_run_benchmark, + &det, + FLAGS_output_dir); } return 0; } diff --git a/static/deploy/cpp/src/object_detector.cc b/static/deploy/cpp/src/object_detector.cc index f257d8021ed10e1f1f3b5fc1a726bc5118e3b13b..96d8a6dd9f7f69fcb1471fc956f8e93bdc3654c2 100644 --- a/static/deploy/cpp/src/object_detector.cc +++ b/static/deploy/cpp/src/object_detector.cc @@ -13,8 +13,8 @@ // limitations under the License. #include // for setprecision -#include #include +#include #include "include/object_detector.h" namespace PaddleDetection { @@ -41,18 +41,18 @@ void ObjectDetector::LoadModel(const std::string& model_dir, } else if (run_mode == "trt_int8") { precision = paddle::AnalysisConfig::Precision::kInt8; } else { - printf("run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 'trt_int8'"); + printf( + "run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 'trt_int8'"); } - config.EnableTensorRtEngine( - 1 << 10, - batch_size, - min_subgraph_size, - precision, - false, - trt_calib_mode); - } - } else if (device == "XPU"){ - config.EnableXpu(10*1024*1024); + config.EnableTensorRtEngine(1 << 10, + batch_size, + min_subgraph_size, + precision, + false, + trt_calib_mode); + } + } else if (device == "XPU") { + config.EnableXpu(10 * 1024 * 1024); } else { config.DisableGpu(); } @@ -88,11 +88,8 @@ cv::Mat VisualizeResult(const cv::Mat& img, int font_face = cv::FONT_HERSHEY_COMPLEX_SMALL; double font_scale = 0.5f; float thickness = 0.5; - cv::Size text_size = cv::getTextSize(text, - font_face, - font_scale, - thickness, - nullptr); + cv::Size text_size = + cv::getTextSize(text, font_face, font_scale, thickness, nullptr); cv::Point origin; origin.x = roi.x; origin.y = roi.y; @@ -124,9 +121,8 @@ void ObjectDetector::Preprocess(const cv::Mat& ori_im) { preprocessor_.Run(&im, 
&inputs_); } -void ObjectDetector::Postprocess( - const cv::Mat& raw_mat, - std::vector* result) { +void ObjectDetector::Postprocess(const cv::Mat& raw_mat, + std::vector* result) { result->clear(); int rh = 1; int rw = 1; @@ -158,11 +154,11 @@ void ObjectDetector::Postprocess( } void ObjectDetector::Predict(const cv::Mat& im, - const double threshold, - const int warmup, - const int repeats, - const bool run_benchmark, - std::vector* result) { + const double threshold, + const int warmup, + const int repeats, + const bool run_benchmark, + std::vector* result) { // Preprocess image Preprocess(im); // Prepare input tensor @@ -189,8 +185,7 @@ void ObjectDetector::Predict(const cv::Mat& im, } } // Run predictor - for (int i = 0; i < warmup; i++) - { + for (int i = 0; i < warmup; i++) { predictor_->ZeroCopyRun(); // Get output tensor auto output_names = predictor_->GetOutputNames(); @@ -206,12 +201,11 @@ void ObjectDetector::Predict(const cv::Mat& im, std::cerr << "[WARNING] No object detected." << std::endl; } output_data_.resize(output_size); - out_tensor->copy_to_cpu(output_data_.data()); + out_tensor->copy_to_cpu(output_data_.data()); } auto start = std::chrono::steady_clock::now(); - for (int i = 0; i < repeats; i++) - { + for (int i = 0; i < repeats; i++) { predictor_->ZeroCopyRun(); // Get output tensor auto output_names = predictor_->GetOutputNames(); @@ -227,15 +221,15 @@ void ObjectDetector::Predict(const cv::Mat& im, std::cerr << "[WARNING] No object detected." 
<< std::endl; } output_data_.resize(output_size); - out_tensor->copy_to_cpu(output_data_.data()); + out_tensor->copy_to_cpu(output_data_.data()); } auto end = std::chrono::steady_clock::now(); std::chrono::duration diff = end - start; float ms = diff.count() / repeats * 1000; printf("Inference: %f ms per batch image\n", ms); // Postprocessing result - if(!run_benchmark) { - Postprocess(im, result); + if (!run_benchmark) { + Postprocess(im, result); } } diff --git a/static/deploy/cpp/src/preprocess_op.cc b/static/deploy/cpp/src/preprocess_op.cc index cec3feb7a11c3b1e6de0aef24ae3eba0037b8152..5017c239a3593a9d0dae97b57fd3b15ccd0671a9 100644 --- a/static/deploy/cpp/src/preprocess_op.cc +++ b/static/deploy/cpp/src/preprocess_op.cc @@ -12,28 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include +#include #include "include/preprocess_op.h" namespace PaddleDetection { void InitInfo::Run(cv::Mat* im, ImageBlob* data) { - data->ori_im_size_ = { - static_cast(im->rows), - static_cast(im->cols) - }; + data->ori_im_size_ = {static_cast(im->rows), static_cast(im->cols)}; data->ori_im_size_f_ = { - static_cast(im->rows), - static_cast(im->cols), - 1.0 - }; + static_cast(im->rows), static_cast(im->cols), 1.0}; data->eval_im_size_f_ = { - static_cast(im->rows), - static_cast(im->cols), - 1.0 - }; + static_cast(im->rows), static_cast(im->cols), 1.0}; data->scale_factor_f_ = {1., 1., 1., 1.}; } @@ -46,11 +37,11 @@ void Normalize::Run(cv::Mat* im, ImageBlob* data) { for (int h = 0; h < im->rows; h++) { for (int w = 0; w < im->cols; w++) { im->at(h, w)[0] = - (im->at(h, w)[0] - mean_[0] ) / scale_[0]; + (im->at(h, w)[0] - mean_[0]) / scale_[0]; im->at(h, w)[1] = - (im->at(h, w)[1] - mean_[1] ) / scale_[1]; + (im->at(h, w)[1] - mean_[1]) / scale_[1]; im->at(h, w)[2] = - (im->at(h, w)[2] - mean_[2] ) / scale_[2]; + (im->at(h, w)[2] - mean_[2]) / scale_[2]; } } } @@ -63,7 +54,8 @@ void Permute::Run(cv::Mat* 
im, ImageBlob* data) { float* base = (data->im_data_).data(); for (int i = 0; i < rc; ++i) { int cur_c = to_bgr_ ? rc - i - 1 : i; - cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, base + cur_c * rh * rw), i); + cv::extractChannel( + *im, cv::Mat(rh, rw, CV_32FC1, base + cur_c * rh * rw), i); } } @@ -73,27 +65,22 @@ void Resize::Run(cv::Mat* im, ImageBlob* data) { *im, *im, cv::Size(), resize_scale.first, resize_scale.second, interp_); if (max_size_ != 0 && !image_shape_.empty()) { // Padding the image with 0 border - cv::copyMakeBorder( - *im, - *im, - 0, - max_size_ - im->rows, - 0, - max_size_ - im->cols, - cv::BORDER_CONSTANT, - cv::Scalar(0)); + cv::copyMakeBorder(*im, + *im, + 0, + max_size_ - im->rows, + 0, + max_size_ - im->cols, + cv::BORDER_CONSTANT, + cv::Scalar(0)); } - data->eval_im_size_f_ = { - static_cast(im->rows), - static_cast(im->cols), - resize_scale.first - }; - data->scale_factor_f_ = { - resize_scale.first, - resize_scale.second, - resize_scale.first, - resize_scale.second - }; + data->eval_im_size_f_ = {static_cast(im->rows), + static_cast(im->cols), + resize_scale.first}; + data->scale_factor_f_ = {resize_scale.first, + resize_scale.second, + resize_scale.first, + resize_scale.second}; } std::pair Resize::GenerateScale(const cv::Mat& im) { @@ -132,23 +119,14 @@ void PadStride::Run(cv::Mat* im, ImageBlob* data) { int nh = (rh / stride_) * stride_ + (rh % stride_ != 0) * stride_; int nw = (rw / stride_) * stride_ + (rw % stride_ != 0) * stride_; cv::copyMakeBorder( - *im, - *im, - 0, - nh - rh, - 0, - nw - rw, - cv::BORDER_CONSTANT, - cv::Scalar(0)); + *im, *im, 0, nh - rh, 0, nw - rw, cv::BORDER_CONSTANT, cv::Scalar(0)); (data->eval_im_size_f_)[0] = static_cast(im->rows); (data->eval_im_size_f_)[1] = static_cast(im->cols); } - // Preprocessor op running order const std::vector Preprocessor::RUN_ORDER = { - "InitInfo", "Resize", "Normalize", "PadStride", "Permute" -}; + "InitInfo", "Resize", "Normalize", "PadStride", "Permute"}; void 
Preprocessor::Run(cv::Mat* im, ImageBlob* data) { for (const auto& name : RUN_ORDER) { diff --git a/static/deploy/lite/run_detection.cc b/static/deploy/lite/run_detection.cc index 109cc2fff8ae99dc1554a65afd74c16644e736e2..088200ee780a13488c825432df024d3162b13854 100644 --- a/static/deploy/lite/run_detection.cc +++ b/static/deploy/lite/run_detection.cc @@ -12,17 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include -#include -#include #include +#include #include "opencv2/core.hpp" #include "opencv2/imgcodecs.hpp" #include "opencv2/imgproc.hpp" #include "paddle_api.h" // NOLINT - using namespace paddle::lite_api; // NOLINT using namespace std; @@ -57,13 +56,14 @@ void PrintBenchmarkLog(std::vector det_time, std::cout << "---------------- Perf info ---------------------" << std::endl; std::cout << "Total number of predicted data: " << img_num << " and total time spent(s): " - << std::accumulate(det_time.begin(), det_time.end(), 0) << std::endl; + << std::accumulate(det_time.begin(), det_time.end(), 0) + << std::endl; std::cout << "preproce_time(ms): " << det_time[0] / img_num << ", inference_time(ms): " << det_time[1] / img_num << ", postprocess_time(ms): " << det_time[2] << std::endl; } -std::vector LoadLabels(const std::string &path) { +std::vector LoadLabels(const std::string& path) { std::ifstream file; std::vector labels; file.open(path); @@ -96,18 +96,17 @@ std::vector ReadDict(std::string path) { return m_vec; } -std::vector split(const std::string &str, - const std::string &delim) { +std::vector split(const std::string& str, + const std::string& delim) { std::vector res; - if ("" == str) - return res; - char *strs = new char[str.length() + 1]; + if ("" == str) return res; + char* strs = new char[str.length() + 1]; std::strcpy(strs, str.c_str()); - char *d = new char[delim.length() + 1]; + char* d = new char[delim.length() + 1]; std::strcpy(d, delim.c_str()); - char *p = 
std::strtok(strs, d); + char* p = std::strtok(strs, d); while (p) { string s = p; res.push_back(s); @@ -128,7 +127,7 @@ std::map LoadConfigTxt(std::string config_path) { return dict; } -void PrintConfig(const std::map &config) { +void PrintConfig(const std::map& config) { std::cout << "=======PaddleDetection lite demo config======" << std::endl; for (auto iter = config.begin(); iter != config.end(); iter++) { std::cout << iter->first << " : " << iter->second << std::endl; @@ -136,7 +135,6 @@ void PrintConfig(const std::map &config) { std::cout << "===End of PaddleDetection lite demo config===" << std::endl; } - // fill tensor with mean and scale and trans layout: nhwc -> nchw, neon speed up void neon_mean_scale(const float* din, float* dout, @@ -182,11 +180,11 @@ void neon_mean_scale(const float* din, } std::vector visualize_result( - const float* data, - int count, - float thresh, - cv::Mat& image, - const std::vector &class_names) { + const float* data, + int count, + float thresh, + cv::Mat& image, + const std::vector& class_names) { if (data == nullptr) { std::cerr << "[ERROR] data can not be nullptr\n"; exit(1); @@ -258,54 +256,59 @@ std::shared_ptr LoadModel(std::string model_file, } ImageBlob prepare_imgdata(const cv::Mat& img, - std::map config) { + std::map config) { ImageBlob img_data; std::vector target_size_; std::vector size_str = split(config.at("Resize"), ","); - transform(size_str.begin(), size_str.end(), back_inserter(target_size_), - [](std::string const& s){return stoi(s);}); + transform(size_str.begin(), + size_str.end(), + back_inserter(target_size_), + [](std::string const& s) { return stoi(s); }); int width = target_size_[0]; int height = target_size_[1]; - img_data.im_shape_ = { - static_cast(target_size_[0]), - static_cast(target_size_[1]) - }; + img_data.im_shape_ = {static_cast(target_size_[0]), + static_cast(target_size_[1])}; std::vector mean_; std::vector scale_; std::vector mean_str = split(config.at("mean"), ","); std::vector std_str 
= split(config.at("std"), ","); - transform(mean_str.begin(), mean_str.end(), back_inserter(mean_), - [](std::string const& s){return stof(s);}); - transform(std_str.begin(), std_str.end(), back_inserter(scale_), - [](std::string const& s){return stof(s);}); + transform(mean_str.begin(), + mean_str.end(), + back_inserter(mean_), + [](std::string const& s) { return stof(s); }); + transform(std_str.begin(), + std_str.end(), + back_inserter(scale_), + [](std::string const& s) { return stof(s); }); img_data.mean_ = mean_; img_data.scale_ = scale_; return img_data; } - void preprocess(const cv::Mat& img, const ImageBlob img_data, float* data) { cv::Mat rgb_img; cv::cvtColor(img, rgb_img, cv::COLOR_BGR2RGB); - cv::resize( - rgb_img, rgb_img, cv::Size(img_data.im_shape_[0],img_data.im_shape_[1]), - 0.f, 0.f, cv::INTER_CUBIC); + cv::resize(rgb_img, + rgb_img, + cv::Size(img_data.im_shape_[0], img_data.im_shape_[1]), + 0.f, + 0.f, + cv::INTER_CUBIC); cv::Mat imgf; rgb_img.convertTo(imgf, CV_32FC3, 1 / 255.f); const float* dimg = reinterpret_cast(imgf.data); - neon_mean_scale( - dimg, data, int(img_data.im_shape_[0] * img_data.im_shape_[1]), - img_data.mean_, img_data.scale_); + neon_mean_scale(dimg, + data, + int(img_data.im_shape_[0] * img_data.im_shape_[1]), + img_data.mean_, + img_data.scale_); } - void RunModel(std::map config, std::string img_path, const int repeats, std::vector* times) { - std::string model_file = config.at("model_file"); std::string label_path = config.at("label_path"); // Load Labels @@ -334,14 +337,12 @@ void RunModel(std::map config, // 2. 
Run predictor // warm up - for (int i = 0; i < repeats / 2; i++) - { + for (int i = 0; i < repeats / 2; i++) { predictor->Run(); } auto inference_start = std::chrono::steady_clock::now(); - for (int i = 0; i < repeats; i++) - { + for (int i = 0; i < repeats; i++) { predictor->Run(); } auto inference_end = std::chrono::steady_clock::now(); diff --git a/static/deploy/python/infer.py b/static/deploy/python/infer.py index bebdf0508402736c58dae22d6200957372e2b9fa..45671cad0eca0f637c7e4c1e040fbab0c08d0a90 100644 --- a/static/deploy/python/infer.py +++ b/static/deploy/python/infer.py @@ -530,7 +530,7 @@ def predict_video(detector, camera_id): fps = 30 width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) - fourcc = cv2.VideoWriter_fourcc(*'mp4v') + fourcc = cv2.VideoWriter_fourcc(* 'mp4v') if not os.path.exists(FLAGS.output_dir): os.makedirs(FLAGS.output_dir) out_path = os.path.join(FLAGS.output_dir, video_name) @@ -660,6 +660,8 @@ if __name__ == '__main__': assert FLAGS.device in ['CPU', 'GPU', 'XPU' ], "device should be CPU, GPU or XPU" assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device" - assert not (FLAGS.enable_mkldnn==False and FLAGS.enable_mkldnn_bfloat16==True),"To turn on mkldnn_bfloat, please set both enable_mkldnn and enable_mkldnn_bfloat16 True" + assert not ( + FLAGS.enable_mkldnn == False and FLAGS.enable_mkldnn_bfloat16 == True + ), "To turn on mkldnn_bfloat, please set both enable_mkldnn and enable_mkldnn_bfloat16 True" main() diff --git a/static/ppdet/ext_op/src/bottom_pool_op.cc b/static/ppdet/ext_op/src/bottom_pool_op.cc index 6a867d1f127a34a9f93a0d0719a5dba039466f24..6b425fee174f1137822c4f424591e0c3166faea4 100644 --- a/static/ppdet/ext_op/src/bottom_pool_op.cc +++ b/static/ppdet/ext_op/src/bottom_pool_op.cc @@ -18,7 +18,7 @@ namespace operators { using Tensor = framework::Tensor; class BottomPoolOp : public framework::OperatorWithKernel { -public: + public: using 
framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { @@ -27,7 +27,7 @@ public: ctx->ShareDim("X", /*->*/ "Output"); } -protected: + protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType(ctx.Input("X")->type(), @@ -36,10 +36,9 @@ protected: }; class BottomPoolOpMaker : public framework::OpProtoAndCheckerMaker { -public: + public: void Make() override { - AddInput("X", - "Input with shape (batch, C, H, W)"); + AddInput("X", "Input with shape (batch, C, H, W)"); AddOutput("MaxMap", "Max map with index of maximum value of input"); AddOutput("Output", "output with same shape as input(X)"); AddComment( @@ -52,10 +51,10 @@ The output has the same shape with input. }; class BottomPoolOpGrad : public framework::OperatorWithKernel { -public: + public: using framework::OperatorWithKernel::OperatorWithKernel; -protected: + protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("MaxMap"), "Input(MaxMap) should not be null"); @@ -75,10 +74,10 @@ protected: template class BottomPoolGradDescMaker : public framework::SingleGradOpMaker { -public: + public: using framework::SingleGradOpMaker::SingleGradOpMaker; -protected: + protected: void Apply(GradOpPtr op) const override { op->SetType("bottom_pool_grad"); op->SetInput("X", this->Input("X")); diff --git a/static/ppdet/ext_op/src/bottom_pool_op.cu b/static/ppdet/ext_op/src/bottom_pool_op.cu index 4912ec3c0effb2d924111203168da821aae16b19..faab8d588c67f3abc8bf2e98c5c33cb20bfdf56a 100644 --- a/static/ppdet/ext_op/src/bottom_pool_op.cu +++ b/static/ppdet/ext_op/src/bottom_pool_op.cu @@ -11,10 +11,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ +#include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/memory/memory.h" -#include +#include "paddle/fluid/platform/cuda_primitives.h" #include "util.cu.h" namespace paddle { @@ -32,14 +32,14 @@ static inline int NumBlocks(const int N) { template class BottomPoolOpCUDAKernel : public framework::OpKernel { -public: - void Compute(const framework::ExecutionContext &ctx) const override { + public: + void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "This kernel only runs on GPU device."); - auto *x = ctx.Input("X"); - auto *max_map = ctx.Output("MaxMap"); - auto *output = ctx.Output("Output"); - auto *x_data = x->data(); + auto* x = ctx.Input("X"); + auto* max_map = ctx.Output("MaxMap"); + auto* output = ctx.Output("Output"); + auto* x_data = x->data(); auto x_dims = x->dims(); int NC_num = x_dims[0] * x_dims[1]; int height = x_dims[2]; @@ -47,22 +47,31 @@ public: int num = x->numel(); auto& dev_ctx = ctx.cuda_device_context(); - int *max_map_data = max_map->mutable_data(x_dims, dev_ctx.GetPlace()); - T *output_data = output->mutable_data(x_dims, dev_ctx.GetPlace()); + int* max_map_data = max_map->mutable_data(x_dims, dev_ctx.GetPlace()); + T* output_data = output->mutable_data(x_dims, dev_ctx.GetPlace()); auto gpu_place = boost::get(dev_ctx.GetPlace()); int threads = kNumCUDAThreads; int blocks = NumBlocks(num / height); - + auto max_val_ptr = memory::Alloc(gpu_place, num / height * sizeof(T)); T* max_val_data = reinterpret_cast(max_val_ptr->ptr()); auto max_ind_ptr = memory::Alloc(gpu_place, num / height * sizeof(int)); int* max_ind_data = reinterpret_cast(max_ind_ptr->ptr()); - GetMaxInfo<<>>(x->data(), NC_num, height, width, 2, false, max_val_data, max_ind_data, max_map_data); + GetMaxInfo<<>>(x->data(), + NC_num, 
+ height, + width, + 2, + false, + max_val_data, + max_ind_data, + max_map_data); blocks = NumBlocks(num); - ScatterAddFw<<>>(x->data(), max_map_data, NC_num, height, width, 2, output_data); + ScatterAddFw<<>>( + x->data(), max_map_data, NC_num, height, width, 2, output_data); } }; @@ -75,20 +84,28 @@ class BottomPoolGradOpCUDAKernel : public framework::OpKernel { auto* out_grad = ctx.Input(framework::GradVarName("Output")); auto* in_grad = ctx.Output(framework::GradVarName("X")); auto x_dims = x->dims(); - + auto& dev_ctx = ctx.cuda_device_context(); T* in_grad_data = in_grad->mutable_data(x_dims, dev_ctx.GetPlace()); auto gpu_place = boost::get(dev_ctx.GetPlace()); - + int threads = kNumCUDAThreads; int NC_num = x_dims[0] * x_dims[1]; int height = x_dims[2]; int width = x_dims[3]; int grad_num = in_grad->numel(); int blocks = NumBlocks(grad_num); - FillConstant<<>>(in_grad_data, 0, grad_num); - - ScatterAddBw<<>>(out_grad->data(), max_map->data(), NC_num, height, width, 2, in_grad_data); + FillConstant<<>>( + in_grad_data, 0, grad_num); + + ScatterAddBw<<>>( + out_grad->data(), + max_map->data(), + NC_num, + height, + width, + 2, + in_grad_data); } }; diff --git a/static/ppdet/ext_op/src/left_pool_op.cc b/static/ppdet/ext_op/src/left_pool_op.cc index c2a8f169fe0f11b7cebfdb16fb38e111a820df48..32fe2c9a95980e612bf11038320dd5ed77a514e0 100644 --- a/static/ppdet/ext_op/src/left_pool_op.cc +++ b/static/ppdet/ext_op/src/left_pool_op.cc @@ -18,7 +18,7 @@ namespace operators { using Tensor = framework::Tensor; class LeftPoolOp : public framework::OperatorWithKernel { -public: + public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { @@ -27,7 +27,7 @@ public: ctx->ShareDim("X", /*->*/ "Output"); } -protected: + protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType(ctx.Input("X")->type(), @@ -36,10 +36,9 @@ 
protected: }; class LeftPoolOpMaker : public framework::OpProtoAndCheckerMaker { -public: + public: void Make() override { - AddInput("X", - "Input with shape (batch, C, H, W)"); + AddInput("X", "Input with shape (batch, C, H, W)"); AddOutput("MaxMap", "Max map with index of maximum value of input"); AddOutput("Output", "output with same shape as input(X)"); AddComment( @@ -52,10 +51,10 @@ The output has the same shape with input. }; class LeftPoolOpGrad : public framework::OperatorWithKernel { -public: + public: using framework::OperatorWithKernel::OperatorWithKernel; -protected: + protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("MaxMap"), "Input(MaxMap) should not be null"); @@ -75,10 +74,10 @@ protected: template class LeftPoolGradDescMaker : public framework::SingleGradOpMaker { -public: + public: using framework::SingleGradOpMaker::SingleGradOpMaker; -protected: + protected: void Apply(GradOpPtr op) const override { op->SetType("left_pool_grad"); op->SetInput("X", this->Input("X")); diff --git a/static/ppdet/ext_op/src/left_pool_op.cu b/static/ppdet/ext_op/src/left_pool_op.cu index a5e9323adc6e268bf6572cf5470dec841664b6ec..06995da6a71b8eddeda618515afdca9be8aa1d4c 100644 --- a/static/ppdet/ext_op/src/left_pool_op.cu +++ b/static/ppdet/ext_op/src/left_pool_op.cu @@ -11,10 +11,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/memory/memory.h" -#include +#include "paddle/fluid/platform/cuda_primitives.h" #include "util.cu.h" namespace paddle { @@ -32,14 +32,14 @@ static inline int NumBlocks(const int N) { template class LeftPoolOpCUDAKernel : public framework::OpKernel { -public: - void Compute(const framework::ExecutionContext &ctx) const override { + public: + void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "This kernel only runs on GPU device."); - auto *x = ctx.Input("X"); - auto *max_map = ctx.Output("MaxMap"); - auto *output = ctx.Output("Output"); - auto *x_data = x->data(); + auto* x = ctx.Input("X"); + auto* max_map = ctx.Output("MaxMap"); + auto* output = ctx.Output("Output"); + auto* x_data = x->data(); auto x_dims = x->dims(); int NC_num = x_dims[0] * x_dims[1]; int height = x_dims[2]; @@ -47,10 +47,10 @@ public: int num = x->numel(); auto& dev_ctx = ctx.cuda_device_context(); - int *max_map_data = max_map->mutable_data(x_dims, dev_ctx.GetPlace()); - T *output_data = output->mutable_data(x_dims, dev_ctx.GetPlace()); + int* max_map_data = max_map->mutable_data(x_dims, dev_ctx.GetPlace()); + T* output_data = output->mutable_data(x_dims, dev_ctx.GetPlace()); auto gpu_place = boost::get(dev_ctx.GetPlace()); - + int threads = kNumCUDAThreads; int blocks = NumBlocks(num / width); @@ -59,11 +59,19 @@ public: auto max_ind_ptr = memory::Alloc(gpu_place, num / width * sizeof(int)); int* max_ind_data = reinterpret_cast(max_ind_ptr->ptr()); - GetMaxInfo<<>>(x->data(), NC_num, height, width, 3, true, max_val_data, max_ind_data, max_map_data); + GetMaxInfo<<>>(x->data(), + NC_num, + height, + width, + 3, + true, + max_val_data, + max_ind_data, + max_map_data); blocks = NumBlocks(num); - ScatterAddFw<<>>(x->data(), max_map_data, NC_num, height, width, 3, output_data); - + 
ScatterAddFw<<>>( + x->data(), max_map_data, NC_num, height, width, 3, output_data); } }; @@ -76,24 +84,31 @@ class LeftPoolGradOpCUDAKernel : public framework::OpKernel { auto* out_grad = ctx.Input(framework::GradVarName("Output")); auto* in_grad = ctx.Output(framework::GradVarName("X")); auto x_dims = x->dims(); - + auto& dev_ctx = ctx.cuda_device_context(); T* in_grad_data = in_grad->mutable_data(x_dims, dev_ctx.GetPlace()); auto gpu_place = boost::get(dev_ctx.GetPlace()); - + int threads = kNumCUDAThreads; int NC_num = x_dims[0] * x_dims[1]; - int height = x_dims[2]; + int height = x_dims[2]; int width = x_dims[3]; int grad_num = in_grad->numel(); int blocks = NumBlocks(grad_num); - FillConstant<<>>(in_grad_data, 0, grad_num); - - ScatterAddBw<<>>(out_grad->data(), max_map->data(), NC_num, height, width, 3, in_grad_data); + FillConstant<<>>( + in_grad_data, 0, grad_num); + + ScatterAddBw<<>>( + out_grad->data(), + max_map->data(), + NC_num, + height, + width, + 3, + in_grad_data); } }; - } // namespace operators } // namespace paddle diff --git a/static/ppdet/ext_op/src/right_pool_op.cc b/static/ppdet/ext_op/src/right_pool_op.cc index 6bf74a1b08878724e388ae66ffe34c1ae0c65ec8..80ca5f900d327ed54631fc686199c3e8dfe4b72b 100644 --- a/static/ppdet/ext_op/src/right_pool_op.cc +++ b/static/ppdet/ext_op/src/right_pool_op.cc @@ -18,7 +18,7 @@ namespace operators { using Tensor = framework::Tensor; class RightPoolOp : public framework::OperatorWithKernel { -public: + public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { @@ -27,7 +27,7 @@ public: ctx->ShareDim("X", /*->*/ "Output"); } -protected: + protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType(ctx.Input("X")->type(), @@ -36,11 +36,10 @@ protected: }; class RightPoolOpMaker : public framework::OpProtoAndCheckerMaker { -public: + public: void Make() 
override { - AddInput("X", - "Input with shape (batch, C, H, W)"); - AddOutput("MaxMap", "Max map with index of maximum value of input"); + AddInput("X", "Input with shape (batch, C, H, W)"); + AddOutput("MaxMap", "Max map with index of maximum value of input"); AddOutput("Output", "output with same shape as input(X)"); AddComment( R"Doc( @@ -52,10 +51,10 @@ The output has the same shape with input. }; class RightPoolOpGrad : public framework::OperatorWithKernel { -public: + public: using framework::OperatorWithKernel::OperatorWithKernel; -protected: + protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("MaxMap"), "Input(MaxMap) should not be null"); @@ -75,10 +74,10 @@ protected: template class RightPoolGradDescMaker : public framework::SingleGradOpMaker { -public: + public: using framework::SingleGradOpMaker::SingleGradOpMaker; -protected: + protected: void Apply(GradOpPtr op) const override { op->SetType("right_pool_grad"); op->SetInput("X", this->Input("X")); diff --git a/static/ppdet/ext_op/src/right_pool_op.cu b/static/ppdet/ext_op/src/right_pool_op.cu index 08a52ecf1eec9816c8a29e50452344d522b75c49..b0fd634f91853922eaf553a1f418f04891c65e1b 100644 --- a/static/ppdet/ext_op/src/right_pool_op.cu +++ b/static/ppdet/ext_op/src/right_pool_op.cu @@ -11,10 +11,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/memory/memory.h" -#include +#include "paddle/fluid/platform/cuda_primitives.h" #include "util.cu.h" namespace paddle { @@ -32,14 +32,14 @@ static inline int NumBlocks(const int N) { template class RightPoolOpCUDAKernel : public framework::OpKernel { -public: - void Compute(const framework::ExecutionContext &ctx) const override { + public: + void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "This kernel only runs on GPU device."); - auto *x = ctx.Input("X"); - auto *max_map = ctx.Output("MaxMap"); - auto *output = ctx.Output("Output"); - auto *x_data = x->data(); + auto* x = ctx.Input("X"); + auto* max_map = ctx.Output("MaxMap"); + auto* output = ctx.Output("Output"); + auto* x_data = x->data(); auto x_dims = x->dims(); int NC_num = x_dims[0] * x_dims[1]; int height = x_dims[2]; @@ -47,23 +47,31 @@ public: int num = x->numel(); auto& dev_ctx = ctx.cuda_device_context(); - int *max_map_data = max_map->mutable_data(x_dims, dev_ctx.GetPlace()); - T *output_data = output->mutable_data(x_dims, dev_ctx.GetPlace()); + int* max_map_data = max_map->mutable_data(x_dims, dev_ctx.GetPlace()); + T* output_data = output->mutable_data(x_dims, dev_ctx.GetPlace()); auto gpu_place = boost::get(dev_ctx.GetPlace()); - + int threads = kNumCUDAThreads; int blocks = NumBlocks(num / width); - + auto max_val_ptr = memory::Alloc(gpu_place, num / width * sizeof(T)); T* max_val_data = reinterpret_cast(max_val_ptr->ptr()); auto max_ind_ptr = memory::Alloc(gpu_place, num / width * sizeof(int)); int* max_ind_data = reinterpret_cast(max_ind_ptr->ptr()); - GetMaxInfo<<>>(x->data(), NC_num, height, width, 3, false, max_val_data, max_ind_data, max_map_data); + GetMaxInfo<<>>(x->data(), + NC_num, + height, + width, + 3, + false, + max_val_data, + max_ind_data, + max_map_data); blocks = 
NumBlocks(num); - ScatterAddFw<<>>(x->data(), max_map_data, NC_num, height, width, 3, output_data); - + ScatterAddFw<<>>( + x->data(), max_map_data, NC_num, height, width, 3, output_data); } }; @@ -76,20 +84,28 @@ class RightPoolGradOpCUDAKernel : public framework::OpKernel { auto* out_grad = ctx.Input(framework::GradVarName("Output")); auto* in_grad = ctx.Output(framework::GradVarName("X")); auto x_dims = x->dims(); - + auto& dev_ctx = ctx.cuda_device_context(); T* in_grad_data = in_grad->mutable_data(x_dims, dev_ctx.GetPlace()); auto gpu_place = boost::get(dev_ctx.GetPlace()); - + int threads = kNumCUDAThreads; int NC_num = x_dims[0] * x_dims[1]; int height = x_dims[2]; int width = x_dims[3]; int grad_num = in_grad->numel(); int blocks = NumBlocks(grad_num); - FillConstant<<>>(in_grad_data, 0, grad_num); - - ScatterAddBw<<>>(out_grad->data(), max_map->data(), NC_num, height, width, 3, in_grad_data); + FillConstant<<>>( + in_grad_data, 0, grad_num); + + ScatterAddBw<<>>( + out_grad->data(), + max_map->data(), + NC_num, + height, + width, + 3, + in_grad_data); } }; diff --git a/static/ppdet/ext_op/src/top_pool_op.cc b/static/ppdet/ext_op/src/top_pool_op.cc old mode 100755 new mode 100644 index 29cba6660193c3bfe861dc81b03edbbb9368ea19..956b2cff1657d609c1524b792dec5e7db5b47352 --- a/static/ppdet/ext_op/src/top_pool_op.cc +++ b/static/ppdet/ext_op/src/top_pool_op.cc @@ -18,7 +18,7 @@ namespace operators { using Tensor = framework::Tensor; class TopPoolOp : public framework::OperatorWithKernel { -public: + public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { @@ -27,19 +27,18 @@ public: ctx->ShareDim("X", /*->*/ "Output"); } -protected: + protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), + return framework::OpKernelType(ctx.Input("X")->type(), ctx.GetPlace()); } }; class 
TopPoolOpMaker : public framework::OpProtoAndCheckerMaker { -public: + public: void Make() override { - AddInput("X", - "Input with shape (batch, C, H, W)"); + AddInput("X", "Input with shape (batch, C, H, W)"); AddOutput("MaxMap", "Max map with index of maximum value of input"); AddOutput("Output", "Output with same shape as input(X)"); AddComment( @@ -52,16 +51,16 @@ The output has the same shape with input. }; class TopPoolOpGrad : public framework::OperatorWithKernel { -public: + public: using framework::OperatorWithKernel::OperatorWithKernel; -protected: + protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("MaxMap"), "Input(MaxMap) should not be null"); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Output")), "Input(Output@GRAD) should not be null"); - + auto out_grad_name = framework::GradVarName("Output"); ctx->ShareDim(out_grad_name, framework::GradVarName("X")); } diff --git a/static/ppdet/ext_op/src/top_pool_op.cu b/static/ppdet/ext_op/src/top_pool_op.cu old mode 100755 new mode 100644 index f6237fe798098ad0c6cb25597de849e65383ab78..af30d5f11c926c2458cd4e55893318f13686c57c --- a/static/ppdet/ext_op/src/top_pool_op.cu +++ b/static/ppdet/ext_op/src/top_pool_op.cu @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/memory/memory.h" -#include +#include "paddle/fluid/platform/cuda_primitives.h" #include "util.cu.h" namespace paddle { @@ -33,14 +33,14 @@ static inline int NumBlocks(const int N) { template class TopPoolOpCUDAKernel : public framework::OpKernel { -public: - void Compute(const framework::ExecutionContext &ctx) const override { + public: + void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "This kernel only runs on GPU device."); - auto *x = ctx.Input("X"); - auto *max_map = ctx.Output("MaxMap"); - auto *output = ctx.Output("Output"); - auto *x_data = x->data(); + auto* x = ctx.Input("X"); + auto* max_map = ctx.Output("MaxMap"); + auto* output = ctx.Output("Output"); + auto* x_data = x->data(); auto x_dims = x->dims(); int NC_num = x_dims[0] * x_dims[1]; int height = x_dims[2]; @@ -48,22 +48,31 @@ public: int num = x->numel(); auto& dev_ctx = ctx.cuda_device_context(); - int *max_map_data = max_map->mutable_data(x_dims, dev_ctx.GetPlace()); - T *output_data = output->mutable_data(x_dims, dev_ctx.GetPlace()); + int* max_map_data = max_map->mutable_data(x_dims, dev_ctx.GetPlace()); + T* output_data = output->mutable_data(x_dims, dev_ctx.GetPlace()); auto gpu_place = boost::get(dev_ctx.GetPlace()); - + int threads = kNumCUDAThreads; int blocks = NumBlocks(num / height); - + auto max_val_ptr = memory::Alloc(gpu_place, num / height * sizeof(T)); T* max_val_data = reinterpret_cast(max_val_ptr->ptr()); auto max_ind_ptr = memory::Alloc(gpu_place, num / height * sizeof(int)); int* max_ind_data = reinterpret_cast(max_ind_ptr->ptr()); - GetMaxInfo<<>>(x->data(), NC_num, height, width, 2, true, max_val_data, max_ind_data, max_map_data); + GetMaxInfo<<>>(x->data(), + NC_num, + height, + width, + 2, + true, + max_val_data, + max_ind_data, + max_map_data); blocks = 
NumBlocks(num); - ScatterAddFw<<>>(x->data(), max_map_data, NC_num, height, width, 2, output_data); + ScatterAddFw<<>>( + x->data(), max_map_data, NC_num, height, width, 2, output_data); } }; @@ -79,16 +88,24 @@ class TopPoolGradOpCUDAKernel : public framework::OpKernel { auto& dev_ctx = ctx.cuda_device_context(); T* in_grad_data = in_grad->mutable_data(x_dims, dev_ctx.GetPlace()); auto gpu_place = boost::get(dev_ctx.GetPlace()); - + int threads = kNumCUDAThreads; int NC_num = x_dims[0] * x_dims[1]; int height = x_dims[2]; int width = x_dims[3]; int grad_num = in_grad->numel(); int blocks = NumBlocks(grad_num); - FillConstant<<>>(in_grad_data, 0, grad_num); - - ScatterAddBw<<>>(out_grad->data(), max_map->data(), NC_num, height, width, 2, in_grad_data); + FillConstant<<>>( + in_grad_data, 0, grad_num); + + ScatterAddBw<<>>( + out_grad->data(), + max_map->data(), + NC_num, + height, + width, + 2, + in_grad_data); } }; diff --git a/static/ppdet/ext_op/src/util.cu.h b/static/ppdet/ext_op/src/util.cu.h index 615e45a7891fca7f89eea5824b4691d33a1fe5ef..60b63fa367ffa926556422214216d3e7099fa5f3 100644 --- a/static/ppdet/ext_op/src/util.cu.h +++ b/static/ppdet/ext_op/src/util.cu.h @@ -11,10 +11,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/memory/memory.h" -#include +#include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { @@ -27,15 +27,18 @@ using framework::Tensor; template __global__ void FillConstant(T* x, int num, int fill_num) { - CUDA_1D_KERNEL_LOOP(i, fill_num) { - x[i] = static_cast(num); - } + CUDA_1D_KERNEL_LOOP(i, fill_num) { x[i] = static_cast(num); } } template -__global__ void SliceOnAxis(const T* x, const int NC_num, const int H, const int W, - const int axis, const int start, const int end, - T* output) { +__global__ void SliceOnAxis(const T* x, + const int NC_num, + const int H, + const int W, + const int axis, + const int start, + const int end, + T* output) { int HW_num = H * W; int length = axis == 2 ? W : H; int sliced_len = end - start; @@ -44,22 +47,28 @@ __global__ void SliceOnAxis(const T* x, const int NC_num, const int H, const int CUDA_1D_KERNEL_LOOP(i, NC_num * cur_HW_num) { int NC_id = i / cur_HW_num; int HW_id = i % cur_HW_num; - if (axis == 2){ + if (axis == 2) { output[i] = x[NC_id * HW_num + start * W + HW_id]; } else if (axis == 3) { int col = HW_id % sliced_len; int row = HW_id / sliced_len; output[i] = x[NC_id * HW_num + row * W + start + col]; } - } + } } template -__global__ void MaxOut(const T* input, const int next_ind, const int NC_num, - const int H, const int W, const int axis, - const int start, const int end, T* output) { +__global__ void MaxOut(const T* input, + const int next_ind, + const int NC_num, + const int H, + const int W, + const int axis, + const int start, + const int end, + T* output) { int HW_num = H * W; - int length = axis == 2 ? W : H; + int length = axis == 2 ? 
W : H; T cur = static_cast(0.); T next = static_cast(0.); T max_v = static_cast(0.); @@ -69,11 +78,11 @@ __global__ void MaxOut(const T* input, const int next_ind, const int NC_num, CUDA_1D_KERNEL_LOOP(i, NC_num * cur_HW_num) { int NC_id = i / cur_HW_num; int HW_id = i % cur_HW_num; - - if (axis == 2){ + + if (axis == 2) { cur = input[NC_id * HW_num + start * W + HW_id]; next = input[NC_id * HW_num + next_ind * W + HW_id]; - max_v = cur > next ? cur : next; + max_v = cur > next ? cur : next; output[NC_id * HW_num + start * W + HW_id] = max_v; } else if (axis == 3) { int col = HW_id % sliced_len; @@ -88,11 +97,16 @@ __global__ void MaxOut(const T* input, const int next_ind, const int NC_num, } template -__global__ void UpdateMaxInfo(const T* input, const int NC_num, - const int H, const int W, const int axis, - const int index, T* max_val, int* max_ind) { +__global__ void UpdateMaxInfo(const T* input, + const int NC_num, + const int H, + const int W, + const int axis, + const int index, + T* max_val, + int* max_ind) { int length = axis == 2 ? W : H; - int HW_num = H * W; + int HW_num = H * W; T val = static_cast(0.); CUDA_1D_KERNEL_LOOP(i, NC_num * length) { int NC_id = i / length; @@ -111,82 +125,104 @@ __global__ void UpdateMaxInfo(const T* input, const int NC_num, } template -__global__ void ScatterAddOnAxis(const T* input, const int start, const int* max_ind, const int NC_num, const int H, const int W, const int axis, T* output) { +__global__ void ScatterAddOnAxis(const T* input, + const int start, + const int* max_ind, + const int NC_num, + const int H, + const int W, + const int axis, + T* output) { int length = axis == 2 ? 
W : H; int HW_num = H * W; - CUDA_1D_KERNEL_LOOP(i, NC_num * length) { + CUDA_1D_KERNEL_LOOP(i, NC_num * length) { int NC_id = i / length; int length_id = i % length; int id_ = max_ind[i]; if (axis == 2) { - platform::CudaAtomicAdd(output + NC_id * HW_num + id_ * W + length_id, input[NC_id * HW_num + start * W + length_id]); - //output[NC_id * HW_num + id_ * W + length_id] += input[NC_id * HW_num + start * W + length_id]; + platform::CudaAtomicAdd(output + NC_id * HW_num + id_ * W + length_id, + input[NC_id * HW_num + start * W + length_id]); + // output[NC_id * HW_num + id_ * W + length_id] += input[NC_id * HW_num + + // start * W + length_id]; } else if (axis == 3) { - platform::CudaAtomicAdd(output + NC_id * HW_num + length_id * W + id_, input[NC_id * HW_num + length_id * W + start]); - //output[NC_id * HW_num + length_id * W + id_] += input[NC_id * HW_num + length_id * W + start]; + platform::CudaAtomicAdd(output + NC_id * HW_num + length_id * W + id_, + input[NC_id * HW_num + length_id * W + start]); + // output[NC_id * HW_num + length_id * W + id_] += input[NC_id * HW_num + + // length_id * W + start]; } __syncthreads(); } } template -__global__ void GetMaxInfo(const T* input, const int NC_num, - const int H, const int W, const int axis, - const bool reverse, T* max_val, int* max_ind, +__global__ void GetMaxInfo(const T* input, + const int NC_num, + const int H, + const int W, + const int axis, + const bool reverse, + T* max_val, + int* max_ind, int* max_map) { - int start = 0; - int end = axis == 2 ? H: W; - int s = reverse ? end-1 : start; - int e = reverse ? start-1 : end; - int step = reverse ? -1 : 1; - int len = axis == 2 ? 
W : H; - int loc = 0; - T val = static_cast(0.); - for (int i = s; ; ) { - if (i == s) { - CUDA_1D_KERNEL_LOOP(j, NC_num * len) { - int NC_id = j / len; - int len_id = j % len; - if (axis == 2) { - loc = NC_id * H * W + i * W + len_id; - } else if (axis == 3){ - loc = NC_id * H * W + len_id * W + i; - } - max_ind[j] = i; - max_map[loc] = max_ind[j]; - max_val[j] = input[loc]; - __syncthreads(); - } - } else { - CUDA_1D_KERNEL_LOOP(j, NC_num * len) { - int NC_id = j / len; - int len_id = j % len; - - if (axis == 2) { - loc = NC_id * H * W + i * W + len_id; - } else if (axis == 3){ - loc = NC_id * H * W + len_id * W + i; - } - val = input[loc]; - T max_v = max_val[j]; - if (val > max_v) { - max_val[j] = val; - max_map[loc] = i; - max_ind[j] = i; - } else { - max_map[loc] = max_ind[j]; - } - __syncthreads(); - } - } - i += step; - if (s < e && i >= e) break; - if (s > e && i <= e) break; - } + int start = 0; + int end = axis == 2 ? H : W; + int s = reverse ? end - 1 : start; + int e = reverse ? start - 1 : end; + int step = reverse ? -1 : 1; + int len = axis == 2 ? 
W : H; + int loc = 0; + T val = static_cast(0.); + for (int i = s;;) { + if (i == s) { + CUDA_1D_KERNEL_LOOP(j, NC_num * len) { + int NC_id = j / len; + int len_id = j % len; + if (axis == 2) { + loc = NC_id * H * W + i * W + len_id; + } else if (axis == 3) { + loc = NC_id * H * W + len_id * W + i; + } + max_ind[j] = i; + max_map[loc] = max_ind[j]; + max_val[j] = input[loc]; + __syncthreads(); + } + } else { + CUDA_1D_KERNEL_LOOP(j, NC_num * len) { + int NC_id = j / len; + int len_id = j % len; + + if (axis == 2) { + loc = NC_id * H * W + i * W + len_id; + } else if (axis == 3) { + loc = NC_id * H * W + len_id * W + i; + } + val = input[loc]; + T max_v = max_val[j]; + if (val > max_v) { + max_val[j] = val; + max_map[loc] = i; + max_ind[j] = i; + } else { + max_map[loc] = max_ind[j]; + } + __syncthreads(); + } + } + i += step; + if (s < e && i >= e) break; + if (s > e && i <= e) break; + } } template -__global__ void ScatterAddFw(const T* input, const int* max_map, const int NC_num, const int H, const int W, const int axis, T* output){ +__global__ void ScatterAddFw(const T* input, + const int* max_map, + const int NC_num, + const int H, + const int W, + const int axis, + T* output) { CUDA_1D_KERNEL_LOOP(i, NC_num * H * W) { int loc = max_map[i]; int NC_id = i / (H * W); @@ -202,7 +238,13 @@ __global__ void ScatterAddFw(const T* input, const int* max_map, const int NC_nu } template -__global__ void ScatterAddBw(const T* input, const int* max_map, const int NC_num, const int H, const int W, const int axis, T* output){ +__global__ void ScatterAddBw(const T* input, + const int* max_map, + const int NC_num, + const int H, + const int W, + const int axis, + T* output) { CUDA_1D_KERNEL_LOOP(i, NC_num * H * W) { int loc = max_map[i]; int NC_id = i / (H * W);