BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit 5ec55e79
Authored on Oct 18, 2017 by zchen0211
deconv impl
Parent: 80ebc8d5
Showing 2 changed files with 180 additions and 16 deletions (+180 -16):

  paddle/operators/deconv2d_op.cc   +17 -16
  paddle/operators/deconv2d_op.h    +163 -0
paddle/operators/deconv2d_op.cc
@@ -31,22 +31,23 @@ void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const {
   std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
   std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
   int groups = ctx->Attrs().Get<int>("groups");
-  int input_channels = in_dims[1];
-  int output_channels = filter_dims[0];
 
-  PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Conv2DOp input should be 4-D.");
-  PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Conv2DOp filter should be 4-D.");
-  PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
-                    "The number of input channels should be equal to filter "
-                    "channels * groups.");
-  PADDLE_ENFORCE_EQ(
-      output_channels % groups, 0,
-      "The number of output channels should be divided by groups.");
+  for (int i = 0; i < paddings.size(); ++i) {
+    PADDLE_ENFORCE_EQ(paddings[i], 0, "No Padding allowed in deconv op.");
+  }
+
+  PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Deconv2DOp input should be 4-D.");
+  PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Deconv2DOp filter should be 4-D.");
+  PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
+                    "input and kernel input dimension should be equal.");
+  PADDLE_ENFORCE_EQ(groups, 1,
+                    "The number of groups should be 1 in case of deconv op.");
 
   auto output_height = (in_dims[2] - 1) * strides[0] + filter_dims[2];
   auto output_width = (in_dims[3] - 1) * strides[1] + filter_dims[3];
   ctx->SetOutputDim("Output",
-                    {in_dims[0], filter_dims[0], output_height, output_width});
+                    {in_dims[0], filter_dims[1], output_height, output_width});
 }
 
 Deconv2DOpMaker::Deconv2DOpMaker(framework::OpProto* proto,
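As a quick check of the output-size rule above (made-up numbers, not values from this patch): for an input with H = W = 5, strides = {2, 2}, and a 3 x 3 filter, output_height = (5 - 1) * 2 + 3 = 11 and likewise output_width = 11, i.e. the usual zero-padding transposed-convolution size (in - 1) * stride + kernel.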
@@ -55,12 +56,12 @@ Deconv2DOpMaker::Deconv2DOpMaker(framework::OpProto* proto,
   AddInput(
       "Input",
       "The input tensor of deconvolution operator. "
-      "The format of input tensor is NCHW. Where N is batch size, C is the "
-      "number of channels, H and W is the height and width of image.");
+      "The format of input tensor is NMHW. Where N is batch size, M is the "
+      "number of input channels, H and W is the height and width of image.");
   AddInput("Filter",
            "The filter tensor of deconvolution operator."
            "The format of the filter tensor is MCHW, where M is the number of "
-           "output image channels, C is the number of input image channels, "
+           "input image channels, C is the number of output image channels, "
            "H and W is height and width of filter. "
            "We enforce groups number == 1 and padding == 0 in our "
            "deconvolution Scenario.");
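To make the two format strings concrete (illustrative shapes only, not from the patch): an Input of shape [N, M, H, W] = [2, 3, 5, 5] has to be paired with a Filter of shape [M, C, K_H, K_W] = [3, 8, 3, 3], since InferShape enforces in_dims[1] == filter_dims[0]; with strides = {2, 2} the resulting Output shape is [N, C, O_H, O_W] = [2, 8, 11, 11].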
@@ -97,6 +98,6 @@ REGISTER_OP(deconv2d, ops::Deconv2DOp, ops::Deconv2DOpMaker, deconv2d_grad,
             ops::Deconv2DOpGrad);
 
 REGISTER_OP_CPU_KERNEL(
-    deconv2d, ops::GemmConvGrad2DKernel<paddle::platform::CPUPlace, float>);
+    deconv2d, ops::GemmDeconv2DKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
     deconv2d_grad, ops::GemmConv2DKernel<paddle::platform::CPUPlace, float>);
paddle/operators/deconv2d_op.h
@@ -23,6 +23,7 @@ namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
+using DDim = framework::DDim;
 
 // Define Op classes in .h file so that other deconv
 // operator implementations can reuse the code.
@@ -48,5 +49,167 @@ class Deconv2DOpGrad : public framework::OperatorWithKernel {
  void InferShape(framework::InferShapeContext* ctx) const override;
};
template <typename Place, typename T>
class GemmDeconv2DKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    // filter will be reshaped, so we do not use constant pointer here
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* output = context.Output<Tensor>("Output");

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    // no paddings and groups allowed in deconv

    int N = input->dims()[0];
    int M = input->dims()[1];
    int H = input->dims()[2];
    int W = input->dims()[3];

    int K_H = filter.dims()[2];
    int K_W = filter.dims()[3];

    int C = output->dims()[1];  // output channels
    int O_H = output->dims()[2];
    int O_W = output->dims()[3];

    paddle::operators::math::Col2ImFunctor<
        paddle::operators::math::ColFormat::kCFO, Place, T>
        col2im;
    // use col_shape in the im2col and col2im calculation
    framework::DDim col_shape = {C, K_H, K_W, H, W};

    // use col_matrix_shape in the gemm calculation
    framework::DDim col_matrix_shape = {M * K_H * K_W, H * W};

    Tensor col;
    col.mutable_data<T>(col_shape, context.GetPlace());
    // col_matrix shares the same piece of data with col,
    // but will be reshaped into a two-dimensional matrix shape
    // to call the matrix multiplication interface.
    Tensor col_matrix = col;
    col_matrix.Resize(col_matrix_shape);

    DDim output_shape = {C, O_H, O_W};
    DDim input_matrix_shape = {M, H * W};

    DDim filter_matrix_shape = {M, C * K_H * K_W};
    filter.Resize(filter_matrix_shape);

    // deconvolution: gemm + col2im (similar to conv-backward on input)

    output->mutable_data<T>(context.GetPlace());
    auto t = framework::EigenVector<T>::Flatten(*output);
    t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));

    for (int i = 0; i < N; i++) {
      // batch with size (M, H * W)
      Tensor input_batch = input->Slice<T>(i, i + 1).Resize(input_matrix_shape);

      // output size: (C, O_H, O_W)
      Tensor output_batch = output->Slice<T>(i, i + 1).Resize(output_shape);

      // filter size: (Co, Ci * Hf * Wf)

      // col_matrix = filter * input_batch
      // of shape (C * K_H * K_W, H * W)
      math::matmul<Place, T>(context.device_context(), filter, true,
                             input_batch, false, T(1.0), &col_matrix, T(0.0));
      col2im(context.device_context(), output_batch, col_matrix, strides[0],
             strides[1], 0, 0);
    }
  }
};
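The per-sample loop above expresses the transposed convolution as one GEMM (filter transposed times input_batch, producing a (C * K_H * K_W) x (H * W) column buffer) followed by col2im, which scatters each K_H x K_W column back into the output map; reusing math::matmul and Col2ImFunctor avoids a hand-written scatter kernel. The standalone sketch below is a minimal reference with explicit loops (plain C++, no Paddle types; the function name and the toy sizes in main() are illustrative assumptions, not part of the patch) that is intended to produce the same zero-padding, stride-only result and can help when checking the shape bookkeeping:

// Minimal reference sketch of what the gemm + col2im path computes for one
// sample: every input value input[m][h][w] is multiplied by the filter slice
// filter[m][c][kh][kw] and accumulated ("scattered") into
// output[c][h * stride_h + kh][w * stride_w + kw].
#include <cstdio>
#include <vector>

void NaiveDeconv2D(const std::vector<float>& input,   // [M, H, W]
                   const std::vector<float>& filter,  // [M, C, K_H, K_W]
                   std::vector<float>* output,        // [C, O_H, O_W]
                   int M, int H, int W, int C, int K_H, int K_W,
                   int stride_h, int stride_w) {
  const int O_H = (H - 1) * stride_h + K_H;  // same formula as InferShape
  const int O_W = (W - 1) * stride_w + K_W;
  output->assign(C * O_H * O_W, 0.0f);
  for (int m = 0; m < M; ++m) {
    for (int h = 0; h < H; ++h) {
      for (int w = 0; w < W; ++w) {
        const float in_val = input[(m * H + h) * W + w];
        for (int c = 0; c < C; ++c) {
          for (int kh = 0; kh < K_H; ++kh) {
            for (int kw = 0; kw < K_W; ++kw) {
              const int oh = h * stride_h + kh;
              const int ow = w * stride_w + kw;
              (*output)[(c * O_H + oh) * O_W + ow] +=
                  in_val * filter[((m * C + c) * K_H + kh) * K_W + kw];
            }
          }
        }
      }
    }
  }
}

int main() {
  // Toy sizes: M = 1 input channel, 2x2 input, C = 1 output channel,
  // 3x3 filter of ones, stride 2 -> output is 5x5 ((2 - 1) * 2 + 3 = 5).
  std::vector<float> input = {1, 2, 3, 4};         // [1, 2, 2]
  std::vector<float> filter(1 * 1 * 3 * 3, 1.0f);  // [1, 1, 3, 3]
  std::vector<float> output;
  NaiveDeconv2D(input, filter, &output, 1, 2, 2, 1, 3, 3, 2, 2);
  for (int i = 0; i < 5 * 5; ++i) {
    printf("%g%c", output[i], (i % 5 == 4) ? '\n' : ' ');
  }
  return 0;
}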
/*
template <typename Place, typename T>
class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* output_grad =
context.Input<Tensor>(framework::GradVarName("Output"));
// For filter, we do not use const pointer
// but we should avoid
Tensor filter = *context.Input<Tensor>("Filter");
Tensor* input_grad =
context.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad =
context.Output<Tensor>(framework::GradVarName("Filter"));
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
// no paddings and groups allowed in deconv
int N = input->dims()[0];
int M = input->dims()[1];
int H = input->dims()[2];
int W = input->dims()[3];
int K_H = filter.dims()[2];
int K_W = filter.dims()[3];
int C = output->dims()[1]; // output channels
int O_H = output->dims()[2];
int O_W = output->dims()[3];
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kCFO, Place, T>
col2im;
// use col_shape in the im2col and col2im calculation
framework::DDim col_shape = {C, K_H, K_W, H, W};
// use col_matrix_shape in the gemm calculation
framework::DDim col_matrix_shape = {M * K_H * K_W, H * W};
Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace());
// col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface.
Tensor col_matrix = col;
col_matrix.Resize(col_matrix_shape);
DDim output_shape = {C, O_H, O_W};
DDim input_matrix_shape = {M, H * W};
DDim filter_matrix_shape = {M, C* K_H * K_W};
filter.Resize(filter_matrix_shape);
// deconvolution: gemm + col2im (similar to conv-backward on input)
output->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*output);
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
for (int i = 0; i < N; i++) {
// batch with size (M, H * W)
Tensor input_batch =
input->Slice<T>(i, i + 1).Resize(input_matrix_shape);
// output size: (C, O_H, O_W)
Tensor output_batch =
output->Slice<T>(i, i + 1).Resize(output_shape);
// filter size: (Co, Ci * Hf * Wf)
// col_matrix = filter * input_batch
// of shape (C * K_H * K_W, H * W)
math::matmul<Place, T>(context.device_context(), filter, true,
input_batch, false, T(1.0), &col_matrix,
T(0.0));
col2im(context.device_context(), output_batch, col_matrix, strides[0],
strides[1], 0, 0);
}
}
};
*/
}  // namespace operators
}  // namespace paddle