未验证 提交 2a6a259d 编写于 作者: W Wilber 提交者: GitHub

fix yolobox_cuda_test (#2208)

fix yolobox_cuda test precision error
上级 8591aaec
......@@ -89,7 +89,7 @@ inline static void calc_label_score(float* scores,
template <typename T>
static void YoloBoxRef(const T* input,
const T* imgsize,
const int* imgsize,
T* boxes,
T* scores,
const float conf_thresh,
......@@ -106,8 +106,8 @@ static void YoloBoxRef(const T* input,
float box[4];
for (int i = 0; i < n; i++) {
int img_height = static_cast<int>(imgsize[2 * i]);
int img_width = static_cast<int>(imgsize[2 * i + 1]);
int img_height = imgsize[2 * i];
int img_width = imgsize[2 * i + 1];
for (int j = 0; j < an_num; j++) {
for (int k = 0; k < h; k++) {
......@@ -184,12 +184,12 @@ TEST(yolo_box, normal) {
auto* scores_data = scores.mutable_data<float>(TARGET(kCUDA));
float* x_cpu_data = x_cpu.mutable_data<float>();
float* sz_cpu_data = sz_cpu.mutable_data<float>();
int* sz_cpu_data = sz_cpu.mutable_data<int>();
float* boxes_cpu_data = boxes_cpu.mutable_data<float>();
float* scores_cpu_data = scores_cpu.mutable_data<float>();
float* x_ref_data = x_ref.mutable_data<float>();
float* sz_ref_data = sz_ref.mutable_data<float>();
int* sz_ref_data = sz_ref.mutable_data<int>();
float* boxes_ref_data = boxes_ref.mutable_data<float>();
float* scores_ref_data = scores_ref.mutable_data<float>();
......@@ -203,7 +203,7 @@ TEST(yolo_box, normal) {
sz_ref_data[1] = 32;
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
sz.Assign<float, lite::DDim, TARGET(kCUDA)>(sz_cpu_data, sz_cpu.dims());
sz.Assign<int, lite::DDim, TARGET(kCUDA)>(sz_cpu_data, sz_cpu.dims());
param.X = &x;
param.ImgSize = &sz;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册