Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
45a8c9dd
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
45a8c9dd
编写于
11月 21, 2017
作者:
S
sweetsky0901
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add unpool2d make ok
上级
f638f910
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
59 addition
and
45 deletion
+59
-45
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+7
-0
paddle/operators/math/unpooling.cc
paddle/operators/math/unpooling.cc
+10
-16
paddle/operators/math/unpooling.cu
paddle/operators/math/unpooling.cu
+12
-9
paddle/operators/math/unpooling.h
paddle/operators/math/unpooling.h
+3
-2
paddle/operators/unpool_op.cc
paddle/operators/unpool_op.cc
+16
-9
paddle/operators/unpool_op.cu.cc
paddle/operators/unpool_op.cu.cc
+5
-2
paddle/operators/unpool_op.h
paddle/operators/unpool_op.h
+6
-7
未找到文件。
paddle/operators/CMakeLists.txt
浏览文件 @
45a8c9dd
...
@@ -80,6 +80,13 @@ function(op_library TARGET)
...
@@ -80,6 +80,13 @@ function(op_library TARGET)
file
(
APPEND
${
pybind_file
}
"USE_OP(pool2d);
\n
"
)
file
(
APPEND
${
pybind_file
}
"USE_OP(pool2d);
\n
"
)
endif
()
endif
()
# unpool_op contains several operators
if
(
"
${
TARGET
}
"
STREQUAL
"unpool_op"
)
set
(
pybind_flag 1
)
# It's enough to just adding one operator to pybind
file
(
APPEND
${
pybind_file
}
"USE_OP(unpool2d);
\n
"
)
endif
()
# pool_cudnn_op contains several operators
# pool_cudnn_op contains several operators
if
(
"
${
TARGET
}
"
STREQUAL
"pool_cudnn_op"
)
if
(
"
${
TARGET
}
"
STREQUAL
"pool_cudnn_op"
)
set
(
pybind_flag 1
)
set
(
pybind_flag 1
)
...
...
paddle/operators/math/unpooling.cc
浏览文件 @
45a8c9dd
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/operators/math/
maxout
ing.h"
#include "paddle/operators/math/
unpool
ing.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -20,7 +20,7 @@ namespace math {
...
@@ -20,7 +20,7 @@ namespace math {
// All tensors are in NCHW format
// All tensors are in NCHW format
template
<
typename
T
>
template
<
typename
T
>
class
Unpool2d_Max
_
Functor
<
platform
::
CPUPlace
,
T
>
{
class
Unpool2d_MaxFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
input
,
...
@@ -36,16 +36,14 @@ class Unpool2d_Max_Functor<platform::CPUPlace, T> {
...
@@ -36,16 +36,14 @@ class Unpool2d_Max_Functor<platform::CPUPlace, T> {
int
input_feasize
=
input_height
*
input_width
;
int
input_feasize
=
input_height
*
input_width
;
int
output_feasize
=
output_height
*
output_width
;
int
output_feasize
=
output_height
*
output_width
;
const
T
*
input_data
=
input
.
data
<
T
>
();
const
T
*
input_data
=
input
.
data
<
T
>
();
const
T
*
indices_data
=
indices
.
data
<
T
>
();
const
int
*
indices_data
=
indices
.
data
<
int
>
();
T
*
output_data
=
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
output_data
=
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
b
=
0
;
b
<
batch_size
;
++
b
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
++
b
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
for
(
int
i
=
0
;
i
<
input_feasize
;
++
i
)
{
for
(
int
i
=
0
;
i
<
input_feasize
;
++
i
)
{
int
index
=
indices_data
[
i
];
int
index
=
indices_data
[
i
];
if
(
index
>
output_feasize
)
{
// PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
//抛一个异常!
}
output_data
[
index
]
=
input_data
[
i
];
output_data
[
index
]
=
input_data
[
i
];
}
}
input_data
+=
input_feasize
;
input_data
+=
input_feasize
;
...
@@ -70,26 +68,22 @@ public:
...
@@ -70,26 +68,22 @@ public:
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
const
int
input_width
=
input
.
dims
()[
3
];
const
int
output_channels
=
output
->
dims
()[
1
];
const
int
output_channels
=
output
.
dims
()[
1
];
const
int
output_height
=
output
->
dims
()[
2
];
const
int
output_height
=
output
.
dims
()[
2
];
const
int
output_width
=
output
->
dims
()[
3
];
const
int
output_width
=
output
.
dims
()[
3
];
int
input_feasize
=
input_height
*
input_width
;
int
input_feasize
=
input_height
*
input_width
;
int
output_feasize
=
output_height
*
output_width
;
int
output_feasize
=
output_height
*
output_width
;
const
T
*
input_data
=
input
.
data
<
T
>
();
const
int
*
indices_data
=
indices
.
data
<
int
>
();
const
T
*
indices_data
=
indices
.
data
<
T
>
();
const
T
*
output_data
=
output
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
b
=
0
;
b
<
batch_size
;
++
b
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
++
b
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
for
(
int
f
=
0
;
f
<
input_feasize
;
++
f
)
{
for
(
int
i
=
0
;
i
<
input_feasize
;
++
i
)
{
int
index
=
indices_data
[
i
];
int
index
=
indices_data
[
i
];
if
(
index
>
output_feasize
)
{
// PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
//抛一个异常!
}
input_grad_data
[
i
]
=
output_grad_data
[
index
];
input_grad_data
[
i
]
=
output_grad_data
[
index
];
}
}
input_grad_data
+=
input_feasize
;
input_grad_data
+=
input_feasize
;
...
...
paddle/operators/math/unpooling.cu
浏览文件 @
45a8c9dd
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/operators/math/
maxout
ing.h"
#include "paddle/operators/math/
unpool
ing.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/cuda_helper.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -22,7 +22,7 @@ namespace math {
...
@@ -22,7 +22,7 @@ namespace math {
template
<
typename
T
>
template
<
typename
T
>
__global__
void
KernelUnpool2dMax
(
const
int
nthreads
,
__global__
void
KernelUnpool2dMax
(
const
int
nthreads
,
const
T
*
input_data
,
const
T
*
input_data
,
const
T
*
indices_data
,
const
int
*
indices_data
,
const
int
input_height
,
const
int
input_height
,
const
int
input_width
,
const
int
input_width
,
T
*
output_data
,
T
*
output_data
,
...
@@ -30,16 +30,19 @@ __global__ void KernelUnpool2dMax(const int nthreads,
...
@@ -30,16 +30,19 @@ __global__ void KernelUnpool2dMax(const int nthreads,
const
int
output_width
)
{
const
int
output_width
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
// int output_feasize = output_height * output_width;
for
(
int
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
for
(
int
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
int
out_offset
=
i
/
(
input_height
*
input_width
)
\
int
out_offset
=
i
/
(
input_height
*
input_width
)
\
*
output_height
*
output_width
;
*
output_height
*
output_width
;
int
out_index
=
indices_data
[
i
];
int
out_index
=
indices_data
[
i
];
// PADDLE_ENFORCE(out_index < output_feasize, "err index in unpooling!");
output_data
[
out_offset
+
out_index
]
=
input_data
[
i
];
output_data
[
out_offset
+
out_index
]
=
input_data
[
i
];
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
KernelUnpool2dMaxGrad
(
const
int
nthreads
,
__global__
void
KernelUnpool2dMaxGrad
(
const
int
nthreads
,
const
T
*
input_data
,
const
T
*
input_data
,
const
int
*
indices_data
,
const
int
input_height
,
const
int
input_height
,
const
int
input_width
,
const
int
input_width
,
const
T
*
output_data
,
const
T
*
output_data
,
...
@@ -49,10 +52,13 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads,
...
@@ -49,10 +52,13 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads,
T
*
input_grad
)
{
T
*
input_grad
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
// int output_feasize = output_height * output_width;
for
(
int
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
for
(
int
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
int
out_offset
=
i
/
(
input_height
*
input_width
)
\
int
out_offset
=
i
/
(
input_height
*
input_width
)
\
*
output_height
*
output_width
;
*
output_height
*
output_width
;
int
out_index
=
indices_data
[
i
];
int
out_index
=
indices_data
[
i
];
// PADDLE_ENFORCE(out_index < output_feasize,
// "err index in unpooling!");
input_grad
[
i
]
=
output_grad
[
out_offset
+
out_index
];
input_grad
[
i
]
=
output_grad
[
out_offset
+
out_index
];
}
}
}
}
...
@@ -72,10 +78,8 @@ class Unpool2d_MaxFunctor<platform::GPUPlace, T> {
...
@@ -72,10 +78,8 @@ class Unpool2d_MaxFunctor<platform::GPUPlace, T> {
const
int
output_channels
=
output
->
dims
()[
1
];
const
int
output_channels
=
output
->
dims
()[
1
];
const
int
output_height
=
output
->
dims
()[
2
];
const
int
output_height
=
output
->
dims
()[
2
];
const
int
output_width
=
output
->
dims
()[
3
];
const
int
output_width
=
output
->
dims
()[
3
];
int
input_feasize
=
input_height
*
input_width
;
int
output_feasize
=
output_height
*
output_width
;
const
T
*
input_data
=
input
.
data
<
T
>
();
const
T
*
input_data
=
input
.
data
<
T
>
();
const
T
*
indices_data
=
indices
.
data
<
T
>
();
const
int
*
indices_data
=
indices
.
data
<
int
>
();
T
*
output_data
=
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
output_data
=
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
nthreads
=
output
->
numel
();
int
nthreads
=
output
->
numel
();
...
@@ -99,19 +103,18 @@ class Unpool2d_MaxGradFunctor<platform::GPUPlace, T> {
...
@@ -99,19 +103,18 @@ class Unpool2d_MaxGradFunctor<platform::GPUPlace, T> {
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
indices
,
framework
::
Tensor
*
input_grad
,
framework
::
Tensor
*
input_grad
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
output_grad
)
{
int
groups
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
const
int
input_width
=
input
.
dims
()[
3
];
const
int
output_channels
=
output
.
dims
()[
1
];
const
int
output_channels
=
output
.
dims
()[
1
];
const
int
output_height
=
output
.
dims
()[
2
];
const
int
output_height
=
output
.
dims
()[
2
];
const
int
output_width
=
output
.
dims
()[
3
];
const
int
output_width
=
output
.
dims
()[
3
];
const
T
*
input_data
=
input
.
data
<
T
>
();
const
T
*
input_data
=
input
.
data
<
T
>
();
const
T
*
indices_data
=
indices
.
data
<
T
>
();
const
int
*
indices_data
=
indices
.
data
<
int
>
();
const
T
*
output_data
=
output
.
data
<
T
>
();
const
T
*
output_data
=
output
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
...
...
paddle/operators/math/unpooling.h
浏览文件 @
45a8c9dd
...
@@ -26,7 +26,7 @@ namespace math {
...
@@ -26,7 +26,7 @@ namespace math {
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
Unpool2d_Max
_
Functor
{
class
Unpool2d_MaxFunctor
{
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
input
,
...
@@ -35,10 +35,11 @@ class Unpool2d_Max_Functor {
...
@@ -35,10 +35,11 @@ class Unpool2d_Max_Functor {
};
};
template
<
typename
Place
,
class
T
>
template
<
typename
Place
,
class
T
>
class
Unpool2d_Max
_
GradFunctor
{
class
Unpool2d_MaxGradFunctor
{
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
indices
,
framework
::
Tensor
*
input_grad
,
framework
::
Tensor
*
input_grad
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
);
const
framework
::
Tensor
&
output_grad
);
...
...
paddle/operators/unpool_op.cc
浏览文件 @
45a8c9dd
...
@@ -20,7 +20,8 @@ using framework::Tensor;
...
@@ -20,7 +20,8 @@ using framework::Tensor;
class
Unpool2dOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
Unpool2dOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
UnpoolOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
Unpool2dOpMaker
(
framework
::
OpProto
*
proto
,
\
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
AddInput
(
"X"
,
"(Tensor) The input tensor of unpool operator. "
"(Tensor) The input tensor of unpool operator. "
...
@@ -39,10 +40,12 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -39,10 +40,12 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr
<
std
::
vector
<
int
>>
(
"ksize"
,
AddAttr
<
std
::
vector
<
int
>>
(
"ksize"
,
"(vector ), the unpooling window size(height, width) "
"(vector ), the unpooling window size(height, width) "
"of unpooling operator."
);
"of unpooling operator."
);
AddAttr
<
std
::
vector
<
int
>>
(
"strides"
,
"(vector, default:{1, 1}), "
AddAttr
<
std
::
vector
<
int
>>
(
"strides"
,
"(vector, default:{1, 1}), "
"strides(height, width) of unpooling operator."
)
"strides(height, width) of unpooling operator."
)
.
SetDefault
({
1
,
1
});
.
SetDefault
({
1
,
1
});
AddAttr
<
std
::
vector
<
int
>>
(
"paddings"
,
"(vector defalut:{0,0}), "
AddAttr
<
std
::
vector
<
int
>>
(
"paddings"
,
"(vector defalut:{0,0}), "
"paddings(height, width) of unpooling operator."
)
"paddings(height, width) of unpooling operator."
)
.
SetDefault
({
0
,
0
});
.
SetDefault
({
0
,
0
});
AddAttr
<
std
::
string
>
(
"unpoolingType"
,
AddAttr
<
std
::
string
>
(
"unpoolingType"
,
...
@@ -73,7 +76,8 @@ class UnpoolOp : public framework::OperatorWithKernel {
...
@@ -73,7 +76,8 @@ class UnpoolOp : public framework::OperatorWithKernel {
auto
in_x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
in_x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
in_y_dims
=
ctx
->
GetInputDim
(
"Y"
);
auto
in_y_dims
=
ctx
->
GetInputDim
(
"Y"
);
std
::
string
unpooling_type
=
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"unpooling_type"
);
std
::
string
unpooling_type
=
\
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"unpooling_type"
);
std
::
vector
<
int
>
ksize
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
ksize
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
strides
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
strides
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"paddings"
);
std
::
vector
<
int
>
paddings
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"paddings"
);
...
@@ -95,7 +99,7 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
...
@@ -95,7 +99,7 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(
X
) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(
Y
) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
"Input(Out@GRAD) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
...
@@ -109,8 +113,11 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
...
@@ -109,8 +113,11 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
unpool2d
,
ops
::
UnpoolOp
,
ops
::
Unpool2dOpMaker
,
unpool2d_grad
,
REGISTER_OP
(
unpool2d
,
ops
::
UnpoolOp
,
ops
::
Unpool2dOpMaker
,
unpool2d_grad
,
ops
::
UnpoolOpGrad
);
ops
::
UnpoolOpGrad
);
REGISTER_OP_CPU_KERNEL
(
unpool2d
,
ops
::
UnpoolKernel
<
paddle
::
platform
::
CPUPlace
,
REGISTER_OP_CPU_KERNEL
(
unpool2d
,
float
>
);
ops
::
UnpoolKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
UnpoolKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
unpool2d_grad
,
REGISTER_OP_CPU_KERNEL
(
unpool2d_grad
,
ops
::
UnpoolGradKernel
<
paddle
::
platform
::
CPUPlace
,
ops
::
UnpoolGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
float
>
,
ops
::
UnpoolGradKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
paddle/operators/unpool_op.cu.cc
浏览文件 @
45a8c9dd
...
@@ -16,7 +16,10 @@
...
@@ -16,7 +16,10 @@
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
unpool2d
,
REGISTER_OP_GPU_KERNEL
(
unpool2d
,
ops
::
UnpoolKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
ops
::
UnpoolKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
,
ops
::
UnpoolKernel
<
paddle
::
platform
::
GPUPlace
,
double
>
);
REGISTER_OP_GPU_KERNEL
(
unpool2d_grad
,
REGISTER_OP_GPU_KERNEL
(
unpool2d_grad
,
ops
::
UnpoolGradKernel
<
paddle
::
platform
::
GPUPlace
,
ops
::
UnpoolGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
float
>
,
ops
::
UnpoolGradKernel
<
paddle
::
platform
::
GPUPlace
,
double
>
);
paddle/operators/unpool_op.h
浏览文件 @
45a8c9dd
...
@@ -37,9 +37,8 @@ class UnpoolKernel : public framework::OpKernel<T> {
...
@@ -37,9 +37,8 @@ class UnpoolKernel : public framework::OpKernel<T> {
switch
(
ksize
.
size
())
{
switch
(
ksize
.
size
())
{
case
2
:
{
case
2
:
{
if
(
pooling_type
==
"max"
)
{
if
(
pooling_type
==
"max"
)
{
math
::
Unpool2d_Max_Functor
<
Place
,
T
>
unpool2d_max_forward
;
math
::
Unpool2d_MaxFunctor
<
Place
,
T
>
unpool2d_max_forward
;
unpool2d_max_forward
(
context
.
device_context
(),
*
in_x
,
*
in_y
,
unpool2d_max_forward
(
context
.
device_context
(),
*
in_x
,
*
in_y
,
out
);
ksize
,
strides
,
paddings
,
out
);
}
}
}
break
;
}
break
;
default:
{
PADDLE_THROW
(
"Pool op only supports 2D input."
);
}
default:
{
PADDLE_THROW
(
"Pool op only supports 2D input."
);
}
...
@@ -71,12 +70,12 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
...
@@ -71,12 +70,12 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
switch
(
ksize
.
size
())
{
switch
(
ksize
.
size
())
{
case
2
:
{
case
2
:
{
if
(
pooling_type
==
"max"
)
{
if
(
pooling_type
==
"max"
)
{
math
::
Unpool
GradFunctor
<
Place
,
T
>
maxout
_backward
;
math
::
Unpool
2d_MaxGradFunctor
<
Place
,
T
>
unpool2d_max
_backward
;
maxout_backward
(
context
.
device_context
(),
*
in_x
,
*
in_y
,
in_x_grad
,
*
out
,
unpool2d_max_backward
(
context
.
device_context
(),
*
in_x
,
*
in_y
,
in_x_grad
,
*
out_grad
,
ksize
,
strides
,
paddings
);
*
out
,
*
out_grad
);
}
}
}
break
;
}
break
;
default:
{
PADDLE_THROW
(
"
P
ool op only supports 2D input."
);
}
default:
{
PADDLE_THROW
(
"
Unp
ool op only supports 2D input."
);
}
}
}
}
}
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录