PaddlePaddle/PaddleDetection, commit 79609288
Authored Nov 22, 2017 by wanghaox

add roi pool operator

Parent: 9216da3f
Showing 3 changed files with 604 additions and 0 deletions:
paddle/operators/roi_pool_op.cc   +126  -0
paddle/operators/roi_pool_op.cu   +265  -0
paddle/operators/roi_pool_op.h    +213  -0
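For orientation before the diff itself: ROI max pooling takes an NCHW feature map plus a list of regions of interest, scales and rounds each region into feature-map coordinates, splits it into a fixed pooled_height x pooled_width grid of bins, and emits the maximum activation of each bin (recording its argmax index for the backward pass). Below is a minimal single-channel sketch of that computation in plain C++; it is illustrative only, and every name in it is hypothetical rather than part of this diff.

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Hypothetical standalone ROI max pooling over one channel of an H x W map.
// The ROI (x1, y1, x2, y2) is assumed to be already scaled to map coordinates.
std::vector<float> RoiMaxPool(const std::vector<float>& feat, int H, int W,
                              int x1, int y1, int x2, int y2,
                              int pooled_h, int pooled_w) {
  int roi_h = std::max(y2 - y1 + 1, 1);  // force malformed ROIs to 1x1
  int roi_w = std::max(x2 - x1 + 1, 1);
  float bin_h = static_cast<float>(roi_h) / pooled_h;
  float bin_w = static_cast<float>(roi_w) / pooled_w;
  std::vector<float> out(pooled_h * pooled_w, 0.f);
  for (int ph = 0; ph < pooled_h; ++ph) {
    for (int pw = 0; pw < pooled_w; ++pw) {
      // Bin [hs, he) x [ws, we), shifted by the ROI origin and clipped to the map.
      int hs = std::min(std::max(static_cast<int>(std::floor(ph * bin_h)) + y1, 0), H);
      int he = std::min(std::max(static_cast<int>(std::ceil((ph + 1) * bin_h)) + y1, 0), H);
      int ws = std::min(std::max(static_cast<int>(std::floor(pw * bin_w)) + x1, 0), W);
      int we = std::min(std::max(static_cast<int>(std::ceil((pw + 1) * bin_w)) + x1, 0), W);
      bool empty = (he <= hs) || (we <= ws);
      float best = empty ? 0.f : -std::numeric_limits<float>::max();
      for (int h = hs; h < he; ++h)
        for (int w = ws; w < we; ++w)
          best = std::max(best, feat[h * W + w]);
      out[ph * pooled_w + pw] = best;
    }
  }
  return out;
}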
paddle/operators/roi_pool_op.cc (new file, mode 100755)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/roi_pool_op.h"
namespace paddle {
namespace operators {

class RoiPoolOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of RoiPoolOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Rois"),
                   "Input(Rois) of RoiPoolOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of RoiPoolOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Argmax"),
                   "Output(Argmax) of RoiPoolOp should not be null.");
    auto input_dims = ctx->GetInputDim("X");
    // Initialize the output's dims to maximum,
    // and re-set to real dims by the value of Rois at kernel
    ctx->SetOutputDim("Out", input_dims);
  }

 protected:
  framework::OpKernelType GetKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
        ctx.device_context());
  }
};
class RoiPoolGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "The gradient of Out should not be null.");
    PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")),
                   "The gradient of X should not be null.");
    ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
  }

 protected:
  framework::OpKernelType GetKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
        ctx.device_context());
  }
};
class RoiPoolOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  RoiPoolOpMaker(framework::OpProto* proto,
                 framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "(Tensor), "
             "the input of RoiPoolOp.");
    AddInput("Rois",
             "(Tensor), "
             "RoIs (Regions of Interest) to pool over. "
             "Should be a 2-D tensor of shape (num_rois, 5) "
             "given as [[batch_id, x1, y1, x2, y2], …].");
    AddOutput("Out",
              "(Tensor), "
              "RoI pooled output 4-D tensor of shape "
              "(num_rois, channels, pooled_h, pooled_w).");
    AddOutput("Argmax",
              "(Tensor), "
              "Argmaxes corresponding to indices in X used "
              "for gradient computation. Only output "
              "if arg \"is_test\" is false.")
        .AsIntermediate();
    AddAttr<float>("spatial_scale",
                   "(float, default 1.0), "
                   "Multiplicative spatial scale factor "
                   "to translate ROI coords from their input scale "
                   "to the scale used when pooling.")
        .SetDefault(1.0);
    AddAttr<int>("pooled_height",
                 "(int, default 1), "
                 "The pooled output height.")
        .SetDefault(1);
    AddAttr<int>("pooled_width",
                 "(int, default 1), "
                 "The pooled output width.")
        .SetDefault(1);
    AddComment(R"DOC(
RoiPool operator
ROI Pooling for Faster-RCNN. The link below is a further introduction:
https://stackoverflow.com/questions/43430056/what-is-roi-layer-in-fast-rcnn
    )DOC");
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(roi_pool, ops::RoiPoolOp, ops::RoiPoolOpMaker, roi_pool_grad,
            ops::RoiPoolGradOp);
REGISTER_OP_CPU_KERNEL(
    roi_pool, ops::CPURoiPoolOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    roi_pool_grad,
    ops::CPURoiPoolGradOpKernel<paddle::platform::CPUPlace, float>);
paddle/operators/roi_pool_op.cu (new file, mode 100755)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/cuda_helper.h"
#include "paddle/operators/roi_pool_op.h"
namespace paddle {
namespace operators {

#define FLT_MAX __FLT_MAX__

constexpr int PADDLE_OPERATORS_ROIPOOL_CUDA_NUM_THREADS = 512;
constexpr int PADDLE_OPERATORS_ROIPOOL_MAXIMUM_NUM_BLOCKS = 4096;

inline int PADDLE_OPERATORS_ROIPOOL_GET_BLOCKS(const int N) {
  return std::min((N + PADDLE_OPERATORS_ROIPOOL_CUDA_NUM_THREADS - 1) /
                      PADDLE_OPERATORS_ROIPOOL_CUDA_NUM_THREADS,
                  PADDLE_OPERATORS_ROIPOOL_MAXIMUM_NUM_BLOCKS);
}
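Launch-configuration note (not part of the diff): the helper caps the grid at 4096 blocks of 512 threads, so outputs larger than 4096 * 512 = 2,097,152 elements are covered by the grid-stride loop inside the kernels below rather than by a bigger grid. For example, N = 1,000,000 yields (1,000,000 + 511) / 512 = 1954 blocks, comfortably under the cap.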
template <typename T>
__global__ void GPURoiPoolForward(
    const int nthreads, const T* input_data, const int64_t* input_rois,
    const float spatial_scale, const int channels, const int height,
    const int width, const int pooled_height, const int pooled_width,
    T* output_data, int64_t* argmax_data) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  // Grid-stride loop over all output elements.
  for (int i = index; i < nthreads; i += offset) {
    int pw = i % pooled_width;
    int ph = (i / pooled_width) % pooled_height;
    int c = (i / pooled_width / pooled_height) % channels;
    int n = i / pooled_width / pooled_height / channels;

    const int64_t* offset_input_rois = input_rois + n * 5;
    int roi_batch_ind = offset_input_rois[0];
    int roi_start_w = round(offset_input_rois[1] * spatial_scale);
    int roi_start_h = round(offset_input_rois[2] * spatial_scale);
    int roi_end_w = round(offset_input_rois[3] * spatial_scale);
    int roi_end_h = round(offset_input_rois[4] * spatial_scale);

    int roi_width = max(roi_end_w - roi_start_w + 1, 1);
    int roi_height = max(roi_end_h - roi_start_h + 1, 1);
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
    int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
    int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
    int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));

    hstart = min(max(hstart + roi_start_h, 0), height);
    hend = min(max(hend + roi_start_h, 0), height);
    wstart = min(max(wstart + roi_start_w, 0), width);
    wend = min(max(wend + roi_start_w, 0), width);
    bool is_empty = (hend <= hstart) || (wend <= wstart);

    T maxval = is_empty ? 0 : -FLT_MAX;
    int maxidx = -1;
    const T* offset_input_data =
        input_data + (roi_batch_ind * channels + c) * height * width;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_data_index = h * width + w;
        if (offset_input_data[input_data_index] > maxval) {
          maxval = offset_input_data[input_data_index];
          maxidx = input_data_index;
        }
      }
    }
    output_data[i] = maxval;
    if (argmax_data) {
      argmax_data[i] = maxidx;
    }
  }
}
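The pw/ph/c/n arithmetic above simply inverts the row-major (n, c, ph, pw) linearization of the output tensor. A quick host-side round-trip check (hypothetical values, not part of the diff):

#include <cassert>

int main() {
  const int channels = 256, pooled_height = 6, pooled_width = 6;
  const int n = 3, c = 17, ph = 4, pw = 5;
  // Linearize the way the output tensor is laid out...
  int index = ((n * channels + c) * pooled_height + ph) * pooled_width + pw;
  // ...and recover the coordinates exactly as GPURoiPoolForward does.
  assert(index % pooled_width == pw);
  assert((index / pooled_width) % pooled_height == ph);
  assert((index / pooled_width / pooled_height) % channels == c);
  assert(index / pooled_width / pooled_height / channels == n);
  return 0;
}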
template <typename T>
__global__ void GPURoiPoolBackward(
    const int nthreads, const int64_t* input_rois, const T* output_grad,
    const int64_t* argmax_data, const int num_rois, const float spatial_scale,
    const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width, T* input_grad) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  // Grid-stride loop over all output-gradient elements.
  for (int i = index; i < nthreads; i += offset) {
    int pw = i % pooled_width;
    int ph = (i / pooled_width) % pooled_height;
    int c = (i / pooled_width / pooled_height) % channels;
    int n = i / pooled_width / pooled_height / channels;

    const int64_t* offset_input_rois = input_rois + n * 5;
    int roi_batch_ind = offset_input_rois[0];
    int input_offset = (roi_batch_ind * channels + c) * height * width;
    int output_offset = (n * channels + c) * pooled_height * pooled_width;
    const T* offset_output_grad = output_grad + output_offset;
    T* offset_input_grad = input_grad + input_offset;
    const int64_t* offset_argmax_data = argmax_data + output_offset;

    int argmax = offset_argmax_data[ph * pooled_width + pw];
    if (argmax != -1) {
      platform::CudaAtomicAdd(
          offset_input_grad + argmax,
          static_cast<T>(offset_output_grad[ph * pooled_width + pw]));
    }
  }
}
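Why the atomic (not part of the diff): different bins, and different ROIs, can record the same input cell as their argmax, so concurrent threads may scatter gradient into the same input_grad address. platform::CudaAtomicAdd makes those read-modify-write updates safe, at the cost of a nondeterministic summation order for floating-point gradients.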
template <typename Place, typename T>
class GPURoiPoolOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<Tensor>("Rois");
    auto* out = ctx.Output<Tensor>("Out");
    auto* argmax = ctx.Output<Tensor>("Argmax");

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");

    PADDLE_ENFORCE_GT(pooled_height, 0,
                      "The pooled output height must be greater than 0");
    PADDLE_ENFORCE_GT(pooled_width, 0,
                      "The pooled output width must be greater than 0");
    PADDLE_ENFORCE_GT(spatial_scale, 0,
                      "The spatial scale must be greater than 0");

    auto in_dims = in->dims();
    auto in_stride = framework::stride(in_dims);
    int channels = in_dims[1];
    int height = in_dims[2];
    int width = in_dims[3];

    int rois_num = rois->dims()[0];
    auto out_dims = in_dims;
    out_dims[0] = rois_num;
    out_dims[1] = in_dims[1];
    out_dims[2] = pooled_height;
    out_dims[3] = pooled_width;

    out->Resize(out_dims);
    out->mutable_data<T>(ctx.GetPlace());
    math::SetConstant<Place, T> set_zero;
    set_zero(ctx.device_context(), out, static_cast<T>(0));
    argmax->Resize(out->dims());
    argmax->mutable_data<int64_t>(ctx.GetPlace());
    math::SetConstant<Place, int64_t> set_init;
    set_init(ctx.device_context(), argmax, static_cast<int64_t>(-1));

    if (rois_num == 0) return;

    int output_size = out->numel();
    int blocks = PADDLE_OPERATORS_ROIPOOL_GET_BLOCKS(output_size);
    int threads = PADDLE_OPERATORS_ROIPOOL_CUDA_NUM_THREADS;

    GPURoiPoolForward<T><<<blocks, threads, 0,
                           ctx.cuda_device_context().stream()>>>(
        output_size, in->data<T>(), rois->data<int64_t>(), spatial_scale,
        channels, height, width, pooled_height, pooled_width,
        out->mutable_data<T>(ctx.GetPlace()),
        argmax->mutable_data<int64_t>(ctx.GetPlace()));

    return;
  }
};
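Note the Resize here (not part of the diff): InferShape in roi_pool_op.cc could only set Out to the input's dims as a placeholder, because num_rois is a runtime value; once the Rois tensor is available, the kernel re-resizes Out and Argmax to (rois_num, channels, pooled_height, pooled_width).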
template <typename Place, typename T>
class GPURoiPoolGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<Tensor>("Rois");
    auto* argmax = ctx.Input<Tensor>("Argmax");
    auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");

    int rois_num = rois->dims()[0];
    int channels = in->dims()[1];
    int height = in->dims()[2];
    int width = in->dims()[3];

    if (x_grad) {
      x_grad->Resize(in->dims());
      x_grad->mutable_data<T>(ctx.GetPlace());
      math::SetConstant<Place, T> set_zero;
      set_zero(ctx.device_context(), x_grad, static_cast<T>(0));

      int output_grad_size = out_grad->numel();
      int blocks = PADDLE_OPERATORS_ROIPOOL_GET_BLOCKS(output_grad_size);
      int threads = PADDLE_OPERATORS_ROIPOOL_CUDA_NUM_THREADS;

      if (output_grad_size > 0) {
        GPURoiPoolBackward<T><<<blocks, threads, 0,
                                ctx.cuda_device_context().stream()>>>(
            output_grad_size, rois->data<int64_t>(), out_grad->data<T>(),
            argmax->data<int64_t>(), rois_num, spatial_scale, channels,
            height, width, pooled_height, pooled_width,
            x_grad->mutable_data<T>(ctx.GetPlace()));
      }
      return;
    }
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
    roi_pool, ops::GPURoiPoolOpKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    roi_pool_grad,
    ops::GPURoiPoolGradOpKernel<paddle::platform::GPUPlace, float>);
paddle/operators/roi_pool_op.h (new file, mode 100755)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/strided_memcpy.h"
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;
template <typename Place, typename T>
class CPURoiPoolOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<Tensor>("Rois");
    auto* out = ctx.Output<Tensor>("Out");
    auto* argmax = ctx.Output<Tensor>("Argmax");

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");

    PADDLE_ENFORCE_GT(pooled_height, 0,
                      "The pooled output height must be greater than 0");
    PADDLE_ENFORCE_GT(pooled_width, 0,
                      "The pooled output width must be greater than 0");
    PADDLE_ENFORCE_GT(spatial_scale, 0,
                      "The spatial scale must be greater than 0");

    auto in_dims = in->dims();
    int batch_size = in_dims[0];
    int channels = in_dims[1];
    int height = in_dims[2];
    int width = in_dims[3];
    int rois_num = rois->dims()[0];

    auto out_dims = in_dims;
    out_dims[0] = rois_num;
    out_dims[1] = channels;
    out_dims[2] = pooled_height;
    out_dims[3] = pooled_width;
    out->Resize(out_dims);
    argmax->Resize(out->dims());

    auto in_stride = framework::stride(in_dims);
    auto argmax_stride = framework::stride(argmax->dims());
    auto roi_stride = framework::stride(rois->dims());
    auto out_stride = framework::stride(out_dims);

    const T* input_data = in->data<T>();
    const int64_t* rois_data = rois->data<int64_t>();
    T* output_data = out->mutable_data<T>(ctx.GetPlace());
    int64_t* argmax_data = argmax->mutable_data<int64_t>(ctx.GetPlace());

    math::SetConstant<Place, T> set_zero;
    set_zero(ctx.device_context(), out, static_cast<T>(0));
    math::SetConstant<Place, int64_t> set_init;
    set_init(ctx.device_context(), argmax, static_cast<int64_t>(-1));

    // Validate every ROI's batch index before pooling.
    for (int n = 0; n < rois_num; ++n) {
      int roi_batch_id = rois_data[0];
      PADDLE_ENFORCE_GE(roi_batch_id, 0);
      PADDLE_ENFORCE_LT(roi_batch_id, batch_size);
      rois_data += roi_stride[0];
    }

    rois_data = rois->data<int64_t>();
    for (int n = 0; n < rois_num; ++n) {
      int roi_batch_id = rois_data[0];
      int roi_start_w = round(rois_data[1] * spatial_scale);
      int roi_start_h = round(rois_data[2] * spatial_scale);
      int roi_end_w = round(rois_data[3] * spatial_scale);
      int roi_end_h = round(rois_data[4] * spatial_scale);

      // Force malformed ROIs to be 1x1
      int roi_height = std::max(roi_end_h - roi_start_h + 1, 1);
      int roi_width = std::max(roi_end_w - roi_start_w + 1, 1);

      const float bin_size_h =
          static_cast<float>(roi_height) / static_cast<float>(pooled_height);
      const float bin_size_w =
          static_cast<float>(roi_width) / static_cast<float>(pooled_width);

      const T* batch_data = input_data + roi_batch_id * in_stride[0];

      for (int c = 0; c < channels; ++c) {
        for (int ph = 0; ph < pooled_height; ++ph) {
          for (int pw = 0; pw < pooled_width; ++pw) {
            //  Compute pooling region for this output unit:
            //  start (included) = floor(ph * roi_height / pooled_height_)
            //  end (excluded) = ceil((ph + 1) * roi_height / pooled_height_)
            int hstart =
                static_cast<int>(floor(static_cast<float>(ph) * bin_size_h));
            int wstart =
                static_cast<int>(floor(static_cast<float>(pw) * bin_size_w));
            int hend =
                static_cast<int>(ceil(static_cast<float>(ph + 1) * bin_size_h));
            int wend =
                static_cast<int>(ceil(static_cast<float>(pw + 1) * bin_size_w));

            hstart = std::min(std::max(hstart + roi_start_h, 0), height);
            hend = std::min(std::max(hend + roi_start_h, 0), height);
            wstart = std::min(std::max(wstart + roi_start_w, 0), width);
            wend = std::min(std::max(wend + roi_start_w, 0), width);

            const int pool_index = ph * pooled_width + pw;

            // Define an empty pooling region to be zero
            bool is_empty = (hend <= hstart) || (wend <= wstart);
            output_data[pool_index] = is_empty ? 0 : -__FLT_MAX__;

            for (int h = hstart; h < hend; ++h) {
              for (int w = wstart; w < wend; ++w) {
                const int index = h * width + w;
                if (batch_data[index] > output_data[pool_index]) {
                  output_data[pool_index] = batch_data[index];
                  argmax_data[pool_index] = index;
                }
              }
            }
          }
        }
        batch_data += in_stride[1];
        output_data += out_stride[1];
        argmax_data += argmax_stride[1];
      }
      // Increment ROI data pointer
      rois_data += roi_stride[0];
    }

    return;
  }
};
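Worked example of the bin boundaries (not part of the diff): with roi_height = 7 and pooled_height = 3, bin_size_h = 7/3 ≈ 2.33, so the three vertical bins are [0, 3), [2, 5), and [4, 7). The floor/ceil pair lets adjacent bins overlap by up to one row while guaranteeing that together they cover the whole ROI.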
template <typename Place, typename T>
class CPURoiPoolGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<Tensor>("Rois");
    auto* argmax = ctx.Input<Tensor>("Argmax");
    auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");

    if (x_grad) {
      int channels = in->dims()[1];
      auto in_stride = framework::stride(in->dims());
      auto roi_stride = framework::stride(rois->dims());

      const int64_t* rois_data = rois->data<int64_t>();
      int rois_num = rois->dims()[0];

      T* x_grad_data = x_grad->mutable_data<T>(ctx.GetPlace());
      math::SetConstant<Place, T> set_zero;
      set_zero(ctx.device_context(), x_grad, static_cast<T>(0));

      size_t roi_offset = roi_stride[0];
      size_t batch_offset = in_stride[0];
      size_t channel_offset = in_stride[1];

      const T* out_grad_data = out_grad->data<T>();
      size_t pool_channel_offset = pooled_height * pooled_width;
      const int64_t* argmax_data = argmax->data<int64_t>();

      for (size_t n = 0; n < rois_num; ++n) {
        size_t roi_batch_idx = rois_data[0];
        T* batch_grad_data = x_grad_data + batch_offset * roi_batch_idx;
        for (size_t c = 0; c < channels; ++c) {
          for (size_t ph = 0; ph < pooled_height; ++ph) {
            for (size_t pw = 0; pw < pooled_width; ++pw) {
              size_t pool_index = ph * pooled_width + pw;
              // Route each output gradient back to the input cell that
              // produced the pooled maximum.
              if (argmax_data[pool_index] >= 0) {
                size_t index = static_cast<size_t>(argmax_data[pool_index]);
                batch_grad_data[index] += out_grad_data[pool_index];
              }
            }
          }
          batch_grad_data += channel_offset;
          out_grad_data += pool_channel_offset;
          argmax_data += pool_channel_offset;
        }
        rois_data += roi_offset;
      }
    }
  }
};
}  // namespace operators
}  // namespace paddle