Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
e3c4b0da
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e3c4b0da
编写于
12月 13, 2018
作者:
S
SunGaofeng
提交者:
qingqing01
12月 13, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
this is for psroi_pool op, test=develop (#14796)
* Add psroi_pool operator.
上级
30aad884
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
920 addition
and
0 deletion
+920
-0
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-0
paddle/fluid/operators/psroi_pool_op.cc
paddle/fluid/operators/psroi_pool_op.cc
+173
-0
paddle/fluid/operators/psroi_pool_op.cu
paddle/fluid/operators/psroi_pool_op.cu
+294
-0
paddle/fluid/operators/psroi_pool_op.h
paddle/fluid/operators/psroi_pool_op.h
+253
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+55
-0
python/paddle/fluid/tests/unittests/test_layers.py
python/paddle/fluid/tests/unittests/test_layers.py
+10
-0
python/paddle/fluid/tests/unittests/test_psroi_pool_op.py
python/paddle/fluid/tests/unittests/test_psroi_pool_op.py
+134
-0
未找到文件。
paddle/fluid/API.spec
浏览文件 @
e3c4b0da
...
...
@@ -198,6 +198,7 @@ paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act
paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1))
paddle.fluid.layers.psroi_pool ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
...
...
paddle/fluid/operators/psroi_pool_op.cc
0 → 100644
浏览文件 @
e3c4b0da
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/psroi_pool_op.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
class
PSROIPoolOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor), "
"the input of PSROIPoolOp. "
"The format of input tensor is NCHW. Where N is the batch size, "
"C is the number of input channels, "
"H is the height of the input feature map, and "
"W is the width."
);
AddInput
(
"ROIs"
,
"(LoDTensor), "
"ROIs (Regions of Interest) to pool over. "
"should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]. "
"where (x1, y1) is the top left coordinates, and "
"(x2, y2) is the bottom right coordinates. "
"The roi batch index can be calculated from LoD."
);
AddOutput
(
"Out"
,
"(Tensor), "
"the output of PSROIPoolOp is a 4-D Tensor with shape "
"(num_rois, output_channels, pooled_h, pooled_w)."
);
AddAttr
<
int
>
(
"output_channels"
,
"(int), "
"the number of channels of the output feature map. "
"For a task of C classes of objects, output_channels should be "
"(C + 1) for classification only."
);
AddAttr
<
float
>
(
"spatial_scale"
,
"(float, default 1.0), "
"Multiplicative spatial scale factor "
"to translate ROI coords from their input scale "
"to the scale used when pooling."
)
.
SetDefault
(
1.0
);
AddAttr
<
int
>
(
"pooled_height"
,
"(int, default 1), "
"the pooled output height."
)
.
SetDefault
(
1
);
AddAttr
<
int
>
(
"pooled_width"
,
"(int, default 1), "
"the pooled output width."
)
.
SetDefault
(
1
);
AddComment
(
R"Doc(
**PSROIPool Operator**
Position sensitive region of interest pooling (also known as PSROIPooling) is to perform
position-sensitive average pooling on regions of interest specified by input, takes as
input N position-sensitive score maps and a list of num_rois regions of interest.
PSROIPooling for R-FCN. Please refer to https://arxiv.org/abs/1605.06409 for more details.
)Doc"
);
}
};
class
PSROIPoolOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of PSROIPoolOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"ROIs"
),
"Input(ROIs) of PSROIPoolOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of PSROIPoolOp should not be null."
);
auto
input_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
rois_dims
=
ctx
->
GetInputDim
(
"ROIs"
);
PADDLE_ENFORCE
(
input_dims
.
size
()
==
4
,
"The format of input tensor is NCHW"
);
PADDLE_ENFORCE
(
rois_dims
.
size
()
==
2
,
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]"
);
PADDLE_ENFORCE
(
rois_dims
[
1
]
==
4
,
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]"
);
int
pooled_height
=
ctx
->
Attrs
().
Get
<
int
>
(
"pooled_height"
);
int
pooled_width
=
ctx
->
Attrs
().
Get
<
int
>
(
"pooled_width"
);
int
output_channels
=
ctx
->
Attrs
().
Get
<
int
>
(
"output_channels"
);
float
spatial_scale
=
ctx
->
Attrs
().
Get
<
float
>
(
"spatial_scale"
);
PADDLE_ENFORCE
(
input_dims
[
1
]
==
output_channels
*
pooled_height
*
pooled_width
,
"the channel of X(%d) should be equal to the product of "
"output_channels(%d), pooled_height(%d) and pooled_width(%d)"
,
input_dims
[
1
],
output_channels
,
pooled_height
,
pooled_width
);
PADDLE_ENFORCE_GT
(
pooled_height
,
0
,
"The pooled output height must be greater than 0"
);
PADDLE_ENFORCE_GT
(
pooled_width
,
0
,
"The pooled output width must be greater than 0"
);
PADDLE_ENFORCE_GT
(
output_channels
,
1
,
"The pooled output channels must greater than 1"
);
PADDLE_ENFORCE_GT
(
spatial_scale
,
0.0
f
,
"The spatial scale must greater than 0."
);
auto
out_dims
=
input_dims
;
out_dims
[
0
]
=
rois_dims
[
0
];
out_dims
[
1
]
=
output_channels
;
// input_dims[1] / (pooled_height * pooled_width);
out_dims
[
2
]
=
pooled_height
;
out_dims
[
3
]
=
pooled_width
;
ctx
->
SetOutputDim
(
"Out"
,
out_dims
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
class
PSROIPoolGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"The gradient of Out should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
"The gradient of X should not be null."
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
psroi_pool
,
ops
::
PSROIPoolOp
,
ops
::
PSROIPoolOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
psroi_pool_grad
,
ops
::
PSROIPoolGradOp
);
REGISTER_OP_CPU_KERNEL
(
psroi_pool
,
ops
::
CPUPSROIPoolOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
CPUPSROIPoolOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
psroi_pool_grad
,
ops
::
CPUPSROIPoolGradOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
CPUPSROIPoolGradOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/psroi_pool_op.cu
0 → 100644
浏览文件 @
e3c4b0da
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/psroi_pool_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
static
constexpr
int
kNumCUDAThreads
=
512
;
static
constexpr
int
kNumMaximumNumBlocks
=
4096
;
static
inline
int
NumBlocks
(
const
int
N
)
{
return
std
::
min
((
N
+
kNumCUDAThreads
-
1
)
/
kNumCUDAThreads
,
kNumMaximumNumBlocks
);
}
template
<
typename
T
>
__global__
void
GPUPSROIPoolForward
(
const
int
nthreads
,
const
T
*
input_data
,
const
T
*
input_rois
,
const
float
spatial_scale
,
const
int
input_channels
,
const
int
height
,
const
int
width
,
const
int
output_channels
,
const
int
pooled_height
,
const
int
pooled_width
,
const
int
*
rois_batch_id_data
,
T
*
output_data
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
for
(
size_t
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
// The output is in order (n, c, ph, pw)
int
pw
=
i
%
pooled_width
;
int
ph
=
(
i
/
pooled_width
)
%
pooled_height
;
int
c
=
(
i
/
pooled_width
/
pooled_height
)
%
output_channels
;
int
n
=
i
/
pooled_width
/
pooled_height
/
output_channels
;
// set roi_batch_id
int
roi_batch_id
=
rois_batch_id_data
[
n
];
// [start, end) interval for spatial sampling
const
T
*
offset_input_rois
=
input_rois
+
n
*
4
;
T
roi_start_w
=
static_cast
<
T
>
(
round
(
offset_input_rois
[
0
]))
*
spatial_scale
;
T
roi_start_h
=
static_cast
<
T
>
(
round
(
offset_input_rois
[
1
]))
*
spatial_scale
;
T
roi_end_w
=
static_cast
<
T
>
(
round
(
offset_input_rois
[
2
])
+
1.
)
*
spatial_scale
;
T
roi_end_h
=
static_cast
<
T
>
(
round
(
offset_input_rois
[
3
])
+
1.
)
*
spatial_scale
;
// Force too small ROIs to be 1x1
T
roi_height
=
max
(
roi_end_h
-
roi_start_h
,
(
T
)
0.1
);
// avoid 0
T
roi_width
=
max
(
roi_end_w
-
roi_start_w
,
(
T
)
0.1
);
// Compute w and h at input feature map
T
bin_size_h
=
roi_height
/
static_cast
<
T
>
(
pooled_height
);
T
bin_size_w
=
roi_width
/
static_cast
<
T
>
(
pooled_width
);
int
hstart
=
floor
(
bin_size_h
*
static_cast
<
T
>
(
ph
)
+
roi_start_h
);
int
wstart
=
floor
(
bin_size_w
*
static_cast
<
T
>
(
pw
)
+
roi_start_w
);
int
hend
=
ceil
(
bin_size_h
*
static_cast
<
T
>
(
ph
+
1
)
+
roi_start_h
);
int
wend
=
ceil
(
bin_size_w
*
static_cast
<
T
>
(
pw
+
1
)
+
roi_start_w
);
// Add roi offsets and clip to input boundaries
hstart
=
min
(
max
(
hstart
,
0
),
height
);
hend
=
min
(
max
(
hend
,
0
),
height
);
wstart
=
min
(
max
(
wstart
,
0
),
width
);
wend
=
min
(
max
(
wend
,
0
),
width
);
bool
is_empty
=
(
hend
<=
hstart
)
||
(
wend
<=
wstart
);
int
input_channel
=
(
c
*
pooled_height
+
ph
)
*
pooled_width
+
pw
;
const
T
*
offset_input_data
=
input_data
+
(
roi_batch_id
*
input_channels
+
input_channel
)
*
height
*
width
;
T
outsum
=
0
;
for
(
int
ih
=
hstart
;
ih
<
hend
;
++
ih
)
{
for
(
int
iw
=
wstart
;
iw
<
wend
;
++
iw
)
{
int
input_index
=
ih
*
width
+
iw
;
outsum
+=
offset_input_data
[
input_index
];
}
}
T
bin_area
=
static_cast
<
T
>
((
hend
-
hstart
)
*
(
wend
-
wstart
));
output_data
[
i
]
=
is_empty
?
0.
:
outsum
/
bin_area
;
}
}
template
<
typename
T
>
__global__
void
GPUPSROIPoolBackward
(
const
int
nthreads
,
const
T
*
input_rois
,
const
T
*
output_grad_data
,
const
float
spatial_scale
,
const
int
input_channels
,
const
int
height
,
const
int
width
,
const
int
output_channels
,
const
int
pooled_height
,
const
int
pooled_width
,
const
int
*
rois_batch_id_data
,
T
*
input_grad_data
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
for
(
int
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
// The output is in order (n, c, ph, pw)
int
pw
=
i
%
pooled_width
;
int
ph
=
(
i
/
pooled_width
)
%
pooled_height
;
int
c
=
(
i
/
pooled_width
/
pooled_height
)
%
output_channels
;
int
n
=
i
/
pooled_width
/
pooled_height
/
output_channels
;
// set roi_batch_id
int
roi_batch_id
=
rois_batch_id_data
[
n
];
int
input_channel
=
(
c
*
pooled_height
+
ph
)
*
pooled_width
+
pw
;
int
input_offset
=
(
roi_batch_id
*
input_channels
+
input_channel
)
*
height
*
width
;
T
*
offset_input_grad_data
=
input_grad_data
+
input_offset
;
// [start, end) interval for spatial sampling
const
T
*
offset_input_rois
=
input_rois
+
n
*
4
;
T
roi_start_w
=
static_cast
<
T
>
(
round
(
offset_input_rois
[
0
]))
*
spatial_scale
;
T
roi_start_h
=
static_cast
<
T
>
(
round
(
offset_input_rois
[
1
]))
*
spatial_scale
;
T
roi_end_w
=
static_cast
<
T
>
(
round
(
offset_input_rois
[
2
])
+
1.
)
*
spatial_scale
;
T
roi_end_h
=
static_cast
<
T
>
(
round
(
offset_input_rois
[
3
])
+
1.
)
*
spatial_scale
;
// Force too small ROIs to be 1x1
T
roi_height
=
max
(
roi_end_h
-
roi_start_h
,
(
T
)
0.1
);
// avoid 0
T
roi_width
=
max
(
roi_end_w
-
roi_start_w
,
(
T
)
0.1
);
// Compute w and h at input feature map
T
bin_size_h
=
roi_height
/
static_cast
<
T
>
(
pooled_height
);
T
bin_size_w
=
roi_width
/
static_cast
<
T
>
(
pooled_width
);
int
hstart
=
floor
(
bin_size_h
*
static_cast
<
T
>
(
ph
)
+
roi_start_h
);
int
wstart
=
floor
(
bin_size_w
*
static_cast
<
T
>
(
pw
)
+
roi_start_w
);
int
hend
=
ceil
(
bin_size_h
*
static_cast
<
T
>
(
ph
+
1
)
+
roi_start_h
);
int
wend
=
ceil
(
bin_size_w
*
static_cast
<
T
>
(
pw
+
1
)
+
roi_start_w
);
// Add roi offsets and clip to input boundaries
hstart
=
min
(
max
(
hstart
,
0
),
height
);
hend
=
min
(
max
(
hend
,
0
),
height
);
wstart
=
min
(
max
(
wstart
,
0
),
width
);
wend
=
min
(
max
(
wend
,
0
),
width
);
bool
is_empty
=
(
hend
<=
hstart
)
||
(
wend
<=
wstart
);
// Accumulate diff_val into input data
T
bin_area
=
static_cast
<
T
>
((
hend
-
hstart
)
*
(
wend
-
wstart
));
T
diff_val
=
is_empty
?
0.
:
output_grad_data
[
i
]
/
bin_area
;
for
(
int
ih
=
hstart
;
ih
<
hend
;
++
ih
)
{
for
(
int
iw
=
wstart
;
iw
<
wend
;
++
iw
)
{
int
input_index
=
ih
*
width
+
iw
;
platform
::
CudaAtomicAdd
(
offset_input_grad_data
+
input_index
,
diff_val
);
}
}
}
}
template
<
typename
Place
,
typename
T
>
class
GPUPSROIPoolOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
in
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
rois
=
ctx
.
Input
<
LoDTensor
>
(
"ROIs"
);
auto
*
out
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
pooled_height
=
ctx
.
Attr
<
int
>
(
"pooled_height"
);
auto
pooled_width
=
ctx
.
Attr
<
int
>
(
"pooled_width"
);
auto
output_channels
=
ctx
.
Attr
<
int
>
(
"output_channels"
);
auto
spatial_scale
=
ctx
.
Attr
<
float
>
(
"spatial_scale"
);
auto
in_dims
=
in
->
dims
();
int
batch_size
=
in_dims
[
0
];
int
input_channels
=
in_dims
[
1
];
int
height
=
in_dims
[
2
];
int
width
=
in_dims
[
3
];
PADDLE_ENFORCE_EQ
(
input_channels
,
output_channels
*
pooled_height
*
pooled_width
,
"the channels of input X should equal the product of "
"output_channels x pooled_height x pooled_width"
);
int
rois_num
=
rois
->
dims
()[
0
];
if
(
rois_num
==
0
)
return
;
auto
rois_lod
=
rois
->
lod
().
back
();
int
rois_batch_size
=
rois_lod
.
size
()
-
1
;
PADDLE_ENFORCE_EQ
(
rois_batch_size
,
batch_size
,
"The rois_batch_size and input(X) batch_size must be the same."
);
int
rois_num_with_lod
=
rois_lod
[
rois_batch_size
];
PADDLE_ENFORCE_EQ
(
rois_num
,
rois_num_with_lod
,
"The rois_num from input and lod must be the same."
);
// set rois batch id
framework
::
Tensor
rois_batch_id_list
;
rois_batch_id_list
.
Resize
({
rois_num
});
int
*
rois_batch_id_data
=
rois_batch_id_list
.
mutable_data
<
int
>
(
platform
::
CPUPlace
());
for
(
int
n
=
0
;
n
<
rois_batch_size
;
++
n
)
{
for
(
size_t
i
=
rois_lod
[
n
];
i
<
rois_lod
[
n
+
1
];
++
i
)
{
rois_batch_id_data
[
i
]
=
n
;
}
}
framework
::
Tensor
rois_batch_id_list_gpu
;
framework
::
TensorCopy
(
rois_batch_id_list
,
ctx
.
GetPlace
(),
ctx
.
device_context
(),
&
rois_batch_id_list_gpu
);
int
output_size
=
out
->
numel
();
int
blocks
=
NumBlocks
(
output_size
);
int
threads
=
kNumCUDAThreads
;
// call cuda kernel function
GPUPSROIPoolForward
<
T
><<<
blocks
,
threads
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
output_size
,
in
->
data
<
T
>
(),
rois
->
data
<
T
>
(),
spatial_scale
,
input_channels
,
height
,
width
,
output_channels
,
pooled_height
,
pooled_width
,
rois_batch_id_list_gpu
.
data
<
int
>
(),
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
}
};
template
<
typename
Place
,
typename
T
>
class
GPUPSROIPoolGradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
in
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
rois
=
ctx
.
Input
<
LoDTensor
>
(
"ROIs"
);
auto
*
output_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
input_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
pooled_height
=
ctx
.
Attr
<
int
>
(
"pooled_height"
);
auto
pooled_width
=
ctx
.
Attr
<
int
>
(
"pooled_width"
);
auto
output_channels
=
ctx
.
Attr
<
int
>
(
"output_channels"
);
auto
spatial_scale
=
ctx
.
Attr
<
float
>
(
"spatial_scale"
);
int
rois_num
=
rois
->
dims
()[
0
];
int
input_channels
=
in
->
dims
()[
1
];
int
height
=
in
->
dims
()[
2
];
int
width
=
in
->
dims
()[
3
];
if
(
input_grad
)
{
// set roi batch id
framework
::
Tensor
rois_batch_id_list
;
rois_batch_id_list
.
Resize
({
rois_num
});
int
*
rois_batch_id_data
=
rois_batch_id_list
.
mutable_data
<
int
>
(
platform
::
CPUPlace
());
auto
rois_lod
=
rois
->
lod
().
back
();
int
rois_batch_size
=
rois_lod
.
size
()
-
1
;
for
(
int
n
=
0
;
n
<
rois_batch_size
;
++
n
)
{
for
(
size_t
i
=
rois_lod
[
n
];
i
<
rois_lod
[
n
+
1
];
++
i
)
{
rois_batch_id_data
[
i
]
=
n
;
}
}
framework
::
Tensor
rois_batch_id_list_gpu
;
framework
::
TensorCopy
(
rois_batch_id_list
,
ctx
.
GetPlace
(),
ctx
.
device_context
(),
&
rois_batch_id_list_gpu
);
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
math
::
SetConstant
<
Place
,
T
>
set_zero
;
set_zero
(
ctx
.
cuda_device_context
(),
input_grad
,
static_cast
<
T
>
(
0
));
int
output_grad_size
=
output_grad
->
numel
();
int
blocks
=
NumBlocks
(
output_grad_size
);
int
threads
=
kNumCUDAThreads
;
if
(
output_grad_size
>
0
)
{
GPUPSROIPoolBackward
<
T
><<<
blocks
,
threads
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
output_grad_size
,
rois
->
data
<
T
>
(),
output_grad
->
data
<
T
>
(),
spatial_scale
,
input_channels
,
height
,
width
,
output_channels
,
pooled_height
,
pooled_width
,
rois_batch_id_list_gpu
.
data
<
int
>
(),
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
}
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
psroi_pool
,
ops
::
GPUPSROIPoolOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
GPUPSROIPoolOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
psroi_pool_grad
,
ops
::
GPUPSROIPoolGradOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
GPUPSROIPoolGradOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/psroi_pool_op.h
0 → 100644
浏览文件 @
e3c4b0da
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
class
CPUPSROIPoolOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
rois
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"ROIs"
);
auto
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
pooled_height
=
ctx
.
Attr
<
int
>
(
"pooled_height"
);
auto
pooled_width
=
ctx
.
Attr
<
int
>
(
"pooled_width"
);
auto
spatial_scale
=
ctx
.
Attr
<
float
>
(
"spatial_scale"
);
auto
output_channels
=
ctx
.
Attr
<
int
>
(
"output_channels"
);
auto
in_dims
=
in
->
dims
();
int
batch_size
=
in_dims
[
0
];
int
input_channels
=
in_dims
[
1
];
int
height
=
in_dims
[
2
];
int
width
=
in_dims
[
3
];
int
rois_num
=
rois
->
dims
()[
0
];
auto
in_stride
=
framework
::
stride
(
in_dims
);
auto
roi_stride
=
framework
::
stride
(
rois
->
dims
());
auto
out_stride
=
framework
::
stride
(
out
->
dims
());
const
T
*
input_data
=
in
->
data
<
T
>
();
framework
::
Tensor
rois_batch_id_list
;
rois_batch_id_list
.
Resize
({
rois_num
});
int
*
rois_batch_id_data
=
rois_batch_id_list
.
mutable_data
<
int
>
(
ctx
.
GetPlace
());
auto
rois_lod
=
rois
->
lod
().
back
();
int
rois_batch_size
=
rois_lod
.
size
()
-
1
;
PADDLE_ENFORCE_EQ
(
rois_batch_size
,
batch_size
,
"the rois_batch_size and input(X) batch_size should be the same."
);
int
rois_num_with_lod
=
rois_lod
[
rois_batch_size
];
PADDLE_ENFORCE_EQ
(
rois_num_with_lod
,
rois_num
,
"the rois_num from input and lod must be the same"
);
PADDLE_ENFORCE_EQ
(
input_channels
,
output_channels
*
pooled_height
*
pooled_width
,
"the channels of input X should equal the product of "
"output_channels x pooled_height x pooled_width"
);
// calculate batch id index for each roi according to LoD
for
(
int
n
=
0
;
n
<
rois_batch_size
;
++
n
)
{
for
(
size_t
i
=
rois_lod
[
n
];
i
<
rois_lod
[
n
+
1
];
++
i
)
{
rois_batch_id_data
[
i
]
=
n
;
}
}
T
*
output_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
T
*
input_rois
=
rois
->
data
<
T
>
();
// calculate psroipooling, parallel processing can be implemented per ROI
for
(
int
n
=
0
;
n
<
rois_num
;
++
n
)
{
// set roi batch id
int
roi_batch_id
=
rois_batch_id_data
[
n
];
// [start, end) interval for spatial sampling
const
T
*
offset_input_rois
=
input_rois
+
n
*
4
;
T
roi_start_w
=
static_cast
<
T
>
(
round
(
offset_input_rois
[
0
]))
*
spatial_scale
;
T
roi_start_h
=
static_cast
<
T
>
(
round
(
offset_input_rois
[
1
]))
*
spatial_scale
;
T
roi_end_w
=
static_cast
<
T
>
(
round
(
offset_input_rois
[
2
])
+
1.
)
*
spatial_scale
;
T
roi_end_h
=
static_cast
<
T
>
(
round
(
offset_input_rois
[
3
])
+
1.
)
*
spatial_scale
;
// Force too small rois to be 1 x 1
T
roi_height
=
std
::
max
(
roi_end_h
-
roi_start_h
,
(
T
)
0.1
);
// avoid 0
T
roi_width
=
std
::
max
(
roi_end_w
-
roi_start_w
,
(
T
)
0.1
);
// Compute bin size w and h at input feature map
T
bin_size_h
=
roi_height
/
static_cast
<
T
>
(
pooled_height
);
T
bin_size_w
=
roi_width
/
static_cast
<
T
>
(
pooled_width
);
// calculate each pixel of the output feature map.
int
out_roi_offset
=
n
*
out_stride
[
0
];
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
// per category
int
out_plane_offset
=
out_roi_offset
+
c
*
out_stride
[
1
];
for
(
int
ph
=
0
;
ph
<
pooled_height
;
++
ph
)
{
int
out_row_offset
=
out_plane_offset
+
ph
*
out_stride
[
2
];
for
(
int
pw
=
0
;
pw
<
pooled_width
;
++
pw
)
{
// calculate w and h at input feature map
int
hstart
=
floor
(
static_cast
<
T
>
(
ph
)
*
bin_size_h
+
roi_start_h
);
int
wstart
=
floor
(
static_cast
<
T
>
(
pw
)
*
bin_size_w
+
roi_start_w
);
int
hend
=
ceil
(
static_cast
<
T
>
(
ph
+
1
)
*
bin_size_h
+
roi_start_h
);
int
wend
=
ceil
(
static_cast
<
T
>
(
pw
+
1
)
*
bin_size_w
+
roi_start_w
);
// Add roi offsets and clip to input boundaries
hstart
=
std
::
min
(
std
::
max
(
hstart
,
0
),
height
);
wstart
=
std
::
min
(
std
::
max
(
wstart
,
0
),
width
);
hend
=
std
::
min
(
std
::
max
(
hend
,
0
),
height
);
wend
=
std
::
min
(
std
::
max
(
wend
,
0
),
width
);
int
output_index
=
out_row_offset
+
pw
;
int
input_channel
=
(
c
*
pooled_height
+
ph
)
*
pooled_width
+
pw
;
int
input_plane_offset
=
roi_batch_id
*
in_stride
[
0
]
+
input_channel
*
in_stride
[
1
];
const
T
*
offset_input_data
=
input_data
+
input_plane_offset
;
T
out_sum
=
0.
;
bool
is_empty
=
(
hend
<=
hstart
)
||
(
wend
<=
wstart
);
for
(
int
ih
=
hstart
;
ih
<
hend
;
++
ih
)
{
for
(
int
iw
=
wstart
;
iw
<
wend
;
++
iw
)
{
int
input_index
=
ih
*
in_stride
[
2
]
+
iw
;
out_sum
+=
offset_input_data
[
input_index
];
}
}
T
bin_area
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
output_data
[
output_index
]
=
is_empty
?
0.
:
out_sum
/
bin_area
;
}
}
}
}
return
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
CPUPSROIPoolGradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
rois
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"ROIs"
);
auto
*
output_grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
input_grad
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
pooled_height
=
ctx
.
Attr
<
int
>
(
"pooled_height"
);
auto
pooled_width
=
ctx
.
Attr
<
int
>
(
"pooled_width"
);
auto
output_channels
=
ctx
.
Attr
<
int
>
(
"output_channels"
);
auto
spatial_scale
=
ctx
.
Attr
<
float
>
(
"spatial_scale"
);
if
(
input_grad
)
{
auto
in_dims
=
in
->
dims
();
int
input_channels
=
in_dims
[
1
];
int
height
=
in_dims
[
2
];
int
width
=
in_dims
[
3
];
int
rois_num
=
rois
->
dims
()[
0
];
// set roi batch id
framework
::
Tensor
rois_batch_id_list
;
rois_batch_id_list
.
Resize
({
rois_num
});
int
*
rois_batch_id_data
=
rois_batch_id_list
.
mutable_data
<
int
>
(
ctx
.
GetPlace
());
auto
rois_lod
=
rois
->
lod
().
back
();
int
rois_batch_size
=
rois_lod
.
size
()
-
1
;
// calculate batch id index for each roi according to LoD
for
(
int
n
=
0
;
n
<
rois_batch_size
;
++
n
)
{
for
(
size_t
i
=
rois_lod
[
n
];
i
<
rois_lod
[
n
+
1
];
++
i
)
{
rois_batch_id_data
[
i
]
=
n
;
}
}
const
T
*
input_rois
=
rois
->
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
->
data
<
T
>
();
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// set gradient of X to be 0. before backpropagate.
math
::
SetConstant
<
DeviceContext
,
T
>
set_zero
;
set_zero
(
ctx
.
template
device_context
<
DeviceContext
>(),
input_grad
,
static_cast
<
T
>
(
0
));
// backpropagate gradient per output pixel
int
output_grad_size
=
output_grad
->
numel
();
for
(
int
i
=
0
;
i
<
output_grad_size
;
++
i
)
{
// The output is in order (n, c, ph, pw)
int
pw
=
i
%
pooled_width
;
int
ph
=
(
i
/
pooled_width
)
%
pooled_height
;
int
c
=
(
i
/
pooled_width
/
pooled_height
)
%
output_channels
;
int
n
=
i
/
pooled_width
/
pooled_height
/
output_channels
;
// set roi_batch_id
int
roi_batch_id
=
rois_batch_id_data
[
n
];
int
input_channel
=
(
c
*
pooled_height
+
ph
)
*
pooled_width
+
pw
;
int
input_offset
=
(
roi_batch_id
*
input_channels
+
input_channel
)
*
height
*
width
;
T
*
offset_input_grad_data
=
input_grad_data
+
input_offset
;
// [start, end) interval for spatial sampling
const
T
*
offset_input_rois
=
input_rois
+
n
*
4
;
T
roi_start_w
=
static_cast
<
T
>
(
round
(
offset_input_rois
[
0
]))
*
spatial_scale
;
T
roi_start_h
=
static_cast
<
T
>
(
round
(
offset_input_rois
[
1
]))
*
spatial_scale
;
T
roi_end_w
=
static_cast
<
T
>
(
round
(
offset_input_rois
[
2
])
+
1.
)
*
spatial_scale
;
T
roi_end_h
=
static_cast
<
T
>
(
round
(
offset_input_rois
[
3
])
+
1.
)
*
spatial_scale
;
// Force too small ROIs to be 1x1
T
roi_height
=
std
::
max
(
roi_end_h
-
roi_start_h
,
(
T
)
0.1
);
// avoid 0
T
roi_width
=
std
::
max
(
roi_end_w
-
roi_start_w
,
(
T
)
0.1
);
// Compute w and h at input feature map
T
bin_size_h
=
roi_height
/
static_cast
<
T
>
(
pooled_height
);
T
bin_size_w
=
roi_width
/
static_cast
<
T
>
(
pooled_width
);
int
hstart
=
floor
(
bin_size_h
*
static_cast
<
T
>
(
ph
)
+
roi_start_h
);
int
wstart
=
floor
(
bin_size_w
*
static_cast
<
T
>
(
pw
)
+
roi_start_w
);
int
hend
=
ceil
(
bin_size_h
*
static_cast
<
T
>
(
ph
+
1
)
+
roi_start_h
);
int
wend
=
ceil
(
bin_size_w
*
static_cast
<
T
>
(
pw
+
1
)
+
roi_start_w
);
// Add roi offsets and clip to input boundaries
hstart
=
std
::
min
(
std
::
max
(
hstart
,
0
),
height
);
hend
=
std
::
min
(
std
::
max
(
hend
,
0
),
height
);
wstart
=
std
::
min
(
std
::
max
(
wstart
,
0
),
width
);
wend
=
std
::
min
(
std
::
max
(
wend
,
0
),
width
);
bool
is_empty
=
(
hend
<=
hstart
)
||
(
wend
<=
wstart
);
// Accumulate diff_val into input data
T
bin_area
=
static_cast
<
T
>
((
hend
-
hstart
)
*
(
wend
-
wstart
));
T
diff_val
=
is_empty
?
0.
:
output_grad_data
[
i
]
/
bin_area
;
for
(
int
ih
=
hstart
;
ih
<
hend
;
++
ih
)
{
for
(
int
iw
=
wstart
;
iw
<
wend
;
++
iw
)
{
int
input_index
=
ih
*
width
+
iw
;
offset_input_grad_data
[
input_index
]
+=
diff_val
;
}
}
}
}
return
;
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/layers/nn.py
浏览文件 @
e3c4b0da
...
...
@@ -173,6 +173,7 @@ __all__ = [
'merge_selected_rows'
,
'get_tensor_from_selected_rows'
,
'lstm'
,
'psroi_pool'
,
]
kIgnoreIndex
=
-
100
...
...
@@ -9122,3 +9123,57 @@ def get_tensor_from_selected_rows(x, name=None):
outputs
=
{
'Out'
:
out
},
attrs
=
{})
return
out
@
templatedoc
()
def
psroi_pool
(
input
,
rois
,
output_channels
,
spatial_scale
,
pooled_height
,
pooled_width
,
name
=
None
):
"""
${comment}
Args:
input (Variable): ${x_comment}
rois (Variable): ROIs (Regions of Interest) to pool over.
output_channels (integer): ${output_channels_comment}
spatial_scale (float): ${spatial_scale_comment} Default: 1.0
pooled_height (integer): ${pooled_height_comment} Default: 1
pooled_width (integer): ${pooled_width_comment} Default: 1
name (str, default None): The name of this layer.
Returns:
Variable: ${out_comment}.
Examples:
.. code-block:: python
pool_out = fluid.layers.psroi_pool(input=x, rois=rois, 490, 1.0, 7, 7)
"""
helper
=
LayerHelper
(
'psroi_pool'
,
**
locals
())
# check attrs
if
not
isinstance
(
output_channels
,
int
):
raise
TypeError
(
"output_channels must be int type"
)
if
not
isinstance
(
spatial_scale
,
float
):
raise
TypeError
(
"spatial_scale must be float type"
)
if
not
isinstance
(
pooled_height
,
int
):
raise
TypeError
(
"pooled_height must be int type"
)
if
not
isinstance
(
pooled_width
,
int
):
raise
TypeError
(
"pooled_width must be int type"
)
dtype
=
helper
.
input_dtype
()
out
=
helper
.
create_variable_for_type_inference
(
dtype
)
helper
.
append_op
(
type
=
'psroi_pool'
,
inputs
=
{
'X'
:
input
,
'ROIs'
:
rois
},
outputs
=
{
'Out'
:
out
},
attrs
=
{
'output_channels'
:
output_channels
,
'spatial_scale'
:
spatial_scale
,
'pooled_height'
:
pooled_height
,
'pooled_width'
:
pooled_width
})
return
out
python/paddle/fluid/tests/unittests/test_layers.py
浏览文件 @
e3c4b0da
...
...
@@ -511,6 +511,16 @@ class TestBook(unittest.TestCase):
self
.
assertIsNotNone
(
output
)
print
(
str
(
program
))
def
test_psroi_pool
(
self
):
program
=
Program
()
with
program_guard
(
program
):
x
=
layers
.
data
(
name
=
"x"
,
shape
=
[
245
,
30
,
30
],
dtype
=
"float32"
)
rois
=
layers
.
data
(
name
=
"rois"
,
shape
=
[
4
],
dtype
=
"float32"
,
lod_level
=
1
)
output
=
layers
.
psroi_pool
(
x
,
rois
,
5
,
0.25
,
7
,
7
)
self
.
assertIsNotNone
(
output
)
print
(
str
(
program
))
def
test_roi_align
(
self
):
program
=
Program
()
with
program_guard
(
program
):
...
...
python/paddle/fluid/tests/unittests/test_psroi_pool_op.py
0 → 100644
浏览文件 @
e3c4b0da
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
math
import
numpy
as
np
import
unittest
from
op_test
import
OpTest
class
TestPSROIPoolOp
(
OpTest
):
def
set_data
(
self
):
self
.
init_test_case
()
self
.
make_rois
()
self
.
calc_psroi_pool
()
self
.
inputs
=
{
'X'
:
self
.
x
,
'ROIs'
:
(
self
.
rois
[:,
1
:
5
],
self
.
rois_lod
)}
self
.
attrs
=
{
'output_channels'
:
self
.
output_channels
,
'spatial_scale'
:
self
.
spatial_scale
,
'pooled_height'
:
self
.
pooled_height
,
'pooled_width'
:
self
.
pooled_width
}
self
.
outputs
=
{
'Out'
:
self
.
outs
}
def
init_test_case
(
self
):
self
.
batch_size
=
3
self
.
channels
=
3
*
2
*
2
self
.
height
=
6
self
.
width
=
4
self
.
x_dim
=
[
self
.
batch_size
,
self
.
channels
,
self
.
height
,
self
.
width
]
self
.
spatial_scale
=
1.0
/
4.0
self
.
output_channels
=
3
self
.
pooled_height
=
2
self
.
pooled_width
=
2
self
.
x
=
np
.
random
.
random
(
self
.
x_dim
).
astype
(
'float32'
)
def
make_rois
(
self
):
rois
=
[]
self
.
rois_lod
=
[[]]
for
bno
in
range
(
self
.
batch_size
):
self
.
rois_lod
[
0
].
append
(
bno
+
1
)
for
i
in
range
(
bno
+
1
):
x1
=
np
.
random
.
random_integers
(
0
,
self
.
width
//
self
.
spatial_scale
-
self
.
pooled_width
)
y1
=
np
.
random
.
random_integers
(
0
,
self
.
height
//
self
.
spatial_scale
-
self
.
pooled_height
)
x2
=
np
.
random
.
random_integers
(
x1
+
self
.
pooled_width
,
self
.
width
//
self
.
spatial_scale
)
y2
=
np
.
random
.
random_integers
(
y1
+
self
.
pooled_height
,
self
.
height
//
self
.
spatial_scale
)
roi
=
[
bno
,
x1
,
y1
,
x2
,
y2
]
rois
.
append
(
roi
)
self
.
rois_num
=
len
(
rois
)
self
.
rois
=
np
.
array
(
rois
).
astype
(
'float32'
)
def
calc_psroi_pool
(
self
):
output_shape
=
(
self
.
rois_num
,
self
.
output_channels
,
self
.
pooled_height
,
self
.
pooled_width
)
out_data
=
np
.
zeros
(
output_shape
)
for
i
in
range
(
self
.
rois_num
):
roi
=
self
.
rois
[
i
]
roi_batch_id
=
int
(
roi
[
0
])
roi_start_w
=
round
(
roi
[
1
])
*
self
.
spatial_scale
roi_start_h
=
round
(
roi
[
2
])
*
self
.
spatial_scale
roi_end_w
=
(
round
(
roi
[
3
])
+
1.
)
*
self
.
spatial_scale
roi_end_h
=
(
round
(
roi
[
4
])
+
1.
)
*
self
.
spatial_scale
roi_height
=
max
(
roi_end_h
-
roi_start_h
,
0.1
)
roi_width
=
max
(
roi_end_w
-
roi_start_w
,
0.1
)
bin_size_h
=
roi_height
/
float
(
self
.
pooled_height
)
bin_size_w
=
roi_width
/
float
(
self
.
pooled_width
)
x_i
=
self
.
x
[
roi_batch_id
]
for
c
in
range
(
self
.
output_channels
):
for
ph
in
range
(
self
.
pooled_height
):
for
pw
in
range
(
self
.
pooled_width
):
hstart
=
int
(
math
.
floor
(
float
(
ph
)
*
bin_size_h
+
roi_start_h
))
wstart
=
int
(
math
.
floor
(
float
(
pw
)
*
bin_size_w
+
roi_start_w
))
hend
=
int
(
math
.
ceil
(
float
(
ph
+
1
)
*
bin_size_h
+
roi_start_h
))
wend
=
int
(
math
.
ceil
(
float
(
pw
+
1
)
*
bin_size_w
+
roi_start_w
))
hstart
=
min
(
max
(
hstart
,
0
),
self
.
height
)
hend
=
min
(
max
(
hend
,
0
),
self
.
height
)
wstart
=
min
(
max
(
wstart
,
0
),
self
.
width
)
wend
=
min
(
max
(
wend
,
0
),
self
.
width
)
c_in
=
(
c
*
self
.
pooled_height
+
ph
)
*
self
.
pooled_width
+
pw
is_empty
=
(
hend
<=
hstart
)
or
(
wend
<=
wstart
)
out_sum
=
0.
for
ih
in
range
(
hstart
,
hend
):
for
iw
in
range
(
wstart
,
wend
):
out_sum
+=
x_i
[
c_in
,
ih
,
iw
]
bin_area
=
(
hend
-
hstart
)
*
(
wend
-
wstart
)
out_data
[
i
,
c
,
ph
,
pw
]
=
0.
if
is_empty
else
(
out_sum
/
float
(
bin_area
))
self
.
outs
=
out_data
.
astype
(
'float32'
)
def
setUp
(
self
):
self
.
op_type
=
'psroi_pool'
self
.
set_data
()
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Out'
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录