提交 e14272bb 编写于 作者: G gaoyuan

update accoding to the code review

上级 c3e89f30
...@@ -18,79 +18,85 @@ namespace operators { ...@@ -18,79 +18,85 @@ namespace operators {
template <typename T> template <typename T>
__global__ void EncodeCenterSizeKernel(const T* prior_box_data, __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
const T* prior_box_var_data, const T* prior_box_var_data,
const T* target_box_data, int row, const T* target_box_data, const int row,
int col, T* output) { const int col, const int len,
T* output) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x; const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < row * col) { if (idx < row * col) {
const int row_idx = idx / col; const int row_idx = idx / col;
const int col_idx = idx % col; const int col_idx = idx % col;
T prior_box_width = T prior_box_width =
prior_box_data[col_idx * 4 + 2] - prior_box_data[col_idx * 4]; prior_box_data[col_idx * len + 2] - prior_box_data[col_idx * len];
T prior_box_height = T prior_box_height =
prior_box_data[col_idx * 4 + 3] - prior_box_data[col_idx * 4 + 1]; prior_box_data[col_idx * len + 3] - prior_box_data[col_idx * len + 1];
T prior_box_center_x = T prior_box_center_x =
(prior_box_data[col_idx * 4 + 2] + prior_box_data[col_idx * 4]) / 2; (prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2;
T prior_box_center_y = T prior_box_center_y = (prior_box_data[col_idx * len + 3] +
(prior_box_data[col_idx * 4 + 3] + prior_box_data[col_idx * 4 + 1]) / 2; prior_box_data[col_idx * len + 1]) /
2;
T target_box_center_x = T target_box_center_x =
(target_box_data[row_idx * 4 + 2] + target_box_data[row_idx * 4]) / 2; (target_box_data[row_idx * len + 2] + target_box_data[row_idx * len]) /
T target_box_center_y =
(target_box_data[row_idx * 4 + 3] + target_box_data[row_idx * 4 + 1]) /
2; 2;
T target_box_center_y = (target_box_data[row_idx * len + 3] +
target_box_data[row_idx * len + 1]) /
2;
T target_box_width = T target_box_width =
target_box_data[row_idx * 4 + 2] - target_box_data[row_idx * 4]; target_box_data[row_idx * len + 2] - target_box_data[row_idx * len];
T target_box_height = T target_box_height =
target_box_data[row_idx * 4 + 3] - target_box_data[row_idx * 4 + 1]; target_box_data[row_idx * len + 3] - target_box_data[row_idx * len + 1];
output[idx * 4] = (target_box_center_x - prior_box_center_x) / output[idx * len] = (target_box_center_x - prior_box_center_x) /
prior_box_width / prior_box_var_data[col_idx * 4]; prior_box_width / prior_box_var_data[col_idx * len];
output[idx * 4 + 1] = (target_box_center_y - prior_box_center_y) / output[idx * len + 1] = (target_box_center_y - prior_box_center_y) /
prior_box_height / prior_box_height /
prior_box_var_data[col_idx * 4 + 1]; prior_box_var_data[col_idx * len + 1];
output[idx * 4 + 2] = log(fabs(target_box_width / prior_box_width)) / output[idx * len + 2] = log(fabs(target_box_width / prior_box_width)) /
prior_box_var_data[col_idx * 4 + 2]; prior_box_var_data[col_idx * len + 2];
output[idx * 4 + 3] = log(fabs(target_box_height / prior_box_height)) / output[idx * len + 3] = log(fabs(target_box_height / prior_box_height)) /
prior_box_var_data[col_idx * 4 + 3]; prior_box_var_data[col_idx * len + 3];
} }
} }
template <typename T> template <typename T>
__global__ void DecodeCenterSizeKernel(const T* prior_box_data, __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
const T* prior_box_var_data, const T* prior_box_var_data,
const T* target_box_data, int row, const T* target_box_data, const int row,
int col, T* output) { const int col, const int len,
T* output) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x; const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < row * col) { if (idx < row * col) {
const int row_idx = idx / col; const int row_idx = idx / col;
const int col_idx = idx % col; const int col_idx = idx % col;
T prior_box_width = T prior_box_width =
prior_box_data[col_idx * 4 + 2] - prior_box_data[col_idx * 4]; prior_box_data[col_idx * len + 2] - prior_box_data[col_idx * len];
T prior_box_height = T prior_box_height =
prior_box_data[col_idx * 4 + 3] - prior_box_data[col_idx * 4 + 1]; prior_box_data[col_idx * len + 3] - prior_box_data[col_idx * len + 1];
T prior_box_center_x = T prior_box_center_x =
(prior_box_data[col_idx * 4 + 2] + prior_box_data[col_idx * 4]) / 2; (prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2;
T prior_box_center_y = T prior_box_center_y = (prior_box_data[col_idx * len + 3] +
(prior_box_data[col_idx * 4 + 3] + prior_box_data[col_idx * 4 + 1]) / 2; prior_box_data[col_idx * len + 1]) /
2;
T target_box_width = exp(prior_box_var_data[col_idx * 4 + 2] * T target_box_width = exp(prior_box_var_data[col_idx * len + 2] *
target_box_data[row_idx * 4 + 2]) * target_box_data[row_idx * len + 2]) *
prior_box_width; prior_box_width;
T target_box_height = exp(prior_box_var_data[col_idx * 4 + 3] * T target_box_height = exp(prior_box_var_data[col_idx * len + 3] *
target_box_data[row_idx * 4 + 3]) * target_box_data[row_idx * len + 3]) *
prior_box_height; prior_box_height;
T target_box_center_x = prior_box_var_data[col_idx * 4] * T target_box_center_x = prior_box_var_data[col_idx * len] *
target_box_data[row_idx * 4] * prior_box_width + target_box_data[row_idx * len] *
prior_box_width +
prior_box_center_x; prior_box_center_x;
T target_box_center_y = prior_box_var_data[col_idx * 4 + 1] * T target_box_center_y = prior_box_var_data[col_idx * len + 1] *
target_box_data[row_idx * 4 + 1] * target_box_data[row_idx * len + 1] *
prior_box_height + prior_box_height +
prior_box_center_y; prior_box_center_y;
output[idx * 4] = target_box_center_x - target_box_width / 2; output[idx * len] = target_box_center_x - target_box_width / 2;
output[idx * 4 + 1] = target_box_center_y - target_box_height / 2; output[idx * len + 1] = target_box_center_y - target_box_height / 2;
output[idx * 4 + 2] = target_box_center_x + target_box_width / 2; output[idx * len + 2] = target_box_center_x + target_box_width / 2;
output[idx * 4 + 3] = target_box_center_y + target_box_height / 2; output[idx * len + 3] = target_box_center_y + target_box_height / 2;
} }
} }
...@@ -111,6 +117,7 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> { ...@@ -111,6 +117,7 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
} }
auto row = target_box->dims()[0]; auto row = target_box->dims()[0];
auto col = prior_box->dims()[0]; auto col = prior_box->dims()[0];
auto len = prior_box->dims()[1];
int block = 512; int block = 512;
int grid = (row * col + block - 1) / block; int grid = (row * col + block - 1) / block;
auto& device_ctx = context.cuda_device_context(); auto& device_ctx = context.cuda_device_context();
...@@ -119,17 +126,17 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> { ...@@ -119,17 +126,17 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
const T* prior_box_var_data = prior_box_var->data<T>(); const T* prior_box_var_data = prior_box_var->data<T>();
const T* target_box_data = target_box->data<T>(); const T* target_box_data = target_box->data<T>();
output_box->mutable_data<T>({row, col, 4}, context.GetPlace()); output_box->mutable_data<T>({row, col, len}, context.GetPlace());
T* output = output_box->data<T>(); T* output = output_box->data<T>();
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type")); auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
if (code_type == BoxCodeType::kEncodeCenterSize) { if (code_type == BoxCodeType::kEncodeCenterSize) {
EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>( EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
prior_box_data, prior_box_var_data, target_box_data, row, col, prior_box_data, prior_box_var_data, target_box_data, row, col, len,
output); output);
} else if (code_type == BoxCodeType::kDecodeCenterSize) { } else if (code_type == BoxCodeType::kDecodeCenterSize) {
DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>( DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
prior_box_data, prior_box_var_data, target_box_data, row, col, prior_box_data, prior_box_var_data, target_box_data, row, col, len,
output); output);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册