Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
e5bf9c56
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e5bf9c56
编写于
11月 21, 2017
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove vector::eraze
上级
e930f496
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
43 addition
and
57 deletion
+43
-57
paddle/operators/conv_op.h
paddle/operators/conv_op.h
+23
-31
paddle/operators/conv_transpose_op.h
paddle/operators/conv_transpose_op.h
+20
-26
未找到文件。
paddle/operators/conv_op.h
浏览文件 @
e5bf9c56
...
...
@@ -38,7 +38,7 @@ inline bool IsExpand(std::vector<int64_t>& filter_dim,
std
::
vector
<
int
>&
dilations
)
{
bool
filter_1
=
true
,
strides_1
=
true
,
padding_0
=
true
,
dilation_1
=
true
;
for
(
size_t
j
=
0
;
j
<
strides
.
size
();
++
j
)
{
filter_1
=
filter_1
&&
(
static_cast
<
int
>
(
filter_dim
[
j
])
==
1
);
filter_1
=
filter_1
&&
(
static_cast
<
int
>
(
filter_dim
[
j
+
2
])
==
1
);
strides_1
=
strides_1
&&
(
strides
[
j
]
==
1
);
padding_0
=
padding_0
&&
(
paddings
[
j
]
==
0
);
dilation_1
=
dilation_1
&&
(
dilations
[
j
]
==
1
);
...
...
@@ -91,24 +91,20 @@ class GemmConvKernel : public framework::OpKernel<T> {
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
// filter_shape_vec: {k_
h, k_w} or {
k_d, k_h, k_w}
// filter_shape_vec: {k_
o, k_i, k_h, k_w} or {k_o, k_i,
k_d, k_h, k_w}
std
::
vector
<
int64_t
>
filter_shape_vec
(
framework
::
vectorize
(
filter
.
dims
()));
filter_shape_vec
.
erase
(
filter_shape_vec
.
begin
(),
filter_shape_vec
.
begin
()
+
2
);
// output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
// output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
std
::
vector
<
int64_t
>
output_shape_vec
(
framework
::
vectorize
(
output
->
dims
()));
output_shape_vec
.
erase
(
output_shape_vec
.
begin
(),
output_shape_vec
.
begin
()
+
2
);
// use col_shape in the im2col calculation
// col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
// o_h, o_w}
std
::
vector
<
int64_t
>
col_shape_vec
;
col_shape_vec
.
push_back
(
input
->
dims
()[
1
]
/
groups
);
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
filter_shape_vec
.
begin
(),
std
::
vector
<
int64_t
>
col_shape_vec
(
filter_shape_vec
.
size
()
+
output_shape_vec
.
size
()
-
3
);
col_shape_vec
.
assign
(
1
,
input
->
dims
()[
1
]
/
groups
);
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
filter_shape_vec
.
begin
()
+
2
,
filter_shape_vec
.
end
());
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
output_shape_vec
.
begin
(),
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
output_shape_vec
.
begin
()
+
2
,
output_shape_vec
.
end
());
framework
::
DDim
col_shape
(
framework
::
make_ddim
(
col_shape_vec
));
...
...
@@ -116,7 +112,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
// size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d *
// o_h * o_w)
framework
::
DDim
col_matrix_shape
=
framework
::
flatten_to_2d
(
col_shape
,
filter_shape_vec
.
size
()
+
1
);
framework
::
flatten_to_2d
(
col_shape
,
filter_shape_vec
.
size
()
-
2
+
1
);
bool
is_expand
=
IsExpand
(
filter_shape_vec
,
strides
,
paddings
,
dilations
);
Tensor
col
;
...
...
@@ -159,13 +155,13 @@ class GemmConvKernel : public framework::OpKernel<T> {
col
.
ShareDataWith
(
in_slice
);
col_matrix
.
ShareDataWith
(
col
);
col_matrix
.
Resize
(
col_matrix_shape
);
}
else
if
(
filter_shape_vec
.
size
()
==
2
)
{
}
else
if
(
filter_shape_vec
.
size
()
==
4
)
{
// im2col
im2col
(
context
.
device_context
(),
in_slice
,
dilations
,
strides
,
std
::
vector
<
int
>
{
paddings
[
0
],
paddings
[
1
],
paddings
[
0
],
paddings
[
1
]},
&
col
);
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
}
else
if
(
filter_shape_vec
.
size
()
==
5
)
{
// vol2col
vol2col
(
context
.
device_context
(),
in_slice
,
dilations
,
strides
,
paddings
,
&
col
);
...
...
@@ -206,25 +202,21 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
// filter_shape_vec: {k_
h, k_w} or {
k_d, k_h, k_w}
// filter_shape_vec: {k_
o, k_i, k_h, k_w} or {k_o, k_i,
k_d, k_h, k_w}
std
::
vector
<
int64_t
>
filter_shape_vec
(
framework
::
vectorize
(
filter
.
dims
()));
filter_shape_vec
.
erase
(
filter_shape_vec
.
begin
(),
filter_shape_vec
.
begin
()
+
2
);
// output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
// output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
std
::
vector
<
int64_t
>
output_shape_vec
(
framework
::
vectorize
(
output_grad
->
dims
()));
output_shape_vec
.
erase
(
output_shape_vec
.
begin
(),
output_shape_vec
.
begin
()
+
2
);
// use col_shape in the im2col calculation
// col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
// o_h, o_w}
std
::
vector
<
int64_t
>
col_shape_vec
;
col_shape_vec
.
push_back
(
input
->
dims
()[
1
]
/
groups
);
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
filter_shape_vec
.
begin
(),
std
::
vector
<
int64_t
>
col_shape_vec
(
filter_shape_vec
.
size
()
+
output_shape_vec
.
size
()
-
3
);
col_shape_vec
.
assign
(
1
,
input
->
dims
()[
1
]
/
groups
);
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
filter_shape_vec
.
begin
()
+
2
,
filter_shape_vec
.
end
());
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
output_shape_vec
.
begin
(),
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
output_shape_vec
.
begin
()
+
2
,
output_shape_vec
.
end
());
framework
::
DDim
col_shape
(
framework
::
make_ddim
(
col_shape_vec
));
...
...
@@ -233,7 +225,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
// or
// (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w)
framework
::
DDim
col_matrix_shape
=
framework
::
flatten_to_2d
(
col_shape
,
filter_shape_vec
.
size
()
+
1
);
framework
::
flatten_to_2d
(
col_shape
,
filter_shape_vec
.
size
()
-
2
+
1
);
framework
::
DDim
input_shape
=
framework
::
slice_ddim
(
input
->
dims
(),
1
,
static_cast
<
int
>
(
input
->
dims
().
size
()));
...
...
@@ -294,12 +286,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
out_grad_slice
,
false
,
T
(
1.0
),
&
col_matrix
,
T
(
0.0
));
if
(
is_expand
&&
filter_shape_vec
.
size
()
==
2
)
{
if
(
is_expand
&&
filter_shape_vec
.
size
()
==
4
)
{
col2im
(
context
.
device_context
(),
col
,
dilations
,
strides
,
std
::
vector
<
int
>
{
paddings
[
0
],
paddings
[
1
],
paddings
[
0
],
paddings
[
1
]},
&
in_grad_slice
);
}
else
if
(
is_expand
&&
filter_shape_vec
.
size
()
==
3
)
{
}
else
if
(
is_expand
&&
filter_shape_vec
.
size
()
==
5
)
{
col2vol
(
context
.
device_context
(),
col
,
dilations
,
strides
,
paddings
,
&
in_grad_slice
);
}
...
...
@@ -328,12 +320,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
col
.
ShareDataWith
(
in_slice
);
col_matrix
.
ShareDataWith
(
col
);
col_matrix
.
Resize
(
col_matrix_shape
);
}
else
if
(
filter_shape_vec
.
size
()
==
2
)
{
}
else
if
(
filter_shape_vec
.
size
()
==
4
)
{
im2col
(
context
.
device_context
(),
in_slice
,
dilations
,
strides
,
std
::
vector
<
int
>
{
paddings
[
0
],
paddings
[
1
],
paddings
[
0
],
paddings
[
1
]},
&
col
);
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
}
else
if
(
filter_shape_vec
.
size
()
==
5
)
{
vol2col
(
context
.
device_context
(),
in_slice
,
dilations
,
strides
,
paddings
,
&
col
);
}
...
...
paddle/operators/conv_transpose_op.h
浏览文件 @
e5bf9c56
...
...
@@ -68,30 +68,27 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
// input_shape_vec: {
h, w} or {
d, h, w}
// input_shape_vec: {
n, c, h, w} or {n, c,
d, h, w}
std
::
vector
<
int64_t
>
input_shape_vec
=
framework
::
vectorize
(
input
->
dims
());
input_shape_vec
.
erase
(
input_shape_vec
.
begin
(),
input_shape_vec
.
begin
()
+
2
);
// filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
// filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w}
std
::
vector
<
int64_t
>
filter_shape_vec
=
framework
::
vectorize
(
filter
.
dims
());
filter_shape_vec
.
erase
(
filter_shape_vec
.
begin
(),
filter_shape_vec
.
begin
()
+
2
);
// use col_shape in the im2col and col2im (or vol2col and col2vol)
// calculation
// col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
std
::
vector
<
int64_t
>
col_shape_vec
;
col_shape_vec
.
push_back
(
output
->
dims
()[
1
]);
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
filter_shape_vec
.
begin
(),
std
::
vector
<
int64_t
>
col_shape_vec
(
filter_shape_vec
.
size
()
+
input_shape_vec
.
size
()
-
3
);
col_shape_vec
.
assign
(
1
,
output
->
dims
()[
1
]);
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
filter_shape_vec
.
begin
()
+
2
,
filter_shape_vec
.
end
());
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
input_shape_vec
.
begin
(),
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
input_shape_vec
.
begin
()
+
2
,
input_shape_vec
.
end
());
DDim
col_shape
(
framework
::
make_ddim
(
col_shape_vec
));
// use col_matrix_shape in the gemm calculation
// size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
DDim
col_matrix_shape
=
framework
::
flatten_to_2d
(
col_shape
,
filter_shape_vec
.
size
()
+
1
);
framework
::
flatten_to_2d
(
col_shape
,
filter_shape_vec
.
size
()
-
2
+
1
);
Tensor
col
;
col
.
mutable_data
<
T
>
(
col_shape
,
context
.
GetPlace
());
...
...
@@ -136,7 +133,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
input_batch
,
false
,
static_cast
<
T
>
(
1.0
),
&
col_matrix
,
static_cast
<
T
>
(
0.0
));
if
(
filter_shape_vec
.
size
()
==
2
)
{
if
(
filter_shape_vec
.
size
()
==
4
)
{
// col2im: col_matrix -> dy
// from (c * k_h * k_w, h * w) to (c, o_h, o_w)
col2im
(
context
.
device_context
(),
col
,
...
...
@@ -144,7 +141,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
std
::
vector
<
int
>
{
paddings
[
0
],
paddings
[
1
],
paddings
[
0
],
paddings
[
1
]},
&
output_batch
);
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
}
else
if
(
filter_shape_vec
.
size
()
==
5
)
{
// col2vol: col_matrix -> dy
// from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
col2vol
(
context
.
device_context
(),
col
,
dilations
,
strides
,
paddings
,
...
...
@@ -176,30 +173,27 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
const
int
batch_size
=
static_cast
<
int
>
(
input
->
dims
()[
0
]);
// input_shape_vec: {
h, w} or {
d, h, w}
// input_shape_vec: {
n, c, h, w} or {n, c,
d, h, w}
std
::
vector
<
int64_t
>
input_shape_vec
=
framework
::
vectorize
(
input
->
dims
());
input_shape_vec
.
erase
(
input_shape_vec
.
begin
(),
input_shape_vec
.
begin
()
+
2
);
// filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
// filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w}
std
::
vector
<
int64_t
>
filter_shape_vec
=
framework
::
vectorize
(
filter
.
dims
());
filter_shape_vec
.
erase
(
filter_shape_vec
.
begin
(),
filter_shape_vec
.
begin
()
+
2
);
// use col_shape in the im2col and col2im (or vol2col and col2vol)
// calculation
// col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
std
::
vector
<
int64_t
>
col_shape_vec
;
col_shape_vec
.
push_back
(
output_grad
->
dims
()[
1
]);
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
filter_shape_vec
.
begin
(),
std
::
vector
<
int64_t
>
col_shape_vec
(
filter_shape_vec
.
size
()
+
input_shape_vec
.
size
()
-
3
);
col_shape_vec
.
assign
(
1
,
output_grad
->
dims
()[
1
]);
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
filter_shape_vec
.
begin
()
+
2
,
filter_shape_vec
.
end
());
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
input_shape_vec
.
begin
(),
col_shape_vec
.
insert
(
col_shape_vec
.
end
(),
input_shape_vec
.
begin
()
+
2
,
input_shape_vec
.
end
());
DDim
col_shape
(
framework
::
make_ddim
(
col_shape_vec
));
// use col_matrix_shape in the gemm calculation
// size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
DDim
col_matrix_shape
=
framework
::
flatten_to_2d
(
col_shape
,
filter_shape_vec
.
size
()
+
1
);
framework
::
flatten_to_2d
(
col_shape
,
filter_shape_vec
.
size
()
-
2
+
1
);
// output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
DDim
output_shape
=
framework
::
slice_ddim
(
output_grad
->
dims
(),
1
,
...
...
@@ -248,7 +242,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
Tensor
output_grad_batch
=
output_grad
->
Slice
(
i
,
i
+
1
).
Resize
(
output_shape
);
if
(
filter_shape_vec
.
size
()
==
2
)
{
if
(
filter_shape_vec
.
size
()
==
4
)
{
// im2col: dy -> col matrix
// from (c, o_h, o_w) to (c * k_h * k_w, h * w)
im2col
(
context
.
device_context
(),
output_grad_batch
,
...
...
@@ -256,7 +250,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
std
::
vector
<
int
>
{
paddings
[
0
],
paddings
[
1
],
paddings
[
0
],
paddings
[
1
]},
&
col
);
}
else
if
(
filter_shape_vec
.
size
()
==
3
)
{
}
else
if
(
filter_shape_vec
.
size
()
==
5
)
{
// vol2col: dy -> col_matrix
// from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
vol2col
(
context
.
device_context
(),
output_grad_batch
,
dilations
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录