提交 d3e99aee 编写于 作者: Y Yuan Gao 提交者: qingqing01

add normalize switch to box_coder_op (#11129)

上级 e0a8c584
...@@ -91,6 +91,10 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -91,6 +91,10 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
"the code type used with the target box") "the code type used with the target box")
.SetDefault("encode_center_size") .SetDefault("encode_center_size")
.InEnum({"encode_center_size", "decode_center_size"}); .InEnum({"encode_center_size", "decode_center_size"});
AddAttr<bool>("box_normalized",
"(bool, default true) "
"whether treat the priorbox as a noramlized box")
.SetDefault(true);
AddOutput("OutputBox", AddOutput("OutputBox",
"(LoDTensor or Tensor) " "(LoDTensor or Tensor) "
"When code_type is 'encode_center_size', the output tensor of " "When code_type is 'encode_center_size', the output tensor of "
......
...@@ -20,15 +20,16 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data, ...@@ -20,15 +20,16 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
const T* prior_box_var_data, const T* prior_box_var_data,
const T* target_box_data, const int row, const T* target_box_data, const int row,
const int col, const int len, const int col, const int len,
T* output) { const bool normalized, T* output) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x; const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < row * col) { if (idx < row * col) {
const int row_idx = idx / col; const int row_idx = idx / col;
const int col_idx = idx % col; const int col_idx = idx % col;
T prior_box_width = T prior_box_width = prior_box_data[col_idx * len + 2] -
prior_box_data[col_idx * len + 2] - prior_box_data[col_idx * len]; prior_box_data[col_idx * len] + (normalized == false);
T prior_box_height = T prior_box_height = prior_box_data[col_idx * len + 3] -
prior_box_data[col_idx * len + 3] - prior_box_data[col_idx * len + 1]; prior_box_data[col_idx * len + 1] +
(normalized == false);
T prior_box_center_x = T prior_box_center_x =
(prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2; (prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2;
T prior_box_center_y = (prior_box_data[col_idx * len + 3] + T prior_box_center_y = (prior_box_data[col_idx * len + 3] +
...@@ -41,10 +42,11 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data, ...@@ -41,10 +42,11 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
T target_box_center_y = (target_box_data[row_idx * len + 3] + T target_box_center_y = (target_box_data[row_idx * len + 3] +
target_box_data[row_idx * len + 1]) / target_box_data[row_idx * len + 1]) /
2; 2;
T target_box_width = T target_box_width = target_box_data[row_idx * len + 2] -
target_box_data[row_idx * len + 2] - target_box_data[row_idx * len]; target_box_data[row_idx * len] + (normalized == false);
T target_box_height = T target_box_height = target_box_data[row_idx * len + 3] -
target_box_data[row_idx * len + 3] - target_box_data[row_idx * len + 1]; target_box_data[row_idx * len + 1] +
(normalized == false);
output[idx * len] = (target_box_center_x - prior_box_center_x) / output[idx * len] = (target_box_center_x - prior_box_center_x) /
prior_box_width / prior_box_var_data[col_idx * len]; prior_box_width / prior_box_var_data[col_idx * len];
...@@ -63,14 +65,15 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data, ...@@ -63,14 +65,15 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
const T* prior_box_var_data, const T* prior_box_var_data,
const T* target_box_data, const int row, const T* target_box_data, const int row,
const int col, const int len, const int col, const int len,
T* output) { const bool normalized, T* output) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x; const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < row * col) { if (idx < row * col) {
const int col_idx = idx % col; const int col_idx = idx % col;
T prior_box_width = T prior_box_width = prior_box_data[col_idx * len + 2] -
prior_box_data[col_idx * len + 2] - prior_box_data[col_idx * len]; prior_box_data[col_idx * len] + (normalized == false);
T prior_box_height = T prior_box_height = prior_box_data[col_idx * len + 3] -
prior_box_data[col_idx * len + 3] - prior_box_data[col_idx * len + 1]; prior_box_data[col_idx * len + 1] +
(normalized == false);
T prior_box_center_x = T prior_box_center_x =
(prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2; (prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2;
T prior_box_center_y = (prior_box_data[col_idx * len + 3] + T prior_box_center_y = (prior_box_data[col_idx * len + 3] +
...@@ -93,8 +96,10 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data, ...@@ -93,8 +96,10 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
output[idx * len] = target_box_center_x - target_box_width / 2; output[idx * len] = target_box_center_x - target_box_width / 2;
output[idx * len + 1] = target_box_center_y - target_box_height / 2; output[idx * len + 1] = target_box_center_y - target_box_height / 2;
output[idx * len + 2] = target_box_center_x + target_box_width / 2; output[idx * len + 2] =
output[idx * len + 3] = target_box_center_y + target_box_height / 2; target_box_center_x + target_box_width / 2 - (normalized == false);
output[idx * len + 3] =
target_box_center_y + target_box_height / 2 - (normalized == false);
} }
} }
...@@ -128,14 +133,15 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> { ...@@ -128,14 +133,15 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
T* output = output_box->data<T>(); T* output = output_box->data<T>();
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type")); auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
bool normalized = context.Attr<bool>("box_normalized");
if (code_type == BoxCodeType::kEncodeCenterSize) { if (code_type == BoxCodeType::kEncodeCenterSize) {
EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>( EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
prior_box_data, prior_box_var_data, target_box_data, row, col, len, prior_box_data, prior_box_var_data, target_box_data, row, col, len,
output); normalized, output);
} else if (code_type == BoxCodeType::kDecodeCenterSize) { } else if (code_type == BoxCodeType::kDecodeCenterSize) {
DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>( DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
prior_box_data, prior_box_var_data, target_box_data, row, col, len, prior_box_data, prior_box_var_data, target_box_data, row, col, len,
output); normalized, output);
} }
} }
}; };
......
...@@ -34,7 +34,7 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -34,7 +34,7 @@ class BoxCoderKernel : public framework::OpKernel<T> {
void EncodeCenterSize(const framework::Tensor& target_box, void EncodeCenterSize(const framework::Tensor& target_box,
const framework::Tensor& prior_box, const framework::Tensor& prior_box,
const framework::Tensor& prior_box_var, const framework::Tensor& prior_box_var,
T* output) const { const bool normalized, T* output) const {
int64_t row = target_box.dims()[0]; int64_t row = target_box.dims()[0];
int64_t col = prior_box.dims()[0]; int64_t col = prior_box.dims()[0];
int64_t len = prior_box.dims()[1]; int64_t len = prior_box.dims()[1];
...@@ -44,10 +44,11 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -44,10 +44,11 @@ class BoxCoderKernel : public framework::OpKernel<T> {
for (int64_t i = 0; i < row; ++i) { for (int64_t i = 0; i < row; ++i) {
for (int64_t j = 0; j < col; ++j) { for (int64_t j = 0; j < col; ++j) {
T prior_box_width = T prior_box_width = prior_box_data[j * len + 2] -
prior_box_data[j * len + 2] - prior_box_data[j * len]; prior_box_data[j * len] + (normalized == false);
T prior_box_height = T prior_box_height = prior_box_data[j * len + 3] -
prior_box_data[j * len + 3] - prior_box_data[j * len + 1]; prior_box_data[j * len + 1] +
(normalized == false);
T prior_box_center_x = T prior_box_center_x =
(prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2; (prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2;
T prior_box_center_y = T prior_box_center_y =
...@@ -57,10 +58,11 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -57,10 +58,11 @@ class BoxCoderKernel : public framework::OpKernel<T> {
(target_box_data[i * len + 2] + target_box_data[i * len]) / 2; (target_box_data[i * len + 2] + target_box_data[i * len]) / 2;
T target_box_center_y = T target_box_center_y =
(target_box_data[i * len + 3] + target_box_data[i * len + 1]) / 2; (target_box_data[i * len + 3] + target_box_data[i * len + 1]) / 2;
T target_box_width = T target_box_width = target_box_data[i * len + 2] -
target_box_data[i * len + 2] - target_box_data[i * len]; target_box_data[i * len] + (normalized == false);
T target_box_height = T target_box_height = target_box_data[i * len + 3] -
target_box_data[i * len + 3] - target_box_data[i * len + 1]; target_box_data[i * len + 1] +
(normalized == false);
size_t offset = i * col * len + j * len; size_t offset = i * col * len + j * len;
output[offset] = (target_box_center_x - prior_box_center_x) / output[offset] = (target_box_center_x - prior_box_center_x) /
...@@ -79,7 +81,7 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -79,7 +81,7 @@ class BoxCoderKernel : public framework::OpKernel<T> {
void DecodeCenterSize(const framework::Tensor& target_box, void DecodeCenterSize(const framework::Tensor& target_box,
const framework::Tensor& prior_box, const framework::Tensor& prior_box,
const framework::Tensor& prior_box_var, const framework::Tensor& prior_box_var,
T* output) const { const bool normalized, T* output) const {
int64_t row = target_box.dims()[0]; int64_t row = target_box.dims()[0];
int64_t col = prior_box.dims()[0]; int64_t col = prior_box.dims()[0];
int64_t len = prior_box.dims()[1]; int64_t len = prior_box.dims()[1];
...@@ -91,10 +93,11 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -91,10 +93,11 @@ class BoxCoderKernel : public framework::OpKernel<T> {
for (int64_t i = 0; i < row; ++i) { for (int64_t i = 0; i < row; ++i) {
for (int64_t j = 0; j < col; ++j) { for (int64_t j = 0; j < col; ++j) {
size_t offset = i * col * len + j * len; size_t offset = i * col * len + j * len;
T prior_box_width = T prior_box_width = prior_box_data[j * len + 2] -
prior_box_data[j * len + 2] - prior_box_data[j * len]; prior_box_data[j * len] + (normalized == false);
T prior_box_height = T prior_box_height = prior_box_data[j * len + 3] -
prior_box_data[j * len + 3] - prior_box_data[j * len + 1]; prior_box_data[j * len + 1] +
(normalized == false);
T prior_box_center_x = T prior_box_center_x =
(prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2; (prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2;
T prior_box_center_y = T prior_box_center_y =
...@@ -116,8 +119,10 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -116,8 +119,10 @@ class BoxCoderKernel : public framework::OpKernel<T> {
output[offset] = target_box_center_x - target_box_width / 2; output[offset] = target_box_center_x - target_box_width / 2;
output[offset + 1] = target_box_center_y - target_box_height / 2; output[offset + 1] = target_box_center_y - target_box_height / 2;
output[offset + 2] = target_box_center_x + target_box_width / 2; output[offset + 2] =
output[offset + 3] = target_box_center_y + target_box_height / 2; target_box_center_x + target_box_width / 2 - (normalized == false);
output[offset + 3] =
target_box_center_y + target_box_height / 2 - (normalized == false);
} }
} }
} }
...@@ -139,11 +144,14 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -139,11 +144,14 @@ class BoxCoderKernel : public framework::OpKernel<T> {
output_box->mutable_data<T>({row, col, len}, context.GetPlace()); output_box->mutable_data<T>({row, col, len}, context.GetPlace());
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type")); auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
bool normalized = context.Attr<bool>("box_normalized");
T* output = output_box->data<T>(); T* output = output_box->data<T>();
if (code_type == BoxCodeType::kEncodeCenterSize) { if (code_type == BoxCodeType::kEncodeCenterSize) {
EncodeCenterSize(*target_box, *prior_box, *prior_box_var, output); EncodeCenterSize(*target_box, *prior_box, *prior_box_var, normalized,
output);
} else if (code_type == BoxCodeType::kDecodeCenterSize) { } else if (code_type == BoxCodeType::kDecodeCenterSize) {
DecodeCenterSize(*target_box, *prior_box, *prior_box_var, output); DecodeCenterSize(*target_box, *prior_box, *prior_box_var, normalized,
output);
} }
} }
}; };
......
...@@ -19,7 +19,8 @@ import math ...@@ -19,7 +19,8 @@ import math
from op_test import OpTest from op_test import OpTest
def box_coder(target_box, prior_box, prior_box_var, output_box, code_type): def box_coder(target_box, prior_box, prior_box_var, output_box, code_type,
box_normalized):
prior_box_x = ( prior_box_x = (
(prior_box[:, 2] + prior_box[:, 0]) / 2).reshape(1, prior_box.shape[0]) (prior_box[:, 2] + prior_box[:, 0]) / 2).reshape(1, prior_box.shape[0])
prior_box_y = ( prior_box_y = (
...@@ -30,6 +31,9 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type): ...@@ -30,6 +31,9 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type):
(prior_box[:, 3] - prior_box[:, 1])).reshape(1, prior_box.shape[0]) (prior_box[:, 3] - prior_box[:, 1])).reshape(1, prior_box.shape[0])
prior_box_var = prior_box_var.reshape(1, prior_box_var.shape[0], prior_box_var = prior_box_var.reshape(1, prior_box_var.shape[0],
prior_box_var.shape[1]) prior_box_var.shape[1])
if not box_normalized:
prior_box_height = prior_box_height + 1
prior_box_width = prior_box_width + 1
if (code_type == "EncodeCenterSize"): if (code_type == "EncodeCenterSize"):
target_box_x = ((target_box[:, 2] + target_box[:, 0]) / 2).reshape( target_box_x = ((target_box[:, 2] + target_box[:, 0]) / 2).reshape(
...@@ -40,6 +44,9 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type): ...@@ -40,6 +44,9 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type):
target_box.shape[0], 1) target_box.shape[0], 1)
target_box_height = ((target_box[:, 3] - target_box[:, 1])).reshape( target_box_height = ((target_box[:, 3] - target_box[:, 1])).reshape(
target_box.shape[0], 1) target_box.shape[0], 1)
if not box_normalized:
target_box_height = target_box_height + 1
target_box_width = target_box_width + 1
output_box[:,:,0] = (target_box_x - prior_box_x) / prior_box_width / \ output_box[:,:,0] = (target_box_x - prior_box_x) / prior_box_width / \
prior_box_var[:,:,0] prior_box_var[:,:,0]
...@@ -64,9 +71,13 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type): ...@@ -64,9 +71,13 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type):
output_box[:, :, 1] = target_box_y - target_box_height / 2 output_box[:, :, 1] = target_box_y - target_box_height / 2
output_box[:, :, 2] = target_box_x + target_box_width / 2 output_box[:, :, 2] = target_box_x + target_box_width / 2
output_box[:, :, 3] = target_box_y + target_box_height / 2 output_box[:, :, 3] = target_box_y + target_box_height / 2
if not box_normalized:
output_box[:, :, 2] = output_box[:, :, 2] - 1
output_box[:, :, 3] = output_box[:, :, 3] - 1
def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type): def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type,
box_normalized):
n = target_box.shape[0] n = target_box.shape[0]
m = prior_box.shape[0] m = prior_box.shape[0]
output_box = np.zeros((n, m, 4), dtype=np.float32) output_box = np.zeros((n, m, 4), dtype=np.float32)
...@@ -74,11 +85,11 @@ def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type): ...@@ -74,11 +85,11 @@ def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type):
if (code_type == "EncodeCenterSize"): if (code_type == "EncodeCenterSize"):
box_coder(target_box[lod[i]:lod[i + 1], :], prior_box, box_coder(target_box[lod[i]:lod[i + 1], :], prior_box,
prior_box_var, output_box[lod[i]:lod[i + 1], :, :], prior_box_var, output_box[lod[i]:lod[i + 1], :, :],
code_type) code_type, box_normalized)
elif (code_type == "DecodeCenterSize"): elif (code_type == "DecodeCenterSize"):
box_coder(target_box[lod[i]:lod[i + 1], :, :], prior_box, box_coder(target_box[lod[i]:lod[i + 1], :, :], prior_box,
prior_box_var, output_box[lod[i]:lod[i + 1], :, :], prior_box_var, output_box[lod[i]:lod[i + 1], :, :],
code_type) code_type, box_normalized)
return output_box return output_box
...@@ -93,15 +104,19 @@ class TestBoxCoderOp(OpTest): ...@@ -93,15 +104,19 @@ class TestBoxCoderOp(OpTest):
prior_box_var = np.random.random((10, 4)).astype('float32') prior_box_var = np.random.random((10, 4)).astype('float32')
target_box = np.random.random((5, 10, 4)).astype('float32') target_box = np.random.random((5, 10, 4)).astype('float32')
code_type = "DecodeCenterSize" code_type = "DecodeCenterSize"
box_normalized = False
output_box = batch_box_coder(prior_box, prior_box_var, target_box, output_box = batch_box_coder(prior_box, prior_box_var, target_box,
lod[0], code_type) lod[0], code_type, box_normalized)
self.inputs = { self.inputs = {
'PriorBox': prior_box, 'PriorBox': prior_box,
'PriorBoxVar': prior_box_var, 'PriorBoxVar': prior_box_var,
'TargetBox': target_box, 'TargetBox': target_box,
} }
self.attrs = {'code_type': 'decode_center_size'} self.attrs = {
'code_type': 'decode_center_size',
'box_normalized': False
}
self.outputs = {'OutputBox': output_box} self.outputs = {'OutputBox': output_box}
...@@ -116,15 +131,16 @@ class TestBoxCoderOpWithLoD(OpTest): ...@@ -116,15 +131,16 @@ class TestBoxCoderOpWithLoD(OpTest):
prior_box_var = np.random.random((10, 4)).astype('float32') prior_box_var = np.random.random((10, 4)).astype('float32')
target_box = np.random.random((20, 4)).astype('float32') target_box = np.random.random((20, 4)).astype('float32')
code_type = "EncodeCenterSize" code_type = "EncodeCenterSize"
box_normalized = True
output_box = batch_box_coder(prior_box, prior_box_var, target_box, output_box = batch_box_coder(prior_box, prior_box_var, target_box,
lod[0], code_type) lod[0], code_type, box_normalized)
self.inputs = { self.inputs = {
'PriorBox': prior_box, 'PriorBox': prior_box,
'PriorBoxVar': prior_box_var, 'PriorBoxVar': prior_box_var,
'TargetBox': (target_box, lod), 'TargetBox': (target_box, lod),
} }
self.attrs = {'code_type': 'encode_center_size'} self.attrs = {'code_type': 'encode_center_size', 'box_normalized': True}
self.outputs = {'OutputBox': output_box} self.outputs = {'OutputBox': output_box}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册