未验证 提交 c103d025 编写于 作者: W wangguanzhong 提交者: GitHub

Clean fluid (#6075)

* clean fluid

* mv static to legacy

* remove yolo box

* revert legacy dir

* revert static link

* update in_dynamic_mode

* clean iou_similarity, collect_fpn_proposals, bipartite_match
上级 ba3ebe20
......@@ -23,7 +23,7 @@ else:
import numpy as np
from paddle.io import DataLoader, DistributedBatchSampler
from paddle.fluid.dataloader.collate import default_collate_fn
from .utils import default_collate_fn
from ppdet.core.workspace import register
from . import transform
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numbers
import numpy as np
try:
from collections.abc import Sequence, Mapping
except:
from collections import Sequence, Mapping
def default_collate_fn(batch):
    """
    Default batch collating function for :code:`paddle.io.DataLoader`.

    Receives a non-empty list of samples and merges them recursively,
    dispatching on the type of the first sample:

    * ``numpy.ndarray``: samples are stacked along a new leading axis.
    * number: samples are packed into a 1-D ``numpy.ndarray``.
    * ``str`` / ``bytes``: the batch is returned unchanged.
    * mapping: a dict is returned with the keys of the first sample, each
      value being the collated result of that key over all samples.
    * sequence: a list is returned with one collated entry per field
      position; every sample must have the same number of fields.

    e.g. for following input data:

        [{'image': np.array(shape=[3, 224, 224]), 'label': 1},
         {'image': np.array(shape=[3, 224, 224]), 'label': 3},
         {'image': np.array(shape=[3, 224, 224]), 'label': 4},
         {'image': np.array(shape=[3, 224, 224]), 'label': 5},]

    the result is:

        {'image': np.array(shape=[4, 3, 224, 224]),
         'label': np.array([1, 3, 4, 5])}

    Args:
        batch(list of sample data): non-empty list of samples; the first
            sample decides how the whole batch is collated.

    Returns:
        Batched data: numbers and numpy arrays of the input stacked into
        numpy arrays, nested inside the same dict / list structure.

    Raises:
        RuntimeError: if sequence samples differ in their number of fields.
        TypeError: if the first sample has an unsupported type.
    """
    sample = batch[0]

    if isinstance(sample, np.ndarray):
        return np.stack(batch, axis=0)

    if isinstance(sample, numbers.Number):
        return np.array(batch)

    # NOTE: str/bytes are Sequences too, so this check has to come before
    # the generic Sequence branch below.
    if isinstance(sample, (str, bytes)):
        return batch

    if isinstance(sample, Mapping):
        return {
            key: default_collate_fn([d[key] for d in batch])
            for key in sample
        }

    if isinstance(sample, Sequence):
        sample_fields_num = len(sample)
        if not all(len(item) == sample_fields_num for item in batch):
            raise RuntimeError(
                "fields number not same among samples in a batch")
        # zip(*batch) regroups the samples field by field
        return [default_collate_fn(fields) for fields in zip(*batch)]

    raise TypeError("batch data can only contain: numpy.ndarray, dict, "
                    "list, number, str, but got {}".format(type(sample)))
......@@ -23,7 +23,6 @@ import random
import paddle
import paddle.nn.functional as F
import paddle.distributed as dist
from ppdet.modeling.ops import paddle_distributed_is_initialized
__all__ = ['YOLOX']
......
......@@ -22,8 +22,7 @@ import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ..ops import iou_similarity
from ..bbox_utils import iou_similarity as batch_iou_similarity
from ..bbox_utils import iou_similarity, batch_iou_similarity
from ..bbox_utils import bbox_center
from .utils import (check_points_inside_bboxes, compute_max_iou_anchor,
compute_max_iou_gt)
......
......@@ -21,7 +21,7 @@ import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ..bbox_utils import iou_similarity
from ..bbox_utils import batch_iou_similarity
from .utils import (gather_topk_anchors, check_points_inside_bboxes,
compute_max_iou_anchor)
......
......@@ -482,8 +482,7 @@ class BasicLayer(nn.Layer):
# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = paddle.fluid.layers.zeros(
[1, Hp, Wp, 1], dtype='float32') # 1 Hp Wp 1
img_mask = paddle.zeros([1, Hp, Wp, 1], dtype='float32') # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
......
......@@ -278,8 +278,8 @@ def decode_yolo(box, anchor, downsample_ratio):
return [x1, y1, w1, h1]
def iou_similarity(box1, box2, eps=1e-9):
"""Calculate iou of box1 and box2
def batch_iou_similarity(box1, box2, eps=1e-9):
"""Calculate iou of box1 and box2 in batch
Args:
box1 (Tensor): box with the shape [N, M1, 4]
......@@ -866,3 +866,26 @@ def bbox2delta_v2(src_boxes,
stds = paddle.to_tensor(stds, place=src_boxes.place)
deltas = (deltas - means) / stds
return deltas
def iou_similarity(box1, box2, eps=1e-10):
    """Compute the pairwise IoU between two sets of boxes.

    Each box is read as two corner points: columns 0:2 (x1, y1) and
    columns 2:4 (x2, y2).

    Args:
        box1 (Tensor): boxes with the shape [M1, 4]
        box2 (Tensor): boxes with the shape [M2, 4]
        eps (float): small constant added to the union so the division
            never hits zero. Default: 1e-10

    Return:
        iou (Tensor): iou between box1 and box2 with the shape [M1, M2]
    """
    # Insert singleton axes so every box1 row broadcasts against every
    # box2 row: [M1, 1, 4] vs [1, M2, 4].
    boxes_a = box1.unsqueeze(1)
    boxes_b = box2.unsqueeze(0)

    a_min, a_max = boxes_a[:, :, 0:2], boxes_a[:, :, 2:4]
    b_min, b_max = boxes_b[:, :, 0:2], boxes_b[:, :, 2:4]

    # Corners of the intersection rectangle.
    inter_min = paddle.maximum(a_min, b_min)
    inter_max = paddle.minimum(a_max, b_max)

    # Negative extents (no overlap / degenerate box) are clipped to 0
    # before width * height is taken.
    inter_area = (inter_max - inter_min).clip(0).prod(-1)
    area_a = (a_max - a_min).clip(0).prod(-1)
    area_b = (b_max - b_min).clip(0).prod(-1)

    return inter_area / (area_a + area_b - inter_area + eps)
......@@ -23,7 +23,6 @@ import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Constant
from paddle.fluid.dygraph import parallel_helper
from ppdet.modeling.ops import get_static_shape
from ..initializer import normal_
......@@ -726,8 +725,7 @@ class PicoHeadV2(GFLHead):
loss_dfl = paddle.zeros([1])
avg_factor = flatten_assigned_scores.sum()
if paddle.fluid.core.is_compiled_with_dist(
) and parallel_helper._is_parallel_ctx_initialized():
if paddle.distributed.get_world_size() > 1:
paddle.distributed.all_reduce(avg_factor)
avg_factor = paddle.clip(
avg_factor / paddle.distributed.get_world_size(), min=1)
......
......@@ -22,7 +22,7 @@ from ..losses import GIoULoss
from ..initializer import bias_init_with_prob, constant_, normal_
from ..assigners.utils import generate_anchors_for_grid_cell
from ppdet.modeling.backbones.cspresnet import ConvBNLayer
from ppdet.modeling.ops import get_static_shape, paddle_distributed_is_initialized, get_act_fn
from ppdet.modeling.ops import get_static_shape, get_act_fn
from ppdet.modeling.layers import MultiClassNMS
__all__ = ['PPYOLOEHead']
......@@ -343,7 +343,7 @@ class PPYOLOEHead(nn.Layer):
loss_cls = self._focal_loss(pred_scores, assigned_scores, alpha_l)
assigned_scores_sum = assigned_scores.sum()
if paddle_distributed_is_initialized():
if paddle.distributed.get_world_size() > 1:
paddle.distributed.all_reduce(assigned_scores_sum)
assigned_scores_sum = paddle.clip(
assigned_scores_sum / paddle.distributed.get_world_size(),
......
......@@ -554,10 +554,15 @@ class YOLOBox(object):
origin_shape = im_shape / scale_factor
origin_shape = paddle.cast(origin_shape, 'int32')
for i, head_out in enumerate(yolo_head_out):
boxes, scores = ops.yolo_box(head_out, origin_shape, anchors[i],
self.num_classes, self.conf_thresh,
self.downsample_ratio // 2**i,
self.clip_bbox, self.scale_x_y)
boxes, scores = paddle.vision.ops.yolo_box(
head_out,
origin_shape,
anchors[i],
self.num_classes,
self.conf_thresh,
self.downsample_ratio // 2**i,
self.clip_bbox,
scale_x_y=self.scale_x_y)
boxes_list.append(boxes)
scores_list.append(paddle.transpose(scores, perm=[0, 2, 1]))
yolo_boxes = paddle.concat(boxes_list, axis=1)
......@@ -622,94 +627,6 @@ class SSDBox(object):
return output_boxes, output_scores
@register
@serializable
class AnchorGrid(object):
    """Generate anchor grid

    Builds, for every pyramid level from `min_level` to `max_level`, a dense
    set of anchor boxes: a small "base cell" of anchors (one per
    scale / aspect-ratio combination) is replicated at every grid position
    of the level, where the level stride is 2**level.

    Args:
        image_size (int or list): input image size, may be a single integer or
            list of [h, w]. Default: 512
        min_level (int): min level of the feature pyramid. Default: 3
        max_level (int): max level of the feature pyramid. Default: 7
        anchor_base_scale: base anchor scale. Default: 4
        num_scales: number of anchor scales. Default: 3
        aspect_ratios: aspect ratios. default: [[1, 1], [1.4, 0.7], [0.7, 1.4]]
    """

    def __init__(self,
                 image_size=512,
                 min_level=3,
                 max_level=7,
                 anchor_base_scale=4,
                 num_scales=3,
                 aspect_ratios=[[1, 1], [1.4, 0.7], [0.7, 1.4]]):
        super(AnchorGrid, self).__init__()
        # A single integer is expanded to a two-element [size, size] list.
        if isinstance(image_size, Integral):
            self.image_size = [image_size, image_size]
        else:
            self.image_size = image_size
        # Every image dimension must be divisible by the largest stride.
        for dim in self.image_size:
            assert dim % 2 ** max_level == 0, \
                "image size should be multiple of the max level stride"
        self.min_level = min_level
        self.max_level = max_level
        self.anchor_base_scale = anchor_base_scale
        self.num_scales = num_scales
        self.aspect_ratios = aspect_ratios

    @property
    def base_cell(self):
        """Base anchor cell; computed by `make_cell` on first access and
        cached in `_base_cell` afterwards."""
        if not hasattr(self, '_base_cell'):
            self._base_cell = self.make_cell()
        return self._base_cell

    def make_cell(self):
        """Build the base anchors centred on the origin.

        Returns:
            numpy.ndarray with one row per (scale, aspect ratio) pair and
            the four columns (-w/2, -h/2, w/2, h/2).
        """
        # Scales are 2**(i / num_scales) for i in [0, num_scales).
        scales = [2**(i / self.num_scales) for i in range(self.num_scales)]
        scales = np.array(scales)
        ratios = np.array(self.aspect_ratios)
        # Column 0 of each ratio multiplies the width, column 1 the height;
        # the outer product enumerates every scale x ratio combination.
        ws = np.outer(scales, ratios[:, 0]).reshape(-1, 1)
        hs = np.outer(scales, ratios[:, 1]).reshape(-1, 1)
        anchors = np.hstack((-0.5 * ws, -0.5 * hs, 0.5 * ws, 0.5 * hs))
        return anchors

    def make_grid(self, stride):
        """Tile the base cell over the image for one stride.

        Args:
            stride (int): distance between neighbouring anchor centres.

        Returns:
            numpy.ndarray of shape [-1, 4]: base cell shifted to every
            grid position.
        """
        cell = self.base_cell * stride * self.anchor_base_scale
        # Centres start at stride // 2 and advance by stride; image_size[1]
        # bounds x and image_size[0] bounds y (image_size is [h, w]).
        x_steps = np.arange(stride // 2, self.image_size[1], stride)
        y_steps = np.arange(stride // 2, self.image_size[0], stride)
        offset_x, offset_y = np.meshgrid(x_steps, y_steps)
        offset_x = offset_x.flatten()
        offset_y = offset_y.flatten()
        # The same (x, y) shift is applied to both corners of each anchor.
        offsets = np.stack((offset_x, offset_y, offset_x, offset_y), axis=-1)
        # Extra axis so each grid position broadcasts against every anchor
        # of the cell.
        offsets = offsets[:, np.newaxis, :]
        return (cell + offsets).reshape(-1, 4)

    def generate(self):
        """Return a list with one anchor array per level, using stride
        2**l for level l in [min_level, max_level]."""
        return [
            self.make_grid(2**l)
            for l in range(self.min_level, self.max_level + 1)
        ]

    def __call__(self):
        """Return the per-level anchors as framework parameters.

        On the first call one parameter per level is created through
        `LayerHelper.create_parameter`, named 'anchors_<idx>', initialised
        from the numpy anchors and marked persistable; the list is cached
        in `_anchor_vars` and reused on later calls.
        """
        if not hasattr(self, '_anchor_vars'):
            anchor_vars = []
            helper = LayerHelper('anchor_grid')
            for idx, l in enumerate(range(self.min_level, self.max_level + 1)):
                stride = 2**l
                anchors = self.make_grid(stride)
                var = helper.create_parameter(
                    attr=ParamAttr(name='anchors_{}'.format(idx)),
                    shape=anchors.shape,
                    dtype='float32',
                    stop_gradient=True,
                    default_initializer=NumpyArrayInitializer(anchors))
                anchor_vars.append(var)
                var.persistable = True
            self._anchor_vars = anchor_vars

        return self._anchor_vars
@register
@serializable
class FCOSBox(object):
......
......@@ -20,8 +20,7 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ..ops import iou_similarity
from ..bbox_utils import bbox2delta
from ..bbox_utils import iou_similarity, bbox2delta
__all__ = ['SSDLoss']
......
......@@ -21,7 +21,7 @@ import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ..bbox_utils import decode_yolo, xywh2xyxy, iou_similarity
from ..bbox_utils import decode_yolo, xywh2xyxy, batch_iou_similarity
__all__ = ['YOLOv3Loss']
......@@ -80,7 +80,7 @@ class YOLOv3Loss(nn.Layer):
gwh = gbox[:, :, 0:2] + gbox[:, :, 2:4] * 0.5
gbox = paddle.concat([gxy, gwh], axis=-1)
iou = iou_similarity(pbox, gbox)
iou = batch_iou_similarity(pbox, gbox)
iou.stop_gradient = True
iou_max = iou.max(2) # [N, M1]
iou_mask = paddle.cast(iou_max <= self.ignore_thresh, dtype=pbox.dtype)
......
此差异已折叠。
......@@ -18,9 +18,7 @@ import unittest
import contextlib
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import Program
from paddle.fluid import core
from paddle.static import Program
class LayerTest(unittest.TestCase):
......@@ -35,19 +33,17 @@ class LayerTest(unittest.TestCase):
def _get_place(self, force_to_use_cpu=False):
# this option for ops that only have cpu kernel
if force_to_use_cpu:
return core.CPUPlace()
return 'cpu'
else:
if core.is_compiled_with_cuda():
return core.CUDAPlace(0)
return core.CPUPlace()
return paddle.device.get_device()
@contextlib.contextmanager
def static_graph(self):
paddle.enable_static()
scope = fluid.core.Scope()
scope = paddle.static.Scope()
program = Program()
with fluid.scope_guard(scope):
with fluid.program_guard(program):
with paddle.static.scope_guard(scope):
with paddle.static.program_guard(program):
paddle.seed(self.seed)
paddle.framework.random._manual_program_seed(self.seed)
yield
......@@ -57,9 +53,9 @@ class LayerTest(unittest.TestCase):
fetch_list,
with_lod=False,
force_to_use_cpu=False):
exe = fluid.Executor(self._get_place(force_to_use_cpu))
exe.run(fluid.default_startup_program())
return exe.run(fluid.default_main_program(),
exe = paddle.static.Executor(self._get_place(force_to_use_cpu))
exe.run(paddle.static.default_startup_program())
return exe.run(paddle.static.default_main_program(),
feed=feed,
fetch_list=fetch_list,
return_numpy=(not with_lod))
......@@ -67,8 +63,8 @@ class LayerTest(unittest.TestCase):
@contextlib.contextmanager
def dynamic_graph(self, force_to_use_cpu=False):
paddle.disable_static()
with fluid.dygraph.guard(
self._get_place(force_to_use_cpu=force_to_use_cpu)):
paddle.seed(self.seed)
paddle.framework.random._manual_program_seed(self.seed)
yield
place = self._get_place(force_to_use_cpu=force_to_use_cpu)
paddle.device.set_device(place)
paddle.seed(self.seed)
paddle.framework.random._manual_program_seed(self.seed)
yield
......@@ -23,8 +23,6 @@ import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import base
import ppdet.modeling.ops as ops
from ppdet.modeling.tests.test_base import LayerTest
......@@ -50,127 +48,6 @@ def softmax(x):
return exps / np.sum(exps)
class TestCollectFpnProposals(LayerTest):
    """Checks `ops.collect_fpn_proposals` in static and dynamic graph mode."""

    def test_collect_fpn_proposals(self):
        """Static-graph and dynamic-graph results must be identical for the
        same random inputs."""
        # Random inputs for 4 levels: 5 boxes / scores per level, with the
        # per-level roi counts fixed to [2, 3].
        multi_bboxes_np = []
        multi_scores_np = []
        rois_num_per_level_np = []
        for i in range(4):
            bboxes_np = np.random.rand(5, 4).astype('float32')
            scores_np = np.random.rand(5, 1).astype('float32')
            rois_num = np.array([2, 3]).astype('int32')
            multi_bboxes_np.append(bboxes_np)
            multi_scores_np.append(scores_np)
            rois_num_per_level_np.append(rois_num)

        # Static graph: declare one placeholder per level, build the op,
        # then feed the numpy inputs and fetch the outputs.
        with self.static_graph():
            multi_bboxes = []
            multi_scores = []
            rois_num_per_level = []
            for i in range(4):
                bboxes = paddle.static.data(
                    name='rois' + str(i),
                    shape=[5, 4],
                    dtype='float32',
                    lod_level=1)
                scores = paddle.static.data(
                    name='scores' + str(i),
                    shape=[5, 1],
                    dtype='float32',
                    lod_level=1)
                rois_num = paddle.static.data(
                    name='rois_num' + str(i), shape=[None], dtype='int32')

                multi_bboxes.append(bboxes)
                multi_scores.append(scores)
                rois_num_per_level.append(rois_num)

            fpn_rois, rois_num = ops.collect_fpn_proposals(
                multi_bboxes,
                multi_scores,
                2,
                5,
                10,
                rois_num_per_level=rois_num_per_level)
            # Feed names must match the placeholder names declared above.
            feed = {}
            for i in range(4):
                feed['rois' + str(i)] = multi_bboxes_np[i]
                feed['scores' + str(i)] = multi_scores_np[i]
                feed['rois_num' + str(i)] = rois_num_per_level_np[i]
            fpn_rois_stat, rois_num_stat = self.get_static_graph_result(
                feed=feed, fetch_list=[fpn_rois, rois_num], with_lod=True)
            # with_lod=True makes the executor return non-numpy results, so
            # convert explicitly before comparing.
            fpn_rois_stat = np.array(fpn_rois_stat)
            rois_num_stat = np.array(rois_num_stat)

        # Dynamic graph: same inputs wrapped as eager variables.
        with self.dynamic_graph():
            multi_bboxes_dy = []
            multi_scores_dy = []
            rois_num_per_level_dy = []
            for i in range(4):
                bboxes_dy = base.to_variable(multi_bboxes_np[i])
                scores_dy = base.to_variable(multi_scores_np[i])
                rois_num_dy = base.to_variable(rois_num_per_level_np[i])
                multi_bboxes_dy.append(bboxes_dy)
                multi_scores_dy.append(scores_dy)
                rois_num_per_level_dy.append(rois_num_dy)
            fpn_rois_dy, rois_num_dy = ops.collect_fpn_proposals(
                multi_bboxes_dy,
                multi_scores_dy,
                2,
                5,
                10,
                rois_num_per_level=rois_num_per_level_dy)
            fpn_rois_dy = fpn_rois_dy.numpy()
            rois_num_dy = rois_num_dy.numpy()

        self.assertTrue(np.array_equal(fpn_rois_stat, fpn_rois_dy))
        self.assertTrue(np.array_equal(rois_num_stat, rois_num_dy))

    def test_collect_fpn_proposals_error(self):
        """`ops.collect_fpn_proposals` is expected to raise TypeError for
        the two malformed inputs built below."""

        def generate_input(bbox_type, score_type, name):
            # Builds 4 levels of box / score placeholders with the given
            # dtypes; `name` keeps the placeholder names unique.
            multi_bboxes = []
            multi_scores = []
            for i in range(4):
                bboxes = paddle.static.data(
                    name='rois' + name + str(i),
                    shape=[10, 4],
                    dtype=bbox_type,
                    lod_level=1)
                scores = paddle.static.data(
                    name='scores' + name + str(i),
                    shape=[10, 1],
                    dtype=score_type,
                    lod_level=1)
                multi_bboxes.append(bboxes)
                multi_scores.append(scores)
            return multi_bboxes, multi_scores

        with self.static_graph():
            # Case 1: single variables are passed where the successful test
            # above passes per-level lists.
            bbox1 = paddle.static.data(
                name='rois', shape=[5, 10, 4], dtype='float32', lod_level=1)
            score1 = paddle.static.data(
                name='scores', shape=[5, 10, 1], dtype='float32', lod_level=1)
            # Case 2: per-level lists, but the boxes are declared as int32.
            bbox2, score2 = generate_input('int32', 'float32', '2')
            self.assertRaises(
                TypeError,
                ops.collect_fpn_proposals,
                multi_rois=bbox1,
                multi_scores=score1,
                min_level=2,
                max_level=5,
                post_nms_top_n=2000)
            self.assertRaises(
                TypeError,
                ops.collect_fpn_proposals,
                multi_rois=bbox2,
                multi_scores=score2,
                min_level=2,
                max_level=5,
                post_nms_top_n=2000)

        # Leave the process in dynamic mode for whatever runs next.
        paddle.disable_static()
class TestDistributeFpnProposals(LayerTest):
def test_distribute_fpn_proposals(self):
rois_np = np.random.rand(10, 4).astype('float32')
......@@ -200,8 +77,8 @@ class TestDistributeFpnProposals(LayerTest):
output_stat_np.append(output_np)
with self.dynamic_graph():
rois_dy = base.to_variable(rois_np)
rois_num_dy = base.to_variable(rois_num_np)
rois_dy = paddle.to_tensor(rois_np)
rois_num_dy = paddle.to_tensor(rois_num_np)
multi_rois_dy, restore_ind_dy, rois_num_per_level_dy = ops.distribute_fpn_proposals(
fpn_rois=rois_dy,
min_level=2,
......@@ -266,9 +143,9 @@ class TestROIAlign(LayerTest):
with_lod=False)
with self.dynamic_graph():
inputs_dy = base.to_variable(inputs_np)
rois_dy = base.to_variable(rois_np)
rois_num_dy = base.to_variable(rois_num_np)
inputs_dy = paddle.to_tensor(inputs_np)
rois_dy = paddle.to_tensor(rois_np)
rois_num_dy = paddle.to_tensor(rois_num_np)
output_dy = ops.roi_align(
input=inputs_dy,
......@@ -326,9 +203,9 @@ class TestROIPool(LayerTest):
with_lod=False)
with self.dynamic_graph():
inputs_dy = base.to_variable(inputs_np)
rois_dy = base.to_variable(rois_np)
rois_num_dy = base.to_variable(rois_num_np)
inputs_dy = paddle.to_tensor(inputs_np)
rois_dy = paddle.to_tensor(rois_np)
rois_num_dy = paddle.to_tensor(rois_num_np)
output_dy, _ = ops.roi_pool(
input=inputs_dy,
......@@ -355,134 +232,6 @@ class TestROIPool(LayerTest):
paddle.disable_static()
class TestIoUSimilarity(LayerTest):
    """Checks `ops.iou_similarity` in static and dynamic graph mode."""

    def test_iou_similarity(self):
        """Static-graph and dynamic-graph IoU must be identical for the
        same inputs."""
        b, c, h, w = 2, 12, 20, 20
        # NOTE(review): inputs_np is never used below; it only consumes
        # random numbers before the rois are generated.
        inputs_np = np.random.rand(b, c, h, w).astype('float32')
        output_size = (7, 7)
        # The static placeholders below declare x as [20, 4] and y as
        # [10, 4], so make_rois presumably yields that many boxes — confirm
        # against make_rois.
        x_np = make_rois(h, w, [20], output_size)
        y_np = make_rois(h, w, [10], output_size)

        with self.static_graph():
            x = paddle.static.data(name='x', shape=[20, 4], dtype='float32')
            y = paddle.static.data(name='y', shape=[10, 4], dtype='float32')

            iou = ops.iou_similarity(x=x, y=y)
            iou_np, = self.get_static_graph_result(
                feed={
                    'x': x_np,
                    'y': y_np,
                }, fetch_list=[iou], with_lod=False)

        with self.dynamic_graph():
            x_dy = base.to_variable(x_np)
            y_dy = base.to_variable(y_np)

            iou_dy = ops.iou_similarity(x=x_dy, y=y_dy)
            iou_dy_np = iou_dy.numpy()

        self.assertTrue(np.array_equal(iou_np, iou_dy_np))
class TestBipartiteMatch(LayerTest):
    """Checks `ops.bipartite_match` in static and dynamic graph mode."""

    def test_bipartite_match(self):
        """Static-graph and dynamic-graph results must be identical for the
        same random 20 x 10 distance matrix."""
        distance = np.random.random((20, 10)).astype('float32')

        with self.static_graph():
            x = paddle.static.data(name='x', shape=[20, 10], dtype='float32')

            match_indices, match_dist = ops.bipartite_match(
                x, match_type='per_prediction', dist_threshold=0.5)
            match_indices_np, match_dist_np = self.get_static_graph_result(
                feed={'x': distance, },
                fetch_list=[match_indices, match_dist],
                with_lod=False)

        with self.dynamic_graph():
            x_dy = base.to_variable(distance)

            match_indices_dy, match_dist_dy = ops.bipartite_match(
                x_dy, match_type='per_prediction', dist_threshold=0.5)
            match_indices_dy_np = match_indices_dy.numpy()
            match_dist_dy_np = match_dist_dy.numpy()

        # Both outputs (indices and distances) have to agree exactly.
        self.assertTrue(np.array_equal(match_indices_np, match_indices_dy_np))
        self.assertTrue(np.array_equal(match_dist_np, match_dist_dy_np))
class TestYoloBox(LayerTest):
    """Checks `ops.yolo_box` in static and dynamic graph mode."""

    def test_yolo_box(self):
        """Static-graph and dynamic-graph boxes / scores must be identical
        for the same random feature map."""

        # x shape [N C H W], C=K * (5 + class_num), class_num=10, K=2
        np_x = np.random.random([1, 30, 7, 7]).astype('float32')
        np_origin_shape = np.array([[608, 608]], dtype='int32')
        class_num = 10
        conf_thresh = 0.01
        downsample_ratio = 32
        scale_x_y = 1.2

        # static
        with self.static_graph():
            # x shape [N C H W], C=K * (5 + class_num), class_num=10, K=2
            x = paddle.static.data(
                name='x', shape=[1, 30, 7, 7], dtype='float32')
            origin_shape = paddle.static.data(
                name='origin_shape', shape=[1, 2], dtype='int32')

            # [10, 13, 30, 13] is the flat anchor list (K=2 anchors).
            boxes, scores = ops.yolo_box(
                x,
                origin_shape, [10, 13, 30, 13],
                class_num,
                conf_thresh,
                downsample_ratio,
                scale_x_y=scale_x_y)

            boxes_np, scores_np = self.get_static_graph_result(
                feed={
                    'x': np_x,
                    'origin_shape': np_origin_shape,
                },
                fetch_list=[boxes, scores],
                with_lod=False)

        # dygraph
        with self.dynamic_graph():
            x_dy = fluid.layers.assign(np_x)
            origin_shape_dy = fluid.layers.assign(np_origin_shape)

            # Literal arguments repeat class_num / conf_thresh /
            # downsample_ratio from the static run above.
            boxes_dy, scores_dy = ops.yolo_box(
                x_dy,
                origin_shape_dy, [10, 13, 30, 13],
                10,
                0.01,
                32,
                scale_x_y=scale_x_y)

            boxes_dy_np = boxes_dy.numpy()
            scores_dy_np = scores_dy.numpy()

        self.assertTrue(np.array_equal(boxes_np, boxes_dy_np))
        self.assertTrue(np.array_equal(scores_np, scores_dy_np))

    def test_yolo_box_error(self):
        """`ops.yolo_box` is expected to raise TypeError when the class
        number argument is a float (10.123)."""
        with self.static_graph():
            # x shape [N C H W], C=K * (5 + class_num), class_num=10, K=2
            x = paddle.static.data(
                name='x', shape=[1, 30, 7, 7], dtype='float32')
            origin_shape = paddle.static.data(
                name='origin_shape', shape=[1, 2], dtype='int32')

            self.assertRaises(
                TypeError,
                ops.yolo_box,
                x,
                origin_shape, [10, 13, 30, 13],
                10.123,
                0.01,
                32,
                scale_x_y=1.2)

        # Leave the process in dynamic mode for whatever runs next.
        paddle.disable_static()
class TestPriorBox(LayerTest):
def test_prior_box(self):
input_np = np.random.rand(2, 10, 32, 32).astype('float32')
......@@ -509,8 +258,8 @@ class TestPriorBox(LayerTest):
with_lod=False)
with self.dynamic_graph():
inputs_dy = base.to_variable(input_np)
image_dy = base.to_variable(image_np)
inputs_dy = paddle.to_tensor(input_np)
image_dy = paddle.to_tensor(image_np)
box_dy, var_dy = ops.prior_box(
input=inputs_dy,
......@@ -582,9 +331,9 @@ class TestMulticlassNms(LayerTest):
nms_rois_num_np = np.array(nms_rois_num_np)
with self.dynamic_graph():
boxes_dy = base.to_variable(boxes_np)
scores_dy = base.to_variable(scores_np)
rois_num_dy = base.to_variable(rois_num_np)
boxes_dy = paddle.to_tensor(boxes_np)
scores_dy = paddle.to_tensor(scores_np)
rois_num_dy = paddle.to_tensor(rois_num_np)
out_dy, index_dy, nms_rois_num_dy = ops.multiclass_nms(
bboxes=boxes_dy,
......@@ -666,8 +415,8 @@ class TestMatrixNMS(LayerTest):
with_lod=True)
with self.dynamic_graph():
boxes_dy = base.to_variable(boxes_np)
scores_dy = base.to_variable(scores_np)
boxes_dy = paddle.to_tensor(boxes_np)
scores_dy = paddle.to_tensor(scores_np)
out_dy, index_dy, _ = ops.matrix_nms(
bboxes=boxes_dy,
......@@ -737,9 +486,9 @@ class TestBoxCoder(LayerTest):
# dygraph
with self.dynamic_graph():
prior_box_dy = base.to_variable(prior_box_np)
prior_box_var_dy = base.to_variable(prior_box_var_np)
target_box_dy = base.to_variable(target_box_np)
prior_box_dy = paddle.to_tensor(prior_box_np)
prior_box_var_dy = paddle.to_tensor(prior_box_var_np)
target_box_dy = paddle.to_tensor(target_box_np)
boxes_dy = ops.box_coder(
prior_box=prior_box_dy,
......@@ -808,11 +557,11 @@ class TestGenerateProposals(LayerTest):
with_lod=True)
with self.dynamic_graph():
scores_dy = base.to_variable(scores_np)
bbox_deltas_dy = base.to_variable(bbox_deltas_np)
im_shape_dy = base.to_variable(im_shape_np)
anchors_dy = base.to_variable(anchors_np)
variances_dy = base.to_variable(variances_np)
scores_dy = paddle.to_tensor(scores_np)
bbox_deltas_dy = paddle.to_tensor(bbox_deltas_np)
im_shape_dy = paddle.to_tensor(im_shape_np)
anchors_dy = paddle.to_tensor(anchors_np)
variances_dy = paddle.to_tensor(variances_np)
rois, roi_probs, rois_num = ops.generate_proposals(
scores_dy,
bbox_deltas_dy,
......
......@@ -17,7 +17,7 @@ from __future__ import division
import unittest
import paddle
from paddle import fluid
import paddle.nn.functional as F
# add python path of PadleDetection to sys.path
import os
import sys
......@@ -27,19 +27,9 @@ if parent_path not in sys.path:
from ppdet.modeling.losses import YOLOv3Loss
from ppdet.data.transform.op_helper import jaccard_overlap
from ppdet.modeling.bbox_utils import iou_similarity
import numpy as np
def _split_ioup(output, an_num, num_classes):
    """
    Separate the predicted-iou channels from the rest of the output
    feature map along the channel dimension (axis 1).

    The first `an_num` channels are passed through a sigmoid and returned
    as the predicted iou; channels [an_num, an_num * (num_classes + 6))
    are returned untouched as the remaining output.
    """
    last_channel = an_num * (num_classes + 6)

    iou_logits = fluid.layers.slice(
        output, axes=[1], starts=[0], ends=[an_num])
    predicted_iou = fluid.layers.sigmoid(iou_logits)

    remaining_output = fluid.layers.slice(
        output, axes=[1], starts=[an_num], ends=[last_channel])

    return (predicted_iou, remaining_output)
np.random.seed(0)
def _split_output(output, an_num, num_classes):
......@@ -47,31 +37,31 @@ def _split_output(output, an_num, num_classes):
Split output feature map to x, y, w, h, objectness, classification
along channel dimension
"""
x = fluid.layers.strided_slice(
x = paddle.strided_slice(
output,
axes=[1],
starts=[0],
ends=[output.shape[1]],
strides=[5 + num_classes])
y = fluid.layers.strided_slice(
y = paddle.strided_slice(
output,
axes=[1],
starts=[1],
ends=[output.shape[1]],
strides=[5 + num_classes])
w = fluid.layers.strided_slice(
w = paddle.strided_slice(
output,
axes=[1],
starts=[2],
ends=[output.shape[1]],
strides=[5 + num_classes])
h = fluid.layers.strided_slice(
h = paddle.strided_slice(
output,
axes=[1],
starts=[3],
ends=[output.shape[1]],
strides=[5 + num_classes])
obj = fluid.layers.strided_slice(
obj = paddle.strided_slice(
output,
axes=[1],
starts=[4],
......@@ -81,14 +71,12 @@ def _split_output(output, an_num, num_classes):
stride = output.shape[1] // an_num
for m in range(an_num):
clss.append(
fluid.layers.slice(
paddle.slice(
output,
axes=[1],
starts=[stride * m + 5],
ends=[stride * m + 5 + num_classes]))
cls = fluid.layers.transpose(
fluid.layers.stack(
clss, axis=1), perm=[0, 1, 3, 4, 2])
cls = paddle.transpose(paddle.stack(clss, axis=1), perm=[0, 1, 3, 4, 2])
return (x, y, w, h, obj, cls)
......@@ -104,7 +92,7 @@ def _split_target(target):
th = target[:, :, 3, :, :]
tscale = target[:, :, 4, :, :]
tobj = target[:, :, 5, :, :]
tcls = fluid.layers.transpose(target[:, :, 6:, :, :], perm=[0, 1, 3, 4, 2])
tcls = paddle.transpose(target[:, :, 6:, :, :], perm=[0, 1, 3, 4, 2])
tcls.stop_gradient = True
return (tx, ty, tw, th, tscale, tobj, tcls)
......@@ -115,9 +103,9 @@ def _calc_obj_loss(output, obj, tobj, gt_box, batch_size, anchors, num_classes,
# objectness loss will be ignored, process as follows:
# 1. get pred bbox, which is same with YOLOv3 infer mode, use yolo_box here
# NOTE: img_size is set as 1.0 to get noramlized pred bbox
bbox, prob = fluid.layers.yolo_box(
bbox, prob = paddle.vision.ops.yolo_box(
x=output,
img_size=fluid.layers.ones(
img_size=paddle.ones(
shape=[batch_size, 2], dtype="int32"),
anchors=anchors,
class_num=num_classes,
......@@ -128,8 +116,8 @@ def _calc_obj_loss(output, obj, tobj, gt_box, batch_size, anchors, num_classes,
# 2. split pred bbox and gt bbox by sample, calculate IoU between pred bbox
# and gt bbox in each sample
if batch_size > 1:
preds = fluid.layers.split(bbox, batch_size, dim=0)
gts = fluid.layers.split(gt_box, batch_size, dim=0)
preds = paddle.split(bbox, batch_size, axis=0)
gts = paddle.split(gt_box, batch_size, axis=0)
else:
preds = [bbox]
gts = [gt_box]
......@@ -142,7 +130,7 @@ def _calc_obj_loss(output, obj, tobj, gt_box, batch_size, anchors, num_classes,
y = box[:, 1]
w = box[:, 2]
h = box[:, 3]
return fluid.layers.stack(
return paddle.stack(
[
x - w / 2.,
y - h / 2.,
......@@ -150,28 +138,29 @@ def _calc_obj_loss(output, obj, tobj, gt_box, batch_size, anchors, num_classes,
y + h / 2.,
], axis=1)
pred = fluid.layers.squeeze(pred, axes=[0])
gt = box_xywh2xyxy(fluid.layers.squeeze(gt, axes=[0]))
ious.append(fluid.layers.iou_similarity(pred, gt))
iou = fluid.layers.stack(ious, axis=0)
pred = paddle.squeeze(pred, axis=[0])
gt = box_xywh2xyxy(paddle.squeeze(gt, axis=[0]))
ious.append(iou_similarity(pred, gt))
iou = paddle.stack(ious, axis=0)
# 3. Get iou_mask by IoU between gt bbox and prediction bbox,
# Get obj_mask by tobj(holds gt_score), calculate objectness loss
max_iou = fluid.layers.reduce_max(iou, dim=-1)
iou_mask = fluid.layers.cast(max_iou <= ignore_thresh, dtype="float32")
output_shape = fluid.layers.shape(output)
max_iou = paddle.max(iou, axis=-1)
iou_mask = paddle.cast(max_iou <= ignore_thresh, dtype="float32")
output_shape = paddle.shape(output)
an_num = len(anchors) // 2
iou_mask = fluid.layers.reshape(iou_mask, (-1, an_num, output_shape[2],
output_shape[3]))
iou_mask = paddle.reshape(iou_mask, (-1, an_num, output_shape[2],
output_shape[3]))
iou_mask.stop_gradient = True
# NOTE: tobj holds gt_score, obj_mask holds object existence mask
obj_mask = fluid.layers.cast(tobj > 0., dtype="float32")
obj_mask = paddle.cast(tobj > 0., dtype="float32")
obj_mask.stop_gradient = True
# For positive objectness grids, objectness loss should be calculated
# For negative objectness grids, objectness loss is calculated only iou_mask == 1.0
loss_obj = fluid.layers.sigmoid_cross_entropy_with_logits(obj, obj_mask)
loss_obj_pos = fluid.layers.reduce_sum(loss_obj * tobj, dim=[1, 2, 3])
loss_obj_neg = fluid.layers.reduce_sum(
loss_obj * (1.0 - obj_mask) * iou_mask, dim=[1, 2, 3])
obj_sigmoid = F.sigmoid(obj)
loss_obj = F.binary_cross_entropy(obj_sigmoid, obj_mask, reduction='none')
loss_obj_pos = paddle.sum(loss_obj * tobj, axis=[1, 2, 3])
loss_obj_neg = paddle.sum(loss_obj * (1.0 - obj_mask) * iou_mask,
axis=[1, 2, 3])
return loss_obj_pos, loss_obj_neg
......@@ -194,45 +183,48 @@ def fine_grained_loss(output,
scale_x_y = scale_x_y
if (abs(scale_x_y - 1.0) < eps):
loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(
x, tx) * tscale_tobj
loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
loss_y = fluid.layers.sigmoid_cross_entropy_with_logits(
y, ty) * tscale_tobj
loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
x = F.sigmoid(x)
y = F.sigmoid(y)
loss_x = F.binary_cross_entropy(x, tx, reduction='none') * tscale_tobj
loss_x = paddle.sum(loss_x, axis=[1, 2, 3])
loss_y = F.binary_cross_entropy(y, ty, reduction='none') * tscale_tobj
loss_y = paddle.sum(loss_y, axis=[1, 2, 3])
else:
dx = scale_x_y * fluid.layers.sigmoid(x) - 0.5 * (scale_x_y - 1.0)
dy = scale_x_y * fluid.layers.sigmoid(y) - 0.5 * (scale_x_y - 1.0)
loss_x = fluid.layers.abs(dx - tx) * tscale_tobj
loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
loss_y = fluid.layers.abs(dy - ty) * tscale_tobj
loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
dx = scale_x_y * F.sigmoid(x) - 0.5 * (scale_x_y - 1.0)
dy = scale_x_y * F.sigmoid(y) - 0.5 * (scale_x_y - 1.0)
loss_x = paddle.abs(dx - tx) * tscale_tobj
loss_x = paddle.sum(loss_x, axis=[1, 2, 3])
loss_y = paddle.abs(dy - ty) * tscale_tobj
loss_y = paddle.sum(loss_y, axis=[1, 2, 3])
# NOTE: we refined loss function of (w, h) as L1Loss
loss_w = fluid.layers.abs(w - tw) * tscale_tobj
loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3])
loss_h = fluid.layers.abs(h - th) * tscale_tobj
loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3])
loss_w = paddle.abs(w - tw) * tscale_tobj
loss_w = paddle.sum(loss_w, axis=[1, 2, 3])
loss_h = paddle.abs(h - th) * tscale_tobj
loss_h = paddle.sum(loss_h, axis=[1, 2, 3])
loss_obj_pos, loss_obj_neg = _calc_obj_loss(
output, obj, tobj, gt_box, batch_size, anchors, num_classes, downsample,
ignore_thresh, scale_x_y)
loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls, tcls)
loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0)
loss_cls = fluid.layers.reduce_sum(loss_cls, dim=[1, 2, 3, 4])
cls = F.sigmoid(cls)
loss_cls = F.binary_cross_entropy(cls, tcls, reduction='none')
tobj = paddle.unsqueeze(tobj, axis=-1)
loss_cls = paddle.multiply(loss_cls, tobj)
loss_cls = paddle.sum(loss_cls, axis=[1, 2, 3, 4])
loss_xys = fluid.layers.reduce_mean(loss_x + loss_y)
loss_whs = fluid.layers.reduce_mean(loss_w + loss_h)
loss_objs = fluid.layers.reduce_mean(loss_obj_pos + loss_obj_neg)
loss_clss = fluid.layers.reduce_mean(loss_cls)
loss_xys = paddle.mean(loss_x + loss_y)
loss_whs = paddle.mean(loss_w + loss_h)
loss_objs = paddle.mean(loss_obj_pos + loss_obj_neg)
loss_clss = paddle.mean(loss_cls)
losses_all = {
"loss_xy": fluid.layers.sum(loss_xys),
"loss_wh": fluid.layers.sum(loss_whs),
"loss_loc": fluid.layers.sum(loss_xys) + fluid.layers.sum(loss_whs),
"loss_obj": fluid.layers.sum(loss_objs),
"loss_cls": fluid.layers.sum(loss_clss),
"loss_xy": paddle.sum(loss_xys),
"loss_wh": paddle.sum(loss_whs),
"loss_loc": paddle.sum(loss_xys) + paddle.sum(loss_whs),
"loss_obj": paddle.sum(loss_objs),
"loss_cls": paddle.sum(loss_clss),
}
return losses_all, x, y, tx, ty
......
......@@ -20,7 +20,7 @@ import sys
import paddle
import six
import paddle.version as fluid_version
import paddle.version as paddle_version
from .logger import setup_logger
logger = setup_logger(__name__)
......@@ -97,8 +97,8 @@ def check_version(version='2.0'):
"Please make sure the version is good with your code.".format(version)
version_installed = [
fluid_version.major, fluid_version.minor, fluid_version.patch,
fluid_version.rc
paddle_version.major, paddle_version.minor, paddle_version.patch,
paddle_version.rc
]
if version_installed == ['0', '0', '0', '0']:
return
......
......@@ -26,17 +26,31 @@ extern "C" {
*/
JNIEXPORT jlong JNICALL
Java_com_baidu_paddledetection_detection_Native_nativeInit(
JNIEnv *env, jclass thiz, jstring jModelDir, jstring jLabelPath,
jint cpuThreadNum, jstring jCPUPowerMode, jint inputWidth, jint inputHeight,
jfloatArray jInputMean, jfloatArray jInputStd, jfloat scoreThreshold) {
JNIEnv *env,
jclass thiz,
jstring jModelDir,
jstring jLabelPath,
jint cpuThreadNum,
jstring jCPUPowerMode,
jint inputWidth,
jint inputHeight,
jfloatArray jInputMean,
jfloatArray jInputStd,
jfloat scoreThreshold) {
std::string modelDir = jstring_to_cpp_string(env, jModelDir);
std::string labelPath = jstring_to_cpp_string(env, jLabelPath);
std::string cpuPowerMode = jstring_to_cpp_string(env, jCPUPowerMode);
std::vector<float> inputMean = jfloatarray_to_float_vector(env, jInputMean);
std::vector<float> inputStd = jfloatarray_to_float_vector(env, jInputStd);
return reinterpret_cast<jlong>(
new Pipeline(modelDir, labelPath, cpuThreadNum, cpuPowerMode, inputWidth,
inputHeight, inputMean, inputStd, scoreThreshold));
return reinterpret_cast<jlong>(new Pipeline(modelDir,
labelPath,
cpuThreadNum,
cpuPowerMode,
inputWidth,
inputHeight,
inputMean,
inputStd,
scoreThreshold));
}
/*
......@@ -45,8 +59,9 @@ Java_com_baidu_paddledetection_detection_Native_nativeInit(
* Signature: (J)Z
*/
JNIEXPORT jboolean JNICALL
Java_com_baidu_paddledetection_detection_Native_nativeRelease(
JNIEnv *env, jclass thiz, jlong ctx) {
Java_com_baidu_paddledetection_detection_Native_nativeRelease(JNIEnv *env,
jclass thiz,
jlong ctx) {
if (ctx == 0) {
return JNI_FALSE;
}
......@@ -62,15 +77,21 @@ Java_com_baidu_paddledetection_detection_Native_nativeRelease(
*/
JNIEXPORT jboolean JNICALL
Java_com_baidu_paddledetection_detection_Native_nativeProcess(
JNIEnv *env, jclass thiz, jlong ctx, jint inTextureId, jint outTextureId,
jint textureWidth, jint textureHeight, jstring jsavedImagePath) {
JNIEnv *env,
jclass thiz,
jlong ctx,
jint inTextureId,
jint outTextureId,
jint textureWidth,
jint textureHeight,
jstring jsavedImagePath) {
if (ctx == 0) {
return JNI_FALSE;
}
std::string savedImagePath = jstring_to_cpp_string(env, jsavedImagePath);
Pipeline *pipeline = reinterpret_cast<Pipeline *>(ctx);
return pipeline->Process(inTextureId, outTextureId, textureWidth,
textureHeight, savedImagePath);
return pipeline->Process(
inTextureId, outTextureId, textureWidth, textureHeight, savedImagePath);
}
#ifdef __cplusplus
......
......@@ -50,8 +50,8 @@ inline jstring cpp_string_to_jstring(JNIEnv *env, std::string str) {
env->GetMethodID(strClass, "<init>", "([BLjava/lang/String;)V");
jbyteArray bytes = env->NewByteArray(strlen(data));
env->SetByteArrayRegion(bytes, 0, strlen(data),
reinterpret_cast<const jbyte *>(data));
env->SetByteArrayRegion(
bytes, 0, strlen(data), reinterpret_cast<const jbyte *>(data));
jstring encoding = env->NewStringUTF("UTF-8");
jstring res = (jstring)(
......@@ -64,21 +64,24 @@ inline jstring cpp_string_to_jstring(JNIEnv *env, std::string str) {
return res;
}
inline jfloatArray cpp_array_to_jfloatarray(JNIEnv *env, const float *buf,
inline jfloatArray cpp_array_to_jfloatarray(JNIEnv *env,
const float *buf,
int64_t len) {
jfloatArray result = env->NewFloatArray(len);
env->SetFloatArrayRegion(result, 0, len, buf);
return result;
}
inline jintArray cpp_array_to_jintarray(JNIEnv *env, const int *buf,
inline jintArray cpp_array_to_jintarray(JNIEnv *env,
const int *buf,
int64_t len) {
jintArray result = env->NewIntArray(len);
env->SetIntArrayRegion(result, 0, len, buf);
return result;
}
inline jbyteArray cpp_array_to_jbytearray(JNIEnv *env, const int8_t *buf,
inline jbyteArray cpp_array_to_jbytearray(JNIEnv *env,
const int8_t *buf,
int64_t len) {
jbyteArray result = env->NewByteArray(len);
env->SetByteArrayRegion(result, 0, len, buf);
......
......@@ -14,13 +14,20 @@
#include "Pipeline.h"
Detector::Detector(const std::string &modelDir, const std::string &labelPath,
const int cpuThreadNum, const std::string &cpuPowerMode,
int inputWidth, int inputHeight,
Detector::Detector(const std::string &modelDir,
const std::string &labelPath,
const int cpuThreadNum,
const std::string &cpuPowerMode,
int inputWidth,
int inputHeight,
const std::vector<float> &inputMean,
const std::vector<float> &inputStd, float scoreThreshold)
: inputWidth_(inputWidth), inputHeight_(inputHeight), inputMean_(inputMean),
inputStd_(inputStd), scoreThreshold_(scoreThreshold) {
const std::vector<float> &inputStd,
float scoreThreshold)
: inputWidth_(inputWidth),
inputHeight_(inputHeight),
inputMean_(inputMean),
inputStd_(inputStd),
scoreThreshold_(scoreThreshold) {
paddle::lite_api::MobileConfig config;
config.set_model_from_file(modelDir + "/model.nb");
config.set_threads(cpuThreadNum);
......@@ -71,13 +78,16 @@ void Detector::Preprocess(const cv::Mat &rgbaImage) {
inputTensor->Resize(inputShape);
auto inputData = inputTensor->mutable_data<float>();
cv::Mat resizedRGBAImage;
cv::resize(rgbaImage, resizedRGBAImage,
cv::Size(inputShape[3], inputShape[2]));
cv::resize(
rgbaImage, resizedRGBAImage, cv::Size(inputShape[3], inputShape[2]));
cv::Mat resizedRGBImage;
cv::cvtColor(resizedRGBAImage, resizedRGBImage, cv::COLOR_BGRA2RGB);
resizedRGBImage.convertTo(resizedRGBImage, CV_32FC3, 1.0 / 255.0f);
NHWC3ToNC3HW(reinterpret_cast<const float *>(resizedRGBImage.data), inputData,
inputMean_.data(), inputStd_.data(), inputShape[3],
NHWC3ToNC3HW(reinterpret_cast<const float *>(resizedRGBImage.data),
inputData,
inputMean_.data(),
inputStd_.data(),
inputShape[3],
inputShape[2]);
// Set the size of input image
auto sizeTensor = predictor_->GetInput(1);
......@@ -97,8 +107,7 @@ void Detector::Postprocess(std::vector<RESULT> *results) {
auto class_id = static_cast<int>(round(outputData[i]));
// Confidence score
auto score = outputData[i + 1];
if (score < scoreThreshold_)
continue;
if (score < scoreThreshold_) continue;
RESULT object;
object.class_name = class_id >= 0 && class_id < labelList_.size()
? labelList_[class_id]
......@@ -115,8 +124,10 @@ void Detector::Postprocess(std::vector<RESULT> *results) {
}
}
void Detector::Predict(const cv::Mat &rgbaImage, std::vector<RESULT> *results,
double *preprocessTime, double *predictTime,
void Detector::Predict(const cv::Mat &rgbaImage,
std::vector<RESULT> *results,
double *preprocessTime,
double *predictTime,
double *postprocessTime) {
auto t = GetCurrentTime();
......@@ -136,13 +147,23 @@ void Detector::Predict(const cv::Mat &rgbaImage, std::vector<RESULT> *results,
LOGD("Detector postprocess costs %f ms", *postprocessTime);
}
Pipeline::Pipeline(const std::string &modelDir, const std::string &labelPath,
const int cpuThreadNum, const std::string &cpuPowerMode,
int inputWidth, int inputHeight,
Pipeline::Pipeline(const std::string &modelDir,
const std::string &labelPath,
const int cpuThreadNum,
const std::string &cpuPowerMode,
int inputWidth,
int inputHeight,
const std::vector<float> &inputMean,
const std::vector<float> &inputStd, float scoreThreshold) {
detector_.reset(new Detector(modelDir, labelPath, cpuThreadNum, cpuPowerMode,
inputWidth, inputHeight, inputMean, inputStd,
const std::vector<float> &inputStd,
float scoreThreshold) {
detector_.reset(new Detector(modelDir,
labelPath,
cpuThreadNum,
cpuPowerMode,
inputWidth,
inputHeight,
inputMean,
inputStd,
scoreThreshold));
}
......@@ -169,15 +190,24 @@ void Pipeline::VisualizeResults(const std::vector<RESULT> &results,
cv::Point2d(boundingBox.x,
boundingBox.y - round(textSize.height * 1.25f)),
cv::Point2d(boundingBox.x + boundingBox.width, boundingBox.y),
object.fill_color, -1);
cv::putText(*rgbaImage, text, cv::Point2d(boundingBox.x, boundingBox.y),
fontFace, fontScale, cv::Scalar(255, 255, 255), fontThickness);
object.fill_color,
-1);
cv::putText(*rgbaImage,
text,
cv::Point2d(boundingBox.x, boundingBox.y),
fontFace,
fontScale,
cv::Scalar(255, 255, 255),
fontThickness);
}
}
void Pipeline::VisualizeStatus(double readGLFBOTime, double writeGLTextureTime,
double preprocessTime, double predictTime,
double postprocessTime, cv::Mat *rgbaImage) {
void Pipeline::VisualizeStatus(double readGLFBOTime,
double writeGLTextureTime,
double preprocessTime,
double predictTime,
double postprocessTime,
cv::Mat *rgbaImage) {
char text[255];
cv::Scalar fontColor = cv::Scalar(255, 255, 255);
int fontFace = cv::FONT_HERSHEY_PLAIN;
......@@ -188,47 +218,54 @@ void Pipeline::VisualizeStatus(double readGLFBOTime, double writeGLTextureTime,
cv::getTextSize(text, fontFace, fontScale, fontThickness, nullptr);
textSize.height *= 1.25f;
cv::Point2d offset(10, textSize.height + 15);
cv::putText(*rgbaImage, text, offset, fontFace, fontScale, fontColor,
fontThickness);
cv::putText(
*rgbaImage, text, offset, fontFace, fontScale, fontColor, fontThickness);
sprintf(text, "Write GLTexture time: %.1f ms", writeGLTextureTime);
offset.y += textSize.height;
cv::putText(*rgbaImage, text, offset, fontFace, fontScale, fontColor,
fontThickness);
cv::putText(
*rgbaImage, text, offset, fontFace, fontScale, fontColor, fontThickness);
sprintf(text, "Preprocess time: %.1f ms", preprocessTime);
offset.y += textSize.height;
cv::putText(*rgbaImage, text, offset, fontFace, fontScale, fontColor,
fontThickness);
cv::putText(
*rgbaImage, text, offset, fontFace, fontScale, fontColor, fontThickness);
sprintf(text, "Predict time: %.1f ms", predictTime);
offset.y += textSize.height;
cv::putText(*rgbaImage, text, offset, fontFace, fontScale, fontColor,
fontThickness);
cv::putText(
*rgbaImage, text, offset, fontFace, fontScale, fontColor, fontThickness);
sprintf(text, "Postprocess time: %.1f ms", postprocessTime);
offset.y += textSize.height;
cv::putText(*rgbaImage, text, offset, fontFace, fontScale, fontColor,
fontThickness);
cv::putText(
*rgbaImage, text, offset, fontFace, fontScale, fontColor, fontThickness);
}
bool Pipeline::Process(int inTexureId, int outTextureId, int textureWidth,
int textureHeight, std::string savedImagePath) {
bool Pipeline::Process(int inTexureId,
int outTextureId,
int textureWidth,
int textureHeight,
std::string savedImagePath) {
static double readGLFBOTime = 0, writeGLTextureTime = 0;
double preprocessTime = 0, predictTime = 0, postprocessTime = 0;
// Read pixels from FBO texture to CV image
cv::Mat rgbaImage;
CreateRGBAImageFromGLFBOTexture(textureWidth, textureHeight, &rgbaImage,
&readGLFBOTime);
CreateRGBAImageFromGLFBOTexture(
textureWidth, textureHeight, &rgbaImage, &readGLFBOTime);
// Feed the image, run inference and parse the results
std::vector<RESULT> results;
detector_->Predict(rgbaImage, &results, &preprocessTime, &predictTime,
&postprocessTime);
detector_->Predict(
rgbaImage, &results, &preprocessTime, &predictTime, &postprocessTime);
// Visualize the objects to the origin image
VisualizeResults(results, &rgbaImage);
// Visualize the status(performance data) to the origin image
VisualizeStatus(readGLFBOTime, writeGLTextureTime, preprocessTime,
predictTime, postprocessTime, &rgbaImage);
VisualizeStatus(readGLFBOTime,
writeGLTextureTime,
preprocessTime,
predictTime,
postprocessTime,
&rgbaImage);
// Dump modified image if savedImagePath is set
if (!savedImagePath.empty()) {
......
......@@ -14,8 +14,6 @@
#pragma once
#include "Utils.h"
#include "paddle_api.h"
#include <EGL/egl.h>
#include <GLES2/gl2.h>
#include <opencv2/core.hpp>
......@@ -24,6 +22,8 @@
#include <opencv2/imgproc.hpp>
#include <string>
#include <vector>
#include "Utils.h"
#include "paddle_api.h"
struct RESULT {
std::string class_name;
......@@ -36,24 +36,30 @@ struct RESULT {
};
class Detector {
public:
explicit Detector(const std::string &modelDir, const std::string &labelPath,
const int cpuThreadNum, const std::string &cpuPowerMode,
int inputWidth, int inputHeight,
public:
explicit Detector(const std::string &modelDir,
const std::string &labelPath,
const int cpuThreadNum,
const std::string &cpuPowerMode,
int inputWidth,
int inputHeight,
const std::vector<float> &inputMean,
const std::vector<float> &inputStd, float scoreThreshold);
const std::vector<float> &inputStd,
float scoreThreshold);
void Predict(const cv::Mat &rgbImage, std::vector<RESULT> *results,
double *preprocessTime, double *predictTime,
void Predict(const cv::Mat &rgbImage,
std::vector<RESULT> *results,
double *preprocessTime,
double *predictTime,
double *postprocessTime);
private:
private:
std::vector<std::string> LoadLabelList(const std::string &path);
std::vector<cv::Scalar> GenerateColorMap(int numOfClasses);
void Preprocess(const cv::Mat &rgbaImage);
void Postprocess(std::vector<RESULT> *results);
private:
private:
int inputWidth_;
int inputHeight_;
std::vector<float> inputMean_;
......@@ -65,36 +71,58 @@ private:
};
class Pipeline {
public:
Pipeline(const std::string &modelDir, const std::string &labelPath,
const int cpuThreadNum, const std::string &cpuPowerMode,
int inputWidth, int inputHeight, const std::vector<float> &inputMean,
const std::vector<float> &inputStd, float scoreThreshold);
public:
Pipeline(const std::string &modelDir,
const std::string &labelPath,
const int cpuThreadNum,
const std::string &cpuPowerMode,
int inputWidth,
int inputHeight,
const std::vector<float> &inputMean,
const std::vector<float> &inputStd,
float scoreThreshold);
bool Process(int inTextureId, int outTextureId, int textureWidth,
int textureHeight, std::string savedImagePath);
bool Process(int inTextureId,
int outTextureId,
int textureWidth,
int textureHeight,
std::string savedImagePath);
private:
private:
// Read pixels from FBO texture to CV image
void CreateRGBAImageFromGLFBOTexture(int textureWidth, int textureHeight,
void CreateRGBAImageFromGLFBOTexture(int textureWidth,
int textureHeight,
cv::Mat *rgbaImage,
double *readGLFBOTime) {
auto t = GetCurrentTime();
rgbaImage->create(textureHeight, textureWidth, CV_8UC4);
glReadPixels(0, 0, textureWidth, textureHeight, GL_RGBA, GL_UNSIGNED_BYTE,
glReadPixels(0,
0,
textureWidth,
textureHeight,
GL_RGBA,
GL_UNSIGNED_BYTE,
rgbaImage->data);
*readGLFBOTime = GetElapsedTime(t);
LOGD("Read from FBO texture costs %f ms", *readGLFBOTime);
}
// Write back to texture2D
void WriteRGBAImageBackToGLTexture(const cv::Mat &rgbaImage, int textureId,
void WriteRGBAImageBackToGLTexture(const cv::Mat &rgbaImage,
int textureId,
double *writeGLTextureTime) {
auto t = GetCurrentTime();
glActiveTexture(GL_TEXTURE0);
glBindTexture(GL_TEXTURE_2D, textureId);
glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, rgbaImage.cols, rgbaImage.rows,
GL_RGBA, GL_UNSIGNED_BYTE, rgbaImage.data);
glTexSubImage2D(GL_TEXTURE_2D,
0,
0,
0,
rgbaImage.cols,
rgbaImage.rows,
GL_RGBA,
GL_UNSIGNED_BYTE,
rgbaImage.data);
*writeGLTextureTime = GetElapsedTime(t);
LOGD("Write back to texture2D costs %f ms", *writeGLTextureTime);
}
......@@ -103,10 +131,13 @@ private:
void VisualizeResults(const std::vector<RESULT> &results, cv::Mat *rgbaImage);
// Visualize the status(performace data) to origin image
void VisualizeStatus(double readGLFBOTime, double writeGLTextureTime,
double preprocessTime, double predictTime,
double postprocessTime, cv::Mat *rgbaImage);
void VisualizeStatus(double readGLFBOTime,
double writeGLTextureTime,
double preprocessTime,
double predictTime,
double postprocessTime,
cv::Mat *rgbaImage);
private:
private:
std::shared_ptr<Detector> detector_;
};
......@@ -17,13 +17,16 @@
int64_t ShapeProduction(const std::vector<int64_t> &shape) {
int64_t res = 1;
for (auto i : shape)
res *= i;
for (auto i : shape) res *= i;
return res;
}
void NHWC3ToNC3HW(const float *src, float *dst, const float *mean,
const float *std, int width, int height) {
void NHWC3ToNC3HW(const float *src,
float *dst,
const float *mean,
const float *std,
int width,
int height) {
int size = height * width;
float32x4_t vmean0 = vdupq_n_f32(mean ? mean[0] : 0.0f);
float32x4_t vmean1 = vdupq_n_f32(mean ? mean[1] : 0.0f);
......@@ -58,8 +61,12 @@ void NHWC3ToNC3HW(const float *src, float *dst, const float *mean,
}
}
void NHWC1ToNC1HW(const float *src, float *dst, const float *mean,
const float *std, int width, int height) {
void NHWC1ToNC1HW(const float *src,
float *dst,
const float *mean,
const float *std,
int width,
int height) {
int size = height * width;
float32x4_t vmean = vdupq_n_f32(mean ? mean[0] : 0.0f);
float32x4_t vscale = vdupq_n_f32(std ? (1.0f / std[0]) : 1.0f);
......
......@@ -14,11 +14,11 @@
#pragma once
#include "paddle_api.h"
#include <android/log.h>
#include <fstream>
#include <string>
#include <vector>
#include "paddle_api.h"
#define TAG "JNI"
#define LOGD(...) __android_log_print(ANDROID_LOG_DEBUG, TAG, __VA_ARGS__)
......@@ -85,8 +85,16 @@ inline paddle::lite_api::PowerMode ParsePowerMode(std::string mode) {
return paddle::lite_api::LITE_POWER_NO_BIND;
}
void NHWC3ToNC3HW(const float *src, float *dst, const float *mean,
const float *std, int width, int height);
void NHWC3ToNC3HW(const float *src,
float *dst,
const float *mean,
const float *std,
int width,
int height);
void NHWC1ToNC1HW(const float *src, float *dst, const float *mean,
const float *std, int width, int height);
void NHWC1ToNC1HW(const float *src,
float *dst,
const float *mean,
const float *std,
int width,
int height);
......@@ -15,9 +15,9 @@
#pragma once
#include <iostream>
#include <vector>
#include <string>
#include <map>
#include <string>
#include <vector>
#include "yaml-cpp/yaml.h"
......@@ -47,8 +47,7 @@ class ConfigPaser {
mode_ = config["mode"].as<std::string>();
} else {
std::cerr << "Please set mode, "
<< "support value : fluid/trt_fp16/trt_fp32."
<< std::endl;
<< "support value : fluid/trt_fp16/trt_fp32." << std::endl;
return false;
}
......@@ -110,4 +109,3 @@ class ConfigPaser {
};
} // namespace PaddleDetection
......@@ -14,21 +14,20 @@
#pragma once
#include <string>
#include <vector>
#include <ctime>
#include <memory>
#include <string>
#include <utility>
#include <ctime>
#include <vector>
#include <opencv2/core/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include "paddle_inference_api.h" // NOLINT
#include "paddle_inference_api.h" // NOLINT
#include "include/preprocess_op.h"
#include "include/config_parser.h"
#include "include/preprocess_op.h"
namespace PaddleDetection {
// Object Detection Result
......@@ -41,48 +40,50 @@ struct ObjectResult {
float confidence;
};
// Generate visualization colormap for each class
std::vector<int> GenerateColorMap(int num_class);
// Visualiztion Detection Result
cv::Mat VisualizeResult(const cv::Mat& img,
const std::vector<ObjectResult>& results,
const std::vector<std::string>& lable_list,
const std::vector<int>& colormap);
const std::vector<ObjectResult>& results,
const std::vector<std::string>& lable_list,
const std::vector<int>& colormap);
class ObjectDetector {
public:
explicit ObjectDetector(const std::string& model_dir,
explicit ObjectDetector(const std::string& model_dir,
const std::string& device,
const std::string& run_mode="fluid",
const int gpu_id=0,
bool trt_calib_mode=false) {
const std::string& run_mode = "fluid",
const int gpu_id = 0,
bool trt_calib_mode = false) {
config_.load_config(model_dir);
threshold_ = config_.draw_threshold_;
preprocessor_.Init(config_.preprocess_info_, config_.arch_);
LoadModel(model_dir, device, config_.min_subgraph_size_, 1, run_mode, gpu_id, trt_calib_mode);
LoadModel(model_dir,
device,
config_.min_subgraph_size_,
1,
run_mode,
gpu_id,
trt_calib_mode);
}
// Load Paddle inference model
void LoadModel(
const std::string& model_dir,
const std::string& device,
const int min_subgraph_size,
const int batch_size = 1,
const std::string& run_mode = "fluid",
const int gpu_id=0,
bool trt_calib_mode=false);
void LoadModel(const std::string& model_dir,
const std::string& device,
const int min_subgraph_size,
const int batch_size = 1,
const std::string& run_mode = "fluid",
const int gpu_id = 0,
bool trt_calib_mode = false);
// Run predictor
void Predict(const cv::Mat& im,
const double threshold = 0.5,
const int warmup = 0,
const int repeats = 1,
const bool run_benchmark = false,
std::vector<ObjectResult>* result = nullptr);
const double threshold = 0.5,
const int warmup = 0,
const int repeats = 1,
const bool run_benchmark = false,
std::vector<ObjectResult>* result = nullptr);
// Get Model Label list
const std::vector<std::string>& GetLabelList() const {
......@@ -93,9 +94,7 @@ class ObjectDetector {
// Preprocess image and copy data to input buffer
void Preprocess(const cv::Mat& image_mat);
// Postprocess result
void Postprocess(
const cv::Mat& raw_mat,
std::vector<ObjectResult>* result);
void Postprocess(const cv::Mat& raw_mat, std::vector<ObjectResult>* result);
std::unique_ptr<paddle::PaddlePredictor> predictor_;
Preprocessor preprocessor_;
......
......@@ -16,15 +16,15 @@
#include <yaml-cpp/yaml.h>
#include <vector>
#include <string>
#include <utility>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <opencv2/core/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
namespace PaddleDetection {
......@@ -38,7 +38,7 @@ class ImageBlob {
// Original image width, height, shrink in float format
std::vector<float> ori_im_size_f_;
// Evaluation image width and height
std::vector<float> eval_im_size_f_;
std::vector<float> eval_im_size_f_;
// Scale factor for image size to origin image size
std::vector<float> scale_factor_f_;
};
......@@ -50,7 +50,7 @@ class PreprocessOp {
virtual void Run(cv::Mat* im, ImageBlob* data) = 0;
};
class InitInfo : public PreprocessOp{
class InitInfo : public PreprocessOp {
public:
virtual void Init(const YAML::Node& item, const std::string& arch) {}
virtual void Run(cv::Mat* im, ImageBlob* data);
......@@ -78,8 +78,8 @@ class Normalize : public PreprocessOp {
class Permute : public PreprocessOp {
public:
virtual void Init(const YAML::Node& item, const std::string& arch) {
to_bgr_ = item["to_bgr"].as<bool>();
is_channel_first_ = item["channel_first"].as<bool>();
to_bgr_ = item["to_bgr"].as<bool>();
is_channel_first_ = item["channel_first"].as<bool>();
}
virtual void Run(cv::Mat* im, ImageBlob* data);
......@@ -97,11 +97,11 @@ class Resize : public PreprocessOp {
arch_ = arch;
interp_ = item["interp"].as<int>();
max_size_ = item["max_size"].as<int>();
if (item["image_shape"].IsDefined()) {
image_shape_ = item["image_shape"].as<std::vector<int>>();
if (item["image_shape"].IsDefined()) {
image_shape_ = item["image_shape"].as<std::vector<int>>();
}
target_size_ = item["target_size"].as<int>();
}
}
// Compute best resize scale for x-dimension, y-dimension
std::pair<float, float> GenerateScale(const cv::Mat& im);
......@@ -166,4 +166,3 @@ class Preprocessor {
};
} // namespace PaddleDetection
......@@ -14,12 +14,12 @@
#include <glog/logging.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>
#include <sys/types.h>
#include <sys/stat.h>
#include <algorithm>
#ifdef _WIN32
#include <direct.h>
......@@ -29,25 +29,35 @@
#include <sys/stat.h>
#endif
#include "include/object_detector.h"
#include <gflags/gflags.h>
#include "include/object_detector.h"
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_string(image_file, "", "Path of input image");
DEFINE_string(video_path, "", "Path of input video");
DEFINE_bool(use_gpu, false, "Deprecated, please use `--device` to set the device you want to run.");
DEFINE_string(device, "CPU", "Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU.");
DEFINE_bool(
use_gpu,
false,
"Deprecated, please use `--device` to set the device you want to run.");
DEFINE_string(device,
"CPU",
"Choose the device you want to run, it can be: CPU/GPU/XPU, "
"default is CPU.");
DEFINE_bool(use_camera, false, "Use camera or not");
DEFINE_string(run_mode, "fluid", "Mode of running(fluid/trt_fp32/trt_fp16)");
DEFINE_int32(gpu_id, 0, "Device id of GPU to execute");
DEFINE_int32(camera_id, -1, "Device id of camera to predict");
DEFINE_bool(run_benchmark, false, "Whether to predict a image_file repeatedly for benchmark");
DEFINE_bool(run_benchmark,
false,
"Whether to predict a image_file repeatedly for benchmark");
DEFINE_double(threshold, 0.5, "Threshold of score.");
DEFINE_string(output_dir, "output", "Directory of output visualization files.");
DEFINE_bool(trt_calib_mode, false, "If the model is produced by TRT offline quantitative calibration, trt_calib_mode need to set True");
DEFINE_bool(trt_calib_mode,
false,
"If the model is produced by TRT offline quantitative calibration, "
"trt_calib_mode need to set True");
static std::string DirName(const std::string &filepath) {
static std::string DirName(const std::string& filepath) {
auto pos = filepath.rfind(OS_PATH_SEP);
if (pos == std::string::npos) {
return "";
......@@ -55,7 +65,7 @@ static std::string DirName(const std::string &filepath) {
return filepath.substr(0, pos);
}
static bool PathExists(const std::string& path){
static bool PathExists(const std::string& path) {
#ifdef _WIN32
struct _stat buffer;
return (_stat(path.c_str(), &buffer) == 0);
......@@ -92,9 +102,9 @@ void PredictVideo(const std::string& video_path,
PaddleDetection::ObjectDetector* det) {
// Open video
cv::VideoCapture capture;
if (FLAGS_camera_id != -1){
if (FLAGS_camera_id != -1) {
capture.open(FLAGS_camera_id);
}else{
} else {
capture.open(video_path.c_str());
}
if (!capture.isOpened()) {
......@@ -131,18 +141,20 @@ void PredictVideo(const std::string& video_path,
break;
}
det->Predict(frame, 0.5, 0, 1, false, &result);
cv::Mat out_im = PaddleDetection::VisualizeResult(
frame, result, labels, colormap);
cv::Mat out_im =
PaddleDetection::VisualizeResult(frame, result, labels, colormap);
for (const auto& item : result) {
printf("In frame id %d, we detect: class=%d confidence=%.2f rect=[%d %d %d %d]\n",
frame_id,
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3]);
}
printf(
"In frame id %d, we detect: class=%d confidence=%.2f rect=[%d %d %d "
"%d]\n",
frame_id,
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3]);
}
video_out.write(out_im);
frame_id += 1;
}
......@@ -159,26 +171,24 @@ void PredictImage(const std::string& image_path,
cv::Mat im = cv::imread(image_path, 1);
// Store all detected result
std::vector<PaddleDetection::ObjectResult> result;
if (run_benchmark)
{
if (run_benchmark) {
det->Predict(im, threshold, 100, 100, run_benchmark, &result);
}else
{
} else {
det->Predict(im, 0.5, 0, 1, run_benchmark, &result);
for (const auto& item : result) {
printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3]);
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3]);
}
// Visualization result
auto labels = det->GetLabelList();
auto colormap = PaddleDetection::GenerateColorMap(labels.size());
cv::Mat vis_img = PaddleDetection::VisualizeResult(
im, result, labels, colormap);
cv::Mat vis_img =
PaddleDetection::VisualizeResult(im, result, labels, colormap);
std::vector<int> compression_params;
compression_params.push_back(CV_IMWRITE_JPEG_QUALITY);
compression_params.push_back(95);
......@@ -195,30 +205,39 @@ void PredictImage(const std::string& image_path,
int main(int argc, char** argv) {
// Parsing command-line
google::ParseCommandLineFlags(&argc, &argv, true);
if (FLAGS_model_dir.empty()
|| (FLAGS_image_file.empty() && FLAGS_video_path.empty())) {
if (FLAGS_model_dir.empty() ||
(FLAGS_image_file.empty() && FLAGS_video_path.empty())) {
std::cout << "Usage: ./main --model_dir=/PATH/TO/INFERENCE_MODEL/ "
<< "--image_file=/PATH/TO/INPUT/IMAGE/" << std::endl;
<< "--image_file=/PATH/TO/INPUT/IMAGE/" << std::endl;
return -1;
}
if (!(FLAGS_run_mode == "fluid" || FLAGS_run_mode == "trt_fp32"
|| FLAGS_run_mode == "trt_fp16" || FLAGS_run_mode == "trt_int8")) {
std::cout << "run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 'trt_int8'.";
if (!(FLAGS_run_mode == "fluid" || FLAGS_run_mode == "trt_fp32" ||
FLAGS_run_mode == "trt_fp16" || FLAGS_run_mode == "trt_int8")) {
std::cout
<< "run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 'trt_int8'.";
return -1;
}
transform(FLAGS_device.begin(),FLAGS_device.end(),FLAGS_device.begin(),::toupper);
if (!(FLAGS_device == "CPU" || FLAGS_device == "GPU" || FLAGS_device == "XPU")) {
transform(FLAGS_device.begin(),
FLAGS_device.end(),
FLAGS_device.begin(),
::toupper);
if (!(FLAGS_device == "CPU" || FLAGS_device == "GPU" ||
FLAGS_device == "XPU")) {
std::cout << "device should be 'CPU', 'GPU' or 'XPU'.";
return -1;
}
if (FLAGS_use_gpu) {
std::cout << "Deprecated, please use `--device` to set the device you want to run.";
std::cout << "Deprecated, please use `--device` to set the device you want "
"to run.";
return -1;
}
// Load model and create a object detector
PaddleDetection::ObjectDetector det(FLAGS_model_dir, FLAGS_device,
FLAGS_run_mode, FLAGS_gpu_id, FLAGS_trt_calib_mode);
PaddleDetection::ObjectDetector det(FLAGS_model_dir,
FLAGS_device,
FLAGS_run_mode,
FLAGS_gpu_id,
FLAGS_trt_calib_mode);
// Do inference on input video or image
if (!FLAGS_video_path.empty() || FLAGS_use_camera) {
PredictVideo(FLAGS_video_path, &det);
......@@ -226,7 +245,11 @@ int main(int argc, char** argv) {
if (!PathExists(FLAGS_output_dir)) {
MkDirs(FLAGS_output_dir);
}
PredictImage(FLAGS_image_file, FLAGS_threshold, FLAGS_run_benchmark, &det, FLAGS_output_dir);
PredictImage(FLAGS_image_file,
FLAGS_threshold,
FLAGS_run_benchmark,
&det,
FLAGS_output_dir);
}
return 0;
}
......@@ -13,8 +13,8 @@
// limitations under the License.
#include <sstream>
// for setprecision
#include <iomanip>
#include <chrono>
#include <iomanip>
#include "include/object_detector.h"
namespace PaddleDetection {
......@@ -41,18 +41,18 @@ void ObjectDetector::LoadModel(const std::string& model_dir,
} else if (run_mode == "trt_int8") {
precision = paddle::AnalysisConfig::Precision::kInt8;
} else {
printf("run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 'trt_int8'");
printf(
"run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 'trt_int8'");
}
config.EnableTensorRtEngine(
1 << 10,
batch_size,
min_subgraph_size,
precision,
false,
trt_calib_mode);
}
} else if (device == "XPU"){
config.EnableXpu(10*1024*1024);
config.EnableTensorRtEngine(1 << 10,
batch_size,
min_subgraph_size,
precision,
false,
trt_calib_mode);
}
} else if (device == "XPU") {
config.EnableXpu(10 * 1024 * 1024);
} else {
config.DisableGpu();
}
......@@ -88,11 +88,8 @@ cv::Mat VisualizeResult(const cv::Mat& img,
int font_face = cv::FONT_HERSHEY_COMPLEX_SMALL;
double font_scale = 0.5f;
float thickness = 0.5;
cv::Size text_size = cv::getTextSize(text,
font_face,
font_scale,
thickness,
nullptr);
cv::Size text_size =
cv::getTextSize(text, font_face, font_scale, thickness, nullptr);
cv::Point origin;
origin.x = roi.x;
origin.y = roi.y;
......@@ -124,9 +121,8 @@ void ObjectDetector::Preprocess(const cv::Mat& ori_im) {
preprocessor_.Run(&im, &inputs_);
}
void ObjectDetector::Postprocess(
const cv::Mat& raw_mat,
std::vector<ObjectResult>* result) {
void ObjectDetector::Postprocess(const cv::Mat& raw_mat,
std::vector<ObjectResult>* result) {
result->clear();
int rh = 1;
int rw = 1;
......@@ -158,11 +154,11 @@ void ObjectDetector::Postprocess(
}
void ObjectDetector::Predict(const cv::Mat& im,
const double threshold,
const int warmup,
const int repeats,
const bool run_benchmark,
std::vector<ObjectResult>* result) {
const double threshold,
const int warmup,
const int repeats,
const bool run_benchmark,
std::vector<ObjectResult>* result) {
// Preprocess image
Preprocess(im);
// Prepare input tensor
......@@ -189,8 +185,7 @@ void ObjectDetector::Predict(const cv::Mat& im,
}
}
// Run predictor
for (int i = 0; i < warmup; i++)
{
for (int i = 0; i < warmup; i++) {
predictor_->ZeroCopyRun();
// Get output tensor
auto output_names = predictor_->GetOutputNames();
......@@ -206,12 +201,11 @@ void ObjectDetector::Predict(const cv::Mat& im,
std::cerr << "[WARNING] No object detected." << std::endl;
}
output_data_.resize(output_size);
out_tensor->copy_to_cpu(output_data_.data());
out_tensor->copy_to_cpu(output_data_.data());
}
auto start = std::chrono::steady_clock::now();
for (int i = 0; i < repeats; i++)
{
for (int i = 0; i < repeats; i++) {
predictor_->ZeroCopyRun();
// Get output tensor
auto output_names = predictor_->GetOutputNames();
......@@ -227,15 +221,15 @@ void ObjectDetector::Predict(const cv::Mat& im,
std::cerr << "[WARNING] No object detected." << std::endl;
}
output_data_.resize(output_size);
out_tensor->copy_to_cpu(output_data_.data());
out_tensor->copy_to_cpu(output_data_.data());
}
auto end = std::chrono::steady_clock::now();
std::chrono::duration<float> diff = end - start;
float ms = diff.count() / repeats * 1000;
printf("Inference: %f ms per batch image\n", ms);
// Postprocessing result
if(!run_benchmark) {
Postprocess(im, result);
if (!run_benchmark) {
Postprocess(im, result);
}
}
......
......@@ -12,28 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include <string>
#include <vector>
#include "include/preprocess_op.h"
namespace PaddleDetection {
// Seeds the ImageBlob with the dimensions of the incoming image.
//
// Params:
//   im   - image whose `rows` / `cols` are read (the image is not modified).
//   data - blob whose size and scale-factor fields are overwritten.
//
// After this call:
//   ori_im_size_    = {rows, cols}            (as int)
//   ori_im_size_f_  = {rows, cols, 1.0}       (as float)
//   eval_im_size_f_ = {rows, cols, 1.0}       (as float)
//   scale_factor_f_ = {1, 1, 1, 1}
void InitInfo::Run(cv::Mat* im, ImageBlob* data) {
  const int rows = static_cast<int>(im->rows);
  const int cols = static_cast<int>(im->cols);
  const float rows_f = static_cast<float>(im->rows);
  const float cols_f = static_cast<float>(im->cols);

  data->ori_im_size_ = {rows, cols};
  data->ori_im_size_f_ = {rows_f, cols_f, 1.0};
  data->eval_im_size_f_ = {rows_f, cols_f, 1.0};
  // Identity scale until a later op records an actual resize.
  data->scale_factor_f_ = {1., 1., 1., 1.};
}
......@@ -46,11 +37,11 @@ void Normalize::Run(cv::Mat* im, ImageBlob* data) {
for (int h = 0; h < im->rows; h++) {
for (int w = 0; w < im->cols; w++) {
im->at<cv::Vec3f>(h, w)[0] =
(im->at<cv::Vec3f>(h, w)[0] - mean_[0] ) / scale_[0];
(im->at<cv::Vec3f>(h, w)[0] - mean_[0]) / scale_[0];
im->at<cv::Vec3f>(h, w)[1] =
(im->at<cv::Vec3f>(h, w)[1] - mean_[1] ) / scale_[1];
(im->at<cv::Vec3f>(h, w)[1] - mean_[1]) / scale_[1];
im->at<cv::Vec3f>(h, w)[2] =
(im->at<cv::Vec3f>(h, w)[2] - mean_[2] ) / scale_[2];
(im->at<cv::Vec3f>(h, w)[2] - mean_[2]) / scale_[2];
}
}
}
......@@ -63,7 +54,8 @@ void Permute::Run(cv::Mat* im, ImageBlob* data) {
float* base = (data->im_data_).data();
for (int i = 0; i < rc; ++i) {
int cur_c = to_bgr_ ? rc - i - 1 : i;
cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, base + cur_c * rh * rw), i);
cv::extractChannel(
*im, cv::Mat(rh, rw, CV_32FC1, base + cur_c * rh * rw), i);
}
}
......@@ -73,27 +65,22 @@ void Resize::Run(cv::Mat* im, ImageBlob* data) {
*im, *im, cv::Size(), resize_scale.first, resize_scale.second, interp_);
if (max_size_ != 0 && !image_shape_.empty()) {
// Padding the image with 0 border
cv::copyMakeBorder(
*im,
*im,
0,
max_size_ - im->rows,
0,
max_size_ - im->cols,
cv::BORDER_CONSTANT,
cv::Scalar(0));
cv::copyMakeBorder(*im,
*im,
0,
max_size_ - im->rows,
0,
max_size_ - im->cols,
cv::BORDER_CONSTANT,
cv::Scalar(0));
}
data->eval_im_size_f_ = {
static_cast<float>(im->rows),
static_cast<float>(im->cols),
resize_scale.first
};
data->scale_factor_f_ = {
resize_scale.first,
resize_scale.second,
resize_scale.first,
resize_scale.second
};
data->eval_im_size_f_ = {static_cast<float>(im->rows),
static_cast<float>(im->cols),
resize_scale.first};
data->scale_factor_f_ = {resize_scale.first,
resize_scale.second,
resize_scale.first,
resize_scale.second};
}
std::pair<float, float> Resize::GenerateScale(const cv::Mat& im) {
......@@ -132,23 +119,14 @@ void PadStride::Run(cv::Mat* im, ImageBlob* data) {
int nh = (rh / stride_) * stride_ + (rh % stride_ != 0) * stride_;
int nw = (rw / stride_) * stride_ + (rw % stride_ != 0) * stride_;
cv::copyMakeBorder(
*im,
*im,
0,
nh - rh,
0,
nw - rw,
cv::BORDER_CONSTANT,
cv::Scalar(0));
*im, *im, 0, nh - rh, 0, nw - rw, cv::BORDER_CONSTANT, cv::Scalar(0));
(data->eval_im_size_f_)[0] = static_cast<float>(im->rows);
(data->eval_im_size_f_)[1] = static_cast<float>(im->cols);
}
// Names of the preprocessing ops in the order Preprocessor::Run iterates over
// them.
const std::vector<std::string> Preprocessor::RUN_ORDER = {
    "InitInfo", "Resize", "Normalize", "PadStride", "Permute"};
void Preprocessor::Run(cv::Mat* im, ImageBlob* data) {
for (const auto& name : RUN_ORDER) {
......
......@@ -12,17 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <chrono>
#include <fstream>
#include <iostream>
#include <vector>
#include <chrono>
#include <numeric>
#include <vector>
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h" // NOLINT
using namespace paddle::lite_api; // NOLINT
using namespace std;
......@@ -57,13 +56,14 @@ void PrintBenchmarkLog(std::vector<double> det_time,
std::cout << "---------------- Perf info ---------------------" << std::endl;
std::cout << "Total number of predicted data: " << img_num
<< " and total time spent(s): "
<< std::accumulate(det_time.begin(), det_time.end(), 0) << std::endl;
<< std::accumulate(det_time.begin(), det_time.end(), 0)
<< std::endl;
std::cout << "preproce_time(ms): " << det_time[0] / img_num
<< ", inference_time(ms): " << det_time[1] / img_num
<< ", postprocess_time(ms): " << det_time[2] << std::endl;
}
std::vector<std::string> LoadLabels(const std::string &path) {
std::vector<std::string> LoadLabels(const std::string& path) {
std::ifstream file;
std::vector<std::string> labels;
file.open(path);
......@@ -96,18 +96,17 @@ std::vector<std::string> ReadDict(std::string path) {
return m_vec;
}
std::vector<std::string> split(const std::string &str,
const std::string &delim) {
std::vector<std::string> split(const std::string& str,
const std::string& delim) {
std::vector<std::string> res;
if ("" == str)
return res;
char *strs = new char[str.length() + 1];
if ("" == str) return res;
char* strs = new char[str.length() + 1];
std::strcpy(strs, str.c_str());
char *d = new char[delim.length() + 1];
char* d = new char[delim.length() + 1];
std::strcpy(d, delim.c_str());
char *p = std::strtok(strs, d);
char* p = std::strtok(strs, d);
while (p) {
string s = p;
res.push_back(s);
......@@ -128,7 +127,7 @@ std::map<std::string, std::string> LoadConfigTxt(std::string config_path) {
return dict;
}
void PrintConfig(const std::map<std::string, std::string> &config) {
void PrintConfig(const std::map<std::string, std::string>& config) {
std::cout << "=======PaddleDetection lite demo config======" << std::endl;
for (auto iter = config.begin(); iter != config.end(); iter++) {
std::cout << iter->first << " : " << iter->second << std::endl;
......@@ -136,7 +135,6 @@ void PrintConfig(const std::map<std::string, std::string> &config) {
std::cout << "===End of PaddleDetection lite demo config===" << std::endl;
}
// Fill the tensor while applying mean and scale, and transpose the layout
// from NHWC to NCHW (NEON-accelerated).
void neon_mean_scale(const float* din,
float* dout,
......@@ -182,11 +180,11 @@ void neon_mean_scale(const float* din,
}
std::vector<Object> visualize_result(
const float* data,
int count,
float thresh,
cv::Mat& image,
const std::vector<std::string> &class_names) {
const float* data,
int count,
float thresh,
cv::Mat& image,
const std::vector<std::string>& class_names) {
if (data == nullptr) {
std::cerr << "[ERROR] data can not be nullptr\n";
exit(1);
......@@ -258,54 +256,59 @@ std::shared_ptr<PaddlePredictor> LoadModel(std::string model_file,
}
// Builds an ImageBlob from the textual demo configuration.
//
// Params:
//   img    - currently unused; kept so the signature stays unchanged.
//   config - key/value settings. The keys "Resize", "mean" and "std" are
//            read; each value is a comma-separated list parsed via split().
//
// Returns an ImageBlob whose im_shape_ holds the first two integers of
// "Resize" (in the order given), mean_ the floats of "mean" and scale_ the
// floats of "std".
//
// Throws std::out_of_range if a required key is missing or "Resize" yields
// fewer than two values; std::stoi / std::stof propagate their own
// exceptions on malformed numbers.
ImageBlob prepare_imgdata(const cv::Mat& img,
                          std::map<std::string, std::string> config) {
  ImageBlob img_data;

  // Target size: parse every entry, then take the first two.
  std::vector<int> target_size_;
  std::vector<std::string> size_str = split(config.at("Resize"), ",");
  std::transform(size_str.begin(),
                 size_str.end(),
                 std::back_inserter(target_size_),
                 [](std::string const& s) { return std::stoi(s); });
  // .at() is bounds-checked, so a short "Resize" entry raises instead of
  // reading past the end of the vector.
  img_data.im_shape_ = {target_size_.at(0), target_size_.at(1)};

  // Normalisation parameters.
  auto to_float = [](std::string const& s) { return std::stof(s); };
  std::vector<float> mean_;
  std::vector<float> scale_;
  std::vector<std::string> mean_str = split(config.at("mean"), ",");
  std::vector<std::string> std_str = split(config.at("std"), ",");
  std::transform(
      mean_str.begin(), mean_str.end(), std::back_inserter(mean_), to_float);
  std::transform(
      std_str.begin(), std_str.end(), std::back_inserter(scale_), to_float);
  img_data.mean_ = mean_;
  img_data.scale_ = scale_;
  return img_data;
}
// Converts a BGR image into the float buffer consumed by the predictor.
//
// Params:
//   img      - input image; converted with cv::COLOR_BGR2RGB, so it is
//              expected to be BGR.
//   img_data - supplies im_shape_ (target size) plus mean_ / scale_.
//   data     - output buffer written by neon_mean_scale().
//              NOTE(review): the required capacity is not visible here;
//              presumably 3 * im_shape_[0] * im_shape_[1] floats - confirm
//              against the caller.
//
// Steps: BGR->RGB, bicubic resize to cv::Size(im_shape_[0], im_shape_[1])
// (i.e. im_shape_[0] is passed as the width), conversion to 32-bit float
// scaled by 1/255, then neon_mean_scale() over im_shape_[0] * im_shape_[1]
// pixels.
void preprocess(const cv::Mat& img, const ImageBlob img_data, float* data) {
  cv::Mat rgb_img;
  cv::cvtColor(img, rgb_img, cv::COLOR_BGR2RGB);

  const cv::Size target_size(img_data.im_shape_[0], img_data.im_shape_[1]);
  cv::resize(rgb_img, rgb_img, target_size, 0.f, 0.f, cv::INTER_CUBIC);

  cv::Mat imgf;
  rgb_img.convertTo(imgf, CV_32FC3, 1 / 255.f);
  const float* dimg = reinterpret_cast<const float*>(imgf.data);

  const int pixel_count =
      static_cast<int>(img_data.im_shape_[0] * img_data.im_shape_[1]);
  neon_mean_scale(dimg, data, pixel_count, img_data.mean_, img_data.scale_);
}
void RunModel(std::map<std::string, std::string> config,
std::string img_path,
const int repeats,
std::vector<double>* times) {
std::string model_file = config.at("model_file");
std::string label_path = config.at("label_path");
// Load Labels
......@@ -334,14 +337,12 @@ void RunModel(std::map<std::string, std::string> config,
// 2. Run predictor
// warm up
for (int i = 0; i < repeats / 2; i++)
{
for (int i = 0; i < repeats / 2; i++) {
predictor->Run();
}
auto inference_start = std::chrono::steady_clock::now();
for (int i = 0; i < repeats; i++)
{
for (int i = 0; i < repeats; i++) {
predictor->Run();
}
auto inference_end = std::chrono::steady_clock::now();
......
......@@ -530,7 +530,7 @@ def predict_video(detector, camera_id):
fps = 30
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
......@@ -660,6 +660,8 @@ if __name__ == '__main__':
assert FLAGS.device in ['CPU', 'GPU', 'XPU'
], "device should be CPU, GPU or XPU"
assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device"
assert not (FLAGS.enable_mkldnn==False and FLAGS.enable_mkldnn_bfloat16==True),"To turn on mkldnn_bfloat, please set both enable_mkldnn and enable_mkldnn_bfloat16 True"
assert not (
FLAGS.enable_mkldnn == False and FLAGS.enable_mkldnn_bfloat16 == True
), "To turn on mkldnn_bfloat, please set both enable_mkldnn and enable_mkldnn_bfloat16 True"
main()
......@@ -18,7 +18,7 @@ namespace operators {
using Tensor = framework::Tensor;
class BottomPoolOp : public framework::OperatorWithKernel {
public:
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
......@@ -27,7 +27,7 @@ public:
ctx->ShareDim("X", /*->*/ "Output");
}
protected:
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
......@@ -36,10 +36,9 @@ protected:
};
class BottomPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
public:
void Make() override {
AddInput("X",
"Input with shape (batch, C, H, W)");
AddInput("X", "Input with shape (batch, C, H, W)");
AddOutput("MaxMap", "Max map with index of maximum value of input");
AddOutput("Output", "output with same shape as input(X)");
AddComment(
......@@ -52,10 +51,10 @@ The output has the same shape with input.
};
class BottomPoolOpGrad : public framework::OperatorWithKernel {
public:
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("MaxMap"), "Input(MaxMap) should not be null");
......@@ -75,10 +74,10 @@ protected:
template <typename T>
class BottomPoolGradDescMaker : public framework::SingleGradOpMaker<T> {
public:
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("bottom_pool_grad");
op->SetInput("X", this->Input("X"));
......
......@@ -11,10 +11,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/memory/memory.h"
#include <vector>
#include "paddle/fluid/platform/cuda_primitives.h"
#include "util.cu.h"
namespace paddle {
......@@ -32,14 +32,14 @@ static inline int NumBlocks(const int N) {
template <typename T>
class BottomPoolOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto *x = ctx.Input<Tensor>("X");
auto *max_map = ctx.Output<Tensor>("MaxMap");
auto *output = ctx.Output<Tensor>("Output");
auto *x_data = x->data<T>();
auto* x = ctx.Input<Tensor>("X");
auto* max_map = ctx.Output<Tensor>("MaxMap");
auto* output = ctx.Output<Tensor>("Output");
auto* x_data = x->data<T>();
auto x_dims = x->dims();
int NC_num = x_dims[0] * x_dims[1];
int height = x_dims[2];
......@@ -47,22 +47,31 @@ public:
int num = x->numel();
auto& dev_ctx = ctx.cuda_device_context();
int *max_map_data = max_map->mutable_data<int>(x_dims, dev_ctx.GetPlace());
T *output_data = output->mutable_data<T>(x_dims, dev_ctx.GetPlace());
int* max_map_data = max_map->mutable_data<int>(x_dims, dev_ctx.GetPlace());
T* output_data = output->mutable_data<T>(x_dims, dev_ctx.GetPlace());
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
int threads = kNumCUDAThreads;
int blocks = NumBlocks(num / height);
auto max_val_ptr = memory::Alloc(gpu_place, num / height * sizeof(T));
T* max_val_data = reinterpret_cast<T*>(max_val_ptr->ptr());
auto max_ind_ptr = memory::Alloc(gpu_place, num / height * sizeof(int));
int* max_ind_data = reinterpret_cast<int*>(max_ind_ptr->ptr());
GetMaxInfo<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(), NC_num, height, width, 2, false, max_val_data, max_ind_data, max_map_data);
GetMaxInfo<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(),
NC_num,
height,
width,
2,
false,
max_val_data,
max_ind_data,
max_map_data);
blocks = NumBlocks(num);
ScatterAddFw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(), max_map_data, NC_num, height, width, 2, output_data);
ScatterAddFw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
x->data<T>(), max_map_data, NC_num, height, width, 2, output_data);
}
};
......@@ -75,20 +84,28 @@ class BottomPoolGradOpCUDAKernel : public framework::OpKernel<T> {
auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
auto* in_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto x_dims = x->dims();
auto& dev_ctx = ctx.cuda_device_context();
T* in_grad_data = in_grad->mutable_data<T>(x_dims, dev_ctx.GetPlace());
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
int threads = kNumCUDAThreads;
int NC_num = x_dims[0] * x_dims[1];
int height = x_dims[2];
int width = x_dims[3];
int grad_num = in_grad->numel();
int blocks = NumBlocks(grad_num);
FillConstant<T><<<blocks, threads, 0, dev_ctx.stream()>>>(in_grad_data, 0, grad_num);
ScatterAddBw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(out_grad->data<T>(), max_map->data<int>(), NC_num, height, width, 2, in_grad_data);
FillConstant<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
in_grad_data, 0, grad_num);
ScatterAddBw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
out_grad->data<T>(),
max_map->data<int>(),
NC_num,
height,
width,
2,
in_grad_data);
}
};
......
......@@ -18,7 +18,7 @@ namespace operators {
using Tensor = framework::Tensor;
class LeftPoolOp : public framework::OperatorWithKernel {
public:
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
......@@ -27,7 +27,7 @@ public:
ctx->ShareDim("X", /*->*/ "Output");
}
protected:
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
......@@ -36,10 +36,9 @@ protected:
};
class LeftPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
public:
void Make() override {
AddInput("X",
"Input with shape (batch, C, H, W)");
AddInput("X", "Input with shape (batch, C, H, W)");
AddOutput("MaxMap", "Max map with index of maximum value of input");
AddOutput("Output", "output with same shape as input(X)");
AddComment(
......@@ -52,10 +51,10 @@ The output has the same shape with input.
};
class LeftPoolOpGrad : public framework::OperatorWithKernel {
public:
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("MaxMap"), "Input(MaxMap) should not be null");
......@@ -75,10 +74,10 @@ protected:
template <typename T>
class LeftPoolGradDescMaker : public framework::SingleGradOpMaker<T> {
public:
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("left_pool_grad");
op->SetInput("X", this->Input("X"));
......
......@@ -11,10 +11,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/memory/memory.h"
#include <vector>
#include "paddle/fluid/platform/cuda_primitives.h"
#include "util.cu.h"
namespace paddle {
......@@ -32,14 +32,14 @@ static inline int NumBlocks(const int N) {
template <typename T>
class LeftPoolOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto *x = ctx.Input<Tensor>("X");
auto *max_map = ctx.Output<Tensor>("MaxMap");
auto *output = ctx.Output<Tensor>("Output");
auto *x_data = x->data<T>();
auto* x = ctx.Input<Tensor>("X");
auto* max_map = ctx.Output<Tensor>("MaxMap");
auto* output = ctx.Output<Tensor>("Output");
auto* x_data = x->data<T>();
auto x_dims = x->dims();
int NC_num = x_dims[0] * x_dims[1];
int height = x_dims[2];
......@@ -47,10 +47,10 @@ public:
int num = x->numel();
auto& dev_ctx = ctx.cuda_device_context();
int *max_map_data = max_map->mutable_data<int>(x_dims, dev_ctx.GetPlace());
T *output_data = output->mutable_data<T>(x_dims, dev_ctx.GetPlace());
int* max_map_data = max_map->mutable_data<int>(x_dims, dev_ctx.GetPlace());
T* output_data = output->mutable_data<T>(x_dims, dev_ctx.GetPlace());
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
int threads = kNumCUDAThreads;
int blocks = NumBlocks(num / width);
......@@ -59,11 +59,19 @@ public:
auto max_ind_ptr = memory::Alloc(gpu_place, num / width * sizeof(int));
int* max_ind_data = reinterpret_cast<int*>(max_ind_ptr->ptr());
GetMaxInfo<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(), NC_num, height, width, 3, true, max_val_data, max_ind_data, max_map_data);
GetMaxInfo<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(),
NC_num,
height,
width,
3,
true,
max_val_data,
max_ind_data,
max_map_data);
blocks = NumBlocks(num);
ScatterAddFw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(), max_map_data, NC_num, height, width, 3, output_data);
ScatterAddFw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
x->data<T>(), max_map_data, NC_num, height, width, 3, output_data);
}
};
......@@ -76,24 +84,31 @@ class LeftPoolGradOpCUDAKernel : public framework::OpKernel<T> {
auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
auto* in_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto x_dims = x->dims();
auto& dev_ctx = ctx.cuda_device_context();
T* in_grad_data = in_grad->mutable_data<T>(x_dims, dev_ctx.GetPlace());
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
int threads = kNumCUDAThreads;
int NC_num = x_dims[0] * x_dims[1];
int height = x_dims[2];
int height = x_dims[2];
int width = x_dims[3];
int grad_num = in_grad->numel();
int blocks = NumBlocks(grad_num);
FillConstant<T><<<blocks, threads, 0, dev_ctx.stream()>>>(in_grad_data, 0, grad_num);
ScatterAddBw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(out_grad->data<T>(), max_map->data<int>(), NC_num, height, width, 3, in_grad_data);
FillConstant<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
in_grad_data, 0, grad_num);
ScatterAddBw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
out_grad->data<T>(),
max_map->data<int>(),
NC_num,
height,
width,
3,
in_grad_data);
}
};
} // namespace operators
} // namespace paddle
......
......@@ -18,7 +18,7 @@ namespace operators {
using Tensor = framework::Tensor;
class RightPoolOp : public framework::OperatorWithKernel {
public:
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
......@@ -27,7 +27,7 @@ public:
ctx->ShareDim("X", /*->*/ "Output");
}
protected:
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
......@@ -36,11 +36,10 @@ protected:
};
class RightPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
public:
void Make() override {
AddInput("X",
"Input with shape (batch, C, H, W)");
AddOutput("MaxMap", "Max map with index of maximum value of input");
AddInput("X", "Input with shape (batch, C, H, W)");
AddOutput("MaxMap", "Max map with index of maximum value of input");
AddOutput("Output", "output with same shape as input(X)");
AddComment(
R"Doc(
......@@ -52,10 +51,10 @@ The output has the same shape with input.
};
class RightPoolOpGrad : public framework::OperatorWithKernel {
public:
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("MaxMap"), "Input(MaxMap) should not be null");
......@@ -75,10 +74,10 @@ protected:
template <typename T>
class RightPoolGradDescMaker : public framework::SingleGradOpMaker<T> {
public:
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("right_pool_grad");
op->SetInput("X", this->Input("X"));
......
......@@ -11,10 +11,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/memory/memory.h"
#include <vector>
#include "paddle/fluid/platform/cuda_primitives.h"
#include "util.cu.h"
namespace paddle {
......@@ -32,14 +32,14 @@ static inline int NumBlocks(const int N) {
template <typename T>
class RightPoolOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto *x = ctx.Input<Tensor>("X");
auto *max_map = ctx.Output<Tensor>("MaxMap");
auto *output = ctx.Output<Tensor>("Output");
auto *x_data = x->data<T>();
auto* x = ctx.Input<Tensor>("X");
auto* max_map = ctx.Output<Tensor>("MaxMap");
auto* output = ctx.Output<Tensor>("Output");
auto* x_data = x->data<T>();
auto x_dims = x->dims();
int NC_num = x_dims[0] * x_dims[1];
int height = x_dims[2];
......@@ -47,23 +47,31 @@ public:
int num = x->numel();
auto& dev_ctx = ctx.cuda_device_context();
int *max_map_data = max_map->mutable_data<int>(x_dims, dev_ctx.GetPlace());
T *output_data = output->mutable_data<T>(x_dims, dev_ctx.GetPlace());
int* max_map_data = max_map->mutable_data<int>(x_dims, dev_ctx.GetPlace());
T* output_data = output->mutable_data<T>(x_dims, dev_ctx.GetPlace());
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
int threads = kNumCUDAThreads;
int blocks = NumBlocks(num / width);
auto max_val_ptr = memory::Alloc(gpu_place, num / width * sizeof(T));
T* max_val_data = reinterpret_cast<T*>(max_val_ptr->ptr());
auto max_ind_ptr = memory::Alloc(gpu_place, num / width * sizeof(int));
int* max_ind_data = reinterpret_cast<int*>(max_ind_ptr->ptr());
GetMaxInfo<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(), NC_num, height, width, 3, false, max_val_data, max_ind_data, max_map_data);
GetMaxInfo<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(),
NC_num,
height,
width,
3,
false,
max_val_data,
max_ind_data,
max_map_data);
blocks = NumBlocks(num);
ScatterAddFw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(), max_map_data, NC_num, height, width, 3, output_data);
ScatterAddFw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
x->data<T>(), max_map_data, NC_num, height, width, 3, output_data);
}
};
......@@ -76,20 +84,28 @@ class RightPoolGradOpCUDAKernel : public framework::OpKernel<T> {
auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
auto* in_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto x_dims = x->dims();
auto& dev_ctx = ctx.cuda_device_context();
T* in_grad_data = in_grad->mutable_data<T>(x_dims, dev_ctx.GetPlace());
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
int threads = kNumCUDAThreads;
int NC_num = x_dims[0] * x_dims[1];
int height = x_dims[2];
int width = x_dims[3];
int grad_num = in_grad->numel();
int blocks = NumBlocks(grad_num);
FillConstant<T><<<blocks, threads, 0, dev_ctx.stream()>>>(in_grad_data, 0, grad_num);
ScatterAddBw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(out_grad->data<T>(), max_map->data<int>(), NC_num, height, width, 3, in_grad_data);
FillConstant<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
in_grad_data, 0, grad_num);
ScatterAddBw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
out_grad->data<T>(),
max_map->data<int>(),
NC_num,
height,
width,
3,
in_grad_data);
}
};
......
......@@ -18,7 +18,7 @@ namespace operators {
using Tensor = framework::Tensor;
class TopPoolOp : public framework::OperatorWithKernel {
public:
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
......@@ -27,19 +27,18 @@ public:
ctx->ShareDim("X", /*->*/ "Output");
}
protected:
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
}
};
class TopPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
public:
void Make() override {
AddInput("X",
"Input with shape (batch, C, H, W)");
AddInput("X", "Input with shape (batch, C, H, W)");
AddOutput("MaxMap", "Max map with index of maximum value of input");
AddOutput("Output", "Output with same shape as input(X)");
AddComment(
......@@ -52,16 +51,16 @@ The output has the same shape with input.
};
class TopPoolOpGrad : public framework::OperatorWithKernel {
public:
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("MaxMap"), "Input(MaxMap) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Output")),
"Input(Output@GRAD) should not be null");
auto out_grad_name = framework::GradVarName("Output");
ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
}
......
......@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/memory/memory.h"
#include <vector>
#include "paddle/fluid/platform/cuda_primitives.h"
#include "util.cu.h"
namespace paddle {
......@@ -33,14 +33,14 @@ static inline int NumBlocks(const int N) {
template <typename T>
class TopPoolOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto *x = ctx.Input<Tensor>("X");
auto *max_map = ctx.Output<Tensor>("MaxMap");
auto *output = ctx.Output<Tensor>("Output");
auto *x_data = x->data<T>();
auto* x = ctx.Input<Tensor>("X");
auto* max_map = ctx.Output<Tensor>("MaxMap");
auto* output = ctx.Output<Tensor>("Output");
auto* x_data = x->data<T>();
auto x_dims = x->dims();
int NC_num = x_dims[0] * x_dims[1];
int height = x_dims[2];
......@@ -48,22 +48,31 @@ public:
int num = x->numel();
auto& dev_ctx = ctx.cuda_device_context();
int *max_map_data = max_map->mutable_data<int>(x_dims, dev_ctx.GetPlace());
T *output_data = output->mutable_data<T>(x_dims, dev_ctx.GetPlace());
int* max_map_data = max_map->mutable_data<int>(x_dims, dev_ctx.GetPlace());
T* output_data = output->mutable_data<T>(x_dims, dev_ctx.GetPlace());
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
int threads = kNumCUDAThreads;
int blocks = NumBlocks(num / height);
auto max_val_ptr = memory::Alloc(gpu_place, num / height * sizeof(T));
T* max_val_data = reinterpret_cast<T*>(max_val_ptr->ptr());
auto max_ind_ptr = memory::Alloc(gpu_place, num / height * sizeof(int));
int* max_ind_data = reinterpret_cast<int*>(max_ind_ptr->ptr());
GetMaxInfo<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(), NC_num, height, width, 2, true, max_val_data, max_ind_data, max_map_data);
GetMaxInfo<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(),
NC_num,
height,
width,
2,
true,
max_val_data,
max_ind_data,
max_map_data);
blocks = NumBlocks(num);
ScatterAddFw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(), max_map_data, NC_num, height, width, 2, output_data);
ScatterAddFw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
x->data<T>(), max_map_data, NC_num, height, width, 2, output_data);
}
};
......@@ -79,16 +88,24 @@ class TopPoolGradOpCUDAKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.cuda_device_context();
T* in_grad_data = in_grad->mutable_data<T>(x_dims, dev_ctx.GetPlace());
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
int threads = kNumCUDAThreads;
int NC_num = x_dims[0] * x_dims[1];
int height = x_dims[2];
int width = x_dims[3];
int grad_num = in_grad->numel();
int blocks = NumBlocks(grad_num);
FillConstant<T><<<blocks, threads, 0, dev_ctx.stream()>>>(in_grad_data, 0, grad_num);
ScatterAddBw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(out_grad->data<T>(), max_map->data<int>(), NC_num, height, width, 2, in_grad_data);
FillConstant<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
in_grad_data, 0, grad_num);
ScatterAddBw<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
out_grad->data<T>(),
max_map->data<int>(),
NC_num,
height,
width,
2,
in_grad_data);
}
};
......
......@@ -11,10 +11,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/memory/memory.h"
#include <vector>
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
......@@ -27,15 +27,18 @@ using framework::Tensor;
// Writes one constant into the first `fill_num` elements of `x`.
//
//   x        - destination buffer; indices [0, fill_num) are overwritten.
//   num      - the VALUE to store (despite its name), converted with
//              static_cast<T>.
//   fill_num - the number of elements to write.
//
// The loop macro CUDA_1D_KERNEL_LOOP is defined outside this file; it is
// presumably a grid-stride loop over [0, fill_num) -- TODO confirm.
template <typename T>
__global__ void FillConstant(T* x, int num, int fill_num) {
  // One pass is enough: every iteration stores the same value, so repeating
  // the loop would only issue redundant writes.
  CUDA_1D_KERNEL_LOOP(i, fill_num) { x[i] = static_cast<T>(num); }
}
template <typename T>
__global__ void SliceOnAxis(const T* x, const int NC_num, const int H, const int W,
const int axis, const int start, const int end,
T* output) {
__global__ void SliceOnAxis(const T* x,
const int NC_num,
const int H,
const int W,
const int axis,
const int start,
const int end,
T* output) {
int HW_num = H * W;
int length = axis == 2 ? W : H;
int sliced_len = end - start;
......@@ -44,22 +47,28 @@ __global__ void SliceOnAxis(const T* x, const int NC_num, const int H, const int
CUDA_1D_KERNEL_LOOP(i, NC_num * cur_HW_num) {
int NC_id = i / cur_HW_num;
int HW_id = i % cur_HW_num;
if (axis == 2){
if (axis == 2) {
output[i] = x[NC_id * HW_num + start * W + HW_id];
} else if (axis == 3) {
int col = HW_id % sliced_len;
int row = HW_id / sliced_len;
output[i] = x[NC_id * HW_num + row * W + start + col];
}
}
}
}
template <typename T>
__global__ void MaxOut(const T* input, const int next_ind, const int NC_num,
const int H, const int W, const int axis,
const int start, const int end, T* output) {
__global__ void MaxOut(const T* input,
const int next_ind,
const int NC_num,
const int H,
const int W,
const int axis,
const int start,
const int end,
T* output) {
int HW_num = H * W;
int length = axis == 2 ? W : H;
int length = axis == 2 ? W : H;
T cur = static_cast<T>(0.);
T next = static_cast<T>(0.);
T max_v = static_cast<T>(0.);
......@@ -69,11 +78,11 @@ __global__ void MaxOut(const T* input, const int next_ind, const int NC_num,
CUDA_1D_KERNEL_LOOP(i, NC_num * cur_HW_num) {
int NC_id = i / cur_HW_num;
int HW_id = i % cur_HW_num;
if (axis == 2){
if (axis == 2) {
cur = input[NC_id * HW_num + start * W + HW_id];
next = input[NC_id * HW_num + next_ind * W + HW_id];
max_v = cur > next ? cur : next;
max_v = cur > next ? cur : next;
output[NC_id * HW_num + start * W + HW_id] = max_v;
} else if (axis == 3) {
int col = HW_id % sliced_len;
......@@ -88,11 +97,16 @@ __global__ void MaxOut(const T* input, const int next_ind, const int NC_num,
}
template <typename T>
__global__ void UpdateMaxInfo(const T* input, const int NC_num,
const int H, const int W, const int axis,
const int index, T* max_val, int* max_ind) {
__global__ void UpdateMaxInfo(const T* input,
const int NC_num,
const int H,
const int W,
const int axis,
const int index,
T* max_val,
int* max_ind) {
int length = axis == 2 ? W : H;
int HW_num = H * W;
int HW_num = H * W;
T val = static_cast<T>(0.);
CUDA_1D_KERNEL_LOOP(i, NC_num * length) {
int NC_id = i / length;
......@@ -111,82 +125,104 @@ __global__ void UpdateMaxInfo(const T* input, const int NC_num,
}
// Accumulates one row (axis == 2) or one column (axis == 3) of `input` into
// `output`, at the positions named by `max_ind`.
//
// Both buffers are indexed as NC_id * H * W + row * W + col.
//
//   input   - source values; only the slice at index `start` along `axis`
//             is read.
//   start   - row index (axis == 2) or column index (axis == 3) of the
//             slice to read.
//   max_ind - one destination index per work item, NC_num * length entries,
//             where length is W for axis == 2 and H otherwise.
//   NC_num  - number of H x W planes addressed.
//   H, W    - plane height and width.
//   axis    - 2 scatters along rows, 3 scatters along columns; any other
//             value makes the kernel a no-op.
//   output  - destination buffer, updated with atomic adds.
template <typename T>
__global__ void ScatterAddOnAxis(const T* input,
                                 const int start,
                                 const int* max_ind,
                                 const int NC_num,
                                 const int H,
                                 const int W,
                                 const int axis,
                                 T* output) {
  const int length = axis == 2 ? W : H;
  const int HW_num = H * W;
  CUDA_1D_KERNEL_LOOP(i, NC_num * length) {
    const int NC_id = i / length;
    const int length_id = i % length;
    const int id_ = max_ind[i];
    const int plane = NC_id * HW_num;
    // Several work items may target the same output element, hence the
    // atomic add instead of a plain "+=".
    if (axis == 2) {
      platform::CudaAtomicAdd(output + plane + id_ * W + length_id,
                              input[plane + start * W + length_id]);
    } else if (axis == 3) {
      platform::CudaAtomicAdd(output + plane + length_id * W + id_,
                              input[plane + length_id * W + start]);
    }
    // No __syncthreads() here on purpose: the kernel uses no shared memory,
    // and a barrier inside this loop is only valid if every thread of the
    // block reaches it. Threads whose index falls outside
    // [0, NC_num * length) skip the body, so the barrier would be executed
    // divergently.
  }
}
// Computes a running maximum along one spatial axis of `input`.
//
// `input` and `max_map` are indexed as NC_id * H * W + row * W + col.
// For axis == 2 the scan walks over rows (H steps) independently for every
// (NC_id, column) pair; otherwise it walks over columns (W steps) for every
// (NC_id, row) pair. At each step the position along the scanned axis that
// holds the largest value seen so far is written to `max_map`.
//
//   input   - source values (read only).
//   NC_num  - number of H x W planes addressed.
//   H, W    - plane height and width.
//   axis    - 2 scans along H, 3 scans along W. For any other value the
//             element offset stays 0, as in the original code.
//   reverse - when true the scan starts at the last index and walks
//             backwards; otherwise it starts at 0 and walks forwards.
//   max_val - scratch/output, NC_num * len entries (len is W for axis == 2,
//             H otherwise): the running maximum per scanned line.
//   max_ind - scratch/output, same size: the index of that running maximum.
//   max_map - output, indexed like `input`: the running arg-max at every
//             position.
//
// A strictly-greater comparison is used, so among equal values the one
// visited first keeps the index.
//
// NOTE(review): correctness relies on every work item j being handled by the
// same thread at each scan step, which holds if CUDA_1D_KERNEL_LOOP (defined
// outside this file) maps indices to threads purely from the thread/block
// ids and the bound -- TODO confirm against the macro definition.
template <typename T>
__global__ void GetMaxInfo(const T* input,
                           const int NC_num,
                           const int H,
                           const int W,
                           const int axis,
                           const bool reverse,
                           T* max_val,
                           int* max_ind,
                           int* max_map) {
  const int scan_len = axis == 2 ? H : W;  // steps along the scanned axis
  const int len = axis == 2 ? W : H;       // lines per plane
  const int first = reverse ? scan_len - 1 : 0;
  const int step = reverse ? -1 : 1;

  // A counted loop visits exactly scan_len positions and, unlike an
  // open-ended "for (;;)" with direction-dependent break tests, also
  // terminates when scan_len is 0.
  for (int k = 0; k < scan_len; ++k) {
    const int i = first + k * step;
    CUDA_1D_KERNEL_LOOP(j, NC_num * len) {
      const int NC_id = j / len;
      const int len_id = j % len;
      int loc = 0;
      if (axis == 2) {
        loc = NC_id * H * W + i * W + len_id;
      } else if (axis == 3) {
        loc = NC_id * H * W + len_id * W + i;
      }
      const T val = input[loc];
      // The first step seeds the running maximum without reading max_val,
      // which the caller has not initialised; later steps replace it only
      // on a strictly larger value.
      if (k == 0 || val > max_val[j]) {
        max_val[j] = val;
        max_ind[j] = i;
      }
      max_map[loc] = max_ind[j];
      // No __syncthreads() here on purpose: no shared memory is used and
      // max_val[j] / max_ind[j] are private to the thread that owns j. A
      // barrier inside this loop would also be reached divergently by
      // threads whose index falls outside [0, NC_num * len).
    }
  }
}
template <typename T>
__global__ void ScatterAddFw(const T* input, const int* max_map, const int NC_num, const int H, const int W, const int axis, T* output){
__global__ void ScatterAddFw(const T* input,
const int* max_map,
const int NC_num,
const int H,
const int W,
const int axis,
T* output) {
CUDA_1D_KERNEL_LOOP(i, NC_num * H * W) {
int loc = max_map[i];
int NC_id = i / (H * W);
......@@ -202,7 +238,13 @@ __global__ void ScatterAddFw(const T* input, const int* max_map, const int NC_nu
}
template <typename T>
__global__ void ScatterAddBw(const T* input, const int* max_map, const int NC_num, const int H, const int W, const int axis, T* output){
__global__ void ScatterAddBw(const T* input,
const int* max_map,
const int NC_num,
const int H,
const int W,
const int axis,
T* output) {
CUDA_1D_KERNEL_LOOP(i, NC_num * H * W) {
int loc = max_map[i];
int NC_id = i / (H * W);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册