PaddlePaddle / Paddle
Commit cf6919bf
Authored on Oct 07, 2019 by Zhang Ting; committed by hong on Oct 07, 2019.
conv_transpose supports channel_last input, test=develop, test=document_preview (#20072)
Parent: c9139c3d
Showing 17 changed files with 2789 additions and 435 deletions (+2789, -435).
paddle/fluid/API.spec (+2, -2)
paddle/fluid/operators/conv_transpose_cudnn_op.cu (+586, -0)
paddle/fluid/operators/conv_transpose_op.cc (+95, -61)
paddle/fluid/operators/conv_transpose_op.cu (+0, -0)
paddle/fluid/operators/conv_transpose_op.h (+405, -75)
paddle/fluid/operators/math/depthwise_conv.cu (+197, -77)
paddle/fluid/operators/math/depthwise_conv.h (+8, -3)
paddle/fluid/operators/math/im2col.cc (+26, -11)
paddle/fluid/operators/math/im2col.cu (+73, -32)
paddle/fluid/operators/math/im2col.h (+6, -2)
paddle/fluid/operators/math/im2col_cfo_cpu.h (+87, -22)
paddle/fluid/operators/math/vol2col.cc (+44, -21)
paddle/fluid/operators/math/vol2col.cu (+68, -30)
paddle/fluid/operators/math/vol2col.h (+7, -4)
python/paddle/fluid/layers/nn.py (+208, -71)
python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py (+551, -19)
python/paddle/fluid/tests/unittests/test_conv3d_transpose_op.py (+426, -5)
paddle/fluid/API.spec
@@ -153,8 +153,8 @@ paddle.fluid.layers.batch_norm (ArgSpec(args=['input', 'act', 'is_test', 'moment
 paddle.fluid.layers.instance_norm (ArgSpec(args=['input', 'epsilon', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None)), ('document', '02972097e089629efdb0ed9404fd36ae'))
 paddle.fluid.layers.data_norm (ArgSpec(args=['input', 'act', 'epsilon', 'param_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var'], varargs=None, keywords=None, defaults=(None, 1e-05, None, 'NCHW', False, None, None, None, False)), ('document', '2460b30fb87037555208fa8ac6fc1787'))
 paddle.fluid.layers.beam_search_decode (ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '83e08f21af41ac8bac37aeab1f86fdd0'))
-paddle.fluid.layers.conv2d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)), ('document', 'ab58296b567bf0c686084add7f3280a4'))
+paddle.fluid.layers.conv2d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name', 'data_format'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None, 'NCHW')), ('document', '9391d75358b6cba0cc5d22a01a223420'))
-paddle.fluid.layers.conv3d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)), ('document', 'fe15dbfb17d97d3d29b2fa7ee6390ee6'))
+paddle.fluid.layers.conv3d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name', 'data_format'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None, 'NCDHW')), ('document', '74bce3cd4224e6ff133d54508dc7f150'))
 paddle.fluid.layers.sequence_expand (ArgSpec(args=['x', 'y', 'ref_level', 'name'], varargs=None, keywords=None, defaults=(-1, None)), ('document', '10e122eb755c2bd1f78ef2332b28f1a0'))
 paddle.fluid.layers.sequence_expand_as (ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '858c432e7cbd8bb952cc2eb555457d50'))
 paddle.fluid.layers.sequence_pad (ArgSpec(args=['x', 'pad_value', 'maxlen', 'name'], varargs=None, keywords=None, defaults=(None, None)), ('document', 'df08b9c499ab3a90f95d08ab5b6c6c62'))
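In user code this surfaces as a new `data_format` argument on `fluid.layers.conv2d_transpose` and `conv3d_transpose`. A minimal usage sketch against the fluid 1.x API as changed here (shapes and variable names are illustrative, not from this diff):

```python
import paddle.fluid as fluid

# Channel-last input: (batch, H, W, C) = (-1, 8, 8, 3).
data = fluid.layers.data(name='img', shape=[8, 8, 3], dtype='float32')

# With data_format='NHWC' the channel count is read from the last axis,
# and the output keeps the channel-last layout: (-1, H_out, W_out, 6).
out = fluid.layers.conv2d_transpose(
    input=data, num_filters=6, filter_size=3, stride=2, padding=1,
    data_format='NHWC')
```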
paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc → paddle/fluid/operators/conv_transpose_cudnn_op.cu
(This diff is collapsed and not shown here.)
paddle/fluid/operators/conv_transpose_op.cc
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <memory>
 #include <string>
 #include <vector>
+#include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/platform/cudnn_workspace_helper.h"
 #ifdef PADDLE_WITH_MKLDNN
@@ -25,13 +26,15 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+using DataLayout = framework::DataLayout;
+
 void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("Input"),
-                 "Input(Input) of ConvTransposeOp should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("Filter"),
-                 "Input(Filter) of ConvTransposeOp should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("Output"),
-                 "Output(Output) of ConvTransposeOp should not be null.");
+  PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
+                    "Input(Input) of ConvTransposeOp should not be null.");
+  PADDLE_ENFORCE_EQ(ctx->HasInput("Filter"), true,
+                    "Input(Filter) of ConvTransposeOp should not be null.");
+  PADDLE_ENFORCE_EQ(ctx->HasOutput("Output"), true,
+                    "Output(Output) of ConvTransposeOp should not be null.");
 
   auto in_dims = ctx->GetInputDim("Input");
   auto filter_dims = ctx->GetInputDim("Filter");
@@ -41,52 +44,75 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
   std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
   std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
   int groups = ctx->Attrs().Get<int>("groups");
+  std::string padding_algorithm =
+      ctx->Attrs().Get<std::string>("padding_algorithm");
+  const DataLayout data_layout = framework::StringToDataLayout(
+      ctx->Attrs().Get<std::string>("data_format"));
 
-  PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5,
-                 "ConvTransposeOp intput should be 4-D or 5-D tensor.");
+  PADDLE_ENFORCE_EQ(in_dims.size() == 4 || in_dims.size() == 5, true,
+                    "ConvTransposeOp intput should be 4-D or 5-D tensor.");
   PADDLE_ENFORCE_EQ(in_dims.size(), filter_dims.size(),
                     "ConvTransposeOp input dimension and filter dimension "
                     "should be the same.");
-  PADDLE_ENFORCE(in_dims.size() - strides.size() == 2U,
-                 "ConvTransposeOp input dimension and strides dimension should "
-                 "be consistent.");
+  PADDLE_ENFORCE_EQ(in_dims.size() - strides.size(), 2U,
+                    "ConvTransposeOp input dimension and strides dimension should "
+                    "be consistent.");
   if (output_size.size())
     PADDLE_ENFORCE_EQ(output_size.size(), strides.size(),
                       "ConvTransposeOp output_size dimension and strides "
                       "dimension should be the same.");
-  PADDLE_ENFORCE_EQ(paddings.size(), strides.size(),
-                    "ConvTransposeOp paddings dimension and strides "
-                    "dimension should be the same.");
-  PADDLE_ENFORCE_EQ(paddings.size(), dilations.size(),
-                    "ConvTransposeOp paddings dimension and dilations "
-                    "dimension should be the same.");
-  PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
-                    "In ConvTransposeOp, The number of input channels should "
-                    "be equal to the number of filter's channels.");
+  const int64_t C =
+      (data_layout == DataLayout::kNCHW ? in_dims[1]
+                                        : in_dims[in_dims.size() - 1]);
+  PADDLE_ENFORCE_EQ(
+      C, filter_dims[0],
+      "The number of input channels of Op(ConvTransposeOp) should "
+      "be equal to the number of filter's channels.");
 
-  std::vector<int64_t> output_shape({in_dims[0], filter_dims[1] * groups});
+  framework::DDim in_data_dims;
+  if (data_layout == DataLayout::kNCHW) {
+    in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
+  } else {
+    in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
+  }
+  framework::DDim filter_data_dims =
+      framework::slice_ddim(filter_dims, 2, filter_dims.size());
+  std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
+  UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
+                           in_data_dims, strides, ksize);
+
+  std::vector<int64_t> output_shape({in_dims[0]});
+  if (data_layout == DataLayout::kNCHW) {
+    output_shape.push_back(filter_dims[1] * groups);
+  }
+  const int offset = (data_layout == DataLayout::kNCHW ? 2 : 1);
   for (size_t i = 0; i < strides.size(); ++i) {
     auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1;
-    auto infer_shape =
-        (in_dims[i + 2] - 1) * strides[i] - 2 * paddings[i] + filter_extent;
+    auto infer_shape = (in_dims[i + offset] - 1) * strides[i] -
+                       paddings[2 * i] - paddings[2 * i + 1] + filter_extent;
     if (output_size.size()) {
-      PADDLE_ENFORCE((output_size[i] >= infer_shape &&
-                      output_size[i] < infer_shape + strides[i]),
-                     "ConvTransposeOp output_size should be "
-                     "in appropriate range.");
+      PADDLE_ENFORCE_EQ((output_size[i] >= infer_shape &&
+                         output_size[i] < infer_shape + strides[i]),
+                        true,
+                        "output_size of Op(ConvTransposeOp) should be "
+                        "in appropriate range.");
       output_shape.push_back(output_size[i]);
     } else {
       output_shape.push_back(infer_shape);
     }
   }
+  if (data_layout == DataLayout::kNHWC) {
+    output_shape.push_back(filter_dims[1] * groups);
+  }
   ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
 }
 
 framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   framework::LibraryType library_{framework::LibraryType::kPlain};
-  framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
+  std::string data_format = ctx.Attr<std::string>("data_format");
+  framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
   bool use_cudnn = ctx.Attr<bool>("use_cudnn");
   use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
 #ifdef PADDLE_WITH_CUDA
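The essence of the InferShape change above: for NHWC the channel axis moves to the end, the first spatial axis offset becomes 1 instead of 2, and each padding is now a separate (begin, end) pair. A small Python sketch of the same arithmetic (illustrative only, not Paddle code):

```python
def conv_transpose_out_shape(in_shape, filter_shape, strides, paddings,
                             dilations, groups, channel_last=False):
    """Mirror of ConvTransposeOp::InferShape. `paddings` holds
    (begin, end) pairs per spatial dim, as after UpdatePaddingAndDilation."""
    offset = 1 if channel_last else 2          # first spatial axis
    out = [in_shape[0]]                        # batch
    if not channel_last:
        out.append(filter_shape[1] * groups)   # channels first
    for i, s in enumerate(strides):
        extent = dilations[i] * (filter_shape[i + 2] - 1) + 1
        out.append((in_shape[i + offset] - 1) * s
                   - paddings[2 * i] - paddings[2 * i + 1] + extent)
    if channel_last:
        out.append(filter_shape[1] * groups)   # channels last
    return out

# NCHW: (2, 3, 8, 8), 3x3 filter, stride 2, pad (1, 1) -> [2, 6, 15, 15]
print(conv_transpose_out_shape([2, 3, 8, 8], [3, 6, 3, 3],
                               [2, 2], [1, 1, 1, 1], [1, 1], 1))
```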
@@ -115,12 +141,11 @@ void Conv2DTransposeOpMaker::Make() {
       "(bool, default false) Set to true for inference only, false "
       "for training. Some layers may run faster when this is true.")
       .SetDefault(false);
   AddInput("Input",
            "(Tensor) The input tensor of convolution transpose operator. "
-           "The format of input tensor is NCHW. Where N is batch size, C is the "
-           "number of input channels, H is the height of the feature, and "
-           "W is the width of the feature.");
+           "The format of input tensor is NCHW or NHWC. Where N is batch size, "
+           "C is the number of input channels, H is the height of the feature, "
+           "and W is the width of the feature.");
   AddInput("Filter",
            "(Tensor) The filter tensor of convolution transpose operator. "
@@ -137,7 +162,7 @@ void Conv2DTransposeOpMaker::Make() {
   AddOutput("Output",
             "(Tensor) The output tensor of convolution transpose operator. "
-            "The format of output tensor is also NCHW.");
+            "The format of output tensor is the same as input tensor.");
   AddAttr<std::vector<int>>("output_size",
                             "(vector<int> default: []), the "
                             "size of the output tensor")
@@ -182,10 +207,15 @@ void Conv2DTransposeOpMaker::Make() {
       "data_format",
       "(string, default NCHW) Only used in "
       "An optional string from: \"NHWC\", \"NCHW\". "
-      "Defaults to \"NHWC\". Specify the data format of the output data, "
-      "the input will be transformed automatically. ")
-      .SetDefault("AnyLayout");
-  // TODO(dzhwinter): need to registered layout transform function
+      "Specify that the data format of the input and output data is "
+      "channel_first or channel_last.")
+      .SetDefault("NCHW");
+  AddAttr<std::string>(
+      "padding_algorithm",
+      "(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\","
+      "\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. "
+      "Set to \"SAME\" or \"VALID\" for algorithm of padding. ")
+      .SetDefault("EXPLICIT");
   AddAttr<int>("workspace_size_MB",
                "Used in cudnn kernel only. workspace size for cudnn, in MB, "
                "workspace is a section of GPU memory which will be "
@@ -199,7 +229,7 @@ Convolution2D Transpose Operator.
 The convolution transpose operation calculates the output based on the input, filter
 and dilations, strides, paddings, groups parameters. The size of each dimension of the
 parameters is checked in the infer-shape.
-Input(Input) and output(Output) are in NCHW format. Where N is batchsize, C is the
+Input(Input) and output(Output) are in NCHW or NHWC format. Where N is batchsize, C is the
 number of channels, H is the height of the feature, and W is the width of the feature.
 Filter(Input) is in MCHW format. Where M is the number of input feature channels,
 C is the number of output feature channels, H is the height of the filter,
@@ -216,19 +246,19 @@ For an example:
   Output shape: $(N, C_{out}, H_{out}, W_{out})$
   Where
   $$
-       H_{out} = (H_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (H_f - 1) + 1 \\
-       W_{out} = (W_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (W_f - 1) + 1
+       H_{out} = (H_{in} - 1) * strides[0] - pad_height_top - pad_height_bottom + dilations[0] * (H_f - 1) + 1 \\
+       W_{out} = (W_{in} - 1) * strides[1] - pad_width_left - pad_width_right + dilations[1] * (W_f - 1) + 1
   $$
 )DOC");
 }
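As a quick check of the updated formula (my own numbers): with H_in = 8, strides[0] = 2, pad_height_top = pad_height_bottom = 1, dilations[0] = 1 and H_f = 3, H_out = (8 - 1) * 2 - 1 - 1 + 1 * (3 - 1) + 1 = 15. Splitting the padding into separate top and bottom terms is what lets asymmetric padding, e.g. (1, 0), produce H_out = 16, which the old 2 * paddings[0] form could not express.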
 void Conv3DTransposeOpMaker::Make() {
   AddInput("Input",
-           "(Tensor) The input tensor of convolution transpose operator."
-           "The format of input tensor is NCDHW. Where N is batch size, C is "
-           "the number of channels, D is the depth of the feature, H is the "
-           "height of the feature, and W is the width of the feature.");
+           "(Tensor) The input tensor of convolution transpose operator. "
+           "The format of input tensor is NCDHW or NDHWC. Where N is batch "
+           "size, C is the number of channels, D is the depth of the feature, "
+           "H is the height of the feature, and W is the width of the feature.");
   AddInput("Filter",
            "(Tensor) The filter tensor of convolution transpose operator."
            "The format of the filter tensor is MCDHW, where M is the number of "
@@ -240,7 +270,7 @@ void Conv3DTransposeOpMaker::Make() {
            "the convolution3d transpose scenario.");
   AddOutput("Output",
             "(Tensor) The output tensor of convolution transpose operator."
-            "The format of output tensor is also NCDHW."
+            "The format of output tensor is the same as input tensor."
             "Where N is batch size, C is "
             "the number of channels, D is the depth of the feature, H is the "
             "height of the feature, and W is the width of the feature.");
@@ -278,10 +308,15 @@ void Conv3DTransposeOpMaker::Make() {
       "data_format",
       "(string, default NCHW) Only used in "
       "An optional string from: \"NHWC\", \"NCHW\". "
-      "Defaults to \"NHWC\". Specify the data format of the output data, "
-      "the input will be transformed automatically. ")
-      .SetDefault("AnyLayout");
-  // TODO(dzhwinter): need to registered layout transform function
+      "Specify that the data format of the input and output data is "
+      "channel_first or channel_last.")
+      .SetDefault("NCHW");
+  AddAttr<std::string>(
+      "padding_algorithm",
+      "(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\","
+      "\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. "
+      "Set to \"SAME\" or \"VALID\" for algorithm of padding. ")
+      .SetDefault("EXPLICIT");
   AddAttr<int>("workspace_size_MB",
                "Used in cudnn kernel only. workspace size for cudnn, in MB, "
                "workspace is a section of GPU memory which will be "
@@ -295,7 +330,7 @@ Convolution3D Transpose Operator.
 The convolution transpose operation calculates the output based on the input, filter
 and dilations, strides, paddings, groups parameters. The size of each dimension of the
 parameters is checked in the infer-shape.
-Input(Input) and output(Output) are in NCDHW format. Where N is batch size, C is the
+Input(Input) and output(Output) are in NCDHW or NDHWC format. Where N is batch size, C is the
 number of channels, D is the depth of the feature, H is the height of the feature,
 and W is the width of the feature.
 Filter(Input) is in MCDHW format. Where M is the number of input feature channels,
@@ -313,9 +348,9 @@ Example:
   Output shape: $(N, C_{out}, D_{out}, H_{out}, W_{out})$
   Where
   $$
-       D_{out} = (D_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (D_f - 1) + 1 \\
-       H_{out} = (H_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (H_f - 1) + 1 \\
-       W_{out} = (W_{in} - 1) * strides[2] - 2 * paddings[2] + dilations[2] * (W_f - 1) + 1
+       D_{out} = (D_{in} - 1) * strides[0] - pad_depth_front - pad_depth_back + dilations[0] * (D_f - 1) + 1 \\
+       H_{out} = (H_{in} - 1) * strides[1] - pad_height_top - pad_height_bottom + dilations[1] * (H_f - 1) + 1 \\
+       W_{out} = (W_{in} - 1) * strides[2] - pad_width_left - pad_width_right + dilations[2] * (W_f - 1) + 1
   $$
 )DOC");
 }
@@ -348,8 +383,7 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
     library_ = framework::LibraryType::kPlain;
   }
-  framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
+  std::string data_format = ctx.Attr<std::string>("data_format");
+  framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
   return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
                                  ctx.GetPlace(), layout_, library_);
 }
paddle/fluid/operators/conv_transpose_op.cu.cc → paddle/fluid/operators/conv_transpose_op.cu
(File moved.)
paddle/fluid/operators/conv_transpose_op.h
(This diff is collapsed and not shown here.)
paddle/fluid/operators/math/depthwise_conv.cu
(This diff is collapsed and not shown here.)
paddle/fluid/operators/math/depthwise_conv.h
@@ -22,6 +22,8 @@ namespace paddle {
 namespace operators {
 namespace math {
 
+using DataLayout = framework::DataLayout;
+
 /*
  * \brief Compute the depthwise convolution which include
  * forward process and backpropagation process
@@ -34,7 +36,8 @@ class DepthwiseConvFunctor {
                   const framework::Tensor& filter,
                   const std::vector<int>& strides,
                   const std::vector<int>& paddings,
-                  const std::vector<int>& dilations, framework::Tensor* output);
+                  const std::vector<int>& dilations, framework::Tensor* output,
+                  const DataLayout data_layout = DataLayout::kNCHW);
 };
 
 template <typename DeviceContext, typename T,
@@ -47,7 +50,8 @@ class DepthwiseConvInputGradFunctor {
                   const std::vector<int>& strides,
                   const std::vector<int>& paddings,
                   const std::vector<int>& dilations,
-                  framework::Tensor* input_grad);
+                  framework::Tensor* input_grad,
+                  const DataLayout data_layout = DataLayout::kNCHW);
 };
 
 template <typename DeviceContext, typename T,
@@ -59,7 +63,8 @@ class DepthwiseConvFilterGradFunctor {
                   const std::vector<int>& strides,
                   const std::vector<int>& paddings,
                   const std::vector<int>& dilations,
-                  framework::Tensor* filter_grad);
+                  framework::Tensor* filter_grad,
+                  const DataLayout data_layout = DataLayout::kNCHW);
 };
 
 }  // namespace math
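Note that the new data_layout parameter is defaulted to DataLayout::kNCHW in all three functor declarations, so existing call sites that do not pass a layout keep compiling and keep their old NCHW behavior.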
paddle/fluid/operators/math/im2col.cc
@@ -32,7 +32,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::Tensor& im, const std::vector<int>& dilation,
                   const std::vector<int>& stride,
-                  const std::vector<int>& padding, framework::Tensor* col) {
+                  const std::vector<int>& padding, framework::Tensor* col,
+                  const DataLayout data_layout) {
     PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3.");
     PADDLE_ENFORCE_EQ(col->dims().size(), 5,
                       "The dimension of col should be 5.");
@@ -41,16 +42,16 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
         dilation[1] == 1) {
       if (padding[0] == 0 && padding[1] == 0 && padding[2] == 0 &&
           padding[3] == 0) {
-        im2col_sh1sw1dh1dw1ph0pw0<T>(im, col);
+        im2col_sh1sw1dh1dw1ph0pw0<T>(im, col, data_layout);
         return;
       } else if (padding[0] == 1 && padding[1] == 1 && padding[2] == 1 &&
                  padding[3] == 1) {
-        im2col_sh1sw1dh1dw1ph1pw1<T>(im, col);
+        im2col_sh1sw1dh1dw1ph1pw1<T>(im, col, data_layout);
         return;
       }
       // TODO(TJ): complete padding >=2
     }
-    im2col_common<T>(im, dilation, stride, padding, col);
+    im2col_common<T>(im, dilation, stride, padding, col, data_layout);
   }
 };
@@ -67,13 +68,17 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                   const framework::Tensor& col,
                   const std::vector<int>& dilation,
                   const std::vector<int>& stride,
-                  const std::vector<int>& padding, framework::Tensor* im) {
+                  const std::vector<int>& padding, framework::Tensor* im,
+                  const DataLayout data_layout) {
     PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3.");
     PADDLE_ENFORCE_EQ(col.dims().size(), 5,
                       "The dimension of col should be 5.");
 
-    int im_channels = im->dims()[0];
-    int im_height = im->dims()[1];
-    int im_width = im->dims()[2];
+    int im_channels =
+        (data_layout == DataLayout::kNCHW ? im->dims()[0] : im->dims()[2]);
+    int im_height =
+        (data_layout == DataLayout::kNCHW ? im->dims()[1] : im->dims()[0]);
+    int im_width =
+        (data_layout == DataLayout::kNCHW ? im->dims()[2] : im->dims()[1]);
     int filter_height = col.dims()[1];
     int filter_width = col.dims()[2];
     int col_height = col.dims()[3];
@@ -109,7 +114,15 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
             int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
             if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
                 (im_col_idx) >= 0 && (im_col_idx) < im_width) {
-              im_data[(im_row_idx + c_im * im_height) * im_width + im_col_idx] +=
+              int im_offset;
+              if (data_layout == DataLayout::kNCHW) {
+                im_offset =
+                    (c_im * im_height + im_row_idx) * im_width + im_col_idx;
+              } else {
+                im_offset =
+                    (im_row_idx * im_width + im_col_idx) * im_channels + c_im;
+              }
+              im_data[im_offset] +=
                   col_data[(c * col_height + h) * col_width + w];
             }
           }
@@ -139,7 +152,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::Tensor& im, const std::vector<int>& dilation,
                   const std::vector<int>& stride,
-                  const std::vector<int>& padding, framework::Tensor* col) {
+                  const std::vector<int>& padding, framework::Tensor* col,
+                  const DataLayout data_layout) {
     PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3.");
     PADDLE_ENFORCE_EQ(col->dims().size(), 5,
                       "The dimension of col should be 5.");
@@ -202,7 +216,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                   const framework::Tensor& col,
                   const std::vector<int>& dilation,
                   const std::vector<int>& stride,
-                  const std::vector<int>& padding, framework::Tensor* im) {
+                  const std::vector<int>& padding, framework::Tensor* im,
+                  const DataLayout data_layout) {
     PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3.");
     PADDLE_ENFORCE_EQ(col.dims().size(), 5,
                       "The dimension of col should be 5.");
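The layout branches added throughout these functors all reduce to one address rule: NCHW stores pixel (c, h, w) at flat offset (c * H + h) * W + w, while NHWC stores it at (h * W + w) * C + c. A small self-contained check (my own example, not Paddle code):

```python
import numpy as np

C, H, W = 2, 3, 4
x = np.arange(C * H * W)
nchw = x.reshape(C, H, W)        # flat offset (c*H + h)*W + w
nhwc = nchw.transpose(1, 2, 0)   # flat offset (h*W + w)*C + c
c, h, w = 1, 2, 3
assert nchw[c, h, w] == x[(c * H + h) * W + w]
# The flattened NHWC buffer finds the same element at (h*W + w)*C + c:
assert nhwc.reshape(-1)[(h * W + w) * C + c] == nchw[c, h, w]
```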
paddle/fluid/operators/math/im2col.cu
@@ -26,27 +26,41 @@ __global__ void im2col(const T* data_im, int num_outs, int im_height,
                        int im_width, int dilation_h, int dilation_w,
                        int filter_height, int filter_width, int stride_height,
                        int stride_width, int padding_height, int padding_width,
-                       int col_height, int col_width, T* data_col) {
+                       int col_height, int col_width, T* data_col,
+                       const DataLayout data_layout) {
+  int input_channels = num_outs / col_height / col_width;
+  int channels_col = input_channels * filter_height * filter_width;
   const int index =
       (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
   if (index < num_outs) {
-    int w_out = index % col_width;
-    int h_out = (index / col_width) % col_height;
-    int channel_in = index / col_width / col_height;
+    int w_out = (data_layout == DataLayout::kNCHW
+                     ? index % col_width
+                     : (index / input_channels) % col_width);
+    int h_out = (data_layout == DataLayout::kNCHW
+                     ? (index / col_width) % col_height
+                     : (index / input_channels / col_width) % col_height);
+    int channel_in = (data_layout == DataLayout::kNCHW
+                          ? index / col_width / col_height
+                          : index % input_channels);
     int channel_out = channel_in * filter_height * filter_width;
     int h_in = h_out * stride_height - padding_height;
     int w_in = w_out * stride_width - padding_width;
 
     data_col += (channel_out * col_height + h_out) * col_width + w_out;
-    data_im += (channel_in * im_height + h_in) * im_width + w_in;
     for (int i = 0; i < filter_height; ++i) {
       for (int j = 0; j < filter_width; ++j) {
         int rIdx = h_in + i * dilation_h;
         int cIdx = w_in + j * dilation_w;
+        int im_idx;
+        if (data_layout == DataLayout::kNCHW) {
+          im_idx = (channel_in * im_height + rIdx) * im_width + cIdx;
+        } else {
+          im_idx = (rIdx * im_width + cIdx) * input_channels + channel_in;
+        }
         *data_col =
             (rIdx >= im_height || rIdx < 0 || cIdx >= im_width || cIdx < 0)
                 ? 0
-                : data_im[i * dilation_h * im_width + j * dilation_w];
+                : data_im[im_idx];
         data_col += col_height * col_width;
       }
     }
@@ -65,13 +79,18 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& im, const std::vector<int>& dilation,
                   const std::vector<int>& stride,
-                  const std::vector<int>& padding, framework::Tensor* col) {
-    PADDLE_ENFORCE_EQ(im.dims().size(), 3);
-    PADDLE_ENFORCE_EQ(col->dims().size(), 5);
+                  const std::vector<int>& padding, framework::Tensor* col,
+                  const DataLayout data_layout) {
+    PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3.");
+    PADDLE_ENFORCE_EQ(col->dims().size(), 5,
+                      "The dimension of col should be 5.");
 
-    int im_channels = im.dims()[0];
-    int im_height = im.dims()[1];
-    int im_width = im.dims()[2];
+    int im_channels =
+        (data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]);
+    int im_height =
+        (data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]);
+    int im_width =
+        (data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]);
     int filter_height = col->dims()[1];
     int filter_width = col->dims()[2];
    int col_height = col->dims()[3];
@@ -86,7 +105,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
     im2col<T><<<grid, threads, 0, context.stream()>>>(
         im.data<T>(), num_outputs, im_height, im_width, dilation[0],
         dilation[1], filter_height, filter_width, stride[0], stride[1],
-        padding[0], padding[1], col_height, col_width, col->data<T>());
+        padding[0], padding[1], col_height, col_width, col->data<T>(),
+        data_layout);
   }
 };
@@ -95,18 +115,27 @@ __global__ void col2im(int n, const T* data_col, int im_height, int im_width,
                        int dilation_h, int dilation_w, int filter_height,
                        int filter_width, int stride_height, int stride_width,
                        int padding_height, int padding_width, int col_height,
-                       int col_width, T* data_im) {
+                       int col_width, T* data_im,
+                       const DataLayout data_layout) {
   const int index =
       (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
   const int d_filter_height = dilation_h * (filter_height - 1) + 1;
   const int d_filter_width = dilation_w * (filter_width - 1) + 1;
+  int input_channels = n / im_height / im_width;
   if (index < n) {
     T val = 0;
-    int w = index % im_width + padding_width;
-    int h = (index / im_width) % im_height + padding_height;
-    int c = index / (im_width * im_height);
+    int w = (data_layout == DataLayout::kNCHW
+                 ? index % im_width + padding_width
+                 : (index / input_channels) % im_width + padding_width);
+    int h = (data_layout == DataLayout::kNCHW
+                 ? (index / im_width) % im_height + padding_height
+                 : (index / input_channels / im_width) % im_height +
+                       padding_height);
+    int c = (data_layout == DataLayout::kNCHW ? index / im_width / im_height
+                                              : index % input_channels);
 
     // compute the start and end of the output
     int w_col_start =
@@ -151,13 +180,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                   const framework::Tensor& col,
                   const std::vector<int>& dilation,
                   const std::vector<int>& stride,
-                  const std::vector<int>& padding, framework::Tensor* im) {
-    PADDLE_ENFORCE_EQ(im->dims().size(), 3);
-    PADDLE_ENFORCE_EQ(col.dims().size(), 5);
+                  const std::vector<int>& padding, framework::Tensor* im,
+                  const DataLayout data_layout) {
+    PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3.");
+    PADDLE_ENFORCE_EQ(col.dims().size(), 5,
+                      "The dimension of col should be 5.");
 
-    int im_channels = im->dims()[0];
-    int im_height = im->dims()[1];
-    int im_width = im->dims()[2];
+    int im_channels =
+        (data_layout == DataLayout::kNCHW ? im->dims()[0] : im->dims()[2]);
+    int im_height =
+        (data_layout == DataLayout::kNCHW ? im->dims()[1] : im->dims()[0]);
+    int im_width =
+        (data_layout == DataLayout::kNCHW ? im->dims()[2] : im->dims()[1]);
     int filter_height = col.dims()[1];
     int filter_width = col.dims()[2];
     int col_height = col.dims()[3];
@@ -191,7 +225,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
     col2im<T><<<grid, threads, 0, context.stream()>>>(
         num_kernels, col.data<T>(), im_height, im_width, dilation[0],
         dilation[1], filter_height, filter_width, stride[0], stride[1],
-        padding[0], padding[2], col_height, col_width, im->data<T>());
+        padding[0], padding[1], col_height, col_width, im->data<T>(),
+        data_layout);
   }
 };
@@ -248,9 +283,12 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& im, const std::vector<int>& dilation,
                   const std::vector<int>& stride,
-                  const std::vector<int>& padding, framework::Tensor* col) {
-    PADDLE_ENFORCE_EQ(im.dims().size(), 3);
-    PADDLE_ENFORCE_EQ(col->dims().size(), 5);
+                  const std::vector<int>& padding, framework::Tensor* col,
+                  const DataLayout data_layout) {
+    PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3.");
+    PADDLE_ENFORCE_EQ(col->dims().size(), 5,
+                      "The dimension of col should be 5.");
     int im_channels = im.dims()[0];
     int im_height = im.dims()[1];
     int im_width = im.dims()[2];
@@ -330,9 +368,12 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                   const framework::Tensor& col,
                   const std::vector<int>& dilation,
                   const std::vector<int>& stride,
-                  const std::vector<int>& padding, framework::Tensor* im) {
-    PADDLE_ENFORCE_EQ(im->dims().size(), 3);
-    PADDLE_ENFORCE_EQ(col.dims().size(), 5);
+                  const std::vector<int>& padding, framework::Tensor* im,
+                  const DataLayout data_layout) {
+    PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3.");
+    PADDLE_ENFORCE_EQ(col.dims().size(), 5,
+                      "The dimension of col should be 5.");
    int im_channels = im->dims()[0];
    int im_height = im->dims()[1];
    int im_width = im->dims()[2];
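In the CUDA kernels the inverse mapping matters: a flat thread index must be split back into (c, h, w). For channel-last layouts the channel varies fastest, which is the decomposition the updated im2col/col2im kernels use. A quick Python verification of that arithmetic (illustrative):

```python
C, H, W = 3, 4, 5
for index in range(C * H * W):
    # channel-last decomposition: c varies fastest, as in the NHWC branches
    c = index % C
    w = (index // C) % W
    h = index // C // W
    assert index == (h * W + w) * C + c
```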
paddle/fluid/operators/math/im2col.h
@@ -23,6 +23,8 @@ namespace paddle {
 namespace operators {
 namespace math {
 
+using DataLayout = framework::DataLayout;
+
 /* The storage format of the coldata in the Im2ColFunctor and Col2ImFunctor. */
 enum class ColFormat { kCFO = 0, kOCF = 1 };
 
@@ -86,7 +88,8 @@ class Im2ColFunctor {
   void operator()(const DeviceContext& context, const framework::Tensor& im,
                   const std::vector<int>& dilation,
                   const std::vector<int>& stride,
-                  const std::vector<int>& padding, framework::Tensor* col);
+                  const std::vector<int>& padding, framework::Tensor* col,
+                  const DataLayout data_layout = DataLayout::kNCHW);
 };
 
 template <ColFormat Format, typename DeviceContext, typename T>
@@ -95,7 +98,8 @@ class Col2ImFunctor {
   void operator()(const DeviceContext& context, const framework::Tensor& col,
                   const std::vector<int>& dilation,
                   const std::vector<int>& stride,
-                  const std::vector<int>& padding, framework::Tensor* im);
+                  const std::vector<int>& padding, framework::Tensor* im,
+                  const DataLayout data_layout = DataLayout::kNCHW);
 };
 
 }  // namespace math
paddle/fluid/operators/math/im2col_cfo_cpu.h
@@ -30,10 +30,14 @@ inline void im2col_common(const framework::Tensor& im,
                           const std::vector<int>& dilation,
                           const std::vector<int>& stride,
                           const std::vector<int>& padding,
-                          framework::Tensor* col) {
-  int im_channels = im.dims()[0];
-  int im_height = im.dims()[1];
-  int im_width = im.dims()[2];
+                          framework::Tensor* col,
+                          const DataLayout data_layout = DataLayout::kNCHW) {
+  int im_channels =
+      (data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]);
+  int im_height =
+      (data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]);
+  int im_width =
+      (data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]);
   int filter_height = col->dims()[1];
   int filter_width = col->dims()[2];
   int output_height = col->dims()[3];
@@ -50,8 +54,14 @@ inline void im2col_common(const framework::Tensor& im,
       int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
       for (int w = 0; w < output_width; ++w) {
         int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
+        int im_idx;
+        if (data_layout == DataLayout::kNCHW) {
+          im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
+        } else {
+          im_idx = (im_row_idx * im_width + im_col_idx) * im_channels + c_im;
+        }
         int col_idx = (c * output_height + h) * output_width + w;
-        int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
         col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
                              im_col_idx < 0 || im_col_idx >= im_width)
                                 ? static_cast<T>(0)
@@ -65,11 +75,15 @@ inline void im2col_common(const framework::Tensor& im,
  * im2col algorithm with strides == 1, dilations == 1, paddings == 0
  */
 template <typename T>
 inline void im2col_sh1sw1dh1dw1ph0pw0(const framework::Tensor& im,
-                                      framework::Tensor* col) {
-  int im_channels = im.dims()[0];
-  int im_height = im.dims()[1];
-  int im_width = im.dims()[2];
+                                      framework::Tensor* col,
+                                      const DataLayout data_layout = DataLayout::kNCHW) {
+  int im_channels =
+      (data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]);
+  int im_height =
+      (data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]);
+  int im_width =
+      (data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]);
   int filter_height = col->dims()[1];
   int filter_width = col->dims()[2];
   int output_height = col->dims()[3];
@@ -89,7 +103,14 @@ inline void im2col_sh1sw1dh1dw1ph0pw0(const framework::Tensor& im,
     const T* src_data = src_data_ic;
     for (int kh = 0; kh < filter_height; ++kh) {
       for (int kw = 0; kw < filter_width; ++kw) {
-        std::memcpy(dst_data, src_data + kw, copy_size);
+        if (data_layout == DataLayout::kNCHW) {
+          std::memcpy(dst_data, src_data + kw, copy_size);
+        } else {
+          for (int kow = 0; kow < output_width; ++kow) {
+            dst_data[kow] =
+                im_data[((oh + kh) * im_width + kw + kow) * im_channels + ic];
+          }
+        }
         dst_data = dst_data + col_matrix_width;
       }
       src_data = src_data + im_width;
@@ -107,10 +128,14 @@ inline void im2col_sh1sw1dh1dw1ph0pw0(const framework::Tensor& im,
  */
 template <typename T>
 inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
-                                      framework::Tensor* col) {
-  int im_channels = im.dims()[0];
-  int im_height = im.dims()[1];
-  int im_width = im.dims()[2];
+                                      framework::Tensor* col,
+                                      const DataLayout data_layout) {
+  int im_channels =
+      (data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]);
+  int im_height =
+      (data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]);
+  int im_width =
+      (data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]);
   int filter_height = col->dims()[1];
   int filter_width = col->dims()[2];
   int output_height = col->dims()[3];
@@ -180,7 +205,17 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
         dst_data = dst_data + col_matrix_width;
         continue;
       }
-      std::memcpy(dst_data + plw, src_data, copy_size);
+      if (data_layout == DataLayout::kNCHW) {
+        std::memcpy(dst_data + plw, src_data, copy_size);
+      } else {
+        for (int kow = 0; kow < output_width - plw - prw; ++kow) {
+          dst_data[plw + kow] =
+              im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width + kow) *
+                          im_channels +
+                      ic];
+        }
+      }
       dst_data = dst_data + col_matrix_width;
       src_data = src_data + im_width;
     }
@@ -226,19 +261,49 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
       // TODO(TJ): reuse plw-kw outside this for
       // try to unify
       for (int kw = 0; kw < plw; ++kw) {
-        std::memcpy(dst_data + (plw - kw), src_data,
-                    sizeof(T) * (output_width - (plw - kw)));
+        if (data_layout == DataLayout::kNCHW) {
+          std::memcpy(dst_data + (plw - kw), src_data,
+                      sizeof(T) * (output_width - (plw - kw)));
+        } else {
+          for (int kow = 0; kow < output_width - (plw - kw); ++kow) {
+            dst_data[plw - kw + kow] =
+                im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width +
+                         kow) *
+                            im_channels +
+                        ic];
+          }
+        }
         dst_data = dst_data + col_matrix_width;
       }
       for (int kw = plw; kw < filter_width - prw; ++kw) {
-        std::memcpy(dst_data, src_data + (kw - plw),
-                    sizeof(T) * output_width);
+        if (data_layout == DataLayout::kNCHW) {
+          std::memcpy(dst_data, src_data + (kw - plw),
+                      sizeof(T) * output_width);
+        } else {
+          for (int kow = 0; kow < output_width; ++kow) {
+            dst_data[kow] =
+                im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width + kw -
+                         plw + kow) *
+                            im_channels +
+                        ic];
+          }
+        }
         dst_data = dst_data + col_matrix_width;
       }
       int i = 1;
       for (int kw = filter_width - prw; kw < filter_width; ++kw, ++i) {
-        std::memcpy(dst_data, src_data + (kw - plw),
-                    sizeof(T) * (output_width - i));
+        if (data_layout == DataLayout::kNCHW) {
+          std::memcpy(dst_data, src_data + (kw - plw),
+                      sizeof(T) * (output_width - i));
+        } else {
+          for (int kow = 0; kow < output_width - i; ++kow) {
+            dst_data[kow] =
+                im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width + kw -
+                         plw + kow) *
+                            im_channels +
+                        ic];
+          }
+        }
         dst_data = dst_data + col_matrix_width;
       }
       src_data = src_data + im_width;
paddle/fluid/operators/math/vol2col.cc
@@ -32,16 +32,21 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> {
                   const framework::Tensor& vol,
                   const std::vector<int>& dilations,
                   const std::vector<int>& strides,
                   const std::vector<int>& paddings,
-                  framework::Tensor* col) const {
+                  framework::Tensor* col,
+                  const DataLayout data_layout) const {
     PADDLE_ENFORCE_EQ(vol.dims().size(), 4,
                       "The dimension of vol should be 4.");
     PADDLE_ENFORCE_EQ(col->dims().size(), 7,
                       "The dimension of col should be 7.");
 
-    int input_channels = vol.dims()[0];
-    int input_depth = vol.dims()[1];
-    int input_height = vol.dims()[2];
-    int input_width = vol.dims()[3];
+    int input_channels =
+        (data_layout == DataLayout::kNCHW ? vol.dims()[0] : vol.dims()[3]);
+    int input_depth =
+        (data_layout == DataLayout::kNCHW ? vol.dims()[1] : vol.dims()[0]);
+    int input_height =
+        (data_layout == DataLayout::kNCHW ? vol.dims()[2] : vol.dims()[1]);
+    int input_width =
+        (data_layout == DataLayout::kNCHW ? vol.dims()[3] : vol.dims()[2]);
     int filter_depth = col->dims()[1];
     int filter_height = col->dims()[2];
     int filter_width = col->dims()[3];
@@ -59,6 +64,7 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> {
     int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
     int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
     int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];
     PADDLE_ENFORCE_EQ((input_depth + pad_d_forth + pad_d_back -
                        ((dilations[0] * (filter_depth - 1) + 1))) /
                           strides[0] +
@@ -97,10 +103,16 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> {
           int col_idx =
               ((c * output_depth + d) * output_height + h) * output_width + w;
-          int vol_idx =
-              ((c_in * input_depth + d_pad) * input_height + h_pad) *
-                  input_width +
-              w_pad;
+          int vol_idx;
+          if (data_layout == DataLayout::kNCHW) {
+            vol_idx = ((c_in * input_depth + d_pad) * input_height + h_pad) *
+                          input_width +
+                      w_pad;
+          } else {
+            vol_idx = ((d_pad * input_height + h_pad) * input_width + w_pad) *
+                          input_channels +
+                      c_in;
+          }
           col_data[col_idx] =
               (h_pad < 0 || h_pad >= input_height || w_pad < 0 ||
                w_pad >= input_width || d_pad < 0 || d_pad >= input_depth)
@@ -126,16 +138,21 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> {
...
@@ -126,16 +138,21 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> {
const
framework
::
Tensor
&
col
,
const
framework
::
Tensor
&
col
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
vol
,
framework
::
Tensor
*
vol
)
const
{
const
DataLayout
data_layout
)
const
{
PADDLE_ENFORCE_EQ
(
vol
->
dims
().
size
(),
4
,
PADDLE_ENFORCE_EQ
(
vol
->
dims
().
size
(),
4
,
"The dimension of vol should be 4."
);
"The dimension of vol should be 4."
);
PADDLE_ENFORCE_EQ
(
col
.
dims
().
size
(),
7
,
PADDLE_ENFORCE_EQ
(
col
.
dims
().
size
(),
7
,
"The dimension of col should be 7."
);
"The dimension of col should be 7."
);
int
input_channels
=
vol
->
dims
()[
0
];
int
input_depth
=
vol
->
dims
()[
1
];
int
input_channels
=
int
input_height
=
vol
->
dims
()[
2
];
(
data_layout
==
DataLayout
::
kNCHW
?
vol
->
dims
()[
0
]
:
vol
->
dims
()[
3
]);
int
input_width
=
vol
->
dims
()[
3
];
int
input_depth
=
(
data_layout
==
DataLayout
::
kNCHW
?
vol
->
dims
()[
1
]
:
vol
->
dims
()[
0
]);
int
input_height
=
(
data_layout
==
DataLayout
::
kNCHW
?
vol
->
dims
()[
2
]
:
vol
->
dims
()[
1
]);
int
input_width
=
(
data_layout
==
DataLayout
::
kNCHW
?
vol
->
dims
()[
3
]
:
vol
->
dims
()[
2
]);
int
filter_depth
=
col
.
dims
()[
1
];
int
filter_depth
=
col
.
dims
()[
1
];
int
filter_height
=
col
.
dims
()[
2
];
int
filter_height
=
col
.
dims
()[
2
];
int
filter_width
=
col
.
dims
()[
3
];
int
filter_width
=
col
.
dims
()[
3
];
...
@@ -191,11 +208,17 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> {
...
@@ -191,11 +208,17 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> {
if
(
h_pad
>=
0
&&
h_pad
<
input_height
&&
w_pad
>=
0
&&
if
(
h_pad
>=
0
&&
h_pad
<
input_height
&&
w_pad
>=
0
&&
w_pad
<
input_width
&&
d_pad
>=
0
&&
d_pad
<
input_depth
)
{
w_pad
<
input_width
&&
d_pad
>=
0
&&
d_pad
<
input_depth
)
{
int
vol_idx
=
int
vol_idx
;
((
cIm
*
input_depth
+
d_pad
)
*
input_height
+
h_pad
)
*
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
input_width
+
vol_idx
=
((
cIm
*
input_depth
+
d_pad
)
*
input_height
+
h_pad
)
*
w_pad
;
input_width
+
w_pad
;
}
else
{
vol_idx
=
((
d_pad
*
input_height
+
h_pad
)
*
input_width
+
w_pad
)
*
input_channels
+
cIm
;
}
int
col_idx
=
int
col_idx
=
((
c
*
output_depth
+
d
)
*
output_height
+
h
)
*
output_width
+
((
c
*
output_depth
+
d
)
*
output_height
+
h
)
*
output_width
+
w
;
w
;
...
...
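The two flattening rules the CPU functors now switch between are simply the row-major offsets of [C, D, H, W] and [D, H, W, C]. A Python transcription of the vol_idx branch (an illustrative sketch, not the shipped code):

import numpy as np

def vol_index(c, d, h, w, channels, depth, height, width, layout="NCHW"):
    # Row-major offset of vol[c, d, h, w] under either layout,
    # mirroring the vol_idx branch added in vol2col.cc.
    if layout == "NCHW":   # stored as [C, D, H, W]
        return ((c * depth + d) * height + h) * width + w
    else:                  # "NHWC", stored as [D, H, W, C]
        return ((d * height + h) * width + w) * channels + c

# Sanity check against NumPy's own row-major flattening:
C, D, H, W = 3, 2, 4, 5
a = np.arange(C * D * H * W).reshape(C, D, H, W)
b = np.moveaxis(a, 0, -1)  # the same elements viewed as [D, H, W, C]
assert a[1, 1, 2, 3] == a.ravel()[vol_index(1, 1, 2, 3, C, D, H, W, "NCHW")]
assert b[1, 2, 3, 1] == np.ascontiguousarray(b).ravel()[
    vol_index(1, 1, 2, 3, C, D, H, W, "NHWC")]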
paddle/fluid/operators/math/vol2col.cu View file @ cf6919bf
...
@@ -28,7 +28,12 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth,
                         int filter_width, int stride_depth, int stride_height,
                         int stride_width, int padding_depth, int padding_height,
                         int padding_width, int output_detph, int output_height,
-                        int output_width, T* data_col) {
+                        int output_width, T* data_col,
+                        const DataLayout data_layout) {
+  int input_channels =
+      num_kernels / output_detph / output_height / output_width;
+  int channels_col =
+      input_channels * filter_depth * filter_height * filter_width;
   for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
        index += blockDim.x * gridDim.x) {
     int w_out = index % output_width;
...
@@ -43,18 +48,22 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth,
     data_col += ((channel_out * output_detph + d_out) * output_height + h_out) *
                     output_width +
                 w_out;
-    data_vol += ((channel_in * depth + d_in) * height + h_in) * width + w_in;
     for (int k = 0; k < filter_depth; ++k) {
       for (int i = 0; i < filter_height; ++i) {
         for (int j = 0; j < filter_width; ++j) {
           int d = d_in + k * dilation_d;
           int h = h_in + i * dilation_h;
           int w = w_in + j * dilation_w;
-          int col_idx = (k * dilation_d * height + i * dilation_h) * width +
-                        j * dilation_w;
+          int vol_idx;
+          if (data_layout == DataLayout::kNCHW) {
+            vol_idx = ((channel_in * depth + d) * height + h) * width + w;
+          } else {
+            vol_idx =
+                ((d * height + h) * width + w) * input_channels + channel_in;
+          }
           *data_col = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 &&
                        w < width)
-                          ? data_vol[col_idx]
+                          ? data_vol[vol_idx]
                           : 0;
           data_col += output_detph * output_height * output_width;
         }
...
@@ -64,7 +73,10 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth,
 }
 /*
- * im = [input_channels,intpu_depth, input_height, input_width]
+ * im = [input_channels,intpu_depth, input_height, input_width] for
+ * channels_first
+ * im = [input_depth, input_height, input_width, input_channels] for
+ * channels_last
  * col =
  * [input_channels, filter_depth, filter_height, filter_width,
  *  output_depth, output_height, output_width]
...
@@ -76,15 +88,21 @@ class Vol2ColFunctor<platform::CUDADeviceContext, T> {
                   const framework::Tensor& vol,
                   const std::vector<int>& dilations,
                   const std::vector<int>& strides,
                   const std::vector<int>& paddings,
-                  framework::Tensor* col) const {
-    PADDLE_ENFORCE_EQ(vol.dims().size(), 4);
-    PADDLE_ENFORCE_EQ(col->dims().size(), 7);
+                  framework::Tensor* col,
+                  const DataLayout data_layout) const {
+    PADDLE_ENFORCE_EQ(vol.dims().size(), 4,
+                      "The dimension of vol should be 4.");
+    PADDLE_ENFORCE_EQ(col->dims().size(), 7,
+                      "The dimension of col should be 7.");

-    int input_channels = vol.dims()[0];
-    int input_depth = vol.dims()[1];
-    int input_height = vol.dims()[2];
-    int input_width = vol.dims()[3];
+    int input_channels =
+        (data_layout == DataLayout::kNCHW ? vol.dims()[0] : vol.dims()[3]);
+    int input_depth =
+        (data_layout == DataLayout::kNCHW ? vol.dims()[1] : vol.dims()[0]);
+    int input_height =
+        (data_layout == DataLayout::kNCHW ? vol.dims()[2] : vol.dims()[1]);
+    int input_width =
+        (data_layout == DataLayout::kNCHW ? vol.dims()[3] : vol.dims()[2]);
     int filter_depth = col->dims()[1];
     int filter_height = col->dims()[2];
     int filter_width = col->dims()[3];
...
@@ -130,7 +148,8 @@ class Vol2ColFunctor<platform::CUDADeviceContext, T> {
         num_outputs, vol.data<T>(), input_depth, input_height, input_width,
         dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
         filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up,
-        pad_w_left, output_depth, output_height, output_width, col->data<T>());
+        pad_w_left, output_depth, output_height, output_width, col->data<T>(),
+        data_layout);
   }
 };
...
@@ -141,18 +160,27 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth,
                         int filter_width, int stride_depth, int stride_height,
                         int stride_width, int padding_depth, int padding_height,
                         int padding_width, int output_detph, int output_height,
-                        int output_width, T* data_vol) {
+                        int output_width, T* data_vol,
+                        const DataLayout data_layout) {
   const int d_filter_depth = dilation_d * (filter_depth - 1) + 1;
   const int d_filter_height = dilation_h * (filter_height - 1) + 1;
   const int d_filter_width = dilation_w * (filter_width - 1) + 1;
+  int input_channels = num_kernels / depth / height / width;
   for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
        index += blockDim.x * gridDim.x) {
     T src_val = 0;
-    int w = index % width + padding_width;
-    int h = (index / width) % height + padding_height;
-    int d = (index / width / height) % depth + padding_depth;
-    int c = index / width / height / depth;
+    int w = (data_layout == DataLayout::kNCHW
+                 ? index % width + padding_width
+                 : (index / input_channels) % width + padding_width);
+    int h = (data_layout == DataLayout::kNCHW
+                 ? (index / width) % height + padding_height
+                 : (index / input_channels / width) % height + padding_height);
+    int d = (data_layout == DataLayout::kNCHW
+                 ? (index / width / height) % depth + padding_depth
+                 : index / input_channels / width / height + padding_depth);
+    int c = (data_layout == DataLayout::kNCHW ? index / width / height / depth
+                                              : index % input_channels);
     // compute the start and end of the output
     int w_col_start =
...
@@ -196,7 +224,10 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth,
 }
 /*
- * im = [input_channels, input_depth, input_height, input_width]
+ * im = [input_channels,intpu_depth, input_height, input_width] for
+ * channels_first
+ * im = [input_depth, input_height, input_width, input_channels] for
+ * channels_last
  * col =
  * [input_channels, filter_depth, filter_height, filter_width,
  *  output_depth, output_height, output_width]
...
@@ -208,15 +239,21 @@ class Col2VolFunctor<platform::CUDADeviceContext, T> {
                   const framework::Tensor& col,
                   const std::vector<int>& dilations,
                   const std::vector<int>& strides,
                   const std::vector<int>& paddings,
-                  framework::Tensor* vol) const {
-    PADDLE_ENFORCE_EQ(vol->dims().size(), 4);
-    PADDLE_ENFORCE_EQ(col.dims().size(), 7);
+                  framework::Tensor* vol,
+                  const DataLayout data_layout) const {
+    PADDLE_ENFORCE_EQ(vol->dims().size(), 4,
+                      "The dimension of vol should be 4.");
+    PADDLE_ENFORCE_EQ(col.dims().size(), 7,
+                      "The dimension of col should be 7.");

-    int input_channels = vol->dims()[0];
-    int input_depth = vol->dims()[1];
-    int input_height = vol->dims()[2];
-    int input_width = vol->dims()[3];
+    int input_channels =
+        (data_layout == DataLayout::kNCHW ? vol->dims()[0] : vol->dims()[3]);
+    int input_depth =
+        (data_layout == DataLayout::kNCHW ? vol->dims()[1] : vol->dims()[0]);
+    int input_height =
+        (data_layout == DataLayout::kNCHW ? vol->dims()[2] : vol->dims()[1]);
+    int input_width =
+        (data_layout == DataLayout::kNCHW ? vol->dims()[3] : vol->dims()[2]);
     int filter_depth = col.dims()[1];
     int filter_height = col.dims()[2];
     int filter_width = col.dims()[3];
...
@@ -263,7 +300,8 @@ class Col2VolFunctor<platform::CUDADeviceContext, T> {
         num_kernels, col.data<T>(), input_depth, input_height, input_width,
         dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
         filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up,
-        pad_w_left, output_depth, output_height, output_width, vol->data<T>());
+        pad_w_left, output_depth, output_height, output_width, vol->data<T>(),
+        data_layout);
   }
 };
...
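col2vol performs the inverse: each CUDA thread owns one vol element and unflattens its thread index, dividing by input_channels first when the layout is channels-last. A sketch of that decomposition with the padding offsets dropped (hypothetical helper, for illustration only):

def decode_vol_index(index, channels, depth, height, width, layout="NCHW"):
    # Inverse of the flattening above; mirrors the w/h/d/c computation in
    # the rewritten col2vol kernel, without the padding offsets.
    if layout == "NCHW":
        w = index % width
        h = (index // width) % height
        d = (index // (width * height)) % depth
        c = index // (width * height * depth)
    else:  # "NHWC": the channel index varies fastest
        c = index % channels
        w = (index // channels) % width
        h = (index // (channels * width)) % height
        d = index // (channels * width * height)
    return c, d, h, w

# Round trip with the toy sizes used above (C, D, H, W = 3, 2, 4, 5):
assert decode_vol_index(((1 * 2 + 1) * 4 + 2) * 5 + 3, 3, 2, 4, 5) == (1, 1, 2, 3)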
paddle/fluid/operators/math/vol2col.h View file @ cf6919bf
...
@@ -22,6 +22,9 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 namespace math {

+using DataLayout = framework::DataLayout;
+
 /*
  * \brief Converts the feature data of four dimensions(CDHW) into a colData of
  * seven dimensions in the Vol2ColFunctor calculation,
...
@@ -70,8 +73,8 @@ class Vol2ColFunctor {
   void operator()(const DeviceContext& context, const framework::Tensor& vol,
                   const std::vector<int>& dilations,
                   const std::vector<int>& strides,
                   const std::vector<int>& paddings,
-                  framework::Tensor* col) const;
+                  framework::Tensor* col,
+                  const DataLayout data_layout = DataLayout::kNCHW) const;
 };

 template <typename DeviceContext, typename T>
...
@@ -80,8 +83,8 @@ class Col2VolFunctor {
   void operator()(const DeviceContext& context, const framework::Tensor& col,
                   const std::vector<int>& dilations,
                   const std::vector<int>& strides,
                   const std::vector<int>& paddings,
-                  framework::Tensor* vol) const;
+                  framework::Tensor* vol,
+                  const DataLayout data_layout = DataLayout::kNCHW) const;
 };
 }  // namespace math
...
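Defaulting data_layout to DataLayout::kNCHW keeps every existing call site source-compatible; only the conv_transpose kernels that handle channels-last pass the argument explicitly. The same backward-compatible pattern in Python terms (hypothetical signature, illustration only):

def vol2col(vol, dilations, strides, paddings, col, data_layout="NCHW"):
    ...  # existing callers: vol2col(vol, d, s, p, col)            -> NCHW
         # channels-last:    vol2col(vol, d, s, p, col, "NHWC")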
python/paddle/fluid/layers/nn.py View file @ cf6919bf
This diff is collapsed. Click to expand.
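The collapsed nn.py diff is where conv2d_transpose and conv3d_transpose gain the data_format argument and the string/per-dimension padding forms exercised by the tests below. A minimal usage sketch against the fluid API of this branch (shapes assumed, error handling omitted):

import numpy as np
import paddle.fluid as fluid

# Channels-last input: [N, D, H, W, C], batch dimension implicit in `shape`.
data = fluid.layers.data(name='x', shape=[5, 5, 5, 3], dtype='float32')
out = fluid.layers.conv3d_transpose(
    input=data, num_filters=6, filter_size=3, groups=1,
    padding='SAME', data_format='NDHWC')

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
x = np.random.random((2, 5, 5, 5, 3)).astype('float32')
res, = exe.run(fluid.default_main_program(), feed={'x': x}, fetch_list=[out])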
python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py View file @ cf6919bf
This diff is collapsed. Click to expand.
python/paddle/fluid/tests/unittests/test_conv3d_transpose_op.py View file @ cf6919bf
...
@@ -18,10 +18,19 @@ import unittest
 import numpy as np
 import paddle.fluid.core as core
+import paddle.fluid as fluid
 from op_test import OpTest


 def conv3dtranspose_forward_naive(input_, filter_, attrs):
+    padding_algorithm = attrs['padding_algorithm']
+    if padding_algorithm not in ["SAME", "VALID", "EXPLICIT"]:
+        raise ValueError("Unknown Attr(padding_algorithm): '%s'. "
+                         "It can only be 'SAME' or 'VALID'." %
+                         str(padding_algorithm))
+
+    if attrs['data_format'] == 'NHWC':
+        input_ = np.transpose(input_, [0, 4, 1, 2, 3])
     in_n, in_c, in_d, in_h, in_w = input_.shape
     f_c, f_out_c, f_d, f_h, f_w = filter_.shape
     groups = attrs['groups']
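The reference implementation stays NCDHW-only internally: channels-last inputs are permuted to channels-first on entry, and the result is permuted back in the cropping hunk further down. The two axis permutations, concretely:

import numpy as np

x_ndhwc = np.zeros((2, 5, 5, 5, 3))                 # [N, D, H, W, C]
x_ncdhw = np.transpose(x_ndhwc, [0, 4, 1, 2, 3])    # -> [N, C, D, H, W]
assert x_ncdhw.shape == (2, 3, 5, 5, 5)

y_ncdhw = np.zeros((2, 6, 7, 7, 7))                 # naive output, channels first
y_ndhwc = np.transpose(y_ncdhw, [0, 2, 3, 4, 1])    # back to [N, D, H, W, C]
assert y_ndhwc.shape == (2, 7, 7, 7, 6)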
...
@@ -32,6 +41,39 @@ def conv3dtranspose_forward_naive(input_, filter_, attrs):
     stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[
         'dilations']

+    def _get_padding_with_SAME(input_shape, kernel_size, kernel_stride):
+        padding = []
+        for input_size, filter_size, stride_size in zip(
+                input_shape, kernel_size, kernel_stride):
+            out_size = int((input_size + stride_size - 1) / stride_size)
+            pad_sum = np.max((
+                (out_size - 1) * stride_size + filter_size - input_size, 0))
+            pad_0 = int(pad_sum / 2)
+            pad_1 = int(pad_sum - pad_0)
+            padding.append(pad_0)
+            padding.append(pad_1)
+        return padding
+
+    ksize = filter_.shape[2:5]
+    if padding_algorithm == "VALID":
+        pad = [0, 0, 0, 0, 0, 0]
+    elif padding_algorithm == "SAME":
+        dilation = [1, 1, 1]
+        input_data_shape = []
+        if attrs['data_format'] == "NCHW":
+            input_data_shape = input_.shape[2:5]
+        elif attrs['data_format'] == "NHWC":
+            input_data_shape = input_.shape[1:4]
+        pad = _get_padding_with_SAME(input_data_shape, ksize, stride)
+
+    pad_d_0, pad_d_1 = pad[0], pad[0]
+    pad_h_0, pad_h_1 = pad[1], pad[1]
+    pad_w_0, pad_w_1 = pad[2], pad[2]
+    if len(pad) == 6:
+        pad_d_0, pad_d_1 = pad[0], pad[1]
+        pad_h_0, pad_h_1 = pad[2], pad[3]
+        pad_w_0, pad_w_1 = pad[4], pad[5]
+
     d_bolck_d = dilations[0] * (f_d - 1) + 1
     d_bolck_h = dilations[1] * (f_h - 1) + 1
     d_bolck_w = dilations[2] * (f_w - 1) + 1
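_get_padding_with_SAME sizes the total pad so the output covers ceil(input / stride) positions, giving the right/back side the extra element when the total is odd. One dimension worked through (input 5, filter 3, stride 2):

import numpy as np

input_size, filter_size, stride_size = 5, 3, 2
out_size = int((input_size + stride_size - 1) / stride_size)    # ceil(5/2) = 3
pad_sum = np.max(((out_size - 1) * stride_size + filter_size - input_size, 0))
pad_0 = int(pad_sum / 2)                                        # front/left pad
pad_1 = int(pad_sum - pad_0)                                    # back/right pad
assert (out_size, int(pad_sum), pad_0, pad_1) == (3, 2, 1, 1)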
...
@@ -62,8 +104,10 @@ def conv3dtranspose_forward_naive(input_, filter_, attrs):
                     out[n, g * f_out_c + k, d1:d2:dilations[0], i1:i2:
                         dilations[1], j1:j2:dilations[2]] += tmp_out

-    out = out[:, :, pad[0]:out_d - pad[0], pad[1]:out_h - pad[1], pad[2]:out_w -
-              pad[2]]
+    out = out[:, :, pad_d_0:out_d - pad_d_1, pad_h_0:out_h - pad_h_1, pad_w_0:
+              out_w - pad_w_1]
+    if attrs['data_format'] == 'NHWC':
+        out = np.transpose(out, [0, 2, 3, 4, 1])
     return out
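For reference, the full naive output computed above spans (in - 1) * stride + dilation * (k - 1) + 1 elements per dimension, and the per-side pads then crop it; that single path is what serves the new asymmetric paddings. A quick check of the depth dimension (standard transposed-convolution arithmetic, assumed rather than quoted from the elided lines):

# Full output, then crop by the per-side pads, as the reference does.
in_d, stride_d, f_d, dilation_d = 5, 1, 3, 1
d_bolck_d = dilation_d * (f_d - 1) + 1          # dilated kernel extent = 3
out_d_full = (in_d - 1) * stride_d + d_bolck_d  # 7
pad_d_0, pad_d_1 = 1, 1
out_d = out_d_full - pad_d_0 - pad_d_1          # cropped extent = 5
assert (out_d_full, out_d) == (7, 5)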
...
@@ -71,6 +115,9 @@ class TestConv3dTransposeOp(OpTest):
     def setUp(self):
         # init as conv transpose
         self.use_cudnn = False
+        self.data_format = 'NCHW'
+        self.pad = [0, 0, 0]
+        self.padding_algorithm = "EXPLICIT"
         self.init_op_type()
         self.init_test_case()
...
@@ -81,10 +128,11 @@ class TestConv3dTransposeOp(OpTest):
         self.attrs = {
             'strides': self.stride,
             'paddings': self.pad,
+            'padding_algorithm': self.padding_algorithm,
             'dilations': self.dilations,
             'groups': self.groups,
             'use_cudnn': self.use_cudnn,
-            'data_format': 'AnyLayout'  # TODO(dzhwinter) : should be fix latter
+            'data_format': self.data_format
         }

         output = conv3dtranspose_forward_naive(input_, filter_,
...
@@ -154,7 +202,7 @@ class TestConv3dTransposeOp(OpTest):
         self.op_type = "conv3d_transpose"


-class TestWithPad(TestConv3dTransposeOp):
+class TestWithSymmetricPad(TestConv3dTransposeOp):
     def init_test_case(self):
         self.pad = [1, 1, 1]
         self.stride = [1, 1, 1]
...
@@ -165,6 +213,39 @@ class TestWithPad(TestConv3dTransposeOp):
         self.filter_size = [f_c, 6, 3, 3, 3]


+class TestWithAsymmetricPad(TestConv3dTransposeOp):
+    def init_test_case(self):
+        self.pad = [1, 0, 1, 0, 1, 2]
+        self.stride = [1, 1, 1]
+        self.dilations = [1, 1, 1]
+        self.groups = 1
+        self.input_size = [2, 3, 5, 5, 5]  # NCDHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3, 3]
+
+
+class TestWithSAMEPad(TestConv3dTransposeOp):
+    def init_test_case(self):
+        self.stride = [1, 1, 1]
+        self.dilations = [1, 1, 1]
+        self.groups = 1
+        self.input_size = [2, 3, 5, 5, 5]  # NCDHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3, 3]
+        self.padding_algorithm = 'SAME'
+
+
+class TestWithVALIDPad(TestConv3dTransposeOp):
+    def init_test_case(self):
+        self.stride = [1, 1, 1]
+        self.dilations = [1, 1, 1]
+        self.groups = 1
+        self.input_size = [2, 3, 5, 5, 5]  # NCDHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3, 3]
+        self.padding_algorithm = 'VALID'
+
+
 class TestWithGroups(TestConv3dTransposeOp):
     def init_test_case(self):
         self.pad = [1, 1, 1]
...
@@ -198,6 +279,78 @@ class TestWithDilation(TestConv3dTransposeOp):
         self.filter_size = [f_c, 6, 3, 3, 3]


+class Test_NHWC(TestConv3dTransposeOp):
+    def init_test_case(self):
+        self.pad = [0, 0, 0]
+        self.stride = [1, 1, 1]
+        self.dilations = [1, 1, 1]
+        self.groups = 1
+        self.input_size = [2, 5, 5, 5, 3]  # NDHWC
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 6, 3, 3, 3]
+        self.data_format = 'NHWC'
+
+
+class TestWithSymmetricPad_NHWC(TestConv3dTransposeOp):
+    def init_test_case(self):
+        self.pad = [1, 1, 1]
+        self.stride = [1, 1, 1]
+        self.dilations = [1, 1, 1]
+        self.groups = 1
+        self.input_size = [2, 5, 5, 5, 3]  # NDHWC
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 6, 3, 3, 3]
+        self.data_format = 'NHWC'
+
+
+class TestWithAsymmetricPad_NHWC(TestConv3dTransposeOp):
+    def init_test_case(self):
+        self.pad = [1, 0, 1, 0, 1, 2]
+        self.stride = [1, 1, 1]
+        self.dilations = [1, 1, 1]
+        self.groups = 1
+        self.input_size = [2, 5, 5, 5, 3]  # NDHWC
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 6, 3, 3, 3]
+        self.data_format = 'NHWC'
+
+
+class TestWithGroups_NHWC(TestConv3dTransposeOp):
+    def init_test_case(self):
+        self.pad = [1, 1, 1]
+        self.stride = [1, 1, 1]
+        self.dilations = [1, 1, 1]
+        self.groups = 2
+        self.input_size = [2, 5, 5, 5, 4]  # NDHWC
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 3, 3, 3, 3]
+        self.data_format = 'NHWC'
+
+
+class TestWithStride_NHWC(TestConv3dTransposeOp):
+    def init_test_case(self):
+        self.pad = [1, 1, 1]
+        self.stride = [2, 2, 2]
+        self.dilations = [1, 1, 1]
+        self.groups = 1
+        self.input_size = [2, 5, 5, 5, 3]  # NCDHW
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 6, 3, 3, 3]
+        self.data_format = 'NHWC'
+
+
+class TestWithDilation_NHWC(TestConv3dTransposeOp):
+    def init_test_case(self):
+        self.pad = [1, 1, 1]
+        self.stride = [1, 1, 1]
+        self.dilations = [2, 2, 2]
+        self.groups = 1
+        self.input_size = [2, 5, 5, 5, 3]  # NCDHW
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 6, 3, 3, 3]
+        self.data_format = 'NHWC'
+
+
 # ------------ test_cudnn ------------
 @unittest.skipIf(not core.is_compiled_with_cuda(),
                  "core is not compiled with CUDA")
...
@@ -209,7 +362,7 @@ class TestCUDNN(TestConv3dTransposeOp):
 @unittest.skipIf(not core.is_compiled_with_cuda(),
                  "core is not compiled with CUDA")
-class TestCUDNNWithPad(TestWithPad):
+class TestCUDNNWithSymmetricPad(TestWithSymmetricPad):
     def init_test_case(self):
         self.pad = [1, 1, 1]
         self.stride = [1, 1, 1]
...
@@ -224,6 +377,57 @@ class TestCUDNNWithPad(TestWithPad):
         self.op_type = "conv3d_transpose"


+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestCUDNNWithAsymmetricPad(TestWithAsymmetricPad):
+    def init_test_case(self):
+        self.pad = [1, 1, 1, 0, 0, 2]
+        self.stride = [1, 1, 1]
+        self.dilations = [1, 1, 1]
+        self.groups = 1
+        self.input_size = [2, 3, 4, 4, 4]  # NCDHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3, 3]
+
+    def init_op_type(self):
+        self.use_cudnn = True
+        self.op_type = "conv3d_transpose"
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestCUDNNWithSAMEPad(TestWithSAMEPad):
+    def init_test_case(self):
+        self.stride = [1, 1, 1]
+        self.dilations = [1, 1, 1]
+        self.groups = 1
+        self.input_size = [2, 3, 5, 5, 5]  # NCDHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3, 3]
+        self.padding_algorithm = 'SAME'
+
+    def init_op_type(self):
+        self.use_cudnn = True
+        self.op_type = "conv3d_transpose"
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestCUDNNWithVALIDPad(TestWithVALIDPad):
+    def init_test_case(self):
+        self.stride = [1, 1, 1]
+        self.dilations = [1, 1, 1]
+        self.groups = 1
+        self.input_size = [2, 3, 5, 5, 5]  # NCDHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3, 3]
+        self.padding_algorithm = 'VALID'
+
+    def init_op_type(self):
+        self.use_cudnn = True
+        self.op_type = "conv3d_transpose"
+
+
 @unittest.skipIf(not core.is_compiled_with_cuda(),
                  "core is not compiled with CUDA")
 class TestCUDNNWithStride(TestWithStride):
...
@@ -272,5 +476,222 @@ class TestCUDNNWithGroups(TestWithGroups):
 # def init_op_type(self):
 #     self.op_type = "conv3d_transpose"

+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestCUDNN_NHWC(TestConv3dTransposeOp):
+    def init_test_case(self):
+        self.pad = [0, 0, 0]
+        self.stride = [1, 1, 1]
+        self.dilations = [1, 1, 1]
+        self.groups = 1
+        self.input_size = [2, 5, 5, 5, 3]  # NDHWC
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 6, 3, 3, 3]
+        self.data_format = 'NHWC'
+
+    def init_op_type(self):
+        self.use_cudnn = True
+        self.op_type = "conv3d_transpose"
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestCUDNNWithSymmetricPad_NHWC(TestWithSymmetricPad):
+    def init_test_case(self):
+        self.pad = [1, 1, 1]
+        self.stride = [1, 1, 1]
+        self.dilations = [1, 1, 1]
+        self.groups = 1
+        self.input_size = [2, 5, 5, 5, 3]  # NDHWC
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 6, 3, 3, 3]
+        self.data_format = 'NHWC'
+
+    def init_op_type(self):
+        self.use_cudnn = True
+        self.op_type = "conv3d_transpose"
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestCUDNNWithAsymmetricPad_NHWC(TestWithAsymmetricPad):
+    def init_test_case(self):
+        self.pad = [1, 0, 1, 0, 0, 2]
+        self.stride = [1, 1, 1]
+        self.dilations = [1, 1, 1]
+        self.groups = 1
+        self.input_size = [2, 5, 5, 5, 3]  # NDHWC
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 6, 3, 3, 3]
+        self.data_format = 'NHWC'
+
+    def init_op_type(self):
+        self.use_cudnn = True
+        self.op_type = "conv3d_transpose"
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestCUDNNWithStride_NHWC(TestWithStride):
+    def init_test_case(self):
+        self.pad = [1, 1, 1]
+        self.stride = [2, 2, 2]
+        self.dilations = [1, 1, 1]
+        self.groups = 1
+        self.input_size = [2, 5, 5, 5, 3]  # NCDHW
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 6, 3, 3, 3]
+        self.data_format = 'NHWC'
+
+    def init_op_type(self):
+        self.use_cudnn = True
+        self.op_type = "conv3d_transpose"
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestCUDNNWithGroups_NHWC(TestWithGroups):
+    def init_test_case(self):
+        self.pad = [1, 1, 1]
+        self.stride = [1, 1, 1]
+        self.dilations = [1, 1, 1]
+        self.groups = 2
+        self.input_size = [2, 5, 5, 5, 4]  # NCHW
+        f_c = self.input_size[-1]
+        self.filter_size = [f_c, 3, 3, 3, 3]
+        self.data_format = 'NHWC'
+
+    def init_op_type(self):
+        self.use_cudnn = True
+        self.op_type = "conv3d_transpose"
+
+
+class TestConv3dTransposeAPI(OpTest):
+    def test_case1(self):
+        data1 = fluid.layers.data(
+            name='data1', shape=[3, 5, 5, 5], dtype='float32')
+        data2 = fluid.layers.data(
+            name='data2', shape=[5, 5, 5, 3], dtype='float32')
+        out1 = fluid.layers.conv3d_transpose(
+            input=data1,
+            groups=1,
+            num_filters=6,
+            filter_size=3,
+            data_format='NCDHW')
+        out2 = fluid.layers.conv3d_transpose(
+            input=data2,
+            groups=1,
+            num_filters=6,
+            filter_size=3,
+            data_format='NDHWC')
+        out3 = fluid.layers.conv3d_transpose(
+            input=data1,
+            groups=1,
+            num_filters=6,
+            filter_size=3,
+            padding=[[0, 0], [0, 0], [1, 1], [0, 0], [1, 1]],
+            data_format='NCDHW')
+        out4 = fluid.layers.conv3d_transpose(
+            input=data2,
+            groups=3,
+            num_filters=6,
+            filter_size=3,
+            padding=[[0, 0], [0, 0], [1, 1], [1, 2], [0, 0]],
+            data_format='NDHWC')
+        out5 = fluid.layers.conv3d_transpose(
+            input=data2,
+            groups=1,
+            num_filters=6,
+            filter_size=3,
+            padding='SAME',
+            data_format='NCDHW')
+        out6 = fluid.layers.conv3d_transpose(
+            input=data2,
+            groups=1,
+            num_filters=6,
+            filter_size=3,
+            padding='VALID',
+            data_format='NDHWC')
+        out7 = fluid.layers.conv3d_transpose(
+            input=data2,
+            groups=1,
+            num_filters=6,
+            output_size=[7, 7, 7],
+            padding=[0, 0, 0],
+            data_format='NDHWC')
+
+        data1_np = np.random.random((2, 3, 5, 5, 5)).astype("float32")
+        data2_np = np.random.random((2, 5, 5, 5, 3)).astype("float32")
+
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+        else:
+            place = core.CPUPlace()
+        exe = fluid.Executor(place)
+        exe.run(fluid.default_startup_program())
+        results = exe.run(
+            fluid.default_main_program(),
+            feed={"data1": data1_np,
+                  "data2": data2_np},
+            fetch_list=[out1, out2, out3, out4, out5, out6, out7],
+            return_numpy=True)
+        self.assertIsNotNone(results[0])
+        self.assertIsNotNone(results[1])
+        self.assertIsNotNone(results[2])
+        self.assertIsNotNone(results[3])
+        self.assertIsNotNone(results[4])
+        self.assertIsNotNone(results[5])
+        self.assertIsNotNone(results[6])
+
+
+class TestConv3dTransposeOpException(OpTest):
+    def test_exception(self):
+        data = fluid.layers.data(
+            name='data', shape=[3, 5, 5, 5], dtype="float32")
+
+        def attr_data_format():
+            out = fluid.layers.conv2d_transpose(
+                input=data,
+                groups=1,
+                num_filters=6,
+                filter_size=3,
+                data_format="NCDW")
+
+        self.assertRaises(ValueError, attr_data_format)
+
+        def attr_padding_str():
+            out = fluid.layers.conv2d_transpose(
+                input=data,
+                groups=1,
+                num_filters=6,
+                filter_size=3,
+                padding='Vald')
+
+        self.assertRaises(ValueError, attr_padding_str)
+
+        def attr_padding_list():
+            out = fluid.layers.conv2d_transpose(
+                input=data,
+                groups=1,
+                num_filters=6,
+                filter_size=3,
+                padding=[[1, 1], [1, 1], [0, 0], [0, 0], [1, 1]])
+
+        self.assertRaises(ValueError, attr_padding_list)
+
+        def attr_padding_with_data_format():
+            out = fluid.layers.conv2d_transpose(
+                input=data,
+                groups=1,
+                num_filters=6,
+                filter_size=3,
+                padding=[[1, 1], [0, 0], [0, 0], [1, 0], [1, 1]],
+                data_format='NDHWC')

+        self.assertRaises(ValueError, attr_padding_with_data_format)
+
+
 if __name__ == '__main__':
     unittest.main()