Commit e4e37640 authored by dengkaipeng

use memory Copy. test=develop

Parent 626fb859
@@ -74,9 +74,8 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X",
-             "The input tensor of YoloBox operator, "
-             "This is a 4-D tensor with shape of [N, C, H, W]. "
-             "H and W should be same, and the second dimension(C) stores "
+             "The input tensor of YoloBox operator is a 4-D tensor with "
+             "shape of [N, C, H, W]. The second dimension(C) stores "
              "box locations, confidence score and classification one-hot "
              "keys of each anchor box. Generally, X should be the output "
              "of YOLOv3 network.");
@@ -91,10 +90,10 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
               "batch num, M is output box number, and the 3rd dimension "
               "stores [xmin, ymin, xmax, ymax] coordinates of boxes.");
     AddOutput("Scores",
-              "The output tensor ofdetection boxes scores of YoloBox "
-              "operator, This is a 3-D tensor with shape of [N, M, C], "
-              "N is the batch num, M is output box number, C is the "
-              "class number.");
+              "The output tensor of detection boxes scores of YoloBox "
+              "operator, This is a 3-D tensor with shape of "
+              "[N, M, :attr:`class_num`], N is the batch num, M is "
+              "output box number.");
     AddAttr<int>("class_num", "The number of classes to predict.");
     AddAttr<std::vector<int>>("anchors",
@@ -112,7 +111,7 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
                   "be ignored.")
         .SetDefault(0.01);
     AddComment(R"DOC(
-This operator generate YOLO detection boxes from output of YOLOv3 network.
+This operator generates YOLO detection boxes from output of YOLOv3 network.
 The output of previous network is in shape [N, C, H, W], while H and W
 should be the same, H and W specify the grid size, each grid point predict
@@ -150,6 +149,10 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
 :attr:`conf_thresh` should be ignored, and box final scores is the product of
 confidence scores and classification scores.
+$$
+score_{pred} = score_{conf} * score_{class}
+$$
 )DOC");
   }
 };
...
@@ -83,12 +83,22 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
     const int an_num = anchors.size() / 2;
     int input_size = downsample_ratio * h;
-    Tensor anchors_t, cpu_anchors_t;
-    auto cpu_anchors_data =
-        cpu_anchors_t.mutable_data<int>({an_num * 2}, platform::CPUPlace());
-    std::copy(anchors.begin(), anchors.end(), cpu_anchors_data);
-    TensorCopySync(cpu_anchors_t, ctx.GetPlace(), &anchors_t);
-    auto anchors_data = anchors_t.data<int>();
+    /* Tensor anchors_t, cpu_anchors_t; */
+    /* auto cpu_anchors_data = */
+    /*     cpu_anchors_t.mutable_data<int>({an_num * 2}, platform::CPUPlace()); */
+    /* std::copy(anchors.begin(), anchors.end(), cpu_anchors_data); */
+    /* TensorCopySync(cpu_anchors_t, ctx.GetPlace(), &anchors_t); */
+    /* auto anchors_data = anchors_t.data<int>(); */
+    auto& dev_ctx = ctx.cuda_device_context();
+    auto& allocator =
+        platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
+    int bytes = sizeof(int) * anchors.size();
+    auto anchors_ptr = allocator.Allocate(sizeof(int) * anchors.size());
+    int* anchors_data = reinterpret_cast<int*>(anchors_ptr->ptr());
+    const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
+    const auto cplace = platform::CPUPlace();
+    memory::Copy(gplace, anchors_data, cplace, anchors.data(), bytes,
+                 dev_ctx.stream());
     const T* input_data = input->data<T>();
     const int* imgsize_data = img_size->data<int>();
@@ -96,7 +106,6 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
     T* scores_data =
         scores->mutable_data<T>({n, box_num, class_num}, ctx.GetPlace());
     math::SetConstant<platform::CUDADeviceContext, T> set_zero;
-    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     set_zero(dev_ctx, boxes, static_cast<T>(0));
     set_zero(dev_ctx, scores, static_cast<T>(0));
...
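Note: the added block above copies the host-side anchors vector into a temporary device allocation with memory::Copy on the kernel's CUDA stream, replacing the earlier CPU tensor plus TensorCopySync round trip. For reference, a minimal standalone CUDA sketch of the same host-to-device copy pattern, using plain cudaMalloc/cudaMemcpyAsync in place of Paddle's DeviceTemporaryAllocator and memory::Copy (variable names here are illustrative, not part of the patch):

// Standalone sketch: copy a small host-side anchor list to a temporary
// device buffer asynchronously on a stream, then synchronize.
#include <cuda_runtime.h>
#include <vector>

int main() {
  std::vector<int> anchors = {10, 13, 16, 30, 33, 23};  // host anchors
  const size_t bytes = sizeof(int) * anchors.size();

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  int* anchors_dev = nullptr;
  cudaMalloc(&anchors_dev, bytes);  // temporary device buffer

  // Host-to-device copy ordered on the stream; the raw-CUDA counterpart of
  // memory::Copy(gplace, dst, cplace, src, bytes, dev_ctx.stream()).
  cudaMemcpyAsync(anchors_dev, anchors.data(), bytes,
                  cudaMemcpyHostToDevice, stream);

  // ... a kernel reading anchors_dev would be launched on the same stream ...

  cudaStreamSynchronize(stream);
  cudaFree(anchors_dev);
  cudaStreamDestroy(stream);
  return 0;
}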
@@ -632,8 +632,8 @@ def yolo_box(x,
     Returns:
         Variable: A 3-D tensor with shape [N, M, 4], the coordinates of boxes,
-        and a 3-D tensor with shape [N, M, C], the classification scores
-        of boxes.
+        and a 3-D tensor with shape [N, M, :attr:`class_num`], the classification
+        scores of boxes.
     Raises:
         TypeError: Input x of yolov_box must be Variable
@@ -647,7 +647,7 @@ def yolo_box(x,
         x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32')
         anchors = [10, 13, 16, 30, 33, 23]
-        loss = fluid.layers.yolov3_loss(x=x, class_num=80, anchors=anchors,
+        loss = fluid.layers.yolo_box(x=x, class_num=80, anchors=anchors,
                                         conf_thresh=0.01, downsample_ratio=32)
     """
     helper = LayerHelper('yolo_box', **locals())
...
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -75,8 +75,8 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs):
     mask_num = len(anchor_mask)
     class_num = attrs["class_num"]
     ignore_thresh = attrs['ignore_thresh']
-    downsample_ratio = attrs['downsample_ratio']
-    input_size = downsample_ratio * h
+    downsample = attrs['downsample']
+    input_size = downsample * h
     x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
     loss = np.zeros((n)).astype('float32')
@@ -86,6 +86,10 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs):
     pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w
     pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h
+    x[:, :, :, :, 5:] = np.where(x[:, :, :, :, 5:] < -0.5, x[:, :, :, :, 5:],
+                                 np.ones_like(x[:, :, :, :, 5:]) * 1.0 /
+                                 class_num)
     mask_anchors = []
     for m in anchor_mask:
         mask_anchors.append((anchors[2 * m], anchors[2 * m + 1]))
@@ -172,7 +176,7 @@ class TestYolov3LossOp(OpTest):
             "anchor_mask": self.anchor_mask,
             "class_num": self.class_num,
             "ignore_thresh": self.ignore_thresh,
-            "downsample_ratio": self.downsample_ratio,
+            "downsample": self.downsample,
         }
         self.inputs = {
@@ -204,7 +208,7 @@ class TestYolov3LossOp(OpTest):
         self.anchor_mask = [1, 2]
         self.class_num = 5
         self.ignore_thresh = 0.5
-        self.downsample_ratio = 32
+        self.downsample = 32
         self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5)
         self.gtbox_shape = (3, 5, 4)
...
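For reference, the numpy helper above decodes the raw (tx, ty) predictions into normalized box centers with cx = (grid_x + sigmoid(tx)) / w and cy = (grid_y + sigmoid(ty)) / h. A minimal C++ sketch of that step (Sigmoid, DecodeCenter, and the sample values are illustrative, not part of this patch):

#include <cmath>
#include <cstdio>

// Sigmoid, as used by the numpy reference implementation above.
static double Sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x)); }

// Decode a raw (tx, ty) prediction at grid cell (gx, gy) of a w x h grid
// into a normalized box center, mirroring
//   pred_box[..., 0] = (grid_x + sigmoid(pred_box[..., 0])) / w
//   pred_box[..., 1] = (grid_y + sigmoid(pred_box[..., 1])) / h
static void DecodeCenter(double tx, double ty, int gx, int gy, int w, int h,
                         double* cx, double* cy) {
  *cx = (gx + Sigmoid(tx)) / w;
  *cy = (gy + Sigmoid(ty)) / h;
}

int main() {
  double cx = 0.0, cy = 0.0;
  DecodeCenter(0.5, -0.3, 6, 4, 13, 13, &cx, &cy);
  std::printf("normalized center = (%.4f, %.4f)\n", cx, cy);
  return 0;
}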