Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d2d00106
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d2d00106
编写于
12月 15, 2016
作者:
H
hedaoyuan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add CrossMapNormalGradFunc
上级
9171ab0a
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
190 addition
and
156 deletion
+190
-156
paddle/gserver/layers/NormProjectionLayer.cpp
paddle/gserver/layers/NormProjectionLayer.cpp
+29
-12
paddle/gserver/layers/NormProjectionLayer.h
paddle/gserver/layers/NormProjectionLayer.h
+4
-3
paddle/math/Function.h
paddle/math/Function.h
+1
-1
paddle/math/cross_map_normal_op.cpp
paddle/math/cross_map_normal_op.cpp
+96
-49
paddle/math/cross_map_normal_op.h
paddle/math/cross_map_normal_op.h
+12
-28
paddle/math/cross_map_normal_op_gpu.cu
paddle/math/cross_map_normal_op_gpu.cu
+16
-38
paddle/math/tests/test_matrixCompare.cpp
paddle/math/tests/test_matrixCompare.cpp
+32
-25
未找到文件。
paddle/gserver/layers/NormProjectionLayer.cpp
浏览文件 @
d2d00106
...
@@ -13,10 +13,9 @@ See the License for the specific language governing permissions and
...
@@ -13,10 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "NormProjectionLayer.h"
#include "NormProjectionLayer.h"
#include "paddle/math/cross_map_normal_op.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
#include "paddle/utils/Stat.h"
#include "paddle/math/cross_map_normal_op.h"
#include "NormProjectionLayer.h"
namespace
paddle
{
namespace
paddle
{
size_t
CMRProjectionNormLayer
::
getSize
()
{
size_t
CMRProjectionNormLayer
::
getSize
()
{
...
@@ -48,13 +47,23 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap,
...
@@ -48,13 +47,23 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap,
CHECK_EQ
(
config_
.
inputs_size
(),
1
);
CHECK_EQ
(
config_
.
inputs_size
(),
1
);
if
(
useGpu_
)
{
if
(
useGpu_
)
{
normal
_
=
FunctionBase
::
funcRegistrar_
.
createByType
(
forward
_
=
FunctionBase
::
funcRegistrar_
.
createByType
(
FUNC_NAME
(
CrossMapNormal
,
GPU
));
FUNC_NAME
(
CrossMapNormal
,
GPU
));
}
else
{
}
else
{
normal
_
=
FunctionBase
::
funcRegistrar_
.
createByType
(
forward
_
=
FunctionBase
::
funcRegistrar_
.
createByType
(
FUNC_NAME
(
CrossMapNormal
,
CPU
));
FUNC_NAME
(
CrossMapNormal
,
CPU
));
}
}
normal_
->
init
(
forward_
->
init
(
FuncConfig
().
set
(
"size"
,
size_
).
set
(
"scale"
,
scale_
).
set
(
"pow"
,
pow_
));
if
(
useGpu_
)
{
backward_
=
FunctionBase
::
funcRegistrar_
.
createByType
(
FUNC_NAME
(
CrossMapNormalGrad
,
GPU
));
}
else
{
backward_
=
FunctionBase
::
funcRegistrar_
.
createByType
(
FUNC_NAME
(
CrossMapNormalGrad
,
CPU
));
}
backward_
->
init
(
FuncConfig
().
set
(
"size"
,
size_
).
set
(
"scale"
,
scale_
).
set
(
"pow"
,
pow_
));
FuncConfig
().
set
(
"size"
,
size_
).
set
(
"scale"
,
scale_
).
set
(
"pow"
,
pow_
));
return
true
;
return
true
;
...
@@ -74,13 +83,13 @@ void CMRProjectionNormLayer::forward(PassType passType) {
...
@@ -74,13 +83,13 @@ void CMRProjectionNormLayer::forward(PassType passType) {
Matrix
::
resizeOrCreate
(
denoms_
,
batchSize
,
size
,
/* trans */
false
,
useGpu_
);
Matrix
::
resizeOrCreate
(
denoms_
,
batchSize
,
size
,
/* trans */
false
,
useGpu_
);
Dims
dims
{(
size_t
)
batchSize
,
dims_
=
{(
size_t
)
batchSize
,
(
size_t
)
channels_
,
(
size_t
)
channels_
,
(
size_t
)
imgSizeH_
,
(
size_t
)
imgSizeH_
,
(
size_t
)
imgSizeW_
};
(
size_t
)
imgSizeW_
};
normal
_
->
calc
(
forward
_
->
calc
(
{
Tensor
(
input
->
getData
(),
dims
)},
{
Tensor
(
input
->
getData
(),
dims
_
)},
{
Tensor
(
outV
->
getData
(),
dims
),
Tensor
(
denoms_
->
getData
(),
dims
)},
{
Tensor
(
outV
->
getData
(),
dims
_
),
Tensor
(
denoms_
->
getData
(),
dims_
)},
{});
{});
}
}
...
@@ -96,6 +105,13 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
...
@@ -96,6 +105,13 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
MatrixPtr
localOutV
=
getOutputValue
();
MatrixPtr
localOutV
=
getOutputValue
();
MatrixPtr
preOutV
=
inputLayers_
[
0
]
->
getOutputValue
();
MatrixPtr
preOutV
=
inputLayers_
[
0
]
->
getOutputValue
();
backward_
->
calc
({
Tensor
(
preOutV
->
getData
(),
dims_
),
Tensor
(
localOutV
->
getData
(),
dims_
),
Tensor
(
localGrad
->
getData
(),
dims_
),
Tensor
(
denoms_
->
getData
(),
dims_
)},
{
Tensor
(
preOutGrad
->
getData
(),
dims_
)},
{});
#if 0
if (useGpu_) {
if (useGpu_) {
CrossMapNormalGrad<DEVICE_TYPE_GPU> crossGrad;
CrossMapNormalGrad<DEVICE_TYPE_GPU> crossGrad;
crossGrad(dynamic_cast<GpuMatrix&>(*preOutGrad),
crossGrad(dynamic_cast<GpuMatrix&>(*preOutGrad),
...
@@ -123,5 +139,6 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
...
@@ -123,5 +139,6 @@ void CMRProjectionNormLayer::backward(const UpdateCallback& callback) {
scale_,
scale_,
pow_);
pow_);
}
}
#endif
}
}
}
// namespace paddle
}
// namespace paddle
paddle/gserver/layers/NormProjectionLayer.h
浏览文件 @
d2d00106
...
@@ -16,9 +16,8 @@ limitations under the License. */
...
@@ -16,9 +16,8 @@ limitations under the License. */
#include <vector>
#include <vector>
#include "NormLayer.h"
#include "NormLayer.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/Function.h"
#include "paddle/math/Function.h"
#include
<vector>
#include
"paddle/math/Matrix.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -43,6 +42,8 @@ public:
...
@@ -43,6 +42,8 @@ public:
void
backward
(
const
UpdateCallback
&
callback
=
nullptr
);
void
backward
(
const
UpdateCallback
&
callback
=
nullptr
);
protected:
protected:
FunctionBase
*
normal_
;
Dims
dims_
;
FunctionBase
*
forward_
;
FunctionBase
*
backward_
;
};
};
}
// namespace paddle
}
// namespace paddle
paddle/math/Function.h
浏览文件 @
d2d00106
...
@@ -16,8 +16,8 @@ limitations under the License. */
...
@@ -16,8 +16,8 @@ limitations under the License. */
#include <map>
#include <map>
#include <vector>
#include <vector>
#include "paddle/utils/ClassRegistrar.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/Matrix.h"
#include "paddle/utils/ClassRegistrar.h"
namespace
paddle
{
namespace
paddle
{
...
...
paddle/math/cross_map_normal_op.cpp
浏览文件 @
d2d00106
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "cross_map_normal_op.h"
#include "cross_map_normal_op.h"
#include "paddle/math/Vector.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -56,66 +57,49 @@ void CrossMapNormal<DEVICE_TYPE_CPU>(real* outputs,
...
@@ -56,66 +57,49 @@ void CrossMapNormal<DEVICE_TYPE_CPU>(real* outputs,
}
}
template
<
>
template
<
>
void
CrossMapNormalGrad
<
DEVICE_TYPE_CPU
>::
operator
()(
CpuMatrix
&
inputsGrad
,
void
CrossMapNormalGrad
<
DEVICE_TYPE_CPU
>
(
real
*
inputsGrad
,
CpuMatrix
&
inputsValue
,
real
*
inputsValue
,
CpuMatrix
&
outputsGrad
,
real
*
outputsValue
,
CpuMatrix
&
outputsValue
,
real
*
outputsGrad
,
CpuMatrix
&
denoms
,
real
*
denoms
,
size_t
channels
,
size_t
numSamples
,
size_t
imgSizeH
,
size_t
channels
,
size_t
imgSizeW
,
size_t
height
,
size_t
sizeX
,
size_t
width
,
real
scale
,
size_t
size
,
real
pow
)
{
real
scale
,
CHECK
(
inputsGrad
.
isContiguous
());
real
pow
)
{
CHECK
(
outputsGrad
.
isContiguous
());
size_t
oneSample
=
channels
*
height
*
width
;
CHECK
(
denoms
.
isContiguous
());
CHECK
(
inputsValue
.
isContiguous
());
CHECK
(
outputsValue
.
isContiguous
());
CHECK_EQ
(
inputsGrad
.
getHeight
(),
outputsGrad
.
getHeight
());
CHECK_EQ
(
inputsGrad
.
getWidth
(),
outputsGrad
.
getWidth
());
CHECK_EQ
(
inputsGrad
.
getHeight
(),
denoms
.
getHeight
());
CHECK_EQ
(
inputsGrad
.
getWidth
(),
denoms
.
getWidth
());
CHECK_EQ
(
inputsGrad
.
getHeight
(),
inputsValue
.
getHeight
());
CHECK_EQ
(
inputsGrad
.
getWidth
(),
inputsValue
.
getWidth
());
CHECK_EQ
(
inputsGrad
.
getHeight
(),
outputsValue
.
getHeight
());
CHECK_EQ
(
inputsGrad
.
getWidth
(),
outputsValue
.
getWidth
());
size_t
numSample
=
inputsGrad
.
getHeight
();
size_t
numCols
=
inputsGrad
.
getWidth
();
size_t
imageSize
=
imgSizeH
*
imgSizeW
;
CHECK
(
imageSize
*
channels
==
numCols
);
std
::
function
<
CpuVector
(
real
*
,
size_t
)
>
oneImage
=
[
=
](
real
*
data
,
std
::
function
<
CpuVector
(
real
*
,
size_t
)
>
oneImage
=
[
=
](
real
*
data
,
size_t
offset
)
{
size_t
offset
)
{
return
CpuVector
(
imageSize
,
data
+
offset
);
return
CpuVector
(
height
*
width
,
data
+
offset
);
};
};
const
int
start
=
-
((
int
)
size
X
)
/
2
;
const
int
start
=
-
((
int
)
size
)
/
2
;
const
int
end
=
(
int
)
size
X
+
start
;
const
int
end
=
(
int
)
size
+
start
;
const
real
ratio
=
-
(
real
)
2
*
scale
*
pow
;
const
real
ratio
=
-
(
real
)
2
*
scale
*
pow
;
for
(
size_t
i
=
0
;
i
<
numSample
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
numSample
s
;
i
++
)
{
size_t
sOffset
=
i
*
numCols
;
size_t
sOffset
=
i
*
oneSample
;
real
*
inputGradData
=
inputsGrad
.
getData
()
+
sOffset
;
real
*
oneInputGrad
=
inputsGrad
+
sOffset
;
real
*
inputData
=
inputsValue
.
getData
()
+
sOffset
;
real
*
oneInputValue
=
inputsValue
+
sOffset
;
real
*
denomData
=
denoms
.
getData
()
+
sOffset
;
real
*
oneDenom
=
denoms
+
sOffset
;
real
*
o
utputGradData
=
outputsGrad
.
getData
()
+
sOffset
;
real
*
o
neOutputGrad
=
outputsGrad
+
sOffset
;
real
*
o
utputData
=
outputsValue
.
getData
()
+
sOffset
;
real
*
o
neOutputValue
=
outputsValue
+
sOffset
;
for
(
int
c
=
0
;
c
<
(
int
)
channels
;
c
++
)
{
for
(
int
c
=
0
;
c
<
(
int
)
channels
;
c
++
)
{
size_t
cOffset
=
c
*
imageSize
;
size_t
cOffset
=
c
*
height
*
width
;
CpuVector
inputGrad
=
oneImage
(
inputGradData
,
cOffset
);
CpuVector
inputGrad
=
oneImage
(
oneInputGrad
,
cOffset
);
CpuVector
inputValue
=
oneImage
(
inputData
,
cOffset
);
CpuVector
inputValue
=
oneImage
(
oneInputValue
,
cOffset
);
CpuVector
denom
=
oneImage
(
denomData
,
cOffset
);
CpuVector
denom
=
oneImage
(
oneDenom
,
cOffset
);
CpuVector
outputGrad
=
oneImage
(
o
utputGradData
,
cOffset
);
CpuVector
outputGrad
=
oneImage
(
o
neOutputGrad
,
cOffset
);
inputGrad
=
inputGrad
+
denom
.
pow
(
-
pow
)
*
outputGrad
;
inputGrad
=
inputGrad
+
denom
.
pow
(
-
pow
)
*
outputGrad
;
for
(
int
s
=
start
;
s
<
end
;
s
++
)
{
for
(
int
s
=
start
;
s
<
end
;
s
++
)
{
if
(
c
+
s
>=
0
&&
c
+
s
<
(
int
)
channels
)
{
if
(
c
+
s
>=
0
&&
c
+
s
<
(
int
)
channels
)
{
size_t
offset
=
(
c
+
s
)
*
imageSize
;
size_t
offset
=
(
c
+
s
)
*
height
*
width
;
CpuVector
output
=
oneImage
(
o
utputData
,
offset
);
CpuVector
output
=
oneImage
(
o
neOutputValue
,
offset
);
CpuVector
outputGrad
=
oneImage
(
o
utputGradData
,
offset
);
CpuVector
outputGrad
=
oneImage
(
o
neOutputGrad
,
offset
);
CpuVector
denom
=
oneImage
(
denomData
,
offset
);
CpuVector
denom
=
oneImage
(
oneDenom
,
offset
);
inputGrad
+=
((
outputGrad
*
output
*
ratio
)
/
denom
)
*
inputValue
;
inputGrad
+=
((
outputGrad
*
output
*
ratio
)
/
denom
)
*
inputValue
;
}
}
...
@@ -124,6 +108,11 @@ void CrossMapNormalGrad<DEVICE_TYPE_CPU>::operator()(CpuMatrix& inputsGrad,
...
@@ -124,6 +108,11 @@ void CrossMapNormalGrad<DEVICE_TYPE_CPU>::operator()(CpuMatrix& inputsGrad,
}
}
}
}
/**
* \param inputs[0] input value.
* \param outputs[0] output value.
* \param outputs[1] denoms.
*/
template
<
DeviceType
Device
>
template
<
DeviceType
Device
>
class
CrossMapNormalFunc
:
public
FunctionBase
{
class
CrossMapNormalFunc
:
public
FunctionBase
{
public:
public:
...
@@ -169,7 +158,65 @@ private:
...
@@ -169,7 +158,65 @@ private:
real
pow_
;
real
pow_
;
};
};
/**
* \param inputs[0] input value.
* \param inputs[1] output value.
* \param inputs[2] output grad.
* \param inputs[3] denoms.
* \param outputs[0] input grad.
*/
template
<
DeviceType
Device
>
class
CrossMapNormalGradFunc
:
public
FunctionBase
{
public:
void
init
(
const
FuncConfig
&
config
)
override
{
size_
=
config
.
get
<
size_t
>
(
"size"
);
scale_
=
config
.
get
<
real
>
(
"scale"
);
pow_
=
config
.
get
<
real
>
(
"pow"
);
}
void
calc
(
const
Arguments
&
inputs
,
const
Arguments
&
outputs
,
const
Arguments
&
inouts
)
override
{
CHECK_EQ
(
4
,
inputs
.
size
());
CHECK_EQ
(
1
,
outputs
.
size
());
CHECK_EQ
(
0
,
inouts
.
size
());
CHECK_EQ
(
inputs
[
0
].
dims_
.
size
(),
4
);
for
(
size_t
i
=
0
;
i
<
inputs
[
0
].
dims_
.
size
();
i
++
)
{
CHECK_EQ
(
inputs
[
0
].
dims_
[
i
],
inputs
[
1
].
dims_
[
i
]);
CHECK_EQ
(
inputs
[
0
].
dims_
[
i
],
inputs
[
2
].
dims_
[
i
]);
CHECK_EQ
(
inputs
[
0
].
dims_
[
i
],
inputs
[
3
].
dims_
[
i
]);
CHECK_EQ
(
inputs
[
0
].
dims_
[
i
],
outputs
[
0
].
dims_
[
i
]);
}
size_t
samples
=
inputs
[
0
].
dims_
[
0
];
size_t
channels
=
inputs
[
0
].
dims_
[
1
];
size_t
height
=
inputs
[
0
].
dims_
[
2
];
size_t
width
=
inputs
[
0
].
dims_
[
3
];
CrossMapNormalGrad
<
Device
>
(
outputs
[
0
].
getData
(),
inputs
[
0
].
getData
(),
inputs
[
1
].
getData
(),
inputs
[
2
].
getData
(),
inputs
[
3
].
getData
(),
samples
,
channels
,
height
,
width
,
size_
,
scale_
,
pow_
);
}
private:
size_t
size_
;
real
scale_
;
real
pow_
;
};
REGISTER_TYPED_FUNC
(
CrossMapNormal
,
CPU
,
CrossMapNormalFunc
);
REGISTER_TYPED_FUNC
(
CrossMapNormal
,
CPU
,
CrossMapNormalFunc
);
REGISTER_TYPED_FUNC
(
CrossMapNormal
,
GPU
,
CrossMapNormalFunc
);
REGISTER_TYPED_FUNC
(
CrossMapNormal
,
GPU
,
CrossMapNormalFunc
);
REGISTER_TYPED_FUNC
(
CrossMapNormalGrad
,
CPU
,
CrossMapNormalGradFunc
);
REGISTER_TYPED_FUNC
(
CrossMapNormalGrad
,
GPU
,
CrossMapNormalGradFunc
);
}
// namespace paddle
}
// namespace paddle
paddle/math/cross_map_normal_op.h
浏览文件 @
d2d00106
...
@@ -15,7 +15,6 @@ limitations under the License. */
...
@@ -15,7 +15,6 @@ limitations under the License. */
#pragma once
#pragma once
#include "Function.h"
#include "Function.h"
#include "paddle/math/Matrix.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -30,34 +29,19 @@ void CrossMapNormal(real* outputs,
...
@@ -30,34 +29,19 @@ void CrossMapNormal(real* outputs,
size_t
size
,
size_t
size
,
real
scale
,
real
scale
,
real
pow
);
real
pow
);
#if 0
template <DeviceType Device>
struct CrossMapNormal {
void operator()(typename MatrixT<Device>::type& outputs,
typename MatrixT<Device>::type& denoms,
typename MatrixT<Device>::type& inputs,
size_t channels,
size_t imgSizeH,
size_t imgSizeW,
size_t sizeX,
real scale,
real pow);
};
#endif
template
<
DeviceType
Device
>
template
<
DeviceType
Device
>
struct
CrossMapNormalGrad
{
void
CrossMapNormalGrad
(
real
*
inputsGrad
,
void
operator
()(
typename
MatrixT
<
Device
>::
type
&
inputsGrad
,
real
*
inputsValue
,
typename
MatrixT
<
Device
>::
type
&
inputsValue
,
real
*
outputsValue
,
typename
MatrixT
<
Device
>::
type
&
outputsGrad
,
real
*
outputsGrad
,
typename
MatrixT
<
Device
>::
type
&
outputsValue
,
real
*
denoms
,
typename
MatrixT
<
Device
>::
type
&
denoms
,
size_t
numSamples
,
size_t
channels
,
size_t
channels
,
size_t
imgSizeH
,
size_t
height
,
size_t
imgSizeW
,
size_t
width
,
size_t
sizeX
,
size_t
size
,
real
scale
,
real
scale
,
real
pow
);
real
pow
);
};
}
// namespace paddle
}
// namespace paddle
paddle/math/cross_map_normal_op_gpu.cu
浏览文件 @
d2d00106
...
@@ -131,48 +131,26 @@ __global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data,
...
@@ -131,48 +131,26 @@ __global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data,
}
}
template
<
>
template
<
>
void
CrossMapNormalGrad
<
DEVICE_TYPE_GPU
>::
operator
()(
GpuMatrix
&
inputsGrad
,
void
CrossMapNormalGrad
<
DEVICE_TYPE_GPU
>
(
real
*
inputsGrad
,
GpuMatrix
&
inputsValue
,
real
*
inputsValue
,
GpuMatrix
&
outputsGrad
,
real
*
outputsValue
,
GpuMatrix
&
outputsValue
,
real
*
outputsGrad
,
GpuMatrix
&
denoms
,
real
*
denoms
,
size_t
channels
,
size_t
numSamples
,
size_t
imgSizeH
,
size_t
channels
,
size_t
imgSizeW
,
size_t
height
,
size_t
sizeX
,
size_t
width
,
real
scale
,
size_t
size
,
real
pow
)
{
real
scale
,
CHECK
(
inputsGrad
.
isContiguous
());
real
pow
)
{
CHECK
(
outputsGrad
.
isContiguous
());
size_t
imageSize
=
numSamples
*
height
*
width
;
CHECK
(
denoms
.
isContiguous
());
CHECK
(
inputsValue
.
isContiguous
());
CHECK
(
outputsValue
.
isContiguous
());
CHECK_EQ
(
inputsGrad
.
getHeight
(),
outputsGrad
.
getHeight
());
CHECK_EQ
(
inputsGrad
.
getWidth
(),
outputsGrad
.
getWidth
());
CHECK_EQ
(
inputsGrad
.
getHeight
(),
denoms
.
getHeight
());
CHECK_EQ
(
inputsGrad
.
getWidth
(),
denoms
.
getWidth
());
CHECK_EQ
(
inputsGrad
.
getHeight
(),
inputsValue
.
getHeight
());
CHECK_EQ
(
inputsGrad
.
getWidth
(),
inputsValue
.
getWidth
());
CHECK_EQ
(
inputsGrad
.
getHeight
(),
outputsValue
.
getHeight
());
CHECK_EQ
(
inputsGrad
.
getWidth
(),
outputsValue
.
getWidth
());
size_t
numSample
=
inputsGrad
.
getHeight
();
size_t
numCols
=
inputsGrad
.
getWidth
();
CHECK
(
imgSizeH
*
imgSizeW
*
channels
==
numCols
);
size_t
imageSize
=
numSample
*
imgSizeH
*
imgSizeW
;
real
*
inputsGradData
=
inputsGrad
.
getData
();
real
*
inputsData
=
inputsValue
.
getData
();
real
*
denomsData
=
denoms
.
getData
();
real
*
outputsGradData
=
outputsGrad
.
getData
();
real
*
outputsData
=
outputsValue
.
getData
();
int
blockSize
=
1024
;
int
blockSize
=
1024
;
int
gridSize
=
(
imageSize
+
1024
-
1
)
/
1024
;
int
gridSize
=
(
imageSize
+
1024
-
1
)
/
1024
;
KeCMRNormDiff
<<<
gridSize
,
blockSize
,
0
,
STREAM_DEFAULT
>>>
KeCMRNormDiff
<<<
gridSize
,
blockSize
,
0
,
STREAM_DEFAULT
>>>
(
imageSize
,
inputs
Data
,
outputsData
,
denomsData
,
outputsGradData
,
channels
,
(
imageSize
,
inputs
Value
,
outputsValue
,
denoms
,
outputsGrad
,
channels
,
imgSizeH
,
imgSizeW
,
sizeX
,
-
pow
,
2.0
f
*
pow
*
scale
,
inputsGradData
);
height
,
width
,
size
,
-
pow
,
2.0
f
*
pow
*
scale
,
inputsGrad
);
CHECK_SYNC
(
"
KeCMRNormDiff
"
);
CHECK_SYNC
(
"
CrossMapNormalGrad
"
);
}
}
}
// namespace paddle
}
// namespace paddle
paddle/math/tests/test_matrixCompare.cpp
浏览文件 @
d2d00106
...
@@ -19,12 +19,11 @@ limitations under the License. */
...
@@ -19,12 +19,11 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "TensorCheck.h"
#include "TensorCheck.h"
#include "paddle/gserver/tests/TestUtil.h"
#include "paddle/gserver/tests/TestUtil.h"
#include "paddle/math/Function.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/SparseMatrix.h"
#include "paddle/math/SparseMatrix.h"
#include "paddle/utils/Stat.h"
#include "TensorCheck.h"
#include "paddle/math/cross_map_normal_op.h"
#include "paddle/math/cross_map_normal_op.h"
#include "paddle/
math/Function
.h"
#include "paddle/
utils/Stat
.h"
#include "paddle/utils/Util.h"
#include "paddle/utils/Util.h"
using
namespace
paddle
;
// NOLINT
using
namespace
paddle
;
// NOLINT
...
@@ -1282,12 +1281,6 @@ void testCrossMapNormalFwd(
...
@@ -1282,12 +1281,6 @@ void testCrossMapNormalFwd(
inputsGpu
.
copyFrom
(
inputs
);
inputsGpu
.
copyFrom
(
inputs
);
outputsGpu
.
copyFrom
(
outputs
);
outputsGpu
.
copyFrom
(
outputs
);
#if 0
FuncConfig config;
config.set("size", (size_t)sizeX);
config.set("scale", scale);
config.set("pow", pow);
#endif
FunctionBase
*
cpu
=
FunctionBase
*
cpu
=
FunctionBase
::
funcRegistrar_
.
createByType
(
FUNC_NAME
(
CrossMapNormal
,
CPU
));
FunctionBase
::
funcRegistrar_
.
createByType
(
FUNC_NAME
(
CrossMapNormal
,
CPU
));
FunctionBase
*
gpu
=
FunctionBase
*
gpu
=
...
@@ -1311,22 +1304,6 @@ void testCrossMapNormalFwd(
...
@@ -1311,22 +1304,6 @@ void testCrossMapNormalFwd(
{
Tensor
(
inputsGpu
.
getData
(),
dims
)},
{
Tensor
(
inputsGpu
.
getData
(),
dims
)},
{
Tensor
(
outputsGpu
.
getData
(),
dims
),
Tensor
(
denomsGpu
.
getData
(),
dims
)},
{
Tensor
(
outputsGpu
.
getData
(),
dims
),
Tensor
(
denomsGpu
.
getData
(),
dims
)},
{});
{});
#if 0
CrossMapNormal<DEVICE_TYPE_CPU> cpuCross;
cpuCross(
outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow);
CrossMapNormal<DEVICE_TYPE_GPU> gpuCross;
gpuCross(outputsGpu,
denomsGpu,
inputsGpu,
channels,
imgSizeH,
imgSizeW,
sizeX,
scale,
pow);
#endif
TensorCheckErr
(
outputs
,
outputsGpu
);
TensorCheckErr
(
outputs
,
outputsGpu
);
TensorCheckErr
(
denoms
,
denomsGpu
);
TensorCheckErr
(
denoms
,
denomsGpu
);
...
@@ -1381,6 +1358,35 @@ void testCrossMapNormalBwd(
...
@@ -1381,6 +1358,35 @@ void testCrossMapNormalBwd(
outputsValueGpu
.
copyFrom
(
outputsValue
);
outputsValueGpu
.
copyFrom
(
outputsValue
);
inputsGradGpu
.
copyFrom
(
inputsGrad
);
inputsGradGpu
.
copyFrom
(
inputsGrad
);
FunctionBase
*
cpu
=
FunctionBase
::
funcRegistrar_
.
createByType
(
FUNC_NAME
(
CrossMapNormalGrad
,
CPU
));
FunctionBase
*
gpu
=
FunctionBase
::
funcRegistrar_
.
createByType
(
FUNC_NAME
(
CrossMapNormalGrad
,
GPU
));
cpu
->
init
(
FuncConfig
()
.
set
(
"size"
,
(
size_t
)
sizeX
)
.
set
(
"scale"
,
scale
)
.
set
(
"pow"
,
pow
));
gpu
->
init
(
FuncConfig
()
.
set
(
"size"
,
(
size_t
)
sizeX
)
.
set
(
"scale"
,
scale
)
.
set
(
"pow"
,
pow
));
Dims
dims
{
(
size_t
)
numSamples
,
(
size_t
)
channels
,
(
size_t
)
imgSizeH
,
(
size_t
)
imgSizeW
};
cpu
->
calc
({
Tensor
(
inputsValue
.
getData
(),
dims
),
Tensor
(
outputsValue
.
getData
(),
dims
),
Tensor
(
outputsGrad
.
getData
(),
dims
),
Tensor
(
denoms
.
getData
(),
dims
)},
{
Tensor
(
inputsGrad
.
getData
(),
dims
)},
{});
gpu
->
calc
({
Tensor
(
inputsValueGpu
.
getData
(),
dims
),
Tensor
(
outputsValueGpu
.
getData
(),
dims
),
Tensor
(
outputsGradGpu
.
getData
(),
dims
),
Tensor
(
denomsGpu
.
getData
(),
dims
)},
{
Tensor
(
inputsGradGpu
.
getData
(),
dims
)},
{});
#if 0
CrossMapNormalGrad<DEVICE_TYPE_CPU> cpuCross;
CrossMapNormalGrad<DEVICE_TYPE_CPU> cpuCross;
cpuCross(inputsGrad,
cpuCross(inputsGrad,
inputsValue,
inputsValue,
...
@@ -1406,6 +1412,7 @@ void testCrossMapNormalBwd(
...
@@ -1406,6 +1412,7 @@ void testCrossMapNormalBwd(
sizeX,
sizeX,
scale,
scale,
pow);
pow);
#endif
TensorCheckErr
(
inputsGrad
,
inputsGradGpu
);
TensorCheckErr
(
inputsGrad
,
inputsGradGpu
);
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录