Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
4ebb3eb7
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4ebb3eb7
编写于
12月 15, 2016
作者:
H
hedaoyuan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
imporve Function
上级
ce1d98e0
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
147 addition
and
92 deletion
+147
-92
paddle/gserver/layers/NormProjectionLayer.cpp
paddle/gserver/layers/NormProjectionLayer.cpp
+46
-14
paddle/gserver/layers/NormProjectionLayer.h
paddle/gserver/layers/NormProjectionLayer.h
+4
-0
paddle/math/Function.cpp
paddle/math/Function.cpp
+4
-2
paddle/math/Function.h
paddle/math/Function.h
+8
-6
paddle/math/cross_map_normal_op.cpp
paddle/math/cross_map_normal_op.cpp
+38
-37
paddle/math/cross_map_normal_op.h
paddle/math/cross_map_normal_op.h
+13
-0
paddle/math/cross_map_normal_op_gpu.cu
paddle/math/cross_map_normal_op_gpu.cu
+15
-31
paddle/math/tests/test_matrixCompare.cpp
paddle/math/tests/test_matrixCompare.cpp
+19
-2
未找到文件。
paddle/gserver/layers/NormProjectionLayer.cpp
浏览文件 @
4ebb3eb7
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
#include "paddle/math/cross_map_normal_op.h"
#include "NormProjectionLayer.h"
namespace
paddle
{
...
...
@@ -45,6 +46,16 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap,
/* the size of inputs for norm-layer is 1 */
CHECK_EQ
(
config_
.
inputs_size
(),
1
);
if
(
useGpu_
)
{
normal_
=
FunctionBase
::
funcRegistrar_
.
createByType
(
FUNC_NAME
(
CrossMapNormal
,
GPU
));
}
else
{
normal_
=
FunctionBase
::
funcRegistrar_
.
createByType
(
FUNC_NAME
(
CrossMapNormal
,
CPU
));
}
normal_
->
init
(
FuncConfig
().
set
(
"size"
,
size_
).
set
(
"scale"
,
scale_
).
set
(
"pow"
,
pow_
));
return
true
;
}
...
...
@@ -62,10 +73,14 @@ void CMRProjectionNormLayer::forward(PassType passType) {
Matrix
::
resizeOrCreate
(
denoms_
,
batchSize
,
size
,
/* trans */
false
,
useGpu_
);
denoms_
->
zeroMem
();
outV
->
crossMapNormalFwd
(
*
input
,
imgSizeH_
,
imgSizeW_
,
*
denoms_
,
channels_
,
size_
,
scale_
,
pow_
);
Dims
dims
{(
size_t
)
batchSize
,
(
size_t
)
channels_
,
(
size_t
)
imgSizeH_
,
(
size_t
)
imgSizeW_
};
normal_
->
calc
(
{
Tensor
(
input
->
getData
(),
dims
)},
{
Tensor
(
outV
->
getData
(),
dims
),
Tensor
(
denoms_
->
getData
(),
dims
)},
{});
}
void
CMRProjectionNormLayer
::
backward
(
const
UpdateCallback
&
callback
)
{
...
...
@@ -80,15 +95,32 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
MatrixPtr
localOutV
=
getOutputValue
();
MatrixPtr
preOutV
=
inputLayers_
[
0
]
->
getOutputValue
();
preOutGrad
->
crossMapNormalBwd
(
*
localGrad
,
*
denoms_
,
*
preOutV
,
*
localOutV
,
channels_
,
imgSizeH_
,
imgSizeW_
,
size_
,
scale_
,
pow_
);
if
(
useGpu_
)
{
CrossMapNormalGrad
<
DEVICE_TYPE_GPU
>
crossGrad
;
crossGrad
(
dynamic_cast
<
GpuMatrix
&>
(
*
preOutGrad
),
dynamic_cast
<
GpuMatrix
&>
(
*
preOutV
),
dynamic_cast
<
GpuMatrix
&>
(
*
localGrad
),
dynamic_cast
<
GpuMatrix
&>
(
*
localOutV
),
dynamic_cast
<
GpuMatrix
&>
(
*
denoms_
),
channels_
,
imgSizeH_
,
imgSizeW_
,
size_
,
scale_
,
pow_
);
}
else
{
CrossMapNormalGrad
<
DEVICE_TYPE_CPU
>
crossGrad
;
crossGrad
(
dynamic_cast
<
CpuMatrix
&>
(
*
preOutGrad
),
dynamic_cast
<
CpuMatrix
&>
(
*
preOutV
),
dynamic_cast
<
CpuMatrix
&>
(
*
localGrad
),
dynamic_cast
<
CpuMatrix
&>
(
*
localOutV
),
dynamic_cast
<
CpuMatrix
&>
(
*
denoms_
),
channels_
,
imgSizeH_
,
imgSizeW_
,
size_
,
scale_
,
pow_
);
}
}
}
// namespace paddle
paddle/gserver/layers/NormProjectionLayer.h
浏览文件 @
4ebb3eb7
...
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include "NormLayer.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/Function.h"
#include <vector>
namespace
paddle
{
...
...
@@ -39,5 +40,8 @@ public:
bool
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
);
void
forward
(
PassType
passType
);
void
backward
(
const
UpdateCallback
&
callback
=
nullptr
);
protected:
FunctionBase
*
normal_
;
};
}
// namespace paddle
paddle/math/Function.cpp
浏览文件 @
4ebb3eb7
...
...
@@ -31,15 +31,17 @@ real FuncConfig::get<real>(const std::string& key) const {
}
template
<
>
void
FuncConfig
::
set
<
size_t
>
(
const
std
::
string
&
key
,
size_t
v
)
{
FuncConfig
&
FuncConfig
::
set
<
size_t
>
(
const
std
::
string
&
key
,
size_t
v
)
{
CHECK
(
valueMap_
.
count
(
key
)
==
0
)
<<
"Duplicated value: "
<<
key
;
valueMap_
[
key
].
s
=
v
;
return
*
this
;
}
template
<
>
void
FuncConfig
::
set
<
real
>
(
const
std
::
string
&
key
,
real
v
)
{
FuncConfig
&
FuncConfig
::
set
<
real
>
(
const
std
::
string
&
key
,
real
v
)
{
CHECK
(
valueMap_
.
count
(
key
)
==
0
)
<<
"Duplicated value: "
<<
key
;
valueMap_
[
key
].
r
=
v
;
return
*
this
;
}
ClassRegistrar
<
FunctionBase
>
FunctionBase
::
funcRegistrar_
;
...
...
paddle/math/Function.h
浏览文件 @
4ebb3eb7
...
...
@@ -46,6 +46,8 @@ class Tensor {
public:
Tensor
(
real
*
data
,
const
Dims
&
dim
)
:
buf_
(
data
),
dims_
(
dim
)
{}
real
*
getData
()
const
{
return
buf_
;
}
real
*
buf_
;
Dims
dims_
;
};
...
...
@@ -63,7 +65,7 @@ public:
T
get
(
const
std
::
string
&
key
)
const
;
template
<
typename
T
>
void
set
(
const
std
::
string
&
key
,
T
v
);
FuncConfig
&
set
(
const
std
::
string
&
key
,
T
v
);
protected:
std
::
map
<
std
::
string
,
value
>
valueMap_
;
...
...
@@ -84,11 +86,11 @@ public:
#define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName
#define REGISTER_TYPED_FUNC(typeName, deviceName, className) \
static InitFunction __reg_type_##typeName
([]() {
\
FunctionBase::funcRegistrar_ \
.registerClass<className<DEVICE_TYPE_##deviceName>>( \
FUNC_NAME(typeName, deviceName)); \
#define REGISTER_TYPED_FUNC(typeName, deviceName, className)
\
static InitFunction __reg_type_##typeName
##deviceName([]() {
\
FunctionBase::funcRegistrar_
\
.registerClass<className<DEVICE_TYPE_##deviceName>>(
\
FUNC_NAME(typeName, deviceName));
\
})
}
// namespace paddle
paddle/math/cross_map_normal_op.cpp
浏览文件 @
4ebb3eb7
...
...
@@ -18,45 +18,41 @@ namespace paddle {
// NCHW
template
<
>
void
CrossMapNormal
<
DEVICE_TYPE_CPU
>::
operator
()(
CpuMatrix
&
outputs
,
CpuMatrix
&
denoms
,
CpuMatrix
&
inputs
,
size_t
channels
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
sizeX
,
real
scale
,
real
pow
)
{
CHECK
(
outputs
.
isContiguous
());
CHECK
(
inputs
.
isContiguous
());
CHECK
(
denoms
.
isContiguous
());
CHECK_EQ
(
outputs
.
getHeight
(),
inputs
.
getHeight
());
CHECK_EQ
(
outputs
.
getWidth
(),
inputs
.
getWidth
());
CHECK_EQ
(
outputs
.
getHeight
(),
denoms
.
getHeight
());
CHECK_EQ
(
outputs
.
getWidth
(),
denoms
.
getWidth
());
size_t
numSample
=
inputs
.
getHeight
();
size_t
numCols
=
inputs
.
getWidth
();
size_t
imageSize
=
imgSizeH
*
imgSizeW
;
CHECK
(
imageSize
*
channels
==
numCols
);
denoms
=
denoms
.
constant
(
1.0
);
const
int
start
=
-
((
int
)
sizeX
-
1
)
/
2
;
const
int
end
=
(
int
)
sizeX
+
start
;
for
(
size_t
i
=
0
;
i
<
numSample
;
i
++
)
{
real
*
denomsData
=
denoms
.
getData
()
+
i
*
numCols
;
real
*
inputData
=
inputs
.
getData
()
+
i
*
numCols
;
void
CrossMapNormal
<
DEVICE_TYPE_CPU
>
(
real
*
outputs
,
real
*
denoms
,
real
*
inputs
,
size_t
numSamples
,
size_t
channels
,
size_t
height
,
size_t
width
,
size_t
size
,
real
scale
,
real
pow
)
{
size_t
oneImage
=
height
*
width
;
size_t
oneSample
=
channels
*
oneImage
;
CpuVector
outputsV
(
numSamples
*
oneSample
,
outputs
);
CpuVector
inputsV
(
numSamples
*
oneSample
,
inputs
);
CpuVector
denomsV
(
numSamples
*
oneSample
,
denoms
);
denomsV
=
denomsV
.
constant
(
1.0
);
const
int
start
=
-
((
int
)
size
-
1
)
/
2
;
const
int
end
=
(
int
)
size
+
start
;
for
(
size_t
i
=
0
;
i
<
numSamples
;
i
++
)
{
real
*
oneDenom
=
denoms
+
i
*
oneSample
;
real
*
oneInput
=
inputs
+
i
*
oneSample
;
for
(
int
c
=
0
;
c
<
(
int
)
channels
;
c
++
)
{
CpuVector
denom
(
imageSize
,
denomsData
+
c
*
imageSiz
e
);
CpuVector
denom
(
oneImage
,
oneDenom
+
c
*
oneImag
e
);
for
(
int
s
=
start
;
s
<
end
;
s
++
)
{
if
(
c
+
s
>=
0
&&
c
+
s
<
(
int
)
channels
)
{
CpuVector
input
(
imageSize
,
inputData
+
(
c
+
s
)
*
imageSiz
e
);
CpuVector
input
(
oneImage
,
oneInput
+
(
c
+
s
)
*
oneImag
e
);
denom
+=
input
.
square
()
*
scale
;
}
}
}
}
outputs
=
inputs
*
denoms
.
pow
(
-
pow
);
outputsV
=
inputsV
*
denomsV
.
pow
(
-
pow
);
}
template
<
>
...
...
@@ -154,13 +150,17 @@ public:
size_t
channels
=
inputs
[
0
].
dims_
[
1
];
size_t
height
=
inputs
[
0
].
dims_
[
2
];
size_t
width
=
inputs
[
0
].
dims_
[
3
];
size_t
imageSize
=
channels
*
height
*
width
;
CpuMatrix
input
(
inputs
[
0
].
buf_
,
samples
,
imageSize
);
CpuMatrix
output
(
outputs
[
0
].
buf_
,
samples
,
imageSize
);
CpuMatrix
denom
(
outputs
[
1
].
buf_
,
samples
,
imageSize
);
CrossMapNormal
<
Device
>
cross
;
cross
(
output
,
denom
,
input
,
channels
,
height
,
width
,
size_
,
scale_
,
pow_
);
CrossMapNormal
<
Device
>
(
outputs
[
0
].
getData
(),
outputs
[
1
].
getData
(),
inputs
[
0
].
getData
(),
samples
,
channels
,
height
,
width
,
size_
,
scale_
,
pow_
);
}
private:
...
...
@@ -170,5 +170,6 @@ private:
};
REGISTER_TYPED_FUNC
(
CrossMapNormal
,
CPU
,
CrossMapNormalFunc
);
REGISTER_TYPED_FUNC
(
CrossMapNormal
,
GPU
,
CrossMapNormalFunc
);
}
// namespace paddle
paddle/math/cross_map_normal_op.h
浏览文件 @
4ebb3eb7
...
...
@@ -19,6 +19,18 @@ limitations under the License. */
namespace
paddle
{
template
<
DeviceType
Device
>
void
CrossMapNormal
(
real
*
outputs
,
real
*
denoms
,
real
*
inputs
,
size_t
numSamples
,
size_t
channels
,
size_t
height
,
size_t
width
,
size_t
size
,
real
scale
,
real
pow
);
#if 0
template <DeviceType Device>
struct CrossMapNormal {
void operator()(typename MatrixT<Device>::type& outputs,
...
...
@@ -31,6 +43,7 @@ struct CrossMapNormal {
real scale,
real pow);
};
#endif
template
<
DeviceType
Device
>
struct
CrossMapNormalGrad
{
...
...
paddle/math/cross_map_normal_op_gpu.cu
浏览文件 @
4ebb3eb7
...
...
@@ -61,45 +61,29 @@ __global__ void KeCMRNormOutput(size_t inputSize, const real* in,
}
template
<
>
void
CrossMapNormal
<
DEVICE_TYPE_GPU
>::
operator
()(
GpuMatrix
&
outputs
,
GpuMatrix
&
denoms
,
GpuMatrix
&
inputs
,
size_t
channels
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
sizeX
,
real
scale
,
real
pow
)
{
CHECK
(
outputs
.
isContiguous
());
CHECK
(
inputs
.
isContiguous
());
CHECK
(
denoms
.
isContiguous
());
CHECK_EQ
(
outputs
.
getHeight
(),
inputs
.
getHeight
());
CHECK_EQ
(
outputs
.
getWidth
(),
inputs
.
getWidth
());
CHECK_EQ
(
outputs
.
getHeight
(),
denoms
.
getHeight
());
CHECK_EQ
(
outputs
.
getWidth
(),
denoms
.
getWidth
());
size_t
numSample
=
inputs
.
getHeight
();
size_t
numCols
=
inputs
.
getWidth
();
CHECK
(
imgSizeH
*
imgSizeW
*
channels
==
numCols
);
real
*
inputsData
=
inputs
.
getData
();
real
*
denomsData
=
denoms
.
getData
();
real
*
outputsData
=
outputs
.
getData
();
size_t
imageSize
=
numSample
*
imgSizeH
*
imgSizeW
;
void
CrossMapNormal
<
DEVICE_TYPE_GPU
>
(
real
*
outputs
,
real
*
denoms
,
real
*
inputs
,
size_t
numSamples
,
size_t
channels
,
size_t
height
,
size_t
width
,
size_t
size
,
real
scale
,
real
pow
)
{
size_t
imageSize
=
numSamples
*
height
*
width
;
int
blockSize
=
1024
;
int
gridSize
=
(
imageSize
+
1024
-
1
)
/
1024
;
KeCMRNormFillScale
<<<
gridSize
,
blockSize
,
0
,
STREAM_DEFAULT
>>>
(
imageSize
,
inputsData
,
denomsData
,
channels
,
imgSizeH
,
imgSizeW
,
sizeX
,
scale
);
(
imageSize
,
inputs
,
denoms
,
channels
,
height
,
width
,
size
,
scale
);
size_t
inputSize
=
numSample
*
imgSizeH
*
imgSizeW
*
channels
;
size_t
inputSize
=
numSample
s
*
height
*
width
*
channels
;
blockSize
=
1024
;
gridSize
=
(
inputSize
+
1024
-
1
)
/
1024
;
KeCMRNormOutput
<<<
gridSize
,
blockSize
,
0
,
STREAM_DEFAULT
>>>
(
inputSize
,
inputs
Data
,
denomsData
,
-
pow
,
outputsData
);
(
inputSize
,
inputs
,
denoms
,
-
pow
,
outputs
);
CHECK_SYNC
(
"CrossMapNormal
Fwd
"
);
CHECK_SYNC
(
"CrossMapNormal"
);
}
__global__
void
KeCMRNormDiff
(
size_t
imageSize
,
const
real
*
bottom_data
,
...
...
paddle/math/tests/test_matrixCompare.cpp
浏览文件 @
4ebb3eb7
...
...
@@ -1281,24 +1281,40 @@ void testCrossMapNormalFwd(
inputsGpu
.
copyFrom
(
inputs
);
outputsGpu
.
copyFrom
(
outputs
);
#if 0
FuncConfig config;
config.set("size", (size_t)sizeX);
config.set("scale", scale);
config.set("pow", pow);
#endif
FunctionBase
*
cpu
=
FunctionBase
::
funcRegistrar_
.
createByType
(
FUNC_NAME
(
CrossMapNormal
,
CPU
));
cpu
->
init
(
config
);
FunctionBase
*
gpu
=
FunctionBase
::
funcRegistrar_
.
createByType
(
FUNC_NAME
(
CrossMapNormal
,
GPU
));
cpu
->
init
(
FuncConfig
()
.
set
(
"size"
,
(
size_t
)
sizeX
)
.
set
(
"scale"
,
scale
)
.
set
(
"pow"
,
pow
));
gpu
->
init
(
FuncConfig
()
.
set
(
"size"
,
(
size_t
)
sizeX
)
.
set
(
"scale"
,
scale
)
.
set
(
"pow"
,
pow
));
Dims
dims
{
(
size_t
)
numSamples
,
(
size_t
)
channels
,
(
size_t
)
imgSizeH
,
(
size_t
)
imgSizeW
};
cpu
->
calc
({
Tensor
(
inputs
.
getData
(),
dims
)},
{
Tensor
(
outputs
.
getData
(),
dims
),
Tensor
(
denoms
.
getData
(),
dims
)},
{});
gpu
->
calc
(
{
Tensor
(
inputsGpu
.
getData
(),
dims
)},
{
Tensor
(
outputsGpu
.
getData
(),
dims
),
Tensor
(
denomsGpu
.
getData
(),
dims
)},
{});
#if 0
CrossMapNormal<DEVICE_TYPE_CPU> cpuCross;
cpuCross(
outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow);
#endif
CrossMapNormal<DEVICE_TYPE_GPU> gpuCross;
gpuCross(outputsGpu,
denomsGpu,
...
...
@@ -1309,6 +1325,7 @@ void testCrossMapNormalFwd(
sizeX,
scale,
pow);
#endif
TensorCheckErr
(
outputs
,
outputsGpu
);
TensorCheckErr
(
denoms
,
denomsGpu
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录