提交 90f39b11 编写于 作者: J jerrywgz

Merge branch 'roialign' of https://github.com/jerrywgz/Paddle into roialign

...@@ -290,9 +290,6 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> { ...@@ -290,9 +290,6 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
for (int n = 0; n < rois_num; ++n) { for (int n = 0; n < rois_num; ++n) {
int roi_batch_idx = roi_batch_id_data[n]; int roi_batch_idx = roi_batch_id_data[n];
T* batch_grad_data = in_grad_data + roi_batch_idx * in_stride[0];
const T* batch_out_grad_data =
out_grad_data + roi_batch_idx * out_stride[0];
T roi_xmin = rois_data[0] * spatial_scale; T roi_xmin = rois_data[0] * spatial_scale;
T roi_ymin = rois_data[1] * spatial_scale; T roi_ymin = rois_data[1] * spatial_scale;
T roi_xmax = rois_data[2] * spatial_scale; T roi_xmax = rois_data[2] * spatial_scale;
...@@ -303,6 +300,10 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> { ...@@ -303,6 +300,10 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
static_cast<T>(roi_height) / static_cast<T>(pooled_height); static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
for (int c = 0; c < channels; ++c) { for (int c = 0; c < channels; ++c) {
T* batch_grad_data =
in_grad_data + roi_batch_idx * in_stride[0] + c * in_stride[1];
const T* batch_out_grad_data =
out_grad_data + n * out_stride[0] + c * out_stride[1];
for (int ph = 0; ph < pooled_height; ++ph) { for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) { for (int pw = 0; pw < pooled_width; ++pw) {
int pool_index = ph * pooled_width + pw; int pool_index = ph * pooled_width + pw;
...@@ -329,8 +330,6 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> { ...@@ -329,8 +330,6 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
} }
} }
} }
batch_grad_data += in_stride[1];
batch_out_grad_data += out_stride[1];
} }
rois_data += roi_stride[0]; rois_data += roi_stride[0];
} }
......
...@@ -37,7 +37,7 @@ class TestROIAlignOp(OpTest): ...@@ -37,7 +37,7 @@ class TestROIAlignOp(OpTest):
self.outputs = {'Out': self.out_data} self.outputs = {'Out': self.out_data}
def init_test_case(self): def init_test_case(self):
self.batch_size = 1 self.batch_size = 3
self.channels = 3 self.channels = 3
self.height = 8 self.height = 8
self.width = 6 self.width = 6
...@@ -45,10 +45,10 @@ class TestROIAlignOp(OpTest): ...@@ -45,10 +45,10 @@ class TestROIAlignOp(OpTest):
# n, c, h, w # n, c, h, w
self.x_dim = (self.batch_size, self.channels, self.height, self.width) self.x_dim = (self.batch_size, self.channels, self.height, self.width)
self.spatial_scale = 1.0 / 1.0 self.spatial_scale = 1.0 / 2.0
self.pooled_height = 2 self.pooled_height = 2
self.pooled_width = 2 self.pooled_width = 2
self.sampling_ratio = 2 self.sampling_ratio = -1
self.x = np.random.random(self.x_dim).astype('float32') self.x = np.random.random(self.x_dim).astype('float32')
...@@ -57,7 +57,7 @@ class TestROIAlignOp(OpTest): ...@@ -57,7 +57,7 @@ class TestROIAlignOp(OpTest):
count = roi_bin_grid_h * roi_bin_grid_w count = roi_bin_grid_h * roi_bin_grid_w
bilinear_pos = np.zeros( bilinear_pos = np.zeros(
[self.channels, self.pooled_height, self.pooled_width, count, 4], [self.channels, self.pooled_height, self.pooled_width, count, 4],
np.int32) np.float32)
bilinear_w = np.zeros( bilinear_w = np.zeros(
[self.pooled_height, self.pooled_width, count, 4], np.float32) [self.pooled_height, self.pooled_width, count, 4], np.float32)
for ph in range(self.pooled_width): for ph in range(self.pooled_width):
...@@ -85,7 +85,7 @@ class TestROIAlignOp(OpTest): ...@@ -85,7 +85,7 @@ class TestROIAlignOp(OpTest):
if x_low >= self.width - 1: if x_low >= self.width - 1:
x = x_high = x_low = self.width - 1 x = x_high = x_low = self.width - 1
else: else:
x_high = x_low = self.width - 1 x_high = x_low + 1
ly = y - y_low ly = y - y_low
lx = x - x_low lx = x - x_low
hy = 1 - ly hy = 1 - ly
...@@ -107,8 +107,9 @@ class TestROIAlignOp(OpTest): ...@@ -107,8 +107,9 @@ class TestROIAlignOp(OpTest):
return bilinear_pos, bilinear_w return bilinear_pos, bilinear_w
def calc_roi_align(self): def calc_roi_align(self):
self.out_data = np.zeros((self.rois_num, self.channels, self.out_data = np.zeros(
self.pooled_height, self.pooled_width)) (self.rois_num, self.channels, self.pooled_height,
self.pooled_width)).astype('float32')
for i in range(self.rois_num): for i in range(self.rois_num):
roi = self.rois[i] roi = self.rois[i]
...@@ -118,14 +119,14 @@ class TestROIAlignOp(OpTest): ...@@ -118,14 +119,14 @@ class TestROIAlignOp(OpTest):
roi_ymin = roi[2] * self.spatial_scale roi_ymin = roi[2] * self.spatial_scale
roi_xmax = roi[3] * self.spatial_scale roi_xmax = roi[3] * self.spatial_scale
roi_ymax = roi[4] * self.spatial_scale roi_ymax = roi[4] * self.spatial_scale
roi_width = int(max(roi_xmax - roi_xmin, 1)) roi_width = max(roi_xmax - roi_xmin, 1)
roi_height = int(max(roi_ymax - roi_ymin, 1)) roi_height = max(roi_ymax - roi_ymin, 1)
bin_size_h = float(roi_height) / float(self.pooled_height) bin_size_h = float(roi_height) / float(self.pooled_height)
bin_size_w = float(roi_width) / float(self.pooled_width) bin_size_w = float(roi_width) / float(self.pooled_width)
roi_bin_grid_h = self.sampling_ratio if self.sampling_ratio > 0 else \ roi_bin_grid_h = self.sampling_ratio if self.sampling_ratio > 0 else \
math.ceil(roi_height / pooled_height) math.ceil(roi_height / self.pooled_height)
roi_bin_grid_w = self.sampling_ratio if self.sampling_ratio > 0 else \ roi_bin_grid_w = self.sampling_ratio if self.sampling_ratio > 0 else \
math.ceil(roi_width / pooled_width) math.ceil(roi_width / self.pooled_width)
count = int(roi_bin_grid_h * roi_bin_grid_w) count = int(roi_bin_grid_h * roi_bin_grid_w)
pre_size = count * self.pooled_width * self.pooled_height pre_size = count * self.pooled_width * self.pooled_height
bilinear_pos, bilinear_w = self.pre_calc(x_i, roi_xmin, roi_ymin, bilinear_pos, bilinear_w = self.pre_calc(x_i, roi_xmin, roi_ymin,
...@@ -139,7 +140,7 @@ class TestROIAlignOp(OpTest): ...@@ -139,7 +140,7 @@ class TestROIAlignOp(OpTest):
def make_rois(self): def make_rois(self):
rois = [] rois = []
self.rois_lod = [[0]] self.rois_lod = [[]]
for bno in range(self.batch_size): for bno in range(self.batch_size):
self.rois_lod[0].append(bno + 1) self.rois_lod[0].append(bno + 1)
for i in range(bno + 1): for i in range(bno + 1):
...@@ -166,4 +167,4 @@ class TestROIAlignOp(OpTest): ...@@ -166,4 +167,4 @@ class TestROIAlignOp(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out', max_relative_error=0.005)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册