提交 cb2dca53 编写于 作者: D dengkaipeng

fix cuda kernel error

上级 04b8b9e9
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/yolo_box_op.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
......@@ -22,11 +23,12 @@ using Tensor = framework::Tensor;
template <typename T>
__global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
T* scores, const float conf_thresh,
std::vector<int> anchors, const int h, const in w,
const int* anchors, const int h, const int w,
const int an_num, const int class_num,
const int box_num, const int input_size) {
const int box_num, int input_size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
T box[4];
for (; tid < box_num; tid += stride) {
int grid_num = h * w;
int i = tid / box_num;
......@@ -47,10 +49,10 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0);
Box<T> pred = GetYoloBox<T>(input, anchors, l, k, j, h, input_size, box_idx,
GetYoloBox<T>(box, input, anchors, l, k, j, h, input_size, box_idx,
grid_num, img_height, img_width);
box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
CalcDetectionBox<T>(boxes, pred, box_idx);
CalcDetectionBox<T>(boxes, box, box_idx);
int label_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5);
......@@ -64,7 +66,7 @@ template <typename T>
class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("Input");
auto* input = ctx.Input<Tensor>("X");
auto* img_size = ctx.Input<Tensor>("ImgSize");
auto* boxes = ctx.Output<Tensor>("Boxes");
auto* scores = ctx.Output<Tensor>("Scores");
......@@ -81,23 +83,35 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
const int an_num = anchors.size() / 2;
int input_size = downsample_ratio * h;
Tensor anchors_t, cpu_anchors_t;
auto cpu_anchors_data = cpu_anchors_t.mutable_data<int>({an_num*2}, platform::CPUPlace());
std::copy(anchors.begin(), anchors.end(), cpu_anchors_data);
TensorCopySync(cpu_anchors_t, ctx.GetPlace(), &anchors_t);
auto anchors_data = anchors_t.data<int>();
const T* input_data = input->data<T>();
const int* imgsize_data = imgsize->data<int>();
const int* imgsize_data = img_size->data<int>();
T* boxes_data = boxes->mutable_data<T>({n, box_num, 4}, ctx.GetPlace());
memset(boxes_data, 0, boxes->numel() * sizeof(T));
T* scores_data =
scores->mutable_data<T>({n, box_num, class_num}, ctx.GetPlace());
memset(scores_data, 0, scores->numel() * sizeof(T));
math::SetConstant<platform::CUDADeviceContext, T> set_zero;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
set_zero(dev_ctx, boxes, static_cast<T>(0));
set_zero(dev_ctx, scores, static_cast<T>(0));
int grid_dim = (n * box_num + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
KeYoloBoxFw<T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
anchors_data, h, w, an_num, class_num, box_num, input_size);
}
}; // namespace operators
};
} // namespace operators
} // namespace paddle
// Register the CUDA kernels of the yolo_box operator for float and double.
// Only YoloBoxOpCUDAKernel is defined in this translation unit, so that is the
// only kernel class registered here.
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(yolo_box, ops::YoloBoxOpCUDAKernel<float>,
                        ops::YoloBoxOpCUDAKernel<double>);
