Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
200f07c2
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
200f07c2
编写于
11月 21, 2017
作者:
S
sweetsky0901
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add test
上级
ab03daa4
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
82 addition
and
29 deletion
+82
-29
paddle/operators/math/unpooling.cc
paddle/operators/math/unpooling.cc
+8
-8
paddle/operators/math/unpooling.cu
paddle/operators/math/unpooling.cu
+8
-11
paddle/operators/math/unpooling.h
paddle/operators/math/unpooling.h
+2
-2
paddle/operators/unpool_op.cc
paddle/operators/unpool_op.cc
+15
-6
paddle/operators/unpool_op.h
paddle/operators/unpool_op.h
+2
-2
python/paddle/v2/fluid/tests/test_unpool2d_op.py
python/paddle/v2/fluid/tests/test_unpool2d_op.py
+47
-0
未找到文件。
paddle/operators/math/unpooling.cc
浏览文件 @
200f07c2
...
@@ -20,7 +20,7 @@ namespace math {
...
@@ -20,7 +20,7 @@ namespace math {
// All tensors are in NCHW format
// All tensors are in NCHW format
template
<
typename
T
>
template
<
typename
T
>
class
Unpool2d
_
MaxFunctor
<
platform
::
CPUPlace
,
T
>
{
class
Unpool2dMaxFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
input
,
...
@@ -43,7 +43,7 @@ class Unpool2d_MaxFunctor<platform::CPUPlace, T> {
...
@@ -43,7 +43,7 @@ class Unpool2d_MaxFunctor<platform::CPUPlace, T> {
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
for
(
int
i
=
0
;
i
<
input_feasize
;
++
i
)
{
for
(
int
i
=
0
;
i
<
input_feasize
;
++
i
)
{
int
index
=
indices_data
[
i
];
int
index
=
indices_data
[
i
];
//
PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
PADDLE_ENFORCE
(
index
<
output_feasize
,
"err index in unpooling!"
);
output_data
[
index
]
=
input_data
[
i
];
output_data
[
index
]
=
input_data
[
i
];
}
}
input_data
+=
input_feasize
;
input_data
+=
input_feasize
;
...
@@ -57,7 +57,7 @@ class Unpool2d_MaxFunctor<platform::CPUPlace, T> {
...
@@ -57,7 +57,7 @@ class Unpool2d_MaxFunctor<platform::CPUPlace, T> {
template
<
class
T
>
template
<
class
T
>
class
Unpool2d
_
MaxGradFunctor
<
platform
::
CPUPlace
,
T
>
{
class
Unpool2dMaxGradFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
input
,
...
@@ -83,7 +83,7 @@ public:
...
@@ -83,7 +83,7 @@ public:
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
for
(
int
i
=
0
;
i
<
input_feasize
;
++
i
)
{
for
(
int
i
=
0
;
i
<
input_feasize
;
++
i
)
{
int
index
=
indices_data
[
i
];
int
index
=
indices_data
[
i
];
//
PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
PADDLE_ENFORCE
(
index
<
output_feasize
,
"err index in unpooling!"
);
input_grad_data
[
i
]
=
output_grad_data
[
index
];
input_grad_data
[
i
]
=
output_grad_data
[
index
];
}
}
input_grad_data
+=
input_feasize
;
input_grad_data
+=
input_feasize
;
...
@@ -94,10 +94,10 @@ public:
...
@@ -94,10 +94,10 @@ public:
}
}
};
};
template
class
Unpool2d
_
MaxGradFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
Unpool2dMaxGradFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
Unpool2d
_
MaxGradFunctor
<
platform
::
CPUPlace
,
double
>;
template
class
Unpool2dMaxGradFunctor
<
platform
::
CPUPlace
,
double
>;
template
class
Unpool2d
_
MaxFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
Unpool2dMaxFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
Unpool2d
_
MaxFunctor
<
platform
::
CPUPlace
,
double
>;
template
class
Unpool2dMaxFunctor
<
platform
::
CPUPlace
,
double
>;
}
// namespace math
}
// namespace math
}
// namespace operators
}
// namespace operators
...
...
paddle/operators/math/unpooling.cu
浏览文件 @
200f07c2
...
@@ -30,12 +30,11 @@ __global__ void KernelUnpool2dMax(const int nthreads,
...
@@ -30,12 +30,11 @@ __global__ void KernelUnpool2dMax(const int nthreads,
const
int
output_width
)
{
const
int
output_width
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
// int output_feasize = output_height * output_width;
for
(
int
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
for
(
int
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
int
out_offset
=
i
/
(
input_height
*
input_width
)
\
int
out_offset
=
i
/
(
input_height
*
input_width
)
\
*
output_height
*
output_width
;
*
output_height
*
output_width
;
int
out_index
=
indices_data
[
i
];
int
out_index
=
indices_data
[
i
];
// PADDLE_ENFORCE(out_index < output_feasize, "err index in unpooling!"
);
PADDLE_ASSERT
(
out_index
<
(
output_height
*
output_width
)
);
output_data
[
out_offset
+
out_index
]
=
input_data
[
i
];
output_data
[
out_offset
+
out_index
]
=
input_data
[
i
];
}
}
}
}
...
@@ -52,13 +51,11 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads,
...
@@ -52,13 +51,11 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads,
T
*
input_grad
)
{
T
*
input_grad
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
// int output_feasize = output_height * output_width;
for
(
int
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
for
(
int
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
int
out_offset
=
i
/
(
input_height
*
input_width
)
\
int
out_offset
=
i
/
(
input_height
*
input_width
)
\
*
output_height
*
output_width
;
*
output_height
*
output_width
;
int
out_index
=
indices_data
[
i
];
int
out_index
=
indices_data
[
i
];
// PADDLE_ENFORCE(out_index < output_feasize,
PADDLE_ASSERT
(
out_index
<
(
output_height
*
output_width
));
// "err index in unpooling!");
input_grad
[
i
]
=
output_grad
[
out_offset
+
out_index
];
input_grad
[
i
]
=
output_grad
[
out_offset
+
out_index
];
}
}
}
}
...
@@ -66,7 +63,7 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads,
...
@@ -66,7 +63,7 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads,
* All tensors are in NCHW format.
* All tensors are in NCHW format.
*/
*/
template
<
typename
T
>
template
<
typename
T
>
class
Unpool2d
_
MaxFunctor
<
platform
::
GPUPlace
,
T
>
{
class
Unpool2dMaxFunctor
<
platform
::
GPUPlace
,
T
>
{
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
input
,
...
@@ -99,7 +96,7 @@ class Unpool2d_MaxFunctor<platform::GPUPlace, T> {
...
@@ -99,7 +96,7 @@ class Unpool2d_MaxFunctor<platform::GPUPlace, T> {
* All tensors are in NCHW format.
* All tensors are in NCHW format.
*/
*/
template
<
typename
T
>
template
<
typename
T
>
class
Unpool2d
_
MaxGradFunctor
<
platform
::
GPUPlace
,
T
>
{
class
Unpool2dMaxGradFunctor
<
platform
::
GPUPlace
,
T
>
{
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
input
,
...
@@ -135,11 +132,11 @@ class Unpool2d_MaxGradFunctor<platform::GPUPlace, T> {
...
@@ -135,11 +132,11 @@ class Unpool2d_MaxGradFunctor<platform::GPUPlace, T> {
}
}
};
};
template
class
Unpool2d
_
MaxGradFunctor
<
platform
::
GPUPlace
,
float
>;
template
class
Unpool2dMaxGradFunctor
<
platform
::
GPUPlace
,
float
>;
template
class
Unpool2d
_
MaxGradFunctor
<
platform
::
GPUPlace
,
double
>;
template
class
Unpool2dMaxGradFunctor
<
platform
::
GPUPlace
,
double
>;
template
class
Unpool2d
_
MaxFunctor
<
platform
::
GPUPlace
,
float
>;
template
class
Unpool2dMaxFunctor
<
platform
::
GPUPlace
,
float
>;
template
class
Unpool2d
_
MaxFunctor
<
platform
::
GPUPlace
,
double
>;
template
class
Unpool2dMaxFunctor
<
platform
::
GPUPlace
,
double
>;
}
// namespace math
}
// namespace math
}
// namespace operators
}
// namespace operators
...
...
paddle/operators/math/unpooling.h
浏览文件 @
200f07c2
...
@@ -26,7 +26,7 @@ namespace math {
...
@@ -26,7 +26,7 @@ namespace math {
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
Unpool2d
_
MaxFunctor
{
class
Unpool2dMaxFunctor
{
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
input
,
...
@@ -35,7 +35,7 @@ class Unpool2d_MaxFunctor {
...
@@ -35,7 +35,7 @@ class Unpool2d_MaxFunctor {
};
};
template
<
typename
Place
,
class
T
>
template
<
typename
Place
,
class
T
>
class
Unpool2d
_
MaxGradFunctor
{
class
Unpool2dMaxGradFunctor
{
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
input
,
...
...
paddle/operators/unpool_op.cc
浏览文件 @
200f07c2
...
@@ -49,11 +49,15 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -49,11 +49,15 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
"paddings(height, width) of unpooling operator."
)
"paddings(height, width) of unpooling operator."
)
.
SetDefault
({
0
,
0
});
.
SetDefault
({
0
,
0
});
AddAttr
<
std
::
string
>
(
"unpoolingType"
,
AddAttr
<
std
::
string
>
(
"unpoolingType"
,
"(string), unpooling type, can be
\"
max
\"
for max-unpooling "
"(string), unpooling type, can be
\"
max
\"
for max-unpooling "
)
"and
\"
avg
\"
for average-unpooling."
)
.
InEnum
({
"max"
});
.
InEnum
({
"max"
,
"avg"
});
AddComment
(
R"DOC(
AddComment
(
R"DOC(
"input: the input Tensor to invert"
"indices: the indices given out by MaxPool2d"
"ksize – Size of the max pooling window."
"stride – Stride of the max pooling window."
"It is set to kernel_size by default."
"padding – Padding that was added to the input"
)DOC"
);
)DOC"
);
}
}
};
};
...
@@ -82,8 +86,13 @@ class UnpoolOp : public framework::OperatorWithKernel {
...
@@ -82,8 +86,13 @@ class UnpoolOp : public framework::OperatorWithKernel {
std
::
vector
<
int
>
strides
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
strides
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"paddings"
);
std
::
vector
<
int
>
paddings
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"paddings"
);
PADDLE_ENFORCE
(
in_x_dims
.
size
()
==
4
||
in_x_dims
.
size
()
==
5
,
PADDLE_ENFORCE
(
in_x_dims
.
size
()
==
4
,
"Unpooling intput should be 4-D or 5-D tensor."
);
"Unpooling intput should be 4-D."
);
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
PADDLE_ENFORCE
(
in_x_dims
[
i
]
==
in_y_dims
[
i
],
"X size must be eq Y size!"
);
}
std
::
vector
<
int64_t
>
output_shape
({
in_x_dims
[
0
],
in_x_dims
[
1
]});
std
::
vector
<
int64_t
>
output_shape
({
in_x_dims
[
0
],
in_x_dims
[
1
]});
for
(
size_t
i
=
0
;
i
<
ksize
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
ksize
.
size
();
++
i
)
{
...
...
paddle/operators/unpool_op.h
浏览文件 @
200f07c2
...
@@ -37,7 +37,7 @@ class UnpoolKernel : public framework::OpKernel<T> {
...
@@ -37,7 +37,7 @@ class UnpoolKernel : public framework::OpKernel<T> {
switch
(
ksize
.
size
())
{
switch
(
ksize
.
size
())
{
case
2
:
{
case
2
:
{
if
(
pooling_type
==
"max"
)
{
if
(
pooling_type
==
"max"
)
{
math
::
Unpool2d
_
MaxFunctor
<
Place
,
T
>
unpool2d_max_forward
;
math
::
Unpool2dMaxFunctor
<
Place
,
T
>
unpool2d_max_forward
;
unpool2d_max_forward
(
context
.
device_context
(),
*
in_x
,
*
in_y
,
out
);
unpool2d_max_forward
(
context
.
device_context
(),
*
in_x
,
*
in_y
,
out
);
}
}
}
break
;
}
break
;
...
@@ -70,7 +70,7 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
...
@@ -70,7 +70,7 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
switch
(
ksize
.
size
())
{
switch
(
ksize
.
size
())
{
case
2
:
{
case
2
:
{
if
(
pooling_type
==
"max"
)
{
if
(
pooling_type
==
"max"
)
{
math
::
Unpool2d
_
MaxGradFunctor
<
Place
,
T
>
unpool2d_max_backward
;
math
::
Unpool2dMaxGradFunctor
<
Place
,
T
>
unpool2d_max_backward
;
unpool2d_max_backward
(
context
.
device_context
(),
*
in_x
,
*
in_y
,
in_x_grad
,
unpool2d_max_backward
(
context
.
device_context
(),
*
in_x
,
*
in_y
,
in_x_grad
,
*
out
,
*
out_grad
);
*
out
,
*
out_grad
);
}
}
...
...
python/paddle/v2/fluid/tests/test_unpool2d_op.py
0 → 100644
浏览文件 @
200f07c2
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
def
maxout_forward_naive
(
input
,
groups
):
s0
,
s1
,
s2
,
s3
=
input
.
shape
return
np
.
ndarray
([
s0
,
s1
/
groups
,
groups
,
s2
,
s3
],
\
buffer
=
input
,
dtype
=
input
.
dtype
).
max
(
axis
=
(
2
))
class
TestUnpool2dOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"unpool2d"
self
.
init_test_case
()
input
=
np
.
random
.
random
(
self
.
shape
).
astype
(
"float32"
)
output
=
self
.
MaxOut_forward_naive
(
input
,
self
.
groups
).
astype
(
"float32"
)
self
.
inputs
=
{
'X'
:
input
}
self
.
attrs
=
{
'strides'
:
self
.
strides
,
'paddings'
:
self
.
paddings
,
'ksize'
:
self
.
ksize
,
'unpooling_type'
:
self
.
pool_type
,
}
self
.
outputs
=
{
'Out'
:
output
.
astype
(
'float32'
)}
def
init_pool_type
(
self
):
self
.
pool_type
=
"max"
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Out'
)
def
init_test_case
(
self
):
self
.
MaxOut_forward_naive
=
maxout_forward_naive
self
.
shape
=
[
100
,
6
,
2
,
2
]
self
.
groups
=
2
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录