Commit f8ef8c17
Repository: Crayon鑫 / Paddle (forked from PaddlePaddle / Paddle)
Authored June 13, 2017 by hedaoyuan
Parent: 152bd2f9

Add the GPU version implementation of ImageExpandGrad function.
Showing 3 changed files with 103 additions and 38 deletions (+103 -38):

paddle/function/Im2ColOpGpu.cu               +88 -19
paddle/function/ImageExpandOp.cpp             +1  -0
paddle/gserver/layers/BlockExpandLayer.cpp   +14 -19
paddle/function/Im2ColOpGpu.cu  (+88 -19)

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "Im2Col.h"
+#include "hl_device_functions.cuh"
 
 namespace paddle {
@@ -25,30 +26,29 @@ void im2colOCF(const T* imData, T* colData,
                int strideHeight, int strideWidth,
                int paddingHeight, int paddingWidth,
                int outputHeight, int outputWidth) {
-  int idx = threadIdx.x;
-  int idy = threadIdx.y;
   int swId = blockIdx.x;
   int shId = blockIdx.y;
   for (int channelId = threadIdx.z;
        channelId < inputChannels;
        channelId += blockDim.z) {
-    int widthOffset = idx + swId * strideWidth - paddingWidth;
-    int heightOffset = idy + shId * strideHeight - paddingHeight;
-    int imOffset = widthOffset + heightOffset * inputWidth
-        + channelId * inputHeight * inputWidth;
-
-    int colOffset = idx + idy * filterWidth
-        + channelId * filterHeight * filterWidth
-        + (shId * outputWidth + swId)
-        * (inputChannels * filterHeight * filterWidth);
-
-    if (idx < filterWidth && idy < filterHeight) {
-      if (heightOffset >= inputHeight || heightOffset < 0 ||
-          widthOffset >= inputWidth || widthOffset < 0) {
-        colData[colOffset] = T(0);
-      } else {
-        colData[colOffset] = imData[imOffset];
+    for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
+      for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
+        int widthOffset = idx + swId * strideWidth - paddingWidth;
+        int heightOffset = idy + shId * strideHeight - paddingHeight;
+        int imOffset = widthOffset + heightOffset * inputWidth
+            + channelId * inputHeight * inputWidth;
+
+        int colOffset = idx + idy * filterWidth
+            + channelId * filterHeight * filterWidth
+            + (shId * outputWidth + swId)
+            * (inputChannels * filterHeight * filterWidth);
+
+        if (heightOffset >= inputHeight || heightOffset < 0 ||
+            widthOffset >= inputWidth || widthOffset < 0) {
+          colData[colOffset] = T(0);
+        } else {
+          colData[colOffset] = imData[imOffset];
+        }
       }
     }
   }
 }
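Note on the hunk above: each CUDA block still handles one output position (blockIdx.x = swId along outputWidth, blockIdx.y = shId along outputHeight), but the block's threads now loop over the channel and filter-window dimensions instead of each thread covering exactly one (idx, idy) pair behind the old `if (idx < filterWidth && idy < filterHeight)` guard, so filters larger than the block dimensions are still written in full. As a reading aid, here is the serial equivalent of the work one block performs; this sketch is illustrative only, is not part of the commit, and simply reuses the kernel's variable names (float buffers assumed, so T(0) becomes 0):

// Illustrative only (not part of this commit): the serial loops that one
// im2colOCF block covers for output position (shId, swId).
for (int channelId = 0; channelId < inputChannels; ++channelId) {
  for (int idy = 0; idy < filterHeight; ++idy) {
    for (int idx = 0; idx < filterWidth; ++idx) {
      int widthOffset = idx + swId * strideWidth - paddingWidth;
      int heightOffset = idy + shId * strideHeight - paddingHeight;
      int imOffset = widthOffset + heightOffset * inputWidth +
                     channelId * inputHeight * inputWidth;
      int colOffset = idx + idy * filterWidth +
                      channelId * filterHeight * filterWidth +
                      (shId * outputWidth + swId) *
                          (inputChannels * filterHeight * filterWidth);
      bool inside = heightOffset >= 0 && heightOffset < inputHeight &&
                    widthOffset >= 0 && widthOffset < inputWidth;
      colData[colOffset] = inside ? imData[imOffset] : 0;  // zero-pad outside
    }
  }
}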
@@ -105,6 +105,41 @@ public:
   }
 };
 
+template <class T>
+__global__
+void col2imOCF(T* imData, const T* colData,
+               int inputChannels,
+               int inputHeight, int inputWidth,
+               int filterHeight, int filterWidth,
+               int strideHeight, int strideWidth,
+               int paddingHeight, int paddingWidth,
+               int outputHeight, int outputWidth) {
+  int swId = blockIdx.x;
+  int shId = blockIdx.y;
+  for (int channelId = threadIdx.z;
+       channelId < inputChannels;
+       channelId += blockDim.z) {
+    for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
+      for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
+        int widthOffset = idx + swId * strideWidth - paddingWidth;
+        int heightOffset = idy + shId * strideHeight - paddingHeight;
+        int imOffset = widthOffset + heightOffset * inputWidth
+            + channelId * inputHeight * inputWidth;
+
+        int colOffset = idx + idy * filterWidth
+            + channelId * filterHeight * filterWidth
+            + (shId * outputWidth + swId)
+            * (inputChannels * filterHeight * filterWidth);
+
+        if (heightOffset >= 0 && heightOffset < inputHeight &&
+            widthOffset >= 0 && widthOffset < inputWidth) {
+          paddle::paddleAtomicAdd(imData + imOffset, colData[colOffset]);
+        }
+      }
+    }
+  }
+}
+
 /*
  * imShape = [inputChannels, inputHeight, inputWidth]
  * colShape =
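col2imOCF is the new scatter kernel for the gradient path: it walks the same index mapping as im2colOCF and accumulates column-buffer gradients back into the image buffer. Whenever the stride is smaller than the filter size, neighboring output positions read overlapping image patches, so several colData entries map to the same imData element and different CUDA blocks may touch it concurrently; that is why the store goes through paddle::paddleAtomicAdd (from the newly included hl_device_functions.cuh) rather than a plain write. A small worked example of the overlap, for illustration only: with paddingWidth = 0, strideWidth = 1 and filterWidth = 3, the expression widthOffset = idx + swId * strideWidth - paddingWidth equals 5 for (swId = 3, idx = 2), (swId = 4, idx = 1) and (swId = 5, idx = 0), so three different column entries add into image column 5.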
@@ -121,10 +156,44 @@ public:
                   int strideWidth,
                   int paddingHeight,
                   int paddingWidth) {
+    int inputChannels = imShape[0];
+    int inputHeight = imShape[1];
+    int inputWidth = imShape[2];
+    int filterHeight = colShape[3];
+    int filterWidth = colShape[4];
+    int outputHeight = colShape[0];
+    int outputWidth = colShape[1];
+
+    int blockDimX = 0;
+    int blockDimY = 0;
+    if (filterHeight <= 4 && filterWidth <= 4) {
+      blockDimX = 4;
+      blockDimY = 4;
+    } else if (filterHeight <= 8 && filterWidth <= 8) {
+      blockDimX = 8;
+      blockDimY = 8;
+    } else if (filterHeight <= 16 && filterWidth <= 16) {
+      blockDimX = 16;
+      blockDimY = 16;
+    } else {
+      blockDimX = 32;
+      blockDimY = 32;
+    }
+
+    int blockDimZ = 1024 / blockDimX / blockDimY;
+    dim3 threads(blockDimX, blockDimY, std::min(blockDimZ, inputChannels));
+    dim3 grid(outputWidth, outputHeight);
+    col2imOCF<T><<<grid, threads, 0, STREAM_DEFAULT>>>
+        (imData, colData, inputChannels, inputHeight, inputWidth,
+         filterHeight, filterWidth, strideHeight, strideWidth,
+         paddingHeight, paddingWidth, outputHeight, outputWidth);
+    CHECK_SYNC("Col2ImFunctor GPU failed");
   }
 };
 
 template class Im2ColFunctor<kOCF, DEVICE_TYPE_GPU, float>;
 template class Im2ColFunctor<kOCF, DEVICE_TYPE_GPU, double>;
+template class Col2ImFunctor<kOCF, DEVICE_TYPE_GPU, float>;
+template class Col2ImFunctor<kOCF, DEVICE_TYPE_GPU, double>;
 
 }  // namespace paddle
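The Col2ImFunctor body above also fixes the launch shape: the x/y block dimensions are bumped to the smallest of 4, 8, 16 or 32 that covers the filter, the z dimension takes whatever is left of the 1024-thread budget (capped at inputChannels), and the grid places one block per output position. A worked example under assumed sizes, illustrative only and not part of the commit:

// Illustrative only: launch shape the selection logic above would pick for a
// hypothetical 3x3 filter, 64 input channels, and a 32x32 output map.
int blockDimX = 4, blockDimY = 4;              // a 3x3 filter fits a 4x4 tile
int blockDimZ = 1024 / blockDimX / blockDimY;  // = 64 threads left for z
dim3 threads(blockDimX, blockDimY, std::min(blockDimZ, 64 /* inputChannels */));
dim3 grid(32, 32);                             // one block per output position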
paddle/function/ImageExpandOp.cpp  (+1 -0)

@@ -293,6 +293,7 @@ REGISTER_TYPED_FUNC(ImageExpand, CPU, ImageExpandForward);
 REGISTER_TYPED_FUNC(ImageExpandGrad, CPU, ImageExpandBackward);
 #ifndef PADDLE_ONLY_CPU
 REGISTER_TYPED_FUNC(ImageExpand, GPU, ImageExpandForward);
+REGISTER_TYPED_FUNC(ImageExpandGrad, GPU, ImageExpandBackward);
 #endif
 
 }  // namespace paddle
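With the GPU registration in place, the same layer-side call path can be used on both devices. The sketch below only connects the registration to its caller; it mirrors the BlockExpandLayer change in the next file (the member names and buffer shapes are the ones used there) and is not additional code introduced by the commit:

// Sketch (mirrors BlockExpandLayer below): create the backward function once
// at init time, then drive it from backward() through BufferArgs.
createFunction(backward_,
               "ImageExpandGrad",
               FuncConfig()
                   .set("strides", strides)
                   .set("paddings", paddings)
                   .set("blocks", blocks));

BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getOutputGrad(), outputShape_);
outputs.addArg(*getInputGrad(0), inputShape_, ADD_TO);  // accumulate into input grad
backward_[0]->calc(inputs, outputs);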
paddle/gserver/layers/BlockExpandLayer.cpp  (+14 -19)

@@ -46,14 +46,12 @@ bool BlockExpandLayer::init(const LayerMap& layerMap,
                      .set("strides", strides)
                      .set("paddings", paddings)
                      .set("blocks", blocks));
-  if (!useGpu_) {
-    createFunction(backward_,
-                   "ImageExpandGrad",
-                   FuncConfig()
-                       .set("strides", strides)
-                       .set("paddings", paddings)
-                       .set("blocks", blocks));
-  }
+  createFunction(backward_,
+                 "ImageExpandGrad",
+                 FuncConfig()
+                     .set("strides", strides)
+                     .set("paddings", paddings)
+                     .set("blocks", blocks));
 
   return true;
 }
@@ -110,14 +108,16 @@ void BlockExpandLayer::forward(PassType passType) {
 }
 
 void BlockExpandLayer::backward(const UpdateCallback& callback) {
-  size_t blockNum = outputH_ * outputW_;
-  size_t blockSize = blockH_ * blockW_ * channels_;
   /* Calculate the input layers error */
-  MatrixPtr preGrad = inputLayers_[0]->getOutputGrad();
-  if (!preGrad) {
-    return;
-  }
-
+  if (getInputGrad(0)) {
+    BufferArgs inputs;
+    BufferArgs outputs;
+    inputs.addArg(*getOutputGrad(), outputShape_);
+    outputs.addArg(*getInputGrad(0), inputShape_, ADD_TO);
+    backward_[0]->calc(inputs, outputs);
+  }
+
+#if 0
   if (useGpu_) {
     MatrixPtr grad = getOutputGrad();
     MatrixPtr gradTrans = Matrix::create(blockSize, blockNum, false, useGpu_);
@@ -155,13 +155,8 @@ void BlockExpandLayer::backward(const UpdateCallback& callback) {
       1.0,
       1.0);
     }
-  } else {
-    BufferArgs inputs;
-    BufferArgs outputs;
-    inputs.addArg(*getOutputGrad(), outputShape_);
-    outputs.addArg(*getInputGrad(0), inputShape_, ADD_TO);
-    backward_[0]->calc(inputs, outputs);
   }
+#endif
 }
 
 }  // namespace paddle