Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a93227a1
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a93227a1
编写于
11月 22, 2017
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine code
上级
e5bf9c56
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
42 addition
and
44 deletion
+42
-44
paddle/operators/conv_op.h
paddle/operators/conv_op.h
+22
-22
paddle/operators/conv_transpose_op.h
paddle/operators/conv_transpose_op.h
+20
-22
未找到文件。
paddle/operators/conv_op.h
浏览文件 @
a93227a1
...
...
@@ -99,20 +99,20 @@ class GemmConvKernel : public framework::OpKernel<T> {
// use col_shape in the im2col calculation
// col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
// o_h, o_w}
s
td
::
vector
<
int64_t
>
col_shape_vec
(
filter_shape_vec
.
size
()
+
output_shape_vec
.
size
()
-
3
);
col_shape_vec
.
assign
(
1
,
input
->
dims
()[
1
]
/
groups
)
;
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
filter_shape_vec
.
begin
()
+
2
,
filter_shape_vec
.
end
())
;
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
output_shape_vec
.
begin
()
+
2
,
output_shape_vec
.
end
());
s
ize_t
data_dim
=
filter_shape_vec
.
size
()
-
2
;
std
::
vector
<
int64_t
>
col_shape_vec
(
1
+
2
*
data_dim
);
col_shape_vec
[
0
]
=
input
->
dims
()[
1
]
/
groups
;
for
(
size_t
j
=
0
;
j
<
data_dim
;
++
j
)
{
col_shape_vec
[
j
+
1
]
=
filter_shape_vec
[
j
+
2
]
;
col_shape_vec
[
j
+
1
+
data_dim
]
=
output_shape_vec
[
j
+
2
];
}
framework
::
DDim
col_shape
(
framework
::
make_ddim
(
col_shape_vec
));
// use col_matrix_shape in the gemm calculation
// size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d *
// o_h * o_w)
framework
::
DDim
col_matrix_shape
=
framework
::
flatten_to_2d
(
col_shape
,
filter_shape_vec
.
size
()
-
2
+
1
);
framework
::
flatten_to_2d
(
col_shape
,
data_dim
+
1
);
bool
is_expand
=
IsExpand
(
filter_shape_vec
,
strides
,
paddings
,
dilations
);
Tensor
col
;
...
...
@@ -155,13 +155,13 @@ class GemmConvKernel : public framework::OpKernel<T> {
col
.
ShareDataWith
(
in_slice
);
col_matrix
.
ShareDataWith
(
col
);
col_matrix
.
Resize
(
col_matrix_shape
);
}
else
if
(
filter_shape_vec
.
size
()
==
4
)
{
}
else
if
(
data_dim
==
2U
)
{
// im2col
im2col
(
context
.
device_context
(),
in_slice
,
dilations
,
strides
,
std
::
vector
<
int
>
{
paddings
[
0
],
paddings
[
1
],
paddings
[
0
],
paddings
[
1
]},
&
col
);
}
else
if
(
filter_shape_vec
.
size
()
==
5
)
{
}
else
if
(
data_dim
==
3U
)
{
// vol2col
vol2col
(
context
.
device_context
(),
in_slice
,
dilations
,
strides
,
paddings
,
&
col
);
...
...
@@ -211,13 +211,13 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
// use col_shape in the im2col calculation
// col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
// o_h, o_w}
s
td
::
vector
<
int64_t
>
col_shape_vec
(
filter_shape_vec
.
size
()
+
output_shape_vec
.
size
()
-
3
);
col_shape_vec
.
assign
(
1
,
input
->
dims
()[
1
]
/
groups
)
;
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
filter_shape_vec
.
begin
()
+
2
,
filter_shape_vec
.
end
())
;
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
output_shape_vec
.
begin
()
+
2
,
output_shape_vec
.
end
());
s
ize_t
data_dim
=
filter_shape_vec
.
size
()
-
2
;
std
::
vector
<
int64_t
>
col_shape_vec
(
1
+
2
*
data_dim
);
col_shape_vec
[
0
]
=
input
->
dims
()[
1
]
/
groups
;
for
(
size_t
j
=
0
;
j
<
data_dim
;
++
j
)
{
col_shape_vec
[
j
+
1
]
=
filter_shape_vec
[
j
+
2
]
;
col_shape_vec
[
j
+
1
+
data_dim
]
=
output_shape_vec
[
j
+
2
];
}
framework
::
DDim
col_shape
(
framework
::
make_ddim
(
col_shape_vec
));
// use col_matrix_shape in the gemm calculation
...
...
@@ -225,7 +225,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
// or
// (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w)
framework
::
DDim
col_matrix_shape
=
framework
::
flatten_to_2d
(
col_shape
,
filter_shape_vec
.
size
()
-
2
+
1
);
framework
::
flatten_to_2d
(
col_shape
,
data_dim
+
1
);
framework
::
DDim
input_shape
=
framework
::
slice_ddim
(
input
->
dims
(),
1
,
static_cast
<
int
>
(
input
->
dims
().
size
()));
...
...
@@ -286,12 +286,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
out_grad_slice
,
false
,
T
(
1.0
),
&
col_matrix
,
T
(
0.0
));
if
(
is_expand
&&
filter_shape_vec
.
size
()
==
4
)
{
if
(
is_expand
&&
data_dim
==
2U
)
{
col2im
(
context
.
device_context
(),
col
,
dilations
,
strides
,
std
::
vector
<
int
>
{
paddings
[
0
],
paddings
[
1
],
paddings
[
0
],
paddings
[
1
]},
&
in_grad_slice
);
}
else
if
(
is_expand
&&
filter_shape_vec
.
size
()
==
5
)
{
}
else
if
(
is_expand
&&
data_dim
==
3U
)
{
col2vol
(
context
.
device_context
(),
col
,
dilations
,
strides
,
paddings
,
&
in_grad_slice
);
}
...
...
@@ -320,12 +320,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
col
.
ShareDataWith
(
in_slice
);
col_matrix
.
ShareDataWith
(
col
);
col_matrix
.
Resize
(
col_matrix_shape
);
}
else
if
(
filter_shape_vec
.
size
()
==
4
)
{
}
else
if
(
data_dim
==
2U
)
{
im2col
(
context
.
device_context
(),
in_slice
,
dilations
,
strides
,
std
::
vector
<
int
>
{
paddings
[
0
],
paddings
[
1
],
paddings
[
0
],
paddings
[
1
]},
&
col
);
}
else
if
(
filter_shape_vec
.
size
()
==
5
)
{
}
else
if
(
data_dim
==
3U
)
{
vol2col
(
context
.
device_context
(),
in_slice
,
dilations
,
strides
,
paddings
,
&
col
);
}
...
...
paddle/operators/conv_transpose_op.h
浏览文件 @
a93227a1
...
...
@@ -76,19 +76,18 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// use col_shape in the im2col and col2im (or vol2col and col2vol)
// calculation
// col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
s
td
::
vector
<
int64_t
>
col_shape_vec
(
filter_shape_vec
.
size
()
+
input_shape_vec
.
size
()
-
3
);
col_shape_vec
.
assign
(
1
,
output
->
dims
()[
1
])
;
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
filter_shape_vec
.
begin
()
+
2
,
filter_shape_vec
.
end
())
;
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
input_shape_vec
.
begin
()
+
2
,
input_shape_vec
.
end
());
s
ize_t
data_dim
=
filter_shape_vec
.
size
()
-
2
;
std
::
vector
<
int64_t
>
col_shape_vec
(
1
+
2
*
data_dim
);
col_shape_vec
[
0
]
=
output
->
dims
()[
1
]
;
for
(
size_t
j
=
0
;
j
<
data_dim
;
++
j
)
{
col_shape_vec
[
j
+
1
]
=
filter_shape_vec
[
j
+
2
]
;
col_shape_vec
[
j
+
1
+
data_dim
]
=
input_shape_vec
[
j
+
2
];
}
DDim
col_shape
(
framework
::
make_ddim
(
col_shape_vec
));
// use col_matrix_shape in the gemm calculation
// size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
DDim
col_matrix_shape
=
framework
::
flatten_to_2d
(
col_shape
,
filter_shape_vec
.
size
()
-
2
+
1
);
DDim
col_matrix_shape
=
framework
::
flatten_to_2d
(
col_shape
,
data_dim
+
1
);
Tensor
col
;
col
.
mutable_data
<
T
>
(
col_shape
,
context
.
GetPlace
());
...
...
@@ -133,7 +132,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
input_batch
,
false
,
static_cast
<
T
>
(
1.0
),
&
col_matrix
,
static_cast
<
T
>
(
0.0
));
if
(
filter_shape_vec
.
size
()
==
4
)
{
if
(
data_dim
==
2U
)
{
// col2im: col_matrix -> dy
// from (c * k_h * k_w, h * w) to (c, o_h, o_w)
col2im
(
context
.
device_context
(),
col
,
...
...
@@ -141,7 +140,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
std
::
vector
<
int
>
{
paddings
[
0
],
paddings
[
1
],
paddings
[
0
],
paddings
[
1
]},
&
output_batch
);
}
else
if
(
filter_shape_vec
.
size
()
==
5
)
{
}
else
if
(
data_dim
==
3U
)
{
// col2vol: col_matrix -> dy
// from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
col2vol
(
context
.
device_context
(),
col
,
dilations
,
strides
,
paddings
,
...
...
@@ -181,19 +180,18 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// use col_shape in the im2col and col2im (or vol2col and col2vol)
// calculation
// col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
s
td
::
vector
<
int64_t
>
col_shape_vec
(
filter_shape_vec
.
size
()
+
input_shape_vec
.
size
()
-
3
);
col_shape_vec
.
assign
(
1
,
output_grad
->
dims
()[
1
])
;
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
filter_shape_vec
.
begin
()
+
2
,
filter_shape_vec
.
end
())
;
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
input_shape_vec
.
begin
()
+
2
,
input_shape_vec
.
end
());
s
ize_t
data_dim
=
filter_shape_vec
.
size
()
-
2
;
std
::
vector
<
int64_t
>
col_shape_vec
(
1
+
2
*
data_dim
);
col_shape_vec
[
0
]
=
output_grad
->
dims
()[
1
]
;
for
(
size_t
j
=
0
;
j
<
data_dim
;
++
j
)
{
col_shape_vec
[
j
+
1
]
=
filter_shape_vec
[
j
+
2
]
;
col_shape_vec
[
j
+
1
+
data_dim
]
=
input_shape_vec
[
j
+
2
];
}
DDim
col_shape
(
framework
::
make_ddim
(
col_shape_vec
));
// use col_matrix_shape in the gemm calculation
// size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
DDim
col_matrix_shape
=
framework
::
flatten_to_2d
(
col_shape
,
filter_shape_vec
.
size
()
-
2
+
1
);
DDim
col_matrix_shape
=
framework
::
flatten_to_2d
(
col_shape
,
data_dim
+
1
);
// output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
DDim
output_shape
=
framework
::
slice_ddim
(
output_grad
->
dims
(),
1
,
...
...
@@ -242,7 +240,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
Tensor
output_grad_batch
=
output_grad
->
Slice
(
i
,
i
+
1
).
Resize
(
output_shape
);
if
(
filter_shape_vec
.
size
()
==
4
)
{
if
(
data_dim
==
2U
)
{
// im2col: dy -> col matrix
// from (c, o_h, o_w) to (c * k_h * k_w, h * w)
im2col
(
context
.
device_context
(),
output_grad_batch
,
...
...
@@ -250,7 +248,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
std
::
vector
<
int
>
{
paddings
[
0
],
paddings
[
1
],
paddings
[
0
],
paddings
[
1
]},
&
col
);
}
else
if
(
filter_shape_vec
.
size
()
==
5
)
{
}
else
if
(
data_dim
==
3U
)
{
// vol2col: dy -> col_matrix
// from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
vol2col
(
context
.
device_context
(),
output_grad_batch
,
dilations
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录