......@@ -13,35 +13,30 @@
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
struct Box {
T x, y, w, h;
};
template <typename T>
static inline T sigmoid(T x) {
HOSTDEVICE inline T sigmoid(T x) {
return 1.0 / (1.0 + std::exp(-x));
}
// Decodes one raw YOLO prediction into a box given as center and extent.
//
// The result is written through `box` instead of being returned as a struct,
// and the anchors are taken as a raw int array instead of std::vector, so the
// same function can be called from both the CPU kernel and the CUDA kernel.
//
//   box        [out] 4 elements: {center x, center y, width, height}.
//   x          raw prediction buffer.
//   anchors    flat anchor array; entry 2 * an_idx scales the width and entry
//              2 * an_idx + 1 scales the height.
//   i, j       grid offsets added to the sigmoid of the x / y predictions.
//   an_idx     index of the anchor used for this prediction.
//   grid_size  divisor applied to both center coordinates.
//   input_size divisor applied to both extents.
//   index      offset in `x` of the x-center entry; the y-center, width and
//              height entries follow at index + stride, + 2 * stride and
//              + 3 * stride.
//   stride     distance in `x` between consecutive entries of one prediction.
//   img_height, img_width  image size the box is scaled to.
template <typename T>
HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i,
                                  int j, int an_idx, int grid_size,
                                  int input_size, int index, int stride,
                                  int img_height, int img_width) {
  // Center: cell offset plus a sigmoid-bounded shift, rescaled to the image.
  box[0] = (i + sigmoid<T>(x[index])) * img_width / grid_size;
  box[1] = (j + sigmoid<T>(x[index + stride])) * img_height / grid_size;
  // Extent: exponential of the prediction times the anchor size, rescaled
  // from the network input size to the image.
  box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width /
           input_size;
  box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] *
           img_height / input_size;
}
HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
......@@ -51,12 +46,12 @@ HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
}
// Converts a center/extent box into corner form and stores it in `boxes`.
//
//   boxes   [out] destination buffer; four consecutive elements starting at
//           box_idx receive {x_min, y_min, x_max, y_max}.
//   box     source box laid out as {center x, center y, width, height}, i.e.
//           the layout GetYoloBox writes.
//   box_idx offset of the first of the four destination elements.
template <typename T>
HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box,
                                        const int box_idx) {
  boxes[box_idx] = box[0] - box[2] / 2;
  boxes[box_idx + 1] = box[1] - box[3] / 2;
  boxes[box_idx + 2] = box[0] + box[2] / 2;
  boxes[box_idx + 3] = box[1] + box[3] / 2;
}
template <typename T>
......@@ -92,6 +87,9 @@ class YoloBoxKernel : public framework::OpKernel<T> {
const int stride = h * w;
const int an_stride = (class_num + 5) * stride;
int anchors_[anchors.size()];
std::copy(anchors.begin(), anchors.end(), anchors_);
const T* input_data = input->data<T>();
const int* imgsize_data = imgsize->data<int>();
T* boxes_data = boxes->mutable_data<T>({n, box_num, 4}, ctx.GetPlace());
......@@ -100,6 +98,7 @@ class YoloBoxKernel : public framework::OpKernel<T> {
scores->mutable_data<T>({n, box_num, class_num}, ctx.GetPlace());
memset(scores_data, 0, scores->numel() * sizeof(T));
T box[4];
for (int i = 0; i < n; i++) {
int img_height = imgsize_data[2 * i];
int img_width = imgsize_data[2 * i + 1];
......@@ -116,11 +115,10 @@ class YoloBoxKernel : public framework::OpKernel<T> {
int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0);
Box<T> pred =
GetYoloBox<T>(input_data, anchors, l, k, j, h, input_size,
box_idx, stride, img_height, img_width);
GetYoloBox<T>(box, input_data, anchors_, l, k, j, h, input_size,
box_idx, stride, img_height, img_width);
box_idx = (i * box_num + j * stride + k * w + l) * 4;
CalcDetectionBox<T>(boxes_data, pred, box_idx);
CalcDetectionBox<T>(boxes_data, box, box_idx);
int label_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5);
......
......@@ -93,16 +93,17 @@ class TestYoloBoxOp(OpTest):
}
def test_check_output(self):
    # Verify the operator's forward outputs on CUDA device 0, using an
    # absolute tolerance of 1e-3 for the comparison against the expected
    # outputs.
    place = core.CUDAPlace(0)
    self.check_output_with_place(place, atol=1e-3)
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
an_num = int(len(self.anchors) // 2)
self.batch_size = 3
self.batch_size = 1
self.class_num = 2
self.conf_thresh = 0.5
self.downsample = 32
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 5, 5)
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 2, 2)
self.imgsize_shape = (self.batch_size, 2)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册