diff --git a/paddle/operators/box_coder_op.cu b/paddle/operators/box_coder_op.cu index f2ea592f8e802a29b2e16aca5808c7bc423b1fd4..883cc5430558889cd44277becf3c77ec2e619dc8 100644 --- a/paddle/operators/box_coder_op.cu +++ b/paddle/operators/box_coder_op.cu @@ -18,79 +18,85 @@ namespace operators { template __global__ void EncodeCenterSizeKernel(const T* prior_box_data, const T* prior_box_var_data, - const T* target_box_data, int row, - int col, T* output) { + const T* target_box_data, const int row, + const int col, const int len, + T* output) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < row * col) { const int row_idx = idx / col; const int col_idx = idx % col; T prior_box_width = - prior_box_data[col_idx * 4 + 2] - prior_box_data[col_idx * 4]; + prior_box_data[col_idx * len + 2] - prior_box_data[col_idx * len]; T prior_box_height = - prior_box_data[col_idx * 4 + 3] - prior_box_data[col_idx * 4 + 1]; + prior_box_data[col_idx * len + 3] - prior_box_data[col_idx * len + 1]; T prior_box_center_x = - (prior_box_data[col_idx * 4 + 2] + prior_box_data[col_idx * 4]) / 2; - T prior_box_center_y = - (prior_box_data[col_idx * 4 + 3] + prior_box_data[col_idx * 4 + 1]) / 2; + (prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2; + T prior_box_center_y = (prior_box_data[col_idx * len + 3] + + prior_box_data[col_idx * len + 1]) / + 2; T target_box_center_x = - (target_box_data[row_idx * 4 + 2] + target_box_data[row_idx * 4]) / 2; - T target_box_center_y = - (target_box_data[row_idx * 4 + 3] + target_box_data[row_idx * 4 + 1]) / + (target_box_data[row_idx * len + 2] + target_box_data[row_idx * len]) / 2; + T target_box_center_y = (target_box_data[row_idx * len + 3] + + target_box_data[row_idx * len + 1]) / + 2; T target_box_width = - target_box_data[row_idx * 4 + 2] - target_box_data[row_idx * 4]; + target_box_data[row_idx * len + 2] - target_box_data[row_idx * len]; T target_box_height = - target_box_data[row_idx * 4 + 3] - target_box_data[row_idx * 4 + 1]; + target_box_data[row_idx * len + 3] - target_box_data[row_idx * len + 1]; - output[idx * 4] = (target_box_center_x - prior_box_center_x) / - prior_box_width / prior_box_var_data[col_idx * 4]; - output[idx * 4 + 1] = (target_box_center_y - prior_box_center_y) / - prior_box_height / - prior_box_var_data[col_idx * 4 + 1]; - output[idx * 4 + 2] = log(fabs(target_box_width / prior_box_width)) / - prior_box_var_data[col_idx * 4 + 2]; - output[idx * 4 + 3] = log(fabs(target_box_height / prior_box_height)) / - prior_box_var_data[col_idx * 4 + 3]; + output[idx * len] = (target_box_center_x - prior_box_center_x) / + prior_box_width / prior_box_var_data[col_idx * len]; + output[idx * len + 1] = (target_box_center_y - prior_box_center_y) / + prior_box_height / + prior_box_var_data[col_idx * len + 1]; + output[idx * len + 2] = log(fabs(target_box_width / prior_box_width)) / + prior_box_var_data[col_idx * len + 2]; + output[idx * len + 3] = log(fabs(target_box_height / prior_box_height)) / + prior_box_var_data[col_idx * len + 3]; } } template __global__ void DecodeCenterSizeKernel(const T* prior_box_data, const T* prior_box_var_data, - const T* target_box_data, int row, - int col, T* output) { + const T* target_box_data, const int row, + const int col, const int len, + T* output) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < row * col) { const int row_idx = idx / col; const int col_idx = idx % col; T prior_box_width = - prior_box_data[col_idx * 4 + 2] - prior_box_data[col_idx * 4]; + prior_box_data[col_idx * len + 2] - prior_box_data[col_idx * len]; T prior_box_height = - prior_box_data[col_idx * 4 + 3] - prior_box_data[col_idx * 4 + 1]; + prior_box_data[col_idx * len + 3] - prior_box_data[col_idx * len + 1]; T prior_box_center_x = - (prior_box_data[col_idx * 4 + 2] + prior_box_data[col_idx * 4]) / 2; - T prior_box_center_y = - (prior_box_data[col_idx * 4 + 3] + prior_box_data[col_idx * 4 + 1]) / 2; + (prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2; + T prior_box_center_y = (prior_box_data[col_idx * len + 3] + + prior_box_data[col_idx * len + 1]) / + 2; - T target_box_width = exp(prior_box_var_data[col_idx * 4 + 2] * - target_box_data[row_idx * 4 + 2]) * + T target_box_width = exp(prior_box_var_data[col_idx * len + 2] * + target_box_data[row_idx * len + 2]) * prior_box_width; - T target_box_height = exp(prior_box_var_data[col_idx * 4 + 3] * - target_box_data[row_idx * 4 + 3]) * + T target_box_height = exp(prior_box_var_data[col_idx * len + 3] * + target_box_data[row_idx * len + 3]) * prior_box_height; - T target_box_center_x = prior_box_var_data[col_idx * 4] * - target_box_data[row_idx * 4] * prior_box_width + + T target_box_center_x = prior_box_var_data[col_idx * len] * + target_box_data[row_idx * len] * + prior_box_width + prior_box_center_x; - T target_box_center_y = prior_box_var_data[col_idx * 4 + 1] * - target_box_data[row_idx * 4 + 1] * + T target_box_center_y = prior_box_var_data[col_idx * len + 1] * + target_box_data[row_idx * len + 1] * prior_box_height + prior_box_center_y; - output[idx * 4] = target_box_center_x - target_box_width / 2; - output[idx * 4 + 1] = target_box_center_y - target_box_height / 2; - output[idx * 4 + 2] = target_box_center_x + target_box_width / 2; - output[idx * 4 + 3] = target_box_center_y + target_box_height / 2; + output[idx * len] = target_box_center_x - target_box_width / 2; + output[idx * len + 1] = target_box_center_y - target_box_height / 2; + output[idx * len + 2] = target_box_center_x + target_box_width / 2; + output[idx * len + 3] = target_box_center_y + target_box_height / 2; } } @@ -111,6 +117,7 @@ class BoxCoderCUDAKernel : public framework::OpKernel { } auto row = target_box->dims()[0]; auto col = prior_box->dims()[0]; + auto len = prior_box->dims()[1]; int block = 512; int grid = (row * col + block - 1) / block; auto& device_ctx = context.cuda_device_context(); @@ -119,17 +126,17 @@ class BoxCoderCUDAKernel : public framework::OpKernel { const T* prior_box_var_data = prior_box_var->data(); const T* target_box_data = target_box->data(); - output_box->mutable_data({row, col, 4}, context.GetPlace()); + output_box->mutable_data({row, col, len}, context.GetPlace()); T* output = output_box->data(); auto code_type = GetBoxCodeType(context.Attr("code_type")); if (code_type == BoxCodeType::kEncodeCenterSize) { EncodeCenterSizeKernel<<>>( - prior_box_data, prior_box_var_data, target_box_data, row, col, + prior_box_data, prior_box_var_data, target_box_data, row, col, len, output); } else if (code_type == BoxCodeType::kDecodeCenterSize) { DecodeCenterSizeKernel<<>>( - prior_box_data, prior_box_var_data, target_box_data, row, col, + prior_box_data, prior_box_var_data, target_box_data, row, col, len, output); } }