Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d97a732f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d97a732f
编写于
10月 19, 2017
作者:
Z
zchen0211
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
deconv
上级
e59ca752
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
24 addition
and
16 deletion
+24
-16
paddle/operators/deconv2d_op.cc
paddle/operators/deconv2d_op.cc
+0
-4
paddle/operators/deconv2d_op.h
paddle/operators/deconv2d_op.h
+24
-12
未找到文件。
paddle/operators/deconv2d_op.cc
浏览文件 @
d97a732f
...
@@ -30,7 +30,6 @@ void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -30,7 +30,6 @@ void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const {
auto
filter_dims
=
ctx
->
GetInputDim
(
"Filter"
);
auto
filter_dims
=
ctx
->
GetInputDim
(
"Filter"
);
std
::
vector
<
int
>
strides
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
strides
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"paddings"
);
std
::
vector
<
int
>
paddings
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"paddings"
);
int
groups
=
ctx
->
Attrs
().
Get
<
int
>
(
"groups"
);
for
(
int
i
=
0
;
i
<
paddings
.
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
paddings
.
size
();
++
i
)
{
PADDLE_ENFORCE_EQ
(
paddings
[
i
],
0
,
"No Padding allowed in deconv op."
);
PADDLE_ENFORCE_EQ
(
paddings
[
i
],
0
,
"No Padding allowed in deconv op."
);
...
@@ -41,9 +40,6 @@ void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -41,9 +40,6 @@ void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ
(
in_dims
[
1
],
filter_dims
[
0
],
PADDLE_ENFORCE_EQ
(
in_dims
[
1
],
filter_dims
[
0
],
"input and kernel input dimension should be equal."
);
"input and kernel input dimension should be equal."
);
PADDLE_ENFORCE_EQ
(
groups
,
1
,
"The number of groups should be 1 in case of deconv op."
);
auto
output_height
=
(
in_dims
[
2
]
-
1
)
*
strides
[
0
]
+
filter_dims
[
2
];
auto
output_height
=
(
in_dims
[
2
]
-
1
)
*
strides
[
0
]
+
filter_dims
[
2
];
auto
output_width
=
(
in_dims
[
3
]
-
1
)
*
strides
[
1
]
+
filter_dims
[
3
];
auto
output_width
=
(
in_dims
[
3
]
-
1
)
*
strides
[
1
]
+
filter_dims
[
3
];
ctx
->
SetOutputDim
(
"Output"
,
ctx
->
SetOutputDim
(
"Output"
,
...
...
paddle/operators/deconv2d_op.h
浏览文件 @
d97a732f
...
@@ -83,7 +83,7 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> {
...
@@ -83,7 +83,7 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> {
DDim
col_shape
=
{
C
,
K_H
,
K_W
,
H
,
W
};
DDim
col_shape
=
{
C
,
K_H
,
K_W
,
H
,
W
};
// use col_matrix_shape in the gemm calculation
// use col_matrix_shape in the gemm calculation
DDim
col_matrix_shape
=
{
M
*
K_H
*
K_W
,
H
*
W
};
DDim
col_matrix_shape
=
{
C
*
K_H
*
K_W
,
H
*
W
};
Tensor
col
;
Tensor
col
;
col
.
mutable_data
<
T
>
(
col_shape
,
context
.
GetPlace
());
col
.
mutable_data
<
T
>
(
col_shape
,
context
.
GetPlace
());
...
@@ -108,11 +108,11 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> {
...
@@ -108,11 +108,11 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> {
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
// batch with size (M, H * W)
// batch with size (M, H * W)
Tensor
input_batch
=
input
->
Slice
<
T
>
(
i
,
i
+
1
).
Resize
(
input_matrix_shape
);
Tensor
input_batch
=
input
->
Slice
<
T
>
(
i
,
i
+
1
).
Resize
(
input_matrix_shape
);
// filter size: (M, C * K_H * K_W)
// output size: (C, O_H, O_W)
// output size: (C, O_H, O_W)
Tensor
output_batch
=
output
->
Slice
<
T
>
(
i
,
i
+
1
).
Resize
(
output_shape
);
Tensor
output_batch
=
output
->
Slice
<
T
>
(
i
,
i
+
1
).
Resize
(
output_shape
);
// filter size: (Co, Ci * Hf * Wf)
// col_matrix = filter * input_batch
// col_matrix = filter * input_batch
// of shape (C * K_H * K_W, H * W)
// of shape (C * K_H * K_W, H * W)
math
::
matmul
<
Place
,
T
>
(
context
.
device_context
(),
filter
,
true
,
math
::
matmul
<
Place
,
T
>
(
context
.
device_context
(),
filter
,
true
,
...
@@ -132,8 +132,8 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
...
@@ -132,8 +132,8 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
const
Tensor
*
output_grad
=
const
Tensor
*
output_grad
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Output"
));
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Output"
));
// For filter, we do not use const pointer
// For filter, we do not use const pointer
b/c we will do reshape
// but we should avoid
// but we should avoid
modifying its value
Tensor
filter
=
*
context
.
Input
<
Tensor
>
(
"Filter"
);
Tensor
filter
=
*
context
.
Input
<
Tensor
>
(
"Filter"
);
Tensor
*
input_grad
=
Tensor
*
input_grad
=
...
@@ -157,7 +157,7 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
...
@@ -157,7 +157,7 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
int
O_H
=
output_grad
->
dims
()[
2
];
int
O_H
=
output_grad
->
dims
()[
2
];
int
O_W
=
output_grad
->
dims
()[
3
];
int
O_W
=
output_grad
->
dims
()[
3
];
//
Two functors required
to get to the right shape
//
Only im2col functor required for bp
to get to the right shape
paddle
::
operators
::
math
::
Im2ColFunctor
<
paddle
::
operators
::
math
::
Im2ColFunctor
<
paddle
::
operators
::
math
::
ColFormat
::
kCFO
,
Place
,
T
>
paddle
::
operators
::
math
::
ColFormat
::
kCFO
,
Place
,
T
>
im2col
;
im2col
;
...
@@ -166,15 +166,13 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
...
@@ -166,15 +166,13 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
DDim
col_shape
=
{
C
,
K_H
,
K_W
,
H
,
W
};
DDim
col_shape
=
{
C
,
K_H
,
K_W
,
H
,
W
};
// use col_matrix_shape in the gemm calculation
// use col_matrix_shape in the gemm calculation
DDim
col_matrix_shape
=
{
C
*
K_H
*
K_W
,
H
*
W
};
DDim
col_matrix_shape
_f
=
{
C
*
H
*
W
,
K_H
*
K_
W
};
Tensor
col
;
Tensor
col
;
col
.
mutable_data
<
T
>
(
col_shape
,
context
.
GetPlace
());
col
.
mutable_data
<
T
>
(
col_shape
,
context
.
GetPlace
());
// col_matrix shares the same piece of data with col,
// col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape
// but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface.
// to call the matrix multiplication interface.
Tensor
col_matrix
=
col
;
col_matrix
.
Resize
(
col_matrix_shape
);
DDim
output_shape
=
{
C
,
O_H
,
O_W
};
DDim
output_shape
=
{
C
,
O_H
,
O_W
};
DDim
input_matrix_shape
=
{
M
,
H
*
W
};
DDim
input_matrix_shape
=
{
M
,
H
*
W
};
...
@@ -186,6 +184,10 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
...
@@ -186,6 +184,10 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
// im2col + gemm (similar to conv-forward)
// im2col + gemm (similar to conv-forward)
// input need to compute gradient
// input need to compute gradient
if
(
input_grad
)
{
if
(
input_grad
)
{
Tensor
col_matrix
=
col
;
DDim
col_matrix_shape
=
{
C
*
K_H
*
K_W
,
H
*
W
};
col_matrix
.
Resize
(
col_matrix_shape
);
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
input_grad
);
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
input_grad
);
t
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
t
.
constant
(
static_cast
<
T
>
(
0
));
t
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
t
.
constant
(
static_cast
<
T
>
(
0
));
...
@@ -194,14 +196,18 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
...
@@ -194,14 +196,18 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
// batch with size (C, O_H * O_W)
// batch with size (C, O_H * O_W)
Tensor
output_grad_batch
=
Tensor
output_grad_batch
=
output_grad
->
Slice
<
T
>
(
i
,
i
+
1
).
Resize
(
output_shape
);
output_grad
->
Slice
<
T
>
(
i
,
i
+
1
).
Resize
(
output_shape
);
// filter of size (M, C * K_H * K_W)
// batch with size (M, H, W)
// batch with size (M, H, W)
Tensor
input_grad_batch
=
Tensor
input_grad_batch
=
input_grad
->
Slice
<
T
>
(
i
,
i
+
1
).
Resize
(
input_matrix_shape
);
input_grad
->
Slice
<
T
>
(
i
,
i
+
1
).
Resize
(
input_matrix_shape
);
// im2col: (C * K_H * K_W, H * W)
// im2col:
dy from (C, O_H, O_W) ->
(C * K_H * K_W, H * W)
im2col
(
context
.
device_context
(),
output_grad_batch
,
col_matrix
,
im2col
(
context
.
device_context
(),
output_grad_batch
,
col_matrix
,
strides
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
1
]);
strides
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
1
]);
// gemm: dx = filter * dy
// gemm: dx = filter * dy
// (M, C * K_H * K_W) * (C * K_H * K_W, H * W) -> (M, C, H)
math
::
matmul
<
Place
,
T
>
(
context
.
device_context
(),
filter
,
false
,
math
::
matmul
<
Place
,
T
>
(
context
.
device_context
(),
filter
,
false
,
col_matrix
,
false
,
T
(
1.0
),
&
input_grad_batch
,
col_matrix
,
false
,
T
(
1.0
),
&
input_grad_batch
,
T
(
0.0
));
T
(
0.0
));
...
@@ -210,6 +216,10 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
...
@@ -210,6 +216,10 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
// filter gradient required
// filter gradient required
if
(
filter_grad
)
{
if
(
filter_grad
)
{
Tensor
col_matrix_f
=
col
;
DDim
col_matrix_shape_f
=
{
C
*
H
*
W
,
K_H
*
K_W
};
col_matrix_f
.
Resize
(
col_matrix_shape_f
);
filter_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
filter_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
Tensor
filter_grad_
=
*
filter_grad
;
Tensor
filter_grad_
=
*
filter_grad
;
filter_grad_
.
Resize
(
filter_matrix_shape
);
filter_grad_
.
Resize
(
filter_matrix_shape
);
...
@@ -223,10 +233,12 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
...
@@ -223,10 +233,12 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
// input batch
// input batch
Tensor
in_batch
=
input
->
Slice
<
T
>
(
i
,
i
+
1
).
Resize
(
input_matrix_shape
);
Tensor
in_batch
=
input
->
Slice
<
T
>
(
i
,
i
+
1
).
Resize
(
input_matrix_shape
);
// im2col: (C *
K_H * K_W, H *
W)
// im2col: (C *
H * W, K_H * K_
W)
im2col
(
context
.
device_context
(),
output_grad_batch
,
col_matrix
,
im2col
(
context
.
device_context
(),
output_grad_batch
,
col_matrix
_f
,
strides
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
1
]);
strides
[
0
],
strides
[
1
],
paddings
[
0
],
paddings
[
1
]);
// gemm: d_filter = x * y_grad^T
// gemm: d_filter = x * y_grad^T
// (M, C * H * W) * (K_H * K_W, C * H * W) -> (M, C, H)
math
::
matmul
<
Place
,
T
>
(
context
.
device_context
(),
in_batch
,
false
,
math
::
matmul
<
Place
,
T
>
(
context
.
device_context
(),
in_batch
,
false
,
col_matrix
,
true
,
T
(
1.0
),
&
filter_grad_
,
T
(
1.0
));
col_matrix
,
true
,
T
(
1.0
),
&
filter_grad_
,
T
(
1.0
));
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录