提交 cee2e1b0 编写于 作者: J jerrywgz

refine code, test=develop

上级 a39240c3
...@@ -11,6 +11,7 @@ limitations under the License. */ ...@@ -11,6 +11,7 @@ limitations under the License. */
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <thrust/host_vector.h> #include <thrust/host_vector.h>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detection/box_coder_op.h" #include "paddle/fluid/operators/detection/box_coder_op.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
...@@ -95,47 +96,33 @@ __global__ void DecodeCenterSizeKernel( ...@@ -95,47 +96,33 @@ __global__ void DecodeCenterSizeKernel(
prior_box_data[prior_box_offset + 1] + prior_box_height / 2; prior_box_data[prior_box_offset + 1] + prior_box_height / 2;
T target_box_width, target_box_height; T target_box_width, target_box_height;
T target_box_center_x, target_box_center_y; T target_box_center_x, target_box_center_y;
T box_var_x = T(1), box_var_y = T(1);
T box_var_w = T(1), box_var_h = T(1);
if (prior_box_var_data) { if (prior_box_var_data) {
int prior_var_offset = 0; int prior_var_offset = 0;
if (prior_box_var_size == 2) { if (prior_box_var_size == 2) {
prior_var_offset = axis == 0 ? col_idx * len : row_idx * len; prior_var_offset = axis == 0 ? col_idx * len : row_idx * len;
} }
target_box_width = exp(prior_box_var_data[prior_var_offset + 2] * box_var_x = prior_box_var_data[prior_var_offset];
target_box_data[idx * len + 2]) * box_var_y = prior_box_var_data[prior_var_offset + 1];
prior_box_width; box_var_w = prior_box_var_data[prior_var_offset + 2];
target_box_height = exp(prior_box_var_data[prior_var_offset + 3] * box_var_h = prior_box_var_data[prior_var_offset + 3];
target_box_data[idx * len + 3]) *
prior_box_height;
target_box_center_x = prior_box_var_data[prior_var_offset] *
target_box_data[idx * len] * prior_box_width +
prior_box_center_x;
target_box_center_y = prior_box_var_data[prior_var_offset + 1] *
target_box_data[idx * len + 1] *
prior_box_height +
prior_box_center_y;
} else if (var_size == 4) { } else if (var_size == 4) {
target_box_width = box_var_x = static_cast<T>(variance[0]);
exp(static_cast<T>(variance[2]) * target_box_data[idx * len + 2]) * box_var_y = static_cast<T>(variance[1]);
prior_box_width; box_var_w = static_cast<T>(variance[2]);
target_box_height = box_var_h = static_cast<T>(variance[3]);
exp(static_cast<T>(variance[3]) * target_box_data[idx * len + 3]) *
prior_box_height;
target_box_center_x = static_cast<T>(variance[0]) *
target_box_data[idx * len] * prior_box_width +
prior_box_center_x;
target_box_center_y = static_cast<T>(variance[1]) *
target_box_data[idx * len + 1] *
prior_box_height +
prior_box_center_y;
} else {
target_box_width = exp(target_box_data[idx * len + 2]) * prior_box_width;
target_box_height =
exp(target_box_data[idx * len + 3]) * prior_box_height;
target_box_center_x =
target_box_data[idx * len] * prior_box_width + prior_box_center_x;
target_box_center_y = target_box_data[idx * len + 1] * prior_box_height +
prior_box_center_y;
} }
target_box_width =
exp(box_var_w * target_box_data[idx * len + 2]) * prior_box_width;
target_box_height =
exp(box_var_h * target_box_data[idx * len + 3]) * prior_box_height;
target_box_center_x =
box_var_x * target_box_data[idx * len] * prior_box_width +
prior_box_center_x;
target_box_center_y =
box_var_y * target_box_data[idx * len + 1] * prior_box_height +
prior_box_center_y;
output[idx * len] = target_box_center_x - target_box_width / 2; output[idx * len] = target_box_center_x - target_box_width / 2;
output[idx * len + 1] = target_box_center_y - target_box_height / 2; output[idx * len + 1] = target_box_center_y - target_box_height / 2;
...@@ -177,9 +164,8 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> { ...@@ -177,9 +164,8 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(target_box->lod().size(), 1, PADDLE_ENFORCE_EQ(target_box->lod().size(), 1,
"Only support 1 level of LoD."); "Only support 1 level of LoD.");
} }
const int var_size = static_cast<T>(variance.size()); const int var_size = static_cast<int>(variance.size());
thrust::device_vector<float> dev_variance(variance.begin(), variance.end());
const float* dev_var_data = thrust::raw_pointer_cast(dev_variance.data());
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type")); auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
bool normalized = context.Attr<bool>("box_normalized"); bool normalized = context.Attr<bool>("box_normalized");
int axis = context.Attr<int>("axis"); int axis = context.Attr<int>("axis");
...@@ -194,6 +180,16 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> { ...@@ -194,6 +180,16 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
int grid = (row * col + block - 1) / block; int grid = (row * col + block - 1) / block;
auto& device_ctx = context.cuda_device_context(); auto& device_ctx = context.cuda_device_context();
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(device_ctx);
int bytes = var_size * sizeof(float);
auto dev_var = allocator.Allocate(bytes);
float* dev_var_data = reinterpret_cast<float*>(dev_var->ptr());
auto cplace = platform::CPUPlace();
const auto gplace = boost::get<platform::CUDAPlace>(context.GetPlace());
memory::Copy(gplace, dev_var_data, cplace, &variance[0], bytes,
device_ctx.stream());
output_box->mutable_data<T>({row, col, len}, context.GetPlace()); output_box->mutable_data<T>({row, col, len}, context.GetPlace());
T* output = output_box->data<T>(); T* output = output_box->data<T>();
......
...@@ -133,6 +133,8 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -133,6 +133,8 @@ class BoxCoderKernel : public framework::OpKernel<T> {
T target_box_center_x = 0, target_box_center_y = 0; T target_box_center_x = 0, target_box_center_y = 0;
T target_box_width = 0, target_box_height = 0; T target_box_width = 0, target_box_height = 0;
T box_var_x = T(1), box_var_y = T(1);
T box_var_w = T(1), box_var_h = T(1);
if (prior_box_var) { if (prior_box_var) {
int prior_var_offset = 0; int prior_var_offset = 0;
if (prior_box_var->dims().size() == 2) { if (prior_box_var->dims().size() == 2) {
...@@ -141,44 +143,26 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -141,44 +143,26 @@ class BoxCoderKernel : public framework::OpKernel<T> {
else if (axis == 1) else if (axis == 1)
prior_var_offset = i * len; prior_var_offset = i * len;
} }
target_box_center_x = prior_box_var_data[prior_var_offset] * box_var_x = prior_box_var_data[prior_var_offset];
target_box_data[offset] * prior_box_width + box_var_y = prior_box_var_data[prior_var_offset + 1];
prior_box_center_x; box_var_w = prior_box_var_data[prior_var_offset + 2];
target_box_center_y = prior_box_var_data[prior_var_offset + 1] * box_var_h = prior_box_var_data[prior_var_offset + 3];
target_box_data[offset + 1] *
prior_box_height +
prior_box_center_y;
target_box_width = std::exp(prior_box_var_data[prior_var_offset + 2] *
target_box_data[offset + 2]) *
prior_box_width;
target_box_height =
std::exp(prior_box_var_data[prior_var_offset + 3] *
target_box_data[offset + 3]) *
prior_box_height;
} else if (!(variance.empty())) { } else if (!(variance.empty())) {
target_box_center_x = static_cast<T>(variance[0]) * box_var_x = static_cast<T>(variance[0]);
target_box_data[offset] * prior_box_width + box_var_y = static_cast<T>(variance[1]);
prior_box_center_x; box_var_w = static_cast<T>(variance[2]);
target_box_center_y = static_cast<T>(variance[1]) * box_var_h = static_cast<T>(variance[3]);
target_box_data[offset + 1] *
prior_box_height +
prior_box_center_y;
target_box_width = std::exp(static_cast<T>(variance[2]) *
target_box_data[offset + 2]) *
prior_box_width;
target_box_height = std::exp(static_cast<T>(variance[3]) *
target_box_data[offset + 3]) *
prior_box_height;
} else {
target_box_center_x =
target_box_data[offset] * prior_box_width + prior_box_center_x;
target_box_center_y = target_box_data[offset + 1] * prior_box_height +
prior_box_center_y;
target_box_width =
std::exp(target_box_data[offset + 2]) * prior_box_width;
target_box_height =
std::exp(target_box_data[offset + 3]) * prior_box_height;
} }
target_box_center_x =
box_var_x * target_box_data[offset] * prior_box_width +
prior_box_center_x;
target_box_center_y =
box_var_y * target_box_data[offset + 1] * prior_box_height +
prior_box_center_y;
target_box_width =
std::exp(box_var_w * target_box_data[offset + 2]) * prior_box_width;
target_box_height = std::exp(box_var_h * target_box_data[offset + 3]) *
prior_box_height;
output[offset] = target_box_center_x - target_box_width / 2; output[offset] = target_box_center_x - target_box_width / 2;
output[offset + 1] = target_box_center_y - target_box_height / 2; output[offset + 1] = target_box_center_y - target_box_height / 2;
......
...@@ -50,6 +50,19 @@ class TestDetection(unittest.TestCase): ...@@ -50,6 +50,19 @@ class TestDetection(unittest.TestCase):
self.assertEqual(out.shape[-1], 6) self.assertEqual(out.shape[-1], 6)
print(str(program)) print(str(program))
def test_box_coder_api(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[4], dtype='float32')
y = layers.data(name='z', shape=[4], dtype='float32', lod_level=1)
bcoder = layers.box_coder(
prior_box=x,
prior_box_var=[0.1, 0.2, 0.1, 0.2],
target_box=y,
code_type='encode_center_size')
self.assertIsNotNone(bcoder)
print(str(program))
def test_detection_api(self): def test_detection_api(self):
program = Program() program = Program()
with program_guard(program): with program_guard(program):
...@@ -59,7 +72,7 @@ class TestDetection(unittest.TestCase): ...@@ -59,7 +72,7 @@ class TestDetection(unittest.TestCase):
iou = layers.iou_similarity(x=x, y=y) iou = layers.iou_similarity(x=x, y=y)
bcoder = layers.box_coder( bcoder = layers.box_coder(
prior_box=x, prior_box=x,
prior_box_var=[0.2, 0.3, 0.3, 0.2], prior_box_var=y,
target_box=z, target_box=z,
code_type='encode_center_size') code_type='encode_center_size')
self.assertIsNotNone(iou) self.assertIsNotNone(iou)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册