Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d0599511
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
d0599511
编写于
2月 02, 2018
作者:
Z
Zhaolong Xing
提交者:
GitHub
2月 02, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #7885 from NHZlX/add_depthwiseConv_op_gpu
Add depthwise conv op gpu
上级
71bd0dfa
3074ae7b
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
504 addition
and
4 deletion
+504
-4
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+4
-1
paddle/operators/conv_op.cc
paddle/operators/conv_op.cc
+16
-0
paddle/operators/conv_op.cu.cc
paddle/operators/conv_op.cu.cc
+10
-0
paddle/operators/conv_op.h
paddle/operators/conv_op.h
+68
-0
paddle/operators/math/CMakeLists.txt
paddle/operators/math/CMakeLists.txt
+1
-0
paddle/operators/math/depthwise_conv.cu
paddle/operators/math/depthwise_conv.cu
+311
-0
paddle/operators/math/depthwise_conv.h
paddle/operators/math/depthwise_conv.h
+60
-0
python/paddle/v2/fluid/layers/nn.py
python/paddle/v2/fluid/layers/nn.py
+10
-3
python/paddle/v2/fluid/tests/test_conv2d_op.py
python/paddle/v2/fluid/tests/test_conv2d_op.py
+24
-0
未找到文件。
paddle/operators/CMakeLists.txt
浏览文件 @
d0599511
...
...
@@ -158,7 +158,10 @@ op_library(parallel_do_op DEPS executor)
# Regist multiple Kernel to pybind
if
(
WITH_GPU
)
op_library
(
conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS vol2col
)
op_library
(
conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS
vol2col depthwise_conv
)
op_library
(
edit_distance_op SRCS edit_distance_op.cc edit_distance_op.cu DEPS math_function
)
op_library
(
pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling
)
op_library
(
conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc
...
...
paddle/operators/conv_op.cc
浏览文件 @
d0599511
...
...
@@ -318,9 +318,25 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
conv2d
,
ops
::
ConvOp
,
ops
::
Conv2DOpMaker
,
conv2d_grad
,
ops
::
ConvOpGrad
);
// depthwise convolution op
REGISTER_OP
(
depthwise_conv2d
,
ops
::
ConvOp
,
ops
::
Conv2DOpMaker
,
depthwise_conv2d_grad
,
ops
::
ConvOpGrad
);
REGISTER_OP
(
conv3d
,
ops
::
ConvOp
,
ops
::
Conv3DOpMaker
,
conv3d_grad
,
ops
::
ConvOpGrad
);
// depthwise conv kernel
// TODO(xingzhaolong): neon kernel for mobile
REGISTER_OP_CPU_KERNEL
(
depthwise_conv2d
,
ops
::
GemmConvKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
GemmConvKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
depthwise_conv2d_grad
,
ops
::
GemmConvGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
GemmConvGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
conv2d
,
ops
::
GemmConvKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
GemmConvKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
...
...
paddle/operators/conv_op.cu.cc
浏览文件 @
d0599511
...
...
@@ -16,6 +16,16 @@ limitations under the License. */
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
depthwise_conv2d
,
ops
::
DepthwiseConvKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
DepthwiseConvKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
depthwise_conv2d_grad
,
ops
::
DepthwiseConvGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
DepthwiseConvGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
conv2d
,
ops
::
GemmConvKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
GemmConvKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
...
...
paddle/operators/conv_op.h
浏览文件 @
d0599511
...
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/depthwise_conv.h"
#include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/vol2col.h"
...
...
@@ -350,5 +351,72 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
DepthwiseConvKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
const
Tensor
*
input
=
context
.
Input
<
Tensor
>
(
"Input"
);
Tensor
filter
=
*
context
.
Input
<
Tensor
>
(
"Filter"
);
Tensor
*
output
=
context
.
Output
<
Tensor
>
(
"Output"
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
PADDLE_ENFORCE_EQ
(
output
->
dims
()[
1
]
%
input
->
dims
()[
1
],
0
,
"The output channels must be a multiple of the input channels"
);
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
std
::
vector
<
int
>
dilations
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"dilations"
);
math
::
DepthwiseConvFunctor
<
DeviceContext
,
T
>
depthwiseConv
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
depthwiseConv
(
dev_ctx
,
*
input
,
filter
,
strides
,
paddings
,
output
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
DepthwiseConvGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
const
Tensor
*
input
=
context
.
Input
<
Tensor
>
(
"Input"
);
const
Tensor
*
output_grad
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Output"
));
Tensor
*
input_grad
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Input"
));
Tensor
*
filter_grad
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Filter"
));
Tensor
filter
=
*
context
.
Input
<
Tensor
>
(
"Filter"
);
if
(
!
input_grad
&&
!
filter_grad
)
return
;
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
std
::
vector
<
int
>
dilations
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"dilations"
);
math
::
SetConstant
<
DeviceContext
,
T
>
set_zero
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
math
::
DepthwiseConvInputGradFunctor
<
DeviceContext
,
T
>
depthwiseConvInputGrad
;
math
::
DepthwiseConvFilterGradFunctor
<
DeviceContext
,
T
>
depthwiseConvFilterGrad
;
if
(
input_grad
)
{
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
set_zero
(
dev_ctx
,
input_grad
,
static_cast
<
T
>
(
0
));
depthwiseConvInputGrad
(
dev_ctx
,
*
input
,
filter
,
*
output_grad
,
strides
,
paddings
,
input_grad
);
}
if
(
filter_grad
)
{
filter_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
set_zero
(
dev_ctx
,
filter_grad
,
static_cast
<
T
>
(
0
));
depthwiseConvFilterGrad
(
dev_ctx
,
*
input
,
*
output_grad
,
strides
,
paddings
,
filter_grad
);
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/operators/math/CMakeLists.txt
浏览文件 @
d0599511
...
...
@@ -8,6 +8,7 @@ if(WITH_GPU)
nv_library
(
softmax SRCS softmax.cc softmax.cu DEPS device_context
)
nv_library
(
cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS device_context
)
nv_library
(
pooling SRCS pooling.cc pooling.cu DEPS device_context
)
nv_library
(
depthwise_conv SRCS depthwise_conv.cu DEPS device_context
)
nv_library
(
sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function
)
nv_library
(
vol2col SRCS vol2col.cc vol2col.cu DEPS device_context tensor
)
nv_library
(
context_project SRCS context_project.cc context_project.cu DEPS device_context math_function
)
...
...
paddle/operators/math/depthwise_conv.cu
0 → 100644
浏览文件 @
d0599511
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/depthwise_conv.h"
#include "paddle/platform/cuda_helper.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
// A Cuda kernel to compute the depthwise convolution forward pass
// in NCHW format.
template
<
typename
T
>
__global__
void
KernelDepthwiseConv
(
const
int
nthreads
,
const
T
*
const
input_data
,
const
T
*
const
filter_data
,
const
int
batch_size
,
const
int
output_channels
,
const
int
output_height
,
const
int
output_width
,
const
int
input_channels
,
const
int
input_height
,
const
int
input_width
,
const
int
filter_multiplier
,
const
int
filter_height
,
const
int
filter_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
,
T
*
const
output_data
)
{
int
index
=
(
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
const
int
batch
=
index
/
output_channels
/
output_height
/
output_width
;
const
int
c_out
=
(
index
/
output_height
/
output_width
)
%
output_channels
;
const
int
h_out
=
(
index
/
output_width
)
%
output_height
;
const
int
w_out
=
index
%
output_width
;
const
int
c_in
=
c_out
/
filter_multiplier
;
const
T
*
weight
=
filter_data
+
c_out
*
filter_height
*
filter_width
;
T
value
=
0
;
const
int
h_in_start
=
-
padding_height
+
h_out
*
stride_height
;
const
int
w_in_start
=
-
padding_width
+
w_out
*
stride_width
;
const
int
h_in_end
=
h_in_start
+
filter_height
;
const
int
w_in_end
=
w_in_start
+
filter_width
;
const
int
in_offset
=
((
batch
*
input_channels
+
c_in
)
*
input_height
)
*
input_width
;
const
int
h_end
=
h_in_end
<
input_height
?
h_in_end
:
input_height
;
const
int
w_end
=
w_in_end
<
input_width
?
w_in_end
:
input_width
;
const
int
h_start
=
h_in_start
>
0
?
h_in_start
:
0
;
const
int
w_start
=
w_in_start
>
0
?
w_in_start
:
0
;
for
(
int
h_in
=
h_start
;
h_in
<
h_end
;
h_in
++
)
{
for
(
int
w_in
=
w_start
;
w_in
<
w_end
;
w_in
++
)
{
const
int
offset
=
in_offset
+
h_in
*
input_width
+
w_in
;
value
+=
weight
[(
h_in
-
h_in_start
)
*
filter_width
+
(
w_in
-
w_in_start
)]
*
input_data
[
offset
];
}
}
output_data
[
index
]
=
value
;
}
}
// CUDA kernel to compute the depthwise convolution backprop w.r.t input.
template
<
typename
T
>
__global__
void
KernelDepthwiseConvInputGrad
(
const
int
nthreads
,
const
T
*
const
output_grad_data
,
const
T
*
const
filter_data
,
const
int
batch_size
,
const
int
output_channels
,
const
int
output_height
,
const
int
output_width
,
const
int
input_channels
,
const
int
input_height
,
const
int
input_width
,
const
int
filter_multiplier
,
const
int
filter_height
,
const
int
filter_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
,
T
*
const
input_grad_data
)
{
int
index
=
(
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
const
int
batch
=
index
/
input_channels
/
input_height
/
input_width
;
const
int
c_in
=
(
index
/
input_height
/
input_width
)
%
input_channels
;
const
int
h_in
=
(
index
/
input_width
)
%
input_height
;
const
int
w_in
=
index
%
input_width
;
const
int
c_out_start
=
c_in
*
filter_multiplier
;
int
h_out_start
=
(
h_in
-
filter_height
+
padding_height
+
stride_height
)
/
stride_height
;
h_out_start
=
0
>
h_out_start
?
0
:
h_out_start
;
int
h_out_end
=
(
h_in
+
padding_height
)
/
stride_height
;
h_out_end
=
output_height
-
1
<
h_out_end
?
output_height
-
1
:
h_out_end
;
int
w_out_start
=
(
w_in
-
filter_width
+
padding_width
+
stride_width
)
/
stride_width
;
w_out_start
=
0
>
w_out_start
?
0
:
w_out_start
;
int
w_out_end
=
(
w_in
+
padding_width
)
/
stride_width
;
w_out_end
=
output_width
-
1
<
w_out_end
?
output_width
-
1
:
w_out_end
;
T
value
=
0
;
for
(
int
c_out
=
c_out_start
;
c_out
<
c_out_start
+
filter_multiplier
;
c_out
++
)
{
for
(
int
h_out
=
h_out_start
;
h_out
<=
h_out_end
;
++
h_out
)
{
const
int
filter_h
=
h_in
+
padding_height
-
h_out
*
stride_height
;
for
(
int
w_out
=
w_out_start
;
w_out
<=
w_out_end
;
++
w_out
)
{
const
int
filter_w
=
w_in
+
padding_width
-
w_out
*
stride_width
;
const
int
filter_offset
=
c_out
*
filter_height
*
filter_width
+
filter_h
*
filter_width
+
filter_w
;
const
int
output_grad_offset
=
((
batch
*
output_channels
+
c_out
)
*
output_height
+
h_out
)
*
output_width
+
w_out
;
value
+=
output_grad_data
[
output_grad_offset
]
*
filter_data
[
filter_offset
];
}
}
}
input_grad_data
[
index
]
+=
value
;
}
}
// Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
template
<
typename
T
>
__global__
void
KernelDepthwiseConvFilterGrad
(
const
int
nthreads
,
const
T
*
const
output_grad_data
,
const
T
*
const
input_data
,
const
int
num
,
const
int
output_channels
,
const
int
output_height
,
const
int
output_width
,
const
int
input_channels
,
const
int
input_height
,
const
int
input_width
,
const
int
filter_multiplier
,
const
int
filter_height
,
const
int
filter_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
,
T
*
const
filter_grad_data
)
{
int
index
=
(
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
const
int
w_out
=
index
%
output_width
;
const
int
h_out
=
(
index
/
output_width
)
%
output_height
;
const
int
c_out
=
(
index
/
output_width
/
output_height
)
%
output_channels
;
const
int
batch
=
(
index
/
output_width
/
output_height
/
output_channels
);
const
int
c_in
=
c_out
/
filter_multiplier
;
const
int
h_in_start
=
-
padding_height
+
h_out
*
stride_height
;
const
int
w_in_start
=
-
padding_width
+
w_out
*
stride_width
;
const
int
h_in_end
=
-
padding_height
+
h_out
*
stride_height
+
filter_height
;
const
int
w_in_end
=
-
padding_width
+
w_out
*
stride_width
+
filter_width
;
const
int
in_offset
=
(
batch
*
input_channels
+
c_in
)
*
input_height
*
input_width
;
T
*
addr_offset
=
filter_grad_data
+
c_out
*
filter_height
*
filter_width
;
const
int
h_end
=
h_in_end
<
input_height
?
h_in_end
:
input_height
;
const
int
w_end
=
w_in_end
<
input_width
?
w_in_end
:
input_width
;
const
int
h_start
=
h_in_start
>
0
?
h_in_start
:
0
;
const
int
w_start
=
w_in_start
>
0
?
w_in_start
:
0
;
for
(
int
h_in
=
h_start
;
h_in
<
h_end
;
h_in
++
)
{
for
(
int
w_in
=
w_start
;
w_in
<
w_end
;
w_in
++
)
{
const
int
offset
=
in_offset
+
h_in
*
input_width
+
w_in
;
const
T
diff_temp
=
output_grad_data
[
index
]
*
input_data
[
offset
];
T
*
addr
=
addr_offset
+
(
h_in
-
h_in_start
)
*
filter_width
+
(
w_in
-
w_in_start
);
paddle
::
platform
::
CudaAtomicAdd
(
addr
,
diff_temp
);
}
}
}
}
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template
<
class
T
>
class
DepthwiseConvFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
filter
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
const
int
output_channels
=
output
->
dims
()[
1
];
const
int
output_height
=
output
->
dims
()[
2
];
const
int
output_width
=
output
->
dims
()[
3
];
const
int
ksize_height
=
filter
.
dims
()[
2
];
const
int
ksize_width
=
filter
.
dims
()[
3
];
const
int
stride_height
=
strides
[
0
];
const
int
stride_width
=
strides
[
1
];
const
int
padding_height
=
paddings
[
0
];
const
int
padding_width
=
paddings
[
1
];
const
T
*
input_data
=
input
.
data
<
T
>
();
const
T
*
filter_data
=
filter
.
data
<
T
>
();
T
*
output_data
=
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
nthreads
=
batch_size
*
output_channels
*
output_height
*
output_width
;
int
blocks
=
(
nthreads
+
1024
-
1
)
/
1024
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blocks
,
1
);
KernelDepthwiseConv
<
T
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
nthreads
,
input_data
,
filter_data
,
batch_size
,
output_channels
,
output_height
,
output_width
,
input_channels
,
input_height
,
input_width
,
output_channels
/
input_channels
,
ksize_height
,
ksize_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
output_data
);
}
};
template
<
typename
T
>
class
DepthwiseConvInputGradFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
filter
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
const
int
output_channels
=
output_grad
.
dims
()[
1
];
const
int
output_height
=
output_grad
.
dims
()[
2
];
const
int
output_width
=
output_grad
.
dims
()[
3
];
const
int
ksize_height
=
filter
.
dims
()[
2
];
const
int
ksize_width
=
filter
.
dims
()[
3
];
const
int
stride_height
=
strides
[
0
];
const
int
stride_width
=
strides
[
1
];
const
int
padding_height
=
paddings
[
0
];
const
int
padding_width
=
paddings
[
1
];
const
T
*
filter_data
=
filter
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
nthreads
=
batch_size
*
input_channels
*
input_height
*
input_width
;
int
blocks
=
(
nthreads
+
1024
-
1
)
/
1024
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blocks
,
1
);
KernelDepthwiseConvInputGrad
<
T
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
nthreads
,
output_grad_data
,
filter_data
,
batch_size
,
output_channels
,
output_height
,
output_width
,
input_channels
,
input_height
,
input_width
,
output_channels
/
input_channels
,
ksize_height
,
ksize_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
input_grad_data
);
}
};
template
<
typename
T
>
class
DepthwiseConvFilterGradFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
filter_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
const
int
output_channels
=
output_grad
.
dims
()[
1
];
const
int
output_height
=
output_grad
.
dims
()[
2
];
const
int
output_width
=
output_grad
.
dims
()[
3
];
const
int
ksize_height
=
filter_grad
->
dims
()[
2
];
const
int
ksize_width
=
filter_grad
->
dims
()[
3
];
const
int
stride_height
=
strides
[
0
];
const
int
stride_width
=
strides
[
1
];
const
int
padding_height
=
paddings
[
0
];
const
int
padding_width
=
paddings
[
1
];
const
T
*
input_data
=
input
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
filter_grad_data
=
filter_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
nthreads
=
batch_size
*
output_channels
*
output_height
*
output_width
;
int
blocks
=
(
nthreads
+
1024
-
1
)
/
1024
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blocks
,
1
);
KernelDepthwiseConvFilterGrad
<
T
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
nthreads
,
output_grad_data
,
input_data
,
batch_size
,
output_channels
,
output_height
,
output_width
,
input_channels
,
input_height
,
input_width
,
output_channels
/
input_channels
,
ksize_height
,
ksize_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
filter_grad_data
);
}
};
template
class
DepthwiseConvFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
class
DepthwiseConvFunctor
<
platform
::
CUDADeviceContext
,
double
>;
template
class
DepthwiseConvInputGradFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
class
DepthwiseConvInputGradFunctor
<
platform
::
CUDADeviceContext
,
double
>;
template
class
DepthwiseConvFilterGradFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
class
DepthwiseConvFilterGradFunctor
<
platform
::
CUDADeviceContext
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/depthwise_conv.h
0 → 100644
浏览文件 @
d0599511
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/hostdevice.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
/*
* \brief Compute the depthwise convolution which include
* forward process and backpropagation process
*/
template
<
typename
DeviceContext
,
typename
T
>
class
DepthwiseConvFunctor
{
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
filter
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
);
};
template
<
typename
DeviceContext
,
typename
T
>
class
DepthwiseConvInputGradFunctor
{
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
filter
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
);
};
template
<
typename
DeviceContext
,
typename
T
>
class
DepthwiseConvFilterGradFunctor
{
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
filter_grad
);
};
}
// namespace math
}
// namespace operators
}
// namespace paddle
python/paddle/v2/fluid/layers/nn.py
浏览文件 @
d0599511
...
...
@@ -1231,10 +1231,17 @@ def conv2d(input,
"""
if
stride
is
None
:
stride
=
[
1
,
1
]
helper
=
LayerHelper
(
'conv2d'
,
**
locals
())
dtype
=
helper
.
input_dtype
()
num_channels
=
input
.
shape
[
1
]
l_type
=
'conv2d'
if
(
num_channels
==
groups
and
num_filters
%
num_channels
==
0
and
not
use_cudnn
):
l_type
=
'depthwise_conv2d'
helper
=
LayerHelper
(
l_type
,
**
locals
())
dtype
=
helper
.
input_dtype
()
if
groups
is
None
:
num_filter_channels
=
num_channels
else
:
...
...
@@ -1267,7 +1274,7 @@ def conv2d(input,
pre_bias
=
helper
.
create_tmp_variable
(
dtype
)
helper
.
append_op
(
type
=
'conv2d'
,
type
=
l_type
,
inputs
=
{
'Input'
:
input
,
'Filter'
:
filter_param
,
...
...
python/paddle/v2/fluid/tests/test_conv2d_op.py
浏览文件 @
d0599511
...
...
@@ -241,6 +241,30 @@ class TestCUDNNWith1x1(TestWith1x1):
self
.
op_type
=
"conv2d"
class
TestDepthwiseConv
(
TestConv2dOp
):
def
init_test_case
(
self
):
self
.
pad
=
[
1
,
1
]
self
.
stride
=
[
2
,
2
]
self
.
input_size
=
[
2
,
3
,
5
,
5
]
# NCHW
self
.
groups
=
3
assert
np
.
mod
(
self
.
input_size
[
1
],
self
.
groups
)
==
0
f_c
=
self
.
input_size
[
1
]
/
self
.
groups
self
.
filter_size
=
[
6
,
f_c
,
3
,
3
]
self
.
op_type
=
"depthwise_conv2d"
class
TestDepthwiseConv2
(
TestConv2dOp
):
def
init_test_case
(
self
):
self
.
pad
=
[
1
,
1
]
self
.
stride
=
[
1
,
1
]
self
.
input_size
=
[
2
,
3
,
5
,
5
]
# NCHW
self
.
groups
=
3
assert
np
.
mod
(
self
.
input_size
[
1
],
self
.
groups
)
==
0
f_c
=
self
.
input_size
[
1
]
/
self
.
groups
self
.
filter_size
=
[
6
,
f_c
,
3
,
3
]
self
.
op_type
=
"depthwise_conv2d"
# cudnn v5 does not support dilation conv.
# class TestCUDNNWithDilation(TestWithDilation):
# def init_op_type(self):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录