Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
e357f271
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e357f271
编写于
12月 13, 2016
作者:
H
hedaoyuan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add GPU CrossMapNormal
上级
95035908
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
286 addition
and
53 deletion
+286
-53
paddle/math/cross_map_normal_op.cpp
paddle/math/cross_map_normal_op.cpp
+22
-20
paddle/math/cross_map_normal_op.h
paddle/math/cross_map_normal_op.h
+29
-8
paddle/math/cross_map_normal_op_gpu.cu
paddle/math/cross_map_normal_op_gpu.cu
+194
-0
paddle/math/tests/test_matrixCompare.cpp
paddle/math/tests/test_matrixCompare.cpp
+41
-25
未找到文件。
paddle/math/cross_map_normal_op.cpp
浏览文件 @
e357f271
...
...
@@ -17,15 +17,16 @@ limitations under the License. */
namespace
paddle
{
// NCHW
void
CrossMapNormal
::
operator
()(
CpuMatrix
&
outputs
,
CpuMatrix
&
denoms
,
CpuMatrix
&
inputs
,
size_t
channels
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
sizeX
,
real
scale
,
real
pow
)
{
template
<
>
void
CrossMapNormal
<
DEVICE_TYPE_CPU
>::
operator
()(
CpuMatrix
&
outputs
,
CpuMatrix
&
denoms
,
CpuMatrix
&
inputs
,
size_t
channels
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
sizeX
,
real
scale
,
real
pow
)
{
CHECK
(
outputs
.
isContiguous
());
CHECK
(
inputs
.
isContiguous
());
CHECK
(
denoms
.
isContiguous
());
...
...
@@ -58,17 +59,18 @@ void CrossMapNormal::operator()(CpuMatrix& outputs,
outputs
=
inputs
*
denoms
.
pow
(
-
pow
);
}
void
CrossMapNormalGrad
::
operator
()(
CpuMatrix
&
inputsGrad
,
CpuMatrix
&
inputsValue
,
CpuMatrix
&
outputsGrad
,
CpuMatrix
&
outputsValue
,
CpuMatrix
&
denoms
,
size_t
channels
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
sizeX
,
real
scale
,
real
pow
)
{
template
<
>
void
CrossMapNormalGrad
<
DEVICE_TYPE_CPU
>::
operator
()(
CpuMatrix
&
inputsGrad
,
CpuMatrix
&
inputsValue
,
CpuMatrix
&
outputsGrad
,
CpuMatrix
&
outputsValue
,
CpuMatrix
&
denoms
,
size_t
channels
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
sizeX
,
real
scale
,
real
pow
)
{
CHECK
(
inputsGrad
.
isContiguous
());
CHECK
(
outputsGrad
.
isContiguous
());
CHECK
(
denoms
.
isContiguous
());
...
...
paddle/math/cross_map_normal_op.h
浏览文件 @
e357f271
...
...
@@ -18,10 +18,30 @@ limitations under the License. */
namespace
paddle
{
enum
DeviceType
{
DEVICE_TYPE_UNSPECIFIED
=
0
,
DEVICE_TYPE_CPU
=
1
,
DEVICE_TYPE_GPU
=
2
,
};
template
<
DeviceType
Device
>
struct
MatrixT
;
template
<
>
struct
MatrixT
<
DEVICE_TYPE_CPU
>
{
using
type
=
CpuMatrix
;
};
template
<
>
struct
MatrixT
<
DEVICE_TYPE_GPU
>
{
using
type
=
GpuMatrix
;
};
template
<
DeviceType
Device
>
struct
CrossMapNormal
{
void
operator
()(
CpuMatrix
&
outputs
,
CpuMatrix
&
denoms
,
CpuMatrix
&
inputs
,
void
operator
()(
typename
MatrixT
<
Device
>::
type
&
outputs
,
typename
MatrixT
<
Device
>::
type
&
denoms
,
typename
MatrixT
<
Device
>::
type
&
inputs
,
size_t
channels
,
size_t
imgSizeH
,
size_t
imgSizeW
,
...
...
@@ -30,12 +50,13 @@ struct CrossMapNormal {
real
pow
);
};
template
<
DeviceType
Device
>
struct
CrossMapNormalGrad
{
void
operator
()(
CpuMatrix
&
inputsGrad
,
CpuMatrix
&
inputsValue
,
CpuMatrix
&
outputsGrad
,
CpuMatrix
&
outputsValue
,
CpuMatrix
&
denoms
,
void
operator
()(
typename
MatrixT
<
Device
>::
type
&
inputsGrad
,
typename
MatrixT
<
Device
>::
type
&
inputsValue
,
typename
MatrixT
<
Device
>::
type
&
outputsGrad
,
typename
MatrixT
<
Device
>::
type
&
outputsValue
,
typename
MatrixT
<
Device
>::
type
&
denoms
,
size_t
channels
,
size_t
imgSizeH
,
size_t
imgSizeW
,
...
...
paddle/math/cross_map_normal_op_gpu.cu
0 → 100644
浏览文件 @
e357f271
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_base.h"
#include "cross_map_normal_op.h"
namespace
paddle
{
__global__
void
KeCMRNormFillScale
(
size_t
imageSize
,
const
real
*
in
,
real
*
scale
,
size_t
channels
,
size_t
height
,
size_t
width
,
size_t
size
,
real
alpha
)
{
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
imageSize
)
{
const
int
w
=
idx
%
width
;
const
int
h
=
(
idx
/
width
)
%
height
;
const
int
n
=
idx
/
width
/
height
;
const
int
offset
=
(
n
*
channels
*
height
+
h
)
*
width
+
w
;
in
+=
offset
;
scale
+=
offset
;
const
int
step
=
height
*
width
;
const
int
pre_pad
=
(
size
-
1
)
/
2
;
const
int
post_pad
=
size
-
pre_pad
-
1
;
real
accum
=
0
;
int
index
=
0
;
while
(
index
<
channels
+
post_pad
)
{
if
(
index
<
channels
)
{
accum
+=
in
[
index
*
step
]
*
in
[
index
*
step
];
}
if
(
index
>=
size
)
{
accum
-=
in
[(
index
-
size
)
*
step
]
*
in
[(
index
-
size
)
*
step
];
}
if
(
index
>=
post_pad
)
{
scale
[(
index
-
post_pad
)
*
step
]
=
1.
+
accum
*
alpha
;
}
++
index
;
}
}
}
__global__
void
KeCMRNormOutput
(
size_t
inputSize
,
const
real
*
in
,
const
real
*
scale
,
real
negative_beta
,
real
*
out
)
{
const
int
index
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
index
<
inputSize
)
{
out
[
index
]
=
in
[
index
]
*
pow
(
scale
[
index
],
negative_beta
);
}
}
template
<
>
void
CrossMapNormal
<
DEVICE_TYPE_GPU
>::
operator
()(
GpuMatrix
&
outputs
,
GpuMatrix
&
denoms
,
GpuMatrix
&
inputs
,
size_t
channels
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
sizeX
,
real
scale
,
real
pow
)
{
CHECK
(
outputs
.
isContiguous
());
CHECK
(
inputs
.
isContiguous
());
CHECK
(
denoms
.
isContiguous
());
CHECK_EQ
(
outputs
.
getHeight
(),
inputs
.
getHeight
());
CHECK_EQ
(
outputs
.
getWidth
(),
inputs
.
getWidth
());
CHECK_EQ
(
outputs
.
getHeight
(),
denoms
.
getHeight
());
CHECK_EQ
(
outputs
.
getWidth
(),
denoms
.
getWidth
());
size_t
numSample
=
inputs
.
getHeight
();
size_t
numCols
=
inputs
.
getWidth
();
CHECK
(
imgSizeH
*
imgSizeW
*
channels
==
numCols
);
real
*
inputsData
=
inputs
.
getData
();
real
*
denomsData
=
denoms
.
getData
();
real
*
outputsData
=
outputs
.
getData
();
size_t
imageSize
=
numSample
*
imgSizeH
*
imgSizeW
;
int
blockSize
=
1024
;
int
gridSize
=
(
imageSize
+
1024
-
1
)
/
1024
;
KeCMRNormFillScale
<<<
gridSize
,
blockSize
,
0
,
STREAM_DEFAULT
>>>
(
imageSize
,
inputsData
,
denomsData
,
channels
,
imgSizeH
,
imgSizeW
,
sizeX
,
scale
);
size_t
inputSize
=
numSample
*
imgSizeH
*
imgSizeW
*
channels
;
blockSize
=
1024
;
gridSize
=
(
inputSize
+
1024
-
1
)
/
1024
;
KeCMRNormOutput
<<<
gridSize
,
blockSize
,
0
,
STREAM_DEFAULT
>>>
(
inputSize
,
inputsData
,
denomsData
,
-
pow
,
outputsData
);
CHECK_SYNC
(
"CrossMapNormalFwd"
);
}
__global__
void
KeCMRNormDiff
(
size_t
imageSize
,
const
real
*
bottom_data
,
const
real
*
top_data
,
const
real
*
scale
,
const
real
*
top_diff
,
size_t
channels
,
size_t
height
,
size_t
width
,
size_t
size
,
real
negative_beta
,
real
cache_ratio
,
real
*
bottom_diff
)
{
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
imageSize
)
{
const
int
w
=
idx
%
width
;
const
int
h
=
(
idx
/
width
)
%
height
;
const
int
n
=
idx
/
width
/
height
;
const
int
offset
=
(
n
*
channels
*
height
+
h
)
*
width
+
w
;
bottom_data
+=
offset
;
top_data
+=
offset
;
scale
+=
offset
;
top_diff
+=
offset
;
bottom_diff
+=
offset
;
const
int
step
=
height
*
width
;
const
int
pre_pad
=
size
-
(
size
+
1
)
/
2
;
const
int
post_pad
=
size
-
pre_pad
-
1
;
int
index
=
0
;
real
accum
=
0
;
while
(
index
<
channels
+
post_pad
)
{
if
(
index
<
channels
)
{
accum
+=
top_diff
[
index
*
step
]
*
top_data
[
index
*
step
]
/
scale
[
index
*
step
];
}
if
(
index
>=
size
)
{
accum
-=
top_diff
[(
index
-
size
)
*
step
]
*
top_data
[(
index
-
size
)
*
step
]
/
scale
[(
index
-
size
)
*
step
];
}
if
(
index
>=
post_pad
)
{
bottom_diff
[(
index
-
post_pad
)
*
step
]
+=
top_diff
[(
index
-
post_pad
)
*
step
]
*
pow
(
scale
[(
index
-
post_pad
)
*
step
],
negative_beta
)
-
cache_ratio
*
bottom_data
[(
index
-
post_pad
)
*
step
]
*
accum
;
}
++
index
;
}
}
}
template
<
>
void
CrossMapNormalGrad
<
DEVICE_TYPE_GPU
>::
operator
()(
GpuMatrix
&
inputsGrad
,
GpuMatrix
&
inputsValue
,
GpuMatrix
&
outputsGrad
,
GpuMatrix
&
outputsValue
,
GpuMatrix
&
denoms
,
size_t
channels
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
sizeX
,
real
scale
,
real
pow
)
{
CHECK
(
inputsGrad
.
isContiguous
());
CHECK
(
outputsGrad
.
isContiguous
());
CHECK
(
denoms
.
isContiguous
());
CHECK
(
inputsValue
.
isContiguous
());
CHECK
(
outputsValue
.
isContiguous
());
CHECK_EQ
(
inputsGrad
.
getHeight
(),
outputsGrad
.
getHeight
());
CHECK_EQ
(
inputsGrad
.
getWidth
(),
outputsGrad
.
getWidth
());
CHECK_EQ
(
inputsGrad
.
getHeight
(),
denoms
.
getHeight
());
CHECK_EQ
(
inputsGrad
.
getWidth
(),
denoms
.
getWidth
());
CHECK_EQ
(
inputsGrad
.
getHeight
(),
inputsValue
.
getHeight
());
CHECK_EQ
(
inputsGrad
.
getWidth
(),
inputsValue
.
getWidth
());
CHECK_EQ
(
inputsGrad
.
getHeight
(),
outputsValue
.
getHeight
());
CHECK_EQ
(
inputsGrad
.
getWidth
(),
outputsValue
.
getWidth
());
size_t
numSample
=
inputsGrad
.
getHeight
();
size_t
numCols
=
inputsGrad
.
getWidth
();
CHECK
(
imgSizeH
*
imgSizeW
*
channels
==
numCols
);
size_t
imageSize
=
numSample
*
imgSizeH
*
imgSizeW
;
real
*
inputsGradData
=
inputsGrad
.
getData
();
real
*
inputsData
=
inputsValue
.
getData
();
real
*
denomsData
=
denoms
.
getData
();
real
*
outputsGradData
=
outputsGrad
.
getData
();
real
*
outputsData
=
outputsValue
.
getData
();
int
blockSize
=
1024
;
int
gridSize
=
(
imageSize
+
1024
-
1
)
/
1024
;
KeCMRNormDiff
<<<
gridSize
,
blockSize
,
0
,
STREAM_DEFAULT
>>>
(
imageSize
,
inputsData
,
outputsData
,
denomsData
,
outputsGradData
,
channels
,
imgSizeH
,
imgSizeW
,
sizeX
,
-
pow
,
2.0
f
*
pow
*
scale
,
inputsGradData
);
CHECK_SYNC
(
"KeCMRNormDiff"
);
}
}
// namespace paddle
paddle/math/tests/test_matrixCompare.cpp
浏览文件 @
e357f271
...
...
@@ -1280,11 +1280,25 @@ void testCrossMapNormalFwd(
inputsGpu
.
copyFrom
(
inputs
);
outputsGpu
.
copyFrom
(
outputs
);
CrossMapNormal
c
ross
;
cross
(
CrossMapNormal
<
DEVICE_TYPE_CPU
>
cpuC
ross
;
c
puC
ross
(
outputs
,
denoms
,
inputs
,
channels
,
imgSizeH
,
imgSizeW
,
sizeX
,
scale
,
pow
);
CrossMapNormal
<
DEVICE_TYPE_GPU
>
gpuCross
;
gpuCross
(
outputsGpu
,
denomsGpu
,
inputsGpu
,
channels
,
imgSizeH
,
imgSizeW
,
sizeX
,
scale
,
pow
);
#if 0
outputsGpu.crossMapNormalFwd(
inputsGpu, imgSizeH, imgSizeW, denomsGpu, channels, sizeX, scale, pow);
#endif
TensorCheckErr
(
outputs
,
outputsGpu
);
TensorCheckErr
(
denoms
,
denomsGpu
);
...
...
@@ -1339,29 +1353,31 @@ void testCrossMapNormalBwd(
outputsValueGpu
.
copyFrom
(
outputsValue
);
inputsGradGpu
.
copyFrom
(
inputsGrad
);
CrossMapNormalGrad
cross
;
cross
(
inputsGrad
,
inputsValue
,
outputsGrad
,
outputsValue
,
denoms
,
channels
,
imgSizeH
,
imgSizeW
,
sizeX
,
scale
,
pow
);
inputsGradGpu
.
crossMapNormalBwd
(
outputsGradGpu
,
denomsGpu
,
inputsValueGpu
,
outputsValueGpu
,
channels
,
imgSizeH
,
imgSizeW
,
sizeX
,
scale
,
pow
);
CrossMapNormalGrad
<
DEVICE_TYPE_CPU
>
cpuCross
;
cpuCross
(
inputsGrad
,
inputsValue
,
outputsGrad
,
outputsValue
,
denoms
,
channels
,
imgSizeH
,
imgSizeW
,
sizeX
,
scale
,
pow
);
CrossMapNormalGrad
<
DEVICE_TYPE_GPU
>
gpuCross
;
gpuCross
(
inputsGradGpu
,
inputsValueGpu
,
outputsGradGpu
,
outputsValueGpu
,
denomsGpu
,
channels
,
imgSizeH
,
imgSizeW
,
sizeX
,
scale
,
pow
);
TensorCheckErr
(
inputsGrad
,
inputsGradGpu
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录