Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
152bd2f9
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
152bd2f9
编写于
6月 13, 2017
作者:
H
hedaoyuan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add the GPU version implementation of ImageExpand function.
上级
34362d93
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
156 addition
and
56 deletion
+156
-56
paddle/function/Im2Col.h
paddle/function/Im2Col.h
+3
-0
paddle/function/Im2ColOpGpu.cu
paddle/function/Im2ColOpGpu.cu
+130
-0
paddle/function/ImageExpandOp.cpp
paddle/function/ImageExpandOp.cpp
+3
-0
paddle/gserver/layers/BlockExpandLayer.cpp
paddle/gserver/layers/BlockExpandLayer.cpp
+20
-53
paddle/gserver/layers/BlockExpandLayer.h
paddle/gserver/layers/BlockExpandLayer.h
+0
-3
未找到文件。
paddle/function/Im2Col.h
浏览文件 @
152bd2f9
...
...
@@ -14,6 +14,9 @@ limitations under the License. */
#pragma once
#include "TensorShape.h"
#include "TensorType.h"
namespace
paddle
{
/* The storage format of the coldata in the Im2ColFunctor and Col2ImFunctor. */
...
...
paddle/function/Im2ColOpGpu.cu
0 → 100644
浏览文件 @
152bd2f9
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Im2Col.h"
namespace
paddle
{
template
<
class
T
>
__global__
void
im2colOCF
(
const
T
*
imData
,
T
*
colData
,
int
inputChannels
,
int
inputHeight
,
int
inputWidth
,
int
filterHeight
,
int
filterWidth
,
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
,
int
outputHeight
,
int
outputWidth
)
{
int
idx
=
threadIdx
.
x
;
int
idy
=
threadIdx
.
y
;
int
swId
=
blockIdx
.
x
;
int
shId
=
blockIdx
.
y
;
for
(
int
channelId
=
threadIdx
.
z
;
channelId
<
inputChannels
;
channelId
+=
blockDim
.
z
)
{
int
widthOffset
=
idx
+
swId
*
strideWidth
-
paddingWidth
;
int
heightOffset
=
idy
+
shId
*
strideHeight
-
paddingHeight
;
int
imOffset
=
widthOffset
+
heightOffset
*
inputWidth
+
channelId
*
inputHeight
*
inputWidth
;
int
colOffset
=
idx
+
idy
*
filterWidth
+
channelId
*
filterHeight
*
filterWidth
+
(
shId
*
outputWidth
+
swId
)
*
(
inputChannels
*
filterHeight
*
filterWidth
);
if
(
idx
<
filterWidth
&&
idy
<
filterHeight
)
{
if
(
heightOffset
>=
inputHeight
||
heightOffset
<
0
||
widthOffset
>=
inputWidth
||
widthOffset
<
0
)
{
colData
[
colOffset
]
=
T
(
0
);
}
else
{
colData
[
colOffset
]
=
imData
[
imOffset
];
}
}
}
}
/*
* imShape = [inputChannels, inputHeight, inputWidth]
* colShape =
* [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth]
*/
template
<
class
T
>
class
Im2ColFunctor
<
kOCF
,
DEVICE_TYPE_GPU
,
T
>
{
public:
void
operator
()(
const
T
*
imData
,
const
TensorShape
&
imShape
,
T
*
colData
,
const
TensorShape
&
colShape
,
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
)
{
int
inputChannels
=
imShape
[
0
];
int
inputHeight
=
imShape
[
1
];
int
inputWidth
=
imShape
[
2
];
int
filterHeight
=
colShape
[
3
];
int
filterWidth
=
colShape
[
4
];
int
outputHeight
=
colShape
[
0
];
int
outputWidth
=
colShape
[
1
];
int
blockDimX
=
0
;
int
blockDimY
=
0
;
if
(
filterHeight
<=
4
&&
filterWidth
<=
4
)
{
blockDimX
=
4
;
blockDimY
=
4
;
}
else
if
(
filterHeight
<=
8
&&
filterWidth
<=
8
)
{
blockDimX
=
8
;
blockDimY
=
8
;
}
else
if
(
filterHeight
<=
16
&&
filterWidth
<=
16
)
{
blockDimX
=
16
;
blockDimY
=
16
;
}
else
{
blockDimX
=
32
;
blockDimY
=
32
;
}
int
blockDimZ
=
1024
/
blockDimX
/
blockDimY
;
dim3
threads
(
blockDimX
,
blockDimY
,
std
::
min
(
blockDimZ
,
inputChannels
));
dim3
grid
(
outputWidth
,
outputHeight
);
im2colOCF
<
T
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
imData
,
colData
,
inputChannels
,
inputHeight
,
inputWidth
,
filterHeight
,
filterWidth
,
strideHeight
,
strideWidth
,
paddingHeight
,
paddingWidth
,
outputHeight
,
outputWidth
);
CHECK_SYNC
(
"Im2ColFunctor GPU failed"
);
}
};
/*
* imShape = [inputChannels, inputHeight, inputWidth]
* colShape =
* [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth]
*/
template
<
class
T
>
class
Col2ImFunctor
<
kOCF
,
DEVICE_TYPE_GPU
,
T
>
{
public:
void
operator
()(
T
*
imData
,
const
TensorShape
&
imShape
,
const
T
*
colData
,
const
TensorShape
&
colShape
,
int
strideHeight
,
int
strideWidth
,
int
paddingHeight
,
int
paddingWidth
)
{
}
};
template
class
Im2ColFunctor
<
kOCF
,
DEVICE_TYPE_GPU
,
float
>;
template
class
Im2ColFunctor
<
kOCF
,
DEVICE_TYPE_GPU
,
double
>;
}
// namespace paddle
paddle/function/ImageExpandOp.cpp
浏览文件 @
152bd2f9
...
...
@@ -291,5 +291,8 @@ public:
REGISTER_TYPED_FUNC
(
ImageExpand
,
CPU
,
ImageExpandForward
);
REGISTER_TYPED_FUNC
(
ImageExpandGrad
,
CPU
,
ImageExpandBackward
);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC
(
ImageExpand
,
GPU
,
ImageExpandForward
);
#endif
}
// namespace paddle
paddle/gserver/layers/BlockExpandLayer.cpp
浏览文件 @
152bd2f9
...
...
@@ -37,16 +37,16 @@ bool BlockExpandLayer::init(const LayerMap& layerMap,
imgSizeH_
=
blockConf
.
img_size_y
();
imgSizeW_
=
blockConf
.
img_size_x
();
std
::
vector
<
size_t
>
strides
=
{(
size_t
)
strideH_
,
(
size_t
)
strideW_
};
std
::
vector
<
size_t
>
paddings
=
{(
size_t
)
paddingH_
,
(
size_t
)
paddingW_
};
std
::
vector
<
size_t
>
blocks
=
{(
size_t
)
blockH_
,
(
size_t
)
blockW_
};
createFunction
(
forward_
,
"ImageExpand"
,
FuncConfig
()
.
set
(
"strides"
,
strides
)
.
set
(
"paddings"
,
paddings
)
.
set
(
"blocks"
,
blocks
));
if
(
!
useGpu_
)
{
std
::
vector
<
size_t
>
strides
=
{(
size_t
)
strideH_
,
(
size_t
)
strideW_
};
std
::
vector
<
size_t
>
paddings
=
{(
size_t
)
paddingH_
,
(
size_t
)
paddingW_
};
std
::
vector
<
size_t
>
blocks
=
{(
size_t
)
blockH_
,
(
size_t
)
blockW_
};
createFunction
(
forward_
,
"ImageExpand"
,
FuncConfig
()
.
set
(
"strides"
,
strides
)
.
set
(
"paddings"
,
paddings
)
.
set
(
"blocks"
,
blocks
));
createFunction
(
backward_
,
"ImageExpandGrad"
,
FuncConfig
()
...
...
@@ -84,62 +84,29 @@ void BlockExpandLayer::forward(PassType passType) {
size_t
blockNum
=
getBlockNum
();
size_t
blockSize
=
blockH_
*
blockW_
*
channels_
;
resetOutput
(
blockNum
*
batchSize
,
blockSize
);
// TODO(hedaoyuan): After completing the GPU version of ImageExpand,
// refactor the following code.
Argument
&
out
=
getOutput
();
MatrixPtr
outV
=
getOutputValue
();
MatrixPtr
input
=
getPrev
(
0
)
->
getOutputValue
();
Matrix
::
resizeOrCreate
(
outVTrans_
,
blockSize
,
blockNum
,
false
,
useGpu_
);
// calculate output_.value
inputShape_
=
TensorShape
({
batchSize
,
channels_
,
imgSizeH_
,
imgSizeW_
});
outputShape_
=
TensorShape
({
batchSize
,
blockNum
,
blockSize
});
BufferArgs
inputs
;
BufferArgs
outputs
;
inputs
.
addArg
(
*
getInputValue
(
0
),
inputShape_
);
outputs
.
addArg
(
*
getOutputValue
(),
outputShape_
,
ASSIGN_TO
);
forward_
[
0
]
->
calc
(
inputs
,
outputs
);
// calculate output_.sequenceStartPositions and output_.cpuSequenceDims
Argument
&
out
=
getOutput
();
ICpuGpuVector
::
resizeOrCreate
(
out
.
sequenceStartPositions
,
batchSize
+
1
,
false
);
IVector
::
resizeOrCreate
(
out
.
cpuSequenceDims
,
2
*
batchSize
,
false
);
int
*
start
=
out
.
sequenceStartPositions
->
getMutableData
(
false
);
int
*
dims
=
out
.
cpuSequenceDims
->
getData
();
for
(
size_t
i
=
0
;
i
<
batchSize
;
i
++
)
{
if
(
useGpu_
)
{
outVTrans_
->
zeroMem
();
/* expand each block as one row */
MatrixPtr
inputTmp
=
Matrix
::
create
(
input
->
getData
()
+
i
*
input
->
getWidth
(),
1
,
input
->
getWidth
(),
false
,
useGpu_
);
outVTrans_
->
convExpand
(
*
inputTmp
,
imgSizeH_
,
imgSizeW_
,
channels_
,
blockH_
,
blockW_
,
strideH_
,
strideW_
,
paddingH_
,
paddingW_
,
outputH_
,
outputW_
);
MatrixPtr
outVTmp
=
Matrix
::
create
(
outV
->
getData
()
+
i
*
blockNum
*
blockSize
,
blockNum
,
blockSize
,
false
,
useGpu_
);
outVTrans_
->
transpose
(
outVTmp
,
false
);
}
start
[
i
]
=
i
*
blockNum
;
dims
[
2
*
i
]
=
outputH_
;
dims
[
2
*
i
+
1
]
=
outputW_
;
}
start
[
batchSize
]
=
batchSize
*
blockNum
;
if
(
!
useGpu_
)
{
inputShape_
=
TensorShape
({
batchSize
,
channels_
,
imgSizeH_
,
imgSizeW_
});
outputShape_
=
TensorShape
({
batchSize
,
blockNum
,
blockSize
});
BufferArgs
inputs
;
BufferArgs
outputs
;
inputs
.
addArg
(
*
getInputValue
(
0
),
inputShape_
);
outputs
.
addArg
(
*
getOutputValue
(),
outputShape_
,
ASSIGN_TO
);
forward_
[
0
]
->
calc
(
inputs
,
outputs
);
}
}
void
BlockExpandLayer
::
backward
(
const
UpdateCallback
&
callback
)
{
...
...
paddle/gserver/layers/BlockExpandLayer.h
浏览文件 @
152bd2f9
...
...
@@ -50,9 +50,6 @@ protected:
size_t
blockH_
,
blockW_
,
strideH_
,
strideW_
,
paddingH_
,
paddingW_
;
size_t
imgSizeH_
,
imgSizeW_
,
outputH_
,
outputW_
,
channels_
;
/// auxiliary variable, which saves the transposed output value.
MatrixPtr
outVTrans_
;
TensorShape
inputShape_
;
TensorShape
outputShape_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录