Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
73192bb1
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
73192bb1
编写于
8月 07, 2017
作者:
D
dangqingqing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add a batch norm inference kernel.
上级
498e9de4
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
216 addition
and
10 deletion
+216
-10
paddle/cuda/CMakeLists.txt
paddle/cuda/CMakeLists.txt
+1
-0
paddle/cuda/include/hl_batch_norm.h
paddle/cuda/include/hl_batch_norm.h
+50
-0
paddle/cuda/src/hl_batch_norm.cu
paddle/cuda/src/hl_batch_norm.cu
+68
-0
paddle/gserver/layers/CudnnBatchNormLayer.cpp
paddle/gserver/layers/CudnnBatchNormLayer.cpp
+27
-10
paddle/gserver/tests/test_BatchNorm.cpp
paddle/gserver/tests/test_BatchNorm.cpp
+70
-0
未找到文件。
paddle/cuda/CMakeLists.txt
浏览文件 @
73192bb1
...
...
@@ -39,6 +39,7 @@ set(CUDA_CU_SOURCES
src/hl_cuda_lstm.cu
src/hl_top_k.cu
src/hl_batch_transpose.cu
src/hl_batch_norm.cu
src/hl_cuda_sequence.cu
src/hl_table_apply.cu
)
...
...
paddle/cuda/include/hl_batch_norm.h
0 → 100644
浏览文件 @
73192bb1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_BATCH_NORM_H_
#define HL_BATCH_NORM_H_
#include "hl_base.h"
/**
* @brief batch norm inferece.
*
* @param[in] input input data.
* @param[out] output output data.
* @param[in] scale batch normalization scale parameter (in original
* paper scale is referred to as gamma).
* @param[in] bias batch normalization bias parameter (in original
* paper scale is referred to as beta).
* @param[in] estimatedMean
* @param[in] estimatedVar It is suggested that resultRunningMean,
* resultRunningVariance from the
* cudnnBatchNormalizationForwardTraining call
* accumulated during the training phase are passed
* as inputs here.
* @param[in] epsilon Epsilon value used in the batch
* normalization formula.
*/
extern
void
hl_batch_norm_cuda_inference
(
const
real
*
input
,
real
*
output
,
const
real
*
scale
,
const
real
*
bias
,
const
real
*
estimatedMean
,
const
real
*
estimatedVar
,
const
double
epsilon
,
size_t
batchSize
,
size_t
channel
,
size_t
height
,
size_t
width
);
#endif // HL_BATCH_NORM_H_
paddle/cuda/src/hl_batch_norm.cu
0 → 100644
浏览文件 @
73192bb1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_batch_norm.h"
__global__
void
batchNormInference
(
real
*
output
,
const
real
*
input
,
const
real
*
scale
,
const
real
*
bias
,
const
real
*
estimatedMean
,
const
real
*
estimatedVar
,
const
double
epsilon
,
size_t
batchSize
,
size_t
channel
,
size_t
height
,
size_t
width
)
{
const
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
num
=
channel
*
height
*
width
;
const
int
batch
=
blockIdx
.
y
;
for
(
int
i
=
tid
;
i
<
num
;
i
+=
blockDim
.
x
)
{
const
int
c
=
(
i
/
(
height
*
width
))
%
channel
;
const
int
id
=
batch
*
num
+
i
;
real
val
=
input
[
id
]
-
estimatedMean
[
c
];
val
/=
sqrt
(
estimatedVar
[
c
]
+
epsilon
);
val
*=
scale
[
c
];
val
+=
bias
[
c
];
output
[
id
]
=
val
;
}
}
void
hl_batch_norm_cuda_inference
(
const
real
*
input
,
real
*
output
,
const
real
*
scale
,
const
real
*
bias
,
const
real
*
estimatedMean
,
const
real
*
estimatedVar
,
const
double
epsilon
,
size_t
batchSize
,
size_t
channel
,
size_t
height
,
size_t
width
)
{
dim3
block
(
256
,
1
);
dim3
grid
(
1
,
batchSize
);
batchNormInference
<<<
grid
,
block
,
0
,
STREAM_DEFAULT
>>>
(
output
,
input
,
scale
,
bias
,
estimatedMean
,
estimatedVar
,
epsilon
,
batchSize
,
channel
,
height
,
width
);
CHECK_SYNC
(
"hl_batch_norm_cuda_inference failed!"
);
}
paddle/gserver/layers/CudnnBatchNormLayer.cpp
浏览文件 @
73192bb1
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "CudnnBatchNormLayer.h"
#include "Layer.h"
#include "paddle/cuda/include/hl_batch_norm.h"
#include "paddle/utils/Stat.h"
namespace
paddle
{
...
...
@@ -79,16 +80,32 @@ void CudnnBatchNormLayer::forward(PassType passType) {
savedInvVar
);
}
else
{
// used movingMean and movingVar in testing
hl_batch_norm_forward_inference
(
ioDesc_
,
input
,
ioDesc_
,
output
,
bnParamDesc_
,
gamma
,
beta
,
movingMean
,
movingVar
,
EPS
);
if
(
batchSize
>
1024
)
{
// when batchSize is larger than 1024, there is a bug
// in cudnn library.
hl_batch_norm_cuda_inference
(
input
,
output
,
gamma
,
beta
,
movingMean
,
movingVar
,
EPS
,
batchSize
,
channels_
,
imageH_
,
imageW_
);
}
else
{
hl_batch_norm_forward_inference
(
ioDesc_
,
input
,
ioDesc_
,
output
,
bnParamDesc_
,
gamma
,
beta
,
movingMean
,
movingVar
,
EPS
);
}
}
/* activation */
{
...
...
paddle/gserver/tests/test_BatchNorm.cpp
浏览文件 @
73192bb1
...
...
@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/utils/GlobalConstants.h"
#include "LayerGradUtil.h"
#include "paddle/cuda/include/hl_batch_norm.h"
#include "paddle/math/tests/TensorCheck.h"
#include "paddle/testing/TestUtil.h"
using
namespace
paddle
;
// NOLINT
...
...
@@ -117,6 +119,74 @@ TEST(Layer, batchNorm) {
CHECK_EQ
(
static_cast
<
int
>
(
convLayer
->
getOutputValue
()
->
getWidth
()),
576
);
}
#ifndef PADDLE_ONLY_CPU
void
batchNormInference
(
int
n
,
int
c
,
int
h
,
int
w
)
{
MatrixPtr
input
=
std
::
make_shared
<
GpuMatrix
>
(
n
,
c
*
h
*
w
);
MatrixPtr
cudnnOut
=
std
::
make_shared
<
GpuMatrix
>
(
n
,
c
*
h
*
w
);
MatrixPtr
cudaOut
=
std
::
make_shared
<
GpuMatrix
>
(
n
,
c
*
h
*
w
);
MatrixPtr
cudnnCheck
=
std
::
make_shared
<
CpuMatrix
>
(
n
,
c
*
h
*
w
);
MatrixPtr
cudaCheck
=
std
::
make_shared
<
CpuMatrix
>
(
n
,
c
*
h
*
w
);
input
->
randomizeUniform
();
cudnnOut
->
zeroMem
();
cudaOut
->
zeroMem
();
MatrixPtr
scale
=
std
::
make_shared
<
GpuMatrix
>
(
1
,
c
);
scale
->
randomizeUniform
();
MatrixPtr
bias
=
std
::
make_shared
<
GpuMatrix
>
(
1
,
c
);
bias
->
randomizeUniform
();
MatrixPtr
movingMean
=
std
::
make_shared
<
GpuMatrix
>
(
1
,
c
);
movingMean
->
randomizeUniform
();
MatrixPtr
movingVar
=
std
::
make_shared
<
GpuMatrix
>
(
1
,
c
);
movingVar
->
randomizeUniform
();
movingVar
->
clip
(
0.01
,
50
);
hl_tensor_descriptor
ioDesc
;
hl_tensor_descriptor
bnDesc
;
hl_create_tensor_descriptor
(
&
ioDesc
);
hl_create_tensor_descriptor
(
&
bnDesc
);
hl_tensor_reshape
(
ioDesc
,
n
,
c
,
h
,
w
);
hl_tensor_reshape
(
bnDesc
,
1
,
c
,
1
,
1
);
double
EPS
=
1E-5
;
hl_batch_norm_forward_inference
(
ioDesc
,
input
->
getData
(),
ioDesc
,
cudnnOut
->
getData
(),
bnDesc
,
scale
->
getData
(),
bias
->
getData
(),
movingMean
->
getData
(),
movingVar
->
getData
(),
EPS
);
hl_batch_norm_cuda_inference
(
input
->
getData
(),
cudaOut
->
getData
(),
scale
->
getData
(),
bias
->
getData
(),
movingMean
->
getData
(),
movingVar
->
getData
(),
EPS
,
n
,
c
,
h
,
w
);
cudnnCheck
->
copyFrom
(
*
cudnnOut
);
cudaCheck
->
copyFrom
(
*
cudaOut
);
autotest
::
TensorCheckErr
(
*
cudnnCheck
,
*
cudaCheck
);
hl_destroy_tensor_descriptor
(
ioDesc
);
hl_destroy_tensor_descriptor
(
bnDesc
);
}
TEST
(
BatchNorm
,
Inference
)
{
batchNormInference
(
33
,
267
,
1
,
1
);
batchNormInference
(
19
,
105
,
4
,
4
);
}
#endif
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
initMain
(
argc
,
argv
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录