diff --git a/paddle/fluid/operators/detection/yolo_box_op.cu b/paddle/fluid/operators/detection/yolo_box_op.cu index 38b514fe90f218a7171cc13911cafafdff4af2cf..bc563107f8824b834216923f422f903f726be78f 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cu +++ b/paddle/fluid/operators/detection/yolo_box_op.cu @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/detection/yolo_box_op.h" +#include "paddle/fluid/operators/math/math_function.h" namespace paddle { namespace operators { @@ -22,11 +23,12 @@ using Tensor = framework::Tensor; template __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, T* scores, const float conf_thresh, - std::vector anchors, const int h, const in w, + const int* anchors, const int h, const int w, const int an_num, const int class_num, - const int box_num, const int input_size) { + const int box_num, int input_size) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; + T box[4]; for (; tid < box_num; tid += stride) { int grid_num = h * w; int i = tid / box_num; @@ -47,10 +49,10 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0); - Box pred = GetYoloBox(input, anchors, l, k, j, h, input_size, box_idx, + GetYoloBox(box, input, anchors, l, k, j, h, input_size, box_idx, grid_num, img_height, img_width); box_idx = (i * box_num + j * grid_num + k * w + l) * 4; - CalcDetectionBox(boxes, pred, box_idx); + CalcDetectionBox(boxes, box, box_idx); int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5); @@ -64,7 +66,7 @@ template class YoloBoxOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("Input"); + auto* input = ctx.Input("X"); auto* img_size = ctx.Input("ImgSize"); auto* boxes = ctx.Output("Boxes"); auto* scores = ctx.Output("Scores"); @@ -81,23 +83,35 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel { const int an_num = anchors.size() / 2; int input_size = downsample_ratio * h; + Tensor anchors_t, cpu_anchors_t; + auto cpu_anchors_data = cpu_anchors_t.mutable_data({an_num*2}, platform::CPUPlace()); + std::copy(anchors.begin(), anchors.end(), cpu_anchors_data); + TensorCopySync(cpu_anchors_t, ctx.GetPlace(), &anchors_t); + auto anchors_data = anchors_t.data(); + const T* input_data = input->data(); - const int* imgsize_data = imgsize->data(); + const int* imgsize_data = img_size->data(); T* boxes_data = boxes->mutable_data({n, box_num, 4}, ctx.GetPlace()); - memset(boxes_data, 0, boxes->numel() * sizeof(T)); T* scores_data = scores->mutable_data({n, box_num, class_num}, ctx.GetPlace()); - memset(scores_data, 0, scores->numel() * sizeof(T)); + math::SetConstant set_zero; + auto& dev_ctx = ctx.template device_context(); + set_zero(dev_ctx, boxes, static_cast(0)); + set_zero(dev_ctx, scores, static_cast(0)); int grid_dim = (n * box_num + 512 - 1) / 512; grid_dim = grid_dim > 8 ? 8 : grid_dim; + + KeYoloBoxFw<<>>( + input_data, imgsize_data, boxes_data, scores_data, conf_thresh, + anchors_data, h, w, an_num, class_num, box_num, input_size); } -}; // namespace operators +}; } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(density_prior_box, - ops::DensityPriorBoxOpCUDAKernel, - ops::DensityPriorBoxOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(yolo_box, + ops::YoloBoxOpCUDAKernel, + ops::YoloBoxOpCUDAKernel); diff --git a/paddle/fluid/operators/detection/yolo_box_op.h b/paddle/fluid/operators/detection/yolo_box_op.h index 90933e123e098407dfc06124c4f2bc3cee9cb12a..6188c5f32b742e3ff62ec3247e2a0a48149f29e6 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.h +++ b/paddle/fluid/operators/detection/yolo_box_op.h @@ -13,35 +13,30 @@ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/hostdevice.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -struct Box { - T x, y, w, h; -}; template -static inline T sigmoid(T x) { +HOSTDEVICE inline T sigmoid(T x) { return 1.0 / (1.0 + std::exp(-x)); } template -HOSTDEVICE inline Box GetYoloBox(const T* x, std::vector anchors, int i, +HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i, int j, int an_idx, int grid_size, int input_size, int index, int stride, int img_height, int img_width) { - Box b; - b.x = (i + sigmoid(x[index])) * img_width / grid_size; - b.y = (j + sigmoid(x[index + stride])) * img_height / grid_size; - b.w = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width / + box[0] = (i + sigmoid(x[index])) * img_width / grid_size; + box[1] = (j + sigmoid(x[index + stride])) * img_height / grid_size; + box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width / input_size; - b.h = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] * img_height / + box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] * img_height / input_size; - return b; } HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx, @@ -51,12 +46,12 @@ HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx, } template -HOSTDEVICE inline void CalcDetectionBox(T* boxes, Box pred, +HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, const int box_idx) { - boxes[box_idx] = pred.x - pred.w / 2; - boxes[box_idx + 1] = pred.y - pred.h / 2; - boxes[box_idx + 2] = pred.x + pred.w / 2; - boxes[box_idx + 3] = pred.y + pred.h / 2; + boxes[box_idx] = box[0] - box[2] / 2; + boxes[box_idx + 1] = box[1] - box[3] / 2; + boxes[box_idx + 2] = box[0] + box[2] / 2; + boxes[box_idx + 3] = box[1] + box[3] / 2; } template @@ -92,6 +87,9 @@ class YoloBoxKernel : public framework::OpKernel { const int stride = h * w; const int an_stride = (class_num + 5) * stride; + int anchors_[anchors.size()]; + std::copy(anchors.begin(), anchors.end(), anchors_); + const T* input_data = input->data(); const int* imgsize_data = imgsize->data(); T* boxes_data = boxes->mutable_data({n, box_num, 4}, ctx.GetPlace()); @@ -100,6 +98,7 @@ class YoloBoxKernel : public framework::OpKernel { scores->mutable_data({n, box_num, class_num}, ctx.GetPlace()); memset(scores_data, 0, scores->numel() * sizeof(T)); + T box[4]; for (int i = 0; i < n; i++) { int img_height = imgsize_data[2 * i]; int img_width = imgsize_data[2 * i + 1]; @@ -116,11 +115,10 @@ class YoloBoxKernel : public framework::OpKernel { int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0); - Box pred = - GetYoloBox(input_data, anchors, l, k, j, h, input_size, - box_idx, stride, img_height, img_width); + GetYoloBox(box, input_data, anchors_, l, k, j, h, input_size, + box_idx, stride, img_height, img_width); box_idx = (i * box_num + j * stride + k * w + l) * 4; - CalcDetectionBox(boxes_data, pred, box_idx); + CalcDetectionBox(boxes_data, box, box_idx); int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5); diff --git a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py index 48465c8f68ac0feb72704861c729efbcdc3c8e55..e28c05e3e6de2438e4a67c34c4f087d3c1c5fc15 100644 --- a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py @@ -93,16 +93,17 @@ class TestYoloBoxOp(OpTest): } def test_check_output(self): - self.check_output() + place = core.CUDAPlace(0) + self.check_output_with_place(place, atol=1e-3) def initTestCase(self): self.anchors = [10, 13, 16, 30, 33, 23] an_num = int(len(self.anchors) // 2) - self.batch_size = 3 + self.batch_size = 1 self.class_num = 2 self.conf_thresh = 0.5 self.downsample = 32 - self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 5, 5) + self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 2, 2) self.imgsize_shape = (self.batch_size, 2)