提交 e14272bb 编写于 作者: G gaoyuan

update accoding to the code review

上级 c3e89f30
......@@ -18,79 +18,85 @@ namespace operators {
template <typename T>
__global__ void EncodeCenterSizeKernel(const T* prior_box_data,
const T* prior_box_var_data,
const T* target_box_data, int row,
int col, T* output) {
const T* target_box_data, const int row,
const int col, const int len,
T* output) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < row * col) {
const int row_idx = idx / col;
const int col_idx = idx % col;
T prior_box_width =
prior_box_data[col_idx * 4 + 2] - prior_box_data[col_idx * 4];
prior_box_data[col_idx * len + 2] - prior_box_data[col_idx * len];
T prior_box_height =
prior_box_data[col_idx * 4 + 3] - prior_box_data[col_idx * 4 + 1];
prior_box_data[col_idx * len + 3] - prior_box_data[col_idx * len + 1];
T prior_box_center_x =
(prior_box_data[col_idx * 4 + 2] + prior_box_data[col_idx * 4]) / 2;
T prior_box_center_y =
(prior_box_data[col_idx * 4 + 3] + prior_box_data[col_idx * 4 + 1]) / 2;
(prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2;
T prior_box_center_y = (prior_box_data[col_idx * len + 3] +
prior_box_data[col_idx * len + 1]) /
2;
T target_box_center_x =
(target_box_data[row_idx * 4 + 2] + target_box_data[row_idx * 4]) / 2;
T target_box_center_y =
(target_box_data[row_idx * 4 + 3] + target_box_data[row_idx * 4 + 1]) /
(target_box_data[row_idx * len + 2] + target_box_data[row_idx * len]) /
2;
T target_box_center_y = (target_box_data[row_idx * len + 3] +
target_box_data[row_idx * len + 1]) /
2;
T target_box_width =
target_box_data[row_idx * 4 + 2] - target_box_data[row_idx * 4];
target_box_data[row_idx * len + 2] - target_box_data[row_idx * len];
T target_box_height =
target_box_data[row_idx * 4 + 3] - target_box_data[row_idx * 4 + 1];
target_box_data[row_idx * len + 3] - target_box_data[row_idx * len + 1];
output[idx * 4] = (target_box_center_x - prior_box_center_x) /
prior_box_width / prior_box_var_data[col_idx * 4];
output[idx * 4 + 1] = (target_box_center_y - prior_box_center_y) /
prior_box_height /
prior_box_var_data[col_idx * 4 + 1];
output[idx * 4 + 2] = log(fabs(target_box_width / prior_box_width)) /
prior_box_var_data[col_idx * 4 + 2];
output[idx * 4 + 3] = log(fabs(target_box_height / prior_box_height)) /
prior_box_var_data[col_idx * 4 + 3];
output[idx * len] = (target_box_center_x - prior_box_center_x) /
prior_box_width / prior_box_var_data[col_idx * len];
output[idx * len + 1] = (target_box_center_y - prior_box_center_y) /
prior_box_height /
prior_box_var_data[col_idx * len + 1];
output[idx * len + 2] = log(fabs(target_box_width / prior_box_width)) /
prior_box_var_data[col_idx * len + 2];
output[idx * len + 3] = log(fabs(target_box_height / prior_box_height)) /
prior_box_var_data[col_idx * len + 3];
}
}
template <typename T>
__global__ void DecodeCenterSizeKernel(const T* prior_box_data,
const T* prior_box_var_data,
const T* target_box_data, int row,
int col, T* output) {
const T* target_box_data, const int row,
const int col, const int len,
T* output) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < row * col) {
const int row_idx = idx / col;
const int col_idx = idx % col;
T prior_box_width =
prior_box_data[col_idx * 4 + 2] - prior_box_data[col_idx * 4];
prior_box_data[col_idx * len + 2] - prior_box_data[col_idx * len];
T prior_box_height =
prior_box_data[col_idx * 4 + 3] - prior_box_data[col_idx * 4 + 1];
prior_box_data[col_idx * len + 3] - prior_box_data[col_idx * len + 1];
T prior_box_center_x =
(prior_box_data[col_idx * 4 + 2] + prior_box_data[col_idx * 4]) / 2;
T prior_box_center_y =
(prior_box_data[col_idx * 4 + 3] + prior_box_data[col_idx * 4 + 1]) / 2;
(prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2;
T prior_box_center_y = (prior_box_data[col_idx * len + 3] +
prior_box_data[col_idx * len + 1]) /
2;
T target_box_width = exp(prior_box_var_data[col_idx * 4 + 2] *
target_box_data[row_idx * 4 + 2]) *
T target_box_width = exp(prior_box_var_data[col_idx * len + 2] *
target_box_data[row_idx * len + 2]) *
prior_box_width;
T target_box_height = exp(prior_box_var_data[col_idx * 4 + 3] *
target_box_data[row_idx * 4 + 3]) *
T target_box_height = exp(prior_box_var_data[col_idx * len + 3] *
target_box_data[row_idx * len + 3]) *
prior_box_height;
T target_box_center_x = prior_box_var_data[col_idx * 4] *
target_box_data[row_idx * 4] * prior_box_width +
T target_box_center_x = prior_box_var_data[col_idx * len] *
target_box_data[row_idx * len] *
prior_box_width +
prior_box_center_x;
T target_box_center_y = prior_box_var_data[col_idx * 4 + 1] *
target_box_data[row_idx * 4 + 1] *
T target_box_center_y = prior_box_var_data[col_idx * len + 1] *
target_box_data[row_idx * len + 1] *
prior_box_height +
prior_box_center_y;
output[idx * 4] = target_box_center_x - target_box_width / 2;
output[idx * 4 + 1] = target_box_center_y - target_box_height / 2;
output[idx * 4 + 2] = target_box_center_x + target_box_width / 2;
output[idx * 4 + 3] = target_box_center_y + target_box_height / 2;
output[idx * len] = target_box_center_x - target_box_width / 2;
output[idx * len + 1] = target_box_center_y - target_box_height / 2;
output[idx * len + 2] = target_box_center_x + target_box_width / 2;
output[idx * len + 3] = target_box_center_y + target_box_height / 2;
}
}
......@@ -111,6 +117,7 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
}
auto row = target_box->dims()[0];
auto col = prior_box->dims()[0];
auto len = prior_box->dims()[1];
int block = 512;
int grid = (row * col + block - 1) / block;
auto& device_ctx = context.cuda_device_context();
......@@ -119,17 +126,17 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
const T* prior_box_var_data = prior_box_var->data<T>();
const T* target_box_data = target_box->data<T>();
output_box->mutable_data<T>({row, col, 4}, context.GetPlace());
output_box->mutable_data<T>({row, col, len}, context.GetPlace());
T* output = output_box->data<T>();
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
if (code_type == BoxCodeType::kEncodeCenterSize) {
EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
prior_box_data, prior_box_var_data, target_box_data, row, col,
prior_box_data, prior_box_var_data, target_box_data, row, col, len,
output);
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
prior_box_data, prior_box_var_data, target_box_data, row, col,
prior_box_data, prior_box_var_data, target_box_data, row, col, len,
output);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册