Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
bd561384
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bd561384
编写于
11月 29, 2017
作者:
S
sweetsky0901
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
format code
上级
d9673cad
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
133 addition
and
140 deletion
+133
-140
paddle/operators/math/unpooling.cc
paddle/operators/math/unpooling.cc
+5
-12
paddle/operators/math/unpooling.cu
paddle/operators/math/unpooling.cu
+40
-47
paddle/operators/math/unpooling.h
paddle/operators/math/unpooling.h
+4
-5
paddle/operators/unpool_op.cc
paddle/operators/unpool_op.cc
+71
-63
paddle/operators/unpool_op.h
paddle/operators/unpool_op.h
+4
-4
python/paddle/v2/fluid/tests/test_unpool_op.py
python/paddle/v2/fluid/tests/test_unpool_op.py
+9
-9
未找到文件。
paddle/operators/math/unpooling.cc
浏览文件 @
bd561384
...
...
@@ -17,15 +17,13 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
namespace
math
{
// All tensors are in NCHW format
template
<
typename
T
>
class
Unpool2dMaxFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
indices
,
framework
::
Tensor
*
output
)
{
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
indices
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
...
...
@@ -40,7 +38,7 @@ class Unpool2dMaxFunctor<platform::CPUPlace, T> {
for
(
int
b
=
0
;
b
<
batch_size
;
++
b
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
for
(
int
i
=
0
;
i
<
input_feasize
;
++
i
)
{
int
index
=
indices_data
[
i
];
int
index
=
indices_data
[
i
];
PADDLE_ENFORCE
(
index
<
output_feasize
,
"err index in unpooling!"
);
output_data
[
index
]
=
input_data
[
i
];
}
...
...
@@ -51,9 +49,6 @@ class Unpool2dMaxFunctor<platform::CPUPlace, T> {
}
}
};
template
<
class
T
>
class
Unpool2dMaxGradFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
...
...
@@ -62,7 +57,7 @@ public:
const
framework
::
Tensor
&
indices
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
framework
::
Tensor
*
input_grad
)
{
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
...
...
@@ -89,12 +84,10 @@ public:
}
}
};
template
class
Unpool2dMaxGradFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
Unpool2dMaxGradFunctor
<
platform
::
CPUPlace
,
double
>;
template
class
Unpool2dMaxFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
Unpool2dMaxFunctor
<
platform
::
CPUPlace
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/unpooling.cu
浏览文件 @
bd561384
...
...
@@ -18,36 +18,33 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
>
__global__
void
KernelUnpool2dMax
(
const
int
nthreads
,
const
T
*
input_data
,
const
int
*
indices_data
,
__global__
void
KernelUnpool2dMax
(
const
int
nthreads
,
const
T
*
input_data
,
const
int
*
indices_data
,
const
int
input_height
,
const
int
input_width
,
const
int
channels
,
T
*
output_data
,
const
int
output_height
,
const
int
output_width
)
{
int
in_n_stride
=
input_height
*
input_width
*
channels
;
int
in_c_stride
=
input_height
*
input_width
;
int
out_n_stride
=
output_height
*
output_width
*
channels
;
int
out_c_stride
=
output_height
*
output_width
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
for
(
int
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
int
bidx
=
i
/
in_n_stride
;
int
boffset
=
i
%
in_n_stride
;
int
cidx
=
boffset
/
in_c_stride
;
int
out_offset
=
bidx
*
out_n_stride
+
cidx
*
out_c_stride
;
int
out_index
=
indices_data
[
i
];
PADDLE_ASSERT
(
out_index
<
out_c_stride
);
output_data
[
out_offset
+
out_index
]
=
input_data
[
i
];
}
int
in_n_stride
=
input_height
*
input_width
*
channels
;
int
in_c_stride
=
input_height
*
input_width
;
int
out_n_stride
=
output_height
*
output_width
*
channels
;
int
out_c_stride
=
output_height
*
output_width
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
for
(
int
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
int
bidx
=
i
/
in_n_stride
;
int
boffset
=
i
%
in_n_stride
;
int
cidx
=
boffset
/
in_c_stride
;
int
out_offset
=
bidx
*
out_n_stride
+
cidx
*
out_c_stride
;
int
out_index
=
indices_data
[
i
];
PADDLE_ASSERT
(
out_index
<
out_c_stride
);
output_data
[
out_offset
+
out_index
]
=
input_data
[
i
];
}
}
template
<
typename
T
>
__global__
void
KernelUnpool2dMaxGrad
(
const
int
nthreads
,
const
T
*
input_data
,
__global__
void
KernelUnpool2dMaxGrad
(
const
int
nthreads
,
const
T
*
input_data
,
const
int
*
indices_data
,
const
int
input_height
,
const
int
input_width
,
...
...
@@ -57,32 +54,32 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads,
const
int
output_height
,
const
int
output_width
,
T
*
input_grad
)
{
int
in_n_stride
=
input_height
*
input_width
*
channels
;
int
in_c_stride
=
input_height
*
input_width
;
int
out_n_stride
=
output_height
*
output_width
*
channels
;
int
out_c_stride
=
output_height
*
output_width
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
for
(
int
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
int
bidx
=
i
/
in_n_stride
;
int
boffset
=
i
%
in_n_stride
;
int
cidx
=
boffset
/
in_c_stride
;
int
out_offset
=
bidx
*
out_n_stride
+
cidx
*
out_c_stride
;
int
out_index
=
indices_data
[
i
];
PADDLE_ASSERT
(
out_index
<
out_c_stride
);
input_grad
[
i
]
=
output_grad
[
out_offset
+
out_index
];
}
int
in_n_stride
=
input_height
*
input_width
*
channels
;
int
in_c_stride
=
input_height
*
input_width
;
int
out_n_stride
=
output_height
*
output_width
*
channels
;
int
out_c_stride
=
output_height
*
output_width
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
for
(
int
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
int
bidx
=
i
/
in_n_stride
;
int
boffset
=
i
%
in_n_stride
;
int
cidx
=
boffset
/
in_c_stride
;
int
out_offset
=
bidx
*
out_n_stride
+
cidx
*
out_c_stride
;
int
out_index
=
indices_data
[
i
];
PADDLE_ASSERT
(
out_index
<
out_c_stride
);
input_grad
[
i
]
=
output_grad
[
out_offset
+
out_index
];
}
}
/*
* All tensors are in NCHW format.
*/
template
<
typename
T
>
class
Unpool2dMaxFunctor
<
platform
::
GPUPlace
,
T
>
{
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
indices
,
framework
::
Tensor
*
output
)
{
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
...
...
@@ -93,7 +90,7 @@ class Unpool2dMaxFunctor<platform::GPUPlace, T> {
const
int
*
indices_data
=
indices
.
data
<
int
>
();
T
*
output_data
=
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
threads
=
1024
;
int
grid
=
(
input
.
numel
()
+
threads
-
1
)
/
threads
;
int
grid
=
(
input
.
numel
()
+
threads
-
1
)
/
threads
;
KernelUnpool2dMax
<
T
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
...
...
@@ -107,13 +104,13 @@ class Unpool2dMaxFunctor<platform::GPUPlace, T> {
*/
template
<
typename
T
>
class
Unpool2dMaxGradFunctor
<
platform
::
GPUPlace
,
T
>
{
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
indices
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
framework
::
Tensor
*
input_grad
)
{
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
...
...
@@ -126,24 +123,20 @@ class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
threads
=
1024
;
int
grid
=
(
input
.
numel
()
+
threads
-
1
)
/
threads
;
int
grid
=
(
input
.
numel
()
+
threads
-
1
)
/
threads
;
KernelUnpool2dMaxGrad
<
T
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
input
.
numel
(),
input_data
,
indices_data
,
input_height
,
input_width
,
output_channels
,
output_data
,
output_grad_data
,
output_height
,
output_width
,
input_grad_data
);
output_height
,
output_width
,
input_grad_data
);
}
};
template
class
Unpool2dMaxGradFunctor
<
platform
::
GPUPlace
,
float
>;
template
class
Unpool2dMaxGradFunctor
<
platform
::
GPUPlace
,
double
>;
template
class
Unpool2dMaxFunctor
<
platform
::
GPUPlace
,
float
>;
template
class
Unpool2dMaxFunctor
<
platform
::
GPUPlace
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/unpooling.h
浏览文件 @
bd561384
...
...
@@ -22,22 +22,21 @@ namespace math {
template
<
typename
Place
,
typename
T
>
class
Unpool2dMaxFunctor
{
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
indices
,
framework
::
Tensor
*
output
);
const
framework
::
Tensor
&
indices
,
framework
::
Tensor
*
output
);
};
template
<
typename
Place
,
class
T
>
class
Unpool2dMaxGradFunctor
{
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
indices
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
framework
::
Tensor
*
input_grad
);
framework
::
Tensor
*
input_grad
);
};
}
// namespace math
}
// namespace operators
...
...
paddle/operators/unpool_op.cc
浏览文件 @
bd561384
...
...
@@ -21,107 +21,115 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
Unpool2dOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
AddInput
(
"X"
,
"(Tensor) The input tensor of unpool operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature."
);
AddInput
(
"Indices"
,
AddInput
(
"Indices"
,
"(Tensor) The input tensor of the indices given out by MaxPool2d. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature."
);
AddOutput
(
"Out"
,
AddOutput
(
"Out"
,
"(Tensor) The output tensor of unpool operator."
"The format of output tensor is also NCHW."
"Where N is batch size, C is "
"the number of channels, H and W is the height and "
"width of feature."
);
AddAttr
<
std
::
vector
<
int
>>
(
"ksize"
,
AddAttr
<
std
::
vector
<
int
>>
(
"ksize"
,
"(vector), the unpooling window size(height, width) "
"of unpooling operator."
);
AddAttr
<
std
::
vector
<
int
>>
(
"strides"
,
AddAttr
<
std
::
vector
<
int
>>
(
"strides"
,
"(vector, default:{1, 1}), "
"strides (height, width) of unpooling operator."
)
.
SetDefault
({
1
,
1
});
AddAttr
<
std
::
vector
<
int
>>
(
"paddings"
,
AddAttr
<
std
::
vector
<
int
>>
(
"paddings"
,
"(vector defalut:{0,0}), "
"paddings (height, width) of unpooling operator."
)
.
SetDefault
({
0
,
0
});
AddAttr
<
std
::
string
>
(
"unpooling_type"
,
AddAttr
<
std
::
string
>
(
"unpooling_type"
,
"(string), unpooling type, can be
\"
max
\"
for max-unpooling "
)
.
InEnum
({
"max"
});
AddComment
(
R"DOC(
"Input shape: $(N, C_{in}, H_{in}, W_{in})$
Output shape: $(N, C_{out}, H_{out}, W_{out})$
Where
"Input shape: $(N, C_{in}, H_{in}, W_{in})$
Output shape: $(N, C_{out}, H_{out}, W_{out})$
Where
$$
H_{out} = (H_{in}−1) * strides[0] − 2 * paddings[0] + ksize[0] \\
W_{out} = (W_{in}−1) * strides[1] − 2 * paddings[1] + ksize[1]
$$
Paper: http://www.matthewzeiler.com/wp-content/uploads/2017
/07/iccv2011.pdf
Paper: http://www.matthewzeiler.com/wp-content/uploads/2017
/07/iccv2011.pdf
)DOC"
);
}
};
int
OutputSize
(
int
input_size
,
int
ksize
,
int
padding
,
int
stride
)
{
int
output_size
=
(
input_size
-
1
)
*
stride
-
2
*
padding
+
ksize
;
int
output_size
=
(
input_size
-
1
)
*
stride
-
2
*
padding
+
ksize
;
return
output_size
;
}
class
UnpoolOp
:
public
framework
::
OperatorWithKernel
{
protected:
framework
::
OpKernelType
GetKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
protected:
framework
::
OpKernelType
GetKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of UnpoolOp"
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of UnpoolOp"
"should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Indices"
),
"Input(Indices) of UnpoolOp"
"should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Indices"
),
"Input(Indices) of UnpoolOp"
"should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of UnpoolOp should not be null."
);
auto
in_x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
in_y_dims
=
ctx
->
GetInputDim
(
"Indices"
);
std
::
string
unpooling_type
=
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"unpooling_type"
);
std
::
vector
<
int
>
ksize
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
strides
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"paddings"
);
PADDLE_ENFORCE
(
in_x_dims
.
size
()
==
4
,
"Unpooling intput must be of 4-dimensional."
);
PADDLE_ENFORCE_EQ
(
in_x_dims
,
in_y_dims
);
std
::
vector
<
int64_t
>
output_shape
({
in_x_dims
[
0
],
in_x_dims
[
1
]});
for
(
size_t
i
=
0
;
i
<
ksize
.
size
();
++
i
)
{
output_shape
.
push_back
(
OutputSize
(
in_x_dims
[
i
+
2
],
ksize
[
i
],
paddings
[
i
],
strides
[
i
]));
}
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
output_shape
));
}
auto
in_x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
in_y_dims
=
ctx
->
GetInputDim
(
"Indices"
);
std
::
string
unpooling_type
=
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"unpooling_type"
);
std
::
vector
<
int
>
ksize
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
strides
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"paddings"
);
PADDLE_ENFORCE
(
in_x_dims
.
size
()
==
4
,
"Unpooling intput must be of 4-dimensional."
);
PADDLE_ENFORCE_EQ
(
in_x_dims
,
in_y_dims
);
std
::
vector
<
int64_t
>
output_shape
({
in_x_dims
[
0
],
in_x_dims
[
1
]});
for
(
size_t
i
=
0
;
i
<
ksize
.
size
();
++
i
)
{
output_shape
.
push_back
(
OutputSize
(
in_x_dims
[
i
+
2
],
ksize
[
i
],
paddings
[
i
],
strides
[
i
]));
}
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
output_shape
));
}
};
class
UnpoolOpGrad
:
public
framework
::
OperatorWithKernel
{
protected:
framework
::
OpKernelType
GetKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
protected:
framework
::
OpKernelType
GetKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
"Input(X@GRAD) should not be null."
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
};
}
// namespace operators
}
// namespace paddle
...
...
@@ -129,10 +137,10 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
unpool
,
ops
::
UnpoolOp
,
ops
::
Unpool2dOpMaker
,
unpool_grad
,
ops
::
UnpoolOpGrad
);
REGISTER_OP_CPU_KERNEL
(
unpool
,
ops
::
UnpoolKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
UnpoolKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
unpool_grad
,
ops
::
UnpoolGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
UnpoolGradKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
unpool
,
ops
::
UnpoolKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
UnpoolKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
unpool_grad
,
ops
::
UnpoolGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
UnpoolGradKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
paddle/operators/unpool_op.h
浏览文件 @
bd561384
...
...
@@ -27,7 +27,7 @@ class UnpoolKernel : public framework::OpKernel<T> {
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
const
framework
::
Tensor
*
in_x
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
const
framework
::
Tensor
*
in_y
=
context
.
Input
<
framework
::
Tensor
>
(
"Indices"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
std
::
string
unpooling_type
=
context
.
Attr
<
std
::
string
>
(
"unpooling_type"
);
std
::
vector
<
int
>
ksize
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
...
...
@@ -52,7 +52,7 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
const
framework
::
Tensor
*
out_grad
=
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
framework
::
Tensor
*
in_x_grad
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
std
::
string
unpooling_type
=
context
.
Attr
<
std
::
string
>
(
"unpooling_type"
);
std
::
vector
<
int
>
ksize
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
...
...
@@ -65,8 +65,8 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
zero
(
device_ctx
,
in_x_grad
,
static_cast
<
T
>
(
0
));
}
math
::
Unpool2dMaxGradFunctor
<
Place
,
T
>
unpool2d_max_backward
;
unpool2d_max_backward
(
context
.
device_context
(),
*
in_x
,
*
in_y
,
*
out
,
*
out
_grad
,
in_x_grad
);
unpool2d_max_backward
(
context
.
device_context
(),
*
in_x
,
*
in_y
,
*
out
,
*
out_grad
,
in_x_grad
);
}
};
...
...
python/paddle/v2/fluid/tests/test_unpool_op.py
浏览文件 @
bd561384
...
...
@@ -52,14 +52,16 @@ class TestUnpoolOp(OpTest):
c_start
+
arg
%
self
.
ksize
[
1
]
output
=
self
.
unpool2d_forward_naive
(
input
,
indices
,
self
.
ksize
,
\
self
.
strides
,
self
.
paddings
).
astype
(
"float32"
)
self
.
inputs
=
{
'X'
:
input
.
astype
(
'float32'
),
'Indices'
:
indices
.
astype
(
'int32'
)}
self
.
inputs
=
{
'X'
:
input
.
astype
(
'float32'
),
'Indices'
:
indices
.
astype
(
'int32'
)
}
self
.
attrs
=
{
'strides'
:
self
.
strides
,
'paddings'
:
self
.
paddings
,
'ksize'
:
self
.
ksize
,
'unpooling_type'
:
self
.
unpooling_type
,
}
'strides'
:
self
.
strides
,
'paddings'
:
self
.
paddings
,
'ksize'
:
self
.
ksize
,
'unpooling_type'
:
self
.
unpooling_type
,
}
self
.
outputs
=
{
'Out'
:
output
.
astype
(
'float32'
)}
def
test_check_output
(
self
):
...
...
@@ -76,7 +78,5 @@ class TestUnpoolOp(OpTest):
self
.
strides
=
[
2
,
2
]
self
.
paddings
=
[
0
,
0
]
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录