Unverified commit e800c0d3
Authored Nov 23, 2017 by chengduo; committed via GitHub on Nov 23, 2017
Merge pull request #5791 from chengduoZH/fix_conv_op
remove vector::erase
Parents: d883547b, a93227a1
Showing 2 changed files with 51 additions and 67 deletions (+51 −67)
paddle/operators/conv_op.h            +27 −35
paddle/operators/conv_transpose_op.h  +24 −32
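The commit removes the vector::erase calls that previously stripped the leading {n, c} (or {k_o, k_i}) entries from the vectorized shape vectors. Instead, the full dims are kept, the number of spatial dimensions is computed as data_dim = filter_shape_vec.size() - 2, and col_shape_vec is pre-sized to 1 + 2 * data_dim and filled by index. A minimal standalone sketch of the new pattern, using plain std::vector<int64_t> with hypothetical conv2d shapes rather than Paddle's framework types:

#include <cstdint>
#include <vector>

int main() {
  // Hypothetical conv2d shapes: filter {k_o, k_i, k_h, k_w},
  // output {o_n, o_c, o_h, o_w}, and i_c/g input channels per group.
  std::vector<int64_t> filter_shape_vec = {64, 3, 3, 3};
  std::vector<int64_t> output_shape_vec = {8, 64, 32, 32};
  int64_t in_channels_per_group = 3;

  // Pattern introduced by the commit: keep the full shape vectors and
  // pre-size col_shape_vec, filling it by index instead of erase/insert.
  size_t data_dim = filter_shape_vec.size() - 2;  // 2 for conv2d, 3 for conv3d
  std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
  col_shape_vec[0] = in_channels_per_group;
  for (size_t j = 0; j < data_dim; ++j) {
    col_shape_vec[j + 1] = filter_shape_vec[j + 2];             // k_h, k_w
    col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];  // o_h, o_w
  }
  // col_shape_vec is now {i_c/g, k_h, k_w, o_h, o_w} = {3, 3, 3, 32, 32},
  // the same layout the old erase/push_back/insert sequence produced.
  return 0;
}

This avoids the intermediate reallocations and copies of the erase/insert sequence and keeps the original dims available for index-based lookups such as filter_dim[j + 2] in IsExpand.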
--- a/paddle/operators/conv_op.h
+++ b/paddle/operators/conv_op.h
@@ -38,7 +38,7 @@ inline bool IsExpand(std::vector<int64_t>& filter_dim,
                      std::vector<int>& dilations) {
   bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
   for (size_t j = 0; j < strides.size(); ++j) {
-    filter_1 = filter_1 && (static_cast<int>(filter_dim[j]) == 1);
+    filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
     strides_1 = strides_1 && (strides[j] == 1);
     padding_0 = padding_0 && (paddings[j] == 0);
     dilation_1 = dilation_1 && (dilations[j] == 1);
@@ -91,32 +91,28 @@ class GemmConvKernel : public framework::OpKernel<T> {
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
-
-    // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
+    // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
     std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
-    output_shape_vec.erase(output_shape_vec.begin(),
-                           output_shape_vec.begin() + 2);
 
     // use col_shape in the im2col calculation
     // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
     // o_h, o_w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(input->dims()[1] / groups);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
-                         output_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = input->dims()[1] / groups;
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
+    }
     framework::DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d *
     // o_h * o_w)
     framework::DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+        framework::flatten_to_2d(col_shape, data_dim + 1);
 
     bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
     Tensor col;
@@ -159,13 +155,13 @@ class GemmConvKernel : public framework::OpKernel<T> {
           col.ShareDataWith(in_slice);
           col_matrix.ShareDataWith(col);
           col_matrix.Resize(col_matrix_shape);
-        } else if (filter_shape_vec.size() == 2) {
+        } else if (data_dim == 2U) {
           // im2col
           im2col(context.device_context(), in_slice, dilations, strides,
                  std::vector<int>{paddings[0], paddings[1], paddings[0],
                                   paddings[1]},
                  &col);
-        } else if (filter_shape_vec.size() == 3) {
+        } else if (data_dim == 3U) {
           // vol2col
           vol2col(context.device_context(), in_slice, dilations, strides,
                   paddings, &col);
@@ -206,26 +202,22 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
-
-    // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
+    // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
     std::vector<int64_t> output_shape_vec(
         framework::vectorize(output_grad->dims()));
-    output_shape_vec.erase(output_shape_vec.begin(),
-                           output_shape_vec.begin() + 2);
 
     // use col_shape in the im2col calculation
     // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
     // o_h, o_w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(input->dims()[1] / groups);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
-                         output_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = input->dims()[1] / groups;
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
+    }
     framework::DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
@@ -233,7 +225,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
     // or
     // (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w)
     framework::DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+        framework::flatten_to_2d(col_shape, data_dim + 1);
 
     framework::DDim input_shape = framework::slice_ddim(
         input->dims(), 1, static_cast<int>(input->dims().size()));
@@ -294,12 +286,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
                                  out_grad_slice, false, T(1.0), &col_matrix,
                                  T(0.0));
 
-        if (is_expand && filter_shape_vec.size() == 2) {
+        if (is_expand && data_dim == 2U) {
           col2im(context.device_context(), col, dilations, strides,
                  std::vector<int>{paddings[0], paddings[1], paddings[0],
                                   paddings[1]},
                  &in_grad_slice);
-        } else if (is_expand && filter_shape_vec.size() == 3) {
+        } else if (is_expand && data_dim == 3U) {
           col2vol(context.device_context(), col, dilations, strides, paddings,
                   &in_grad_slice);
         }
@@ -328,12 +320,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
           col.ShareDataWith(in_slice);
           col_matrix.ShareDataWith(col);
           col_matrix.Resize(col_matrix_shape);
-        } else if (filter_shape_vec.size() == 2) {
+        } else if (data_dim == 2U) {
           im2col(context.device_context(), in_slice, dilations, strides,
                  std::vector<int>{paddings[0], paddings[1], paddings[0],
                                   paddings[1]},
                  &col);
-        } else if (filter_shape_vec.size() == 3) {
+        } else if (data_dim == 3U) {
           vol2col(context.device_context(), in_slice, dilations, strides,
                   paddings, &col);
         }
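For the gemm step, col_shape is flattened to 2-D at axis data_dim + 1: the leading 1 + data_dim entries (i_c/g and the kernel extents) collapse into the row count and the trailing data_dim entries (the output extents) into the column count, matching the existing comment "size: (i_c/g * k_h * k_w, o_h * o_w)". Before the change the same call passed filter_shape_vec.size() + 1, which equaled data_dim + 1 only because the vector had already been truncated by erase. A small arithmetic check of that flattening in plain C++ (not framework::flatten_to_2d), with the same hypothetical 2-D shapes as above:

#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical col_shape_vec for conv2d: {i_c/g, k_h, k_w, o_h, o_w}.
  std::vector<int64_t> col_shape_vec = {3, 3, 3, 32, 32};
  size_t data_dim = 2;  // number of spatial dimensions

  // Mirror flatten_to_2d(col_shape, data_dim + 1): multiply the first
  // data_dim + 1 extents into the rows and the rest into the columns.
  int64_t rows = std::accumulate(col_shape_vec.begin(),
                                 col_shape_vec.begin() + data_dim + 1,
                                 int64_t{1}, std::multiplies<int64_t>());
  int64_t cols = std::accumulate(col_shape_vec.begin() + data_dim + 1,
                                 col_shape_vec.end(), int64_t{1},
                                 std::multiplies<int64_t>());
  // rows = 3 * 3 * 3 = 27 (i_c/g * k_h * k_w), cols = 32 * 32 = 1024 (o_h * o_w).
  std::printf("col_matrix_shape: (%lld, %lld)\n",
              static_cast<long long>(rows), static_cast<long long>(cols));
  return 0;
}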
--- a/paddle/operators/conv_transpose_op.h
+++ b/paddle/operators/conv_transpose_op.h
@@ -68,30 +68,26 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // input_shape_vec: {h, w} or {d, h, w}
+    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w}
     std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
-    input_shape_vec.erase(input_shape_vec.begin(),
-                          input_shape_vec.begin() + 2);
-
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
 
     // use col_shape in the im2col and col2im (or vol2col and col2vol)
     // calculation
     // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(output->dims()[1]);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin(),
-                         input_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = output->dims()[1];
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
+    }
     DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
     DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+        framework::flatten_to_2d(col_shape, data_dim + 1);
 
     Tensor col;
     col.mutable_data<T>(col_shape, context.GetPlace());
@@ -136,7 +132,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
                               input_batch, false, static_cast<T>(1.0),
                               &col_matrix, static_cast<T>(0.0));
 
-      if (filter_shape_vec.size() == 2) {
+      if (data_dim == 2U) {
         // col2im: col_matrix -> dy
         // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
         col2im(context.device_context(), col,
@@ -144,7 +140,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
                std::vector<int>{paddings[0], paddings[1], paddings[0],
                                 paddings[1]},
                &output_batch);
-      } else if (filter_shape_vec.size() == 3) {
+      } else if (data_dim == 3U) {
         // col2vol: col_matrix -> dy
         // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
         col2vol(context.device_context(), col, dilations, strides, paddings,
@@ -176,30 +172,26 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // input_shape_vec: {h, w} or {d, h, w}
+    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w}
     std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
-    input_shape_vec.erase(input_shape_vec.begin(),
-                          input_shape_vec.begin() + 2);
-
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
 
     // use col_shape in the im2col and col2im (or vol2col and col2vol)
     // calculation
     // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(output_grad->dims()[1]);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin(),
-                         input_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = output_grad->dims()[1];
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
+    }
     DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
     DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+        framework::flatten_to_2d(col_shape, data_dim + 1);
 
     // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
     DDim output_shape = framework::slice_ddim(output_grad->dims(), 1,
@@ -248,7 +240,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
         Tensor output_grad_batch =
             output_grad->Slice(i, i + 1).Resize(output_shape);
 
-        if (filter_shape_vec.size() == 2) {
+        if (data_dim == 2U) {
           // im2col: dy -> col matrix
           // from (c, o_h, o_w) to (c * k_h * k_w, h * w)
           im2col(context.device_context(), output_grad_batch,
@@ -256,7 +248,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
                  std::vector<int>{paddings[0], paddings[1], paddings[0],
                                   paddings[1]},
                  &col);
-        } else if (filter_shape_vec.size() == 3) {
+        } else if (data_dim == 3U) {
          // vol2col: dy -> col_matrix
          // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
          vol2col(context.device_context(), output_grad_batch, dilations,
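The conv_transpose kernels follow the same construction, with the leading channel entry taken from output->dims()[1] (or output_grad->dims()[1] in the gradient kernel) and the trailing extents taken from the input spatial dims. A hypothetical helper, not part of the commit, that captures the shared layout now used by all four kernels:

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (illustration only): build the col_shape_vec layout
// {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w} from full NCHW/NCDHW
// shape vectors, reading the spatial extents at offset j + 2.
std::vector<int64_t> BuildColShape(int64_t channels,
                                   const std::vector<int64_t>& filter_shape_vec,
                                   const std::vector<int64_t>& spatial_shape_vec) {
  size_t data_dim = filter_shape_vec.size() - 2;
  std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
  col_shape_vec[0] = channels;
  for (size_t j = 0; j < data_dim; ++j) {
    col_shape_vec[j + 1] = filter_shape_vec[j + 2];
    col_shape_vec[j + 1 + data_dim] = spatial_shape_vec[j + 2];
  }
  return col_shape_vec;
}

int main() {
  // Hypothetical conv_transpose2d shapes: filter {k_o, k_c, k_h, k_w},
  // input {n, c, h, w}, and 8 output channels.
  std::vector<int64_t> col = BuildColShape(8, {3, 8, 5, 5}, {1, 3, 16, 16});
  // col == {8, 5, 5, 16, 16}, i.e. {c, k_h, k_w, h, w}.
  return 0;
}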