Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
2558c3f1
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
2558c3f1
编写于
2月 01, 2017
作者:
H
Haonan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
revisions according to reviews
上级
b4c1d175
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
142 addition
and
58 deletion
+142
-58
paddle/cuda/include/hl_matrix.h
paddle/cuda/include/hl_matrix.h
+12
-0
paddle/cuda/include/stub/hl_matrix_stub.h
paddle/cuda/include/stub/hl_matrix_stub.h
+4
-0
paddle/cuda/src/hl_cuda_matrix.cu
paddle/cuda/src/hl_cuda_matrix.cu
+25
-0
paddle/gserver/layers/RotateLayer.cpp
paddle/gserver/layers/RotateLayer.cpp
+42
-35
paddle/gserver/layers/RotateLayer.h
paddle/gserver/layers/RotateLayer.h
+9
-7
paddle/gserver/tests/test_LayerGrad.cpp
paddle/gserver/tests/test_LayerGrad.cpp
+5
-2
paddle/math/Matrix.cpp
paddle/math/Matrix.cpp
+13
-11
paddle/math/Matrix.h
paddle/math/Matrix.h
+12
-2
paddle/math/tests/test_matrixCompare.cpp
paddle/math/tests/test_matrixCompare.cpp
+20
-1
未找到文件。
paddle/cuda/include/hl_matrix.h
浏览文件 @
2558c3f1
...
...
@@ -267,4 +267,16 @@ extern void hl_matrix_collect_shared_bias(real* B_d,
const
int
dimN
,
real
scale
);
/**
* @brief Matrix rotation in 90 degrees
*
* @param[in] mat input matrix (M x N).
* @param[out] matRot output matrix (N x M).
* @param[in] dimM input matrix height.
* @param[in] dimN input matrix width.
* @param[in] clockWise rotation direction
*/
extern
void
hl_matrix_rotate
(
real
*
mat
,
real
*
matRot
,
int
dimM
,
int
dimN
,
bool
clockWise
);
#endif
/* HL_MATRIX_H_ */
paddle/cuda/include/stub/hl_matrix_stub.h
浏览文件 @
2558c3f1
...
...
@@ -106,4 +106,8 @@ inline void hl_matrix_collect_shared_bias(real* B_d,
const
int
dimM
,
const
int
dimN
,
real
scale
)
{}
inline
void
hl_matrix_rotate
(
real
*
mat
,
real
*
matRot
,
int
dimM
,
int
dimN
,
bool
clockWise
);
#endif // HL_MATRIX_STUB_H_
paddle/cuda/src/hl_cuda_matrix.cu
浏览文件 @
2558c3f1
...
...
@@ -840,3 +840,28 @@ void hl_matrix_collect_shared_bias(real* B_d,
(
B_d
,
A_d
,
channel
,
dimM
,
dimN
,
dim
,
limit
,
scale
);
CHECK_SYNC
(
"hl_matrix_collect_shared_bias failed"
);
}
__global__
void
keMatrixRotate
(
real
*
mat
,
real
*
matRot
,
int
dimM
,
int
dimN
,
bool
clockWise
)
{
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
idx
<
dimM
*
dimN
)
{
int
i
=
idx
/
dimN
;
int
j
=
idx
%
dimN
;
if
(
clockWise
)
{
matRot
[
j
*
dimM
+
i
]
=
mat
[(
dimM
-
i
-
1
)
*
dimN
+
j
];
}
else
{
matRot
[
j
*
dimM
+
i
]
=
mat
[
i
*
dimN
+
(
dimN
-
j
-
1
)];
}
}
}
void
hl_matrix_rotate
(
real
*
mat
,
real
*
matRot
,
int
dimM
,
int
dimN
,
bool
clockWise
)
{
CHECK_NOTNULL
(
mat
);
CHECK_NOTNULL
(
matRot
);
const
int
threads
=
512
;
const
int
blocks
=
DIVUP
(
dimM
*
dimN
,
threads
);
keMatrixRotate
<<<
blocks
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
mat
,
matRot
,
dimM
,
dimN
,
clockWise
);
CHECK_SYNC
(
"hl_matrix_rotate failed"
);
}
paddle/gserver/layers/RotateLayer.cpp
浏览文件 @
2558c3f1
...
...
@@ -23,7 +23,8 @@ bool RotateLayer::init(const LayerMap& layerMap,
Layer
::
init
(
layerMap
,
parameterMap
);
CHECK_EQ
(
inputLayers_
.
size
(),
1UL
);
sampleHeight_
=
config_
.
height
();
height_
=
config_
.
height
();
width_
=
config_
.
width
();
return
true
;
}
...
...
@@ -32,26 +33,31 @@ void RotateLayer::forward(PassType passType) {
MatrixPtr
input
=
getInputValue
(
0
);
batchSize_
=
input
->
getHeight
();
sampleSize_
=
input
->
getWidth
();
sampleWidth_
=
sampleSize_
/
sampleHeight_
;
CHECK_EQ
(
sampleSize_
%
sampleHeight_
,
0
);
size_
=
input
->
getWidth
();
CHECK_GE
(
size_
,
height_
*
width_
);
CHECK_EQ
(
size_
%
(
height_
*
width_
),
0
)
<<
"The input's depth should be an int"
;
channels_
=
size_
/
(
height_
*
width_
);
resizeOutput
(
batchSize_
,
s
ampleS
ize_
);
resizeOutput
(
batchSize_
,
size_
);
MatrixPtr
outV
=
getOutputValue
();
for
(
int
b
=
0
;
b
<
batchSize_
;
b
++
)
{
MatrixPtr
inputSample
=
Matrix
::
create
(
input
->
getData
()
+
b
*
sampleSize_
,
sampleHeight_
,
sampleWidth_
,
false
,
useGpu_
);
MatrixPtr
outputSample
=
Matrix
::
create
(
outV
->
getData
()
+
b
*
sampleSize_
,
sampleWidth_
,
sampleHeight_
,
false
,
useGpu_
);
inputSample
->
rotate
(
outputSample
,
false
,
true
);
for
(
int
b
=
0
;
b
<
batchSize_
;
b
++
)
{
// for each input feat map
for
(
int
c
=
0
;
c
<
channels_
;
c
++
)
{
// for each feat channel
MatrixPtr
inputSample
=
Matrix
::
create
(
input
->
getData
()
+
b
*
size_
+
c
*
height_
*
width_
,
height_
,
width_
,
false
,
useGpu_
);
MatrixPtr
outputSample
=
Matrix
::
create
(
outV
->
getData
()
+
b
*
size_
+
c
*
height_
*
width_
,
width_
,
height_
,
false
,
useGpu_
);
inputSample
->
rotate
(
outputSample
,
false
,
true
/* clock-wise */
);
}
}
if
(
getInputGrad
(
0
))
{
...
...
@@ -69,23 +75,24 @@ void RotateLayer::backward(const UpdateCallback& callback) {
// the grad should be rotated in the reverse direction
MatrixPtr
preGrad
=
getInputGrad
(
0
);
for
(
int
b
=
0
;
b
<
batchSize_
;
b
++
)
{
MatrixPtr
inputSampleGrad
=
Matrix
::
create
(
preGrad
->
getData
()
+
b
*
sampleSize_
,
sampleHeight_
,
sampleWidth_
,
false
,
useGpu_
);
MatrixPtr
outputSampleGrad
=
Matrix
::
create
(
outputGrad
->
getData
()
+
b
*
sampleSize_
,
sampleWidth_
,
sampleHeight_
,
false
,
useGpu_
);
MatrixPtr
tmpGrad
=
Matrix
::
create
(
sampleHeight_
,
sampleWidth_
,
false
,
useGpu_
);
outputSampleGrad
->
rotate
(
tmpGrad
,
false
,
false
);
inputSampleGrad
->
add
(
*
tmpGrad
);
for
(
int
b
=
0
;
b
<
batchSize_
;
b
++
)
{
// for each input feat map
for
(
int
c
=
0
;
c
<
channels_
;
c
++
)
{
// for each feat channel
MatrixPtr
inputSampleGrad
=
Matrix
::
create
(
preGrad
->
getData
()
+
b
*
size_
+
c
*
height_
*
width_
,
height_
,
width_
,
false
,
useGpu_
);
MatrixPtr
outputSampleGrad
=
Matrix
::
create
(
outputGrad
->
getData
()
+
b
*
size_
+
c
*
height_
*
width_
,
width_
,
height_
,
false
,
useGpu_
);
MatrixPtr
tmpGrad
=
nullptr
;
outputSampleGrad
->
rotate
(
tmpGrad
,
true
,
false
/* anti clock-wise */
);
inputSampleGrad
->
add
(
*
tmpGrad
);
}
}
}
...
...
paddle/gserver/layers/RotateLayer.h
浏览文件 @
2558c3f1
...
...
@@ -19,12 +19,13 @@ limitations under the License. */
namespace
paddle
{
/**
* A layer for rotating an input sample (assume it's a matrix)
* The rotation is in clock-wise
* A layer for rotating a multi-channel feature map (M x N x C) in the spatial
* domain
* The rotation is 90 degrees in clock-wise
* \f[
* y(j,i
) = x(M-i-1,j
)
* y(j,i
,:) = x(M-i-1,j,:
)
* \f]
* where \f$x\f$ is (M x N
) input, and \f$y\f$ is (N x M
) output.
* where \f$x\f$ is (M x N
x C) input, and \f$y\f$ is (N x M x C
) output.
*
* The config file api is rotate_layer
*
...
...
@@ -41,9 +42,10 @@ public:
private:
int
batchSize_
;
int
sampleSize_
;
int
sampleHeight_
;
int
sampleWidth_
;
int
size_
;
int
height_
;
int
width_
;
int
channels_
;
};
}
// namespace paddle
paddle/gserver/tests/test_LayerGrad.cpp
浏览文件 @
2558c3f1
...
...
@@ -1320,9 +1320,12 @@ TEST(Layer, RotateLayer) {
TestConfig
config
;
config
.
biasSize
=
0
;
config
.
layerConfig
.
set_type
(
"rotate"
);
const
int
INPUT_SIZE
=
64
;
// height * width
const
int
INPUT_SIZE
=
64
;
// height * width * depth
const
int
HEIGHT
=
8
;
const
int
WIDTH
=
4
;
config
.
layerConfig
.
set_size
(
INPUT_SIZE
);
config
.
layerConfig
.
set_height
(
32
);
config
.
layerConfig
.
set_height
(
HEIGHT
);
config
.
layerConfig
.
set_width
(
WIDTH
);
config
.
inputDefs
.
push_back
({
INPUT_DATA
,
"layer_0"
,
INPUT_SIZE
,
0
});
config
.
layerConfig
.
add_inputs
();
...
...
paddle/math/Matrix.cpp
浏览文件 @
2558c3f1
...
...
@@ -388,6 +388,8 @@ void GpuMatrix::transpose(MatrixPtr& matTrans, bool memAlloc) {
matTrans
=
std
::
make_shared
<
GpuMatrix
>
(
width_
,
height_
);
}
else
{
CHECK
(
matTrans
!=
NULL
);
CHECK_EQ
(
matTrans
->
getHeight
(),
width_
);
CHECK_EQ
(
matTrans
->
getWidth
(),
height_
);
}
real
*
dataTrans
=
matTrans
->
getData
();
real
*
data
=
getData
();
...
...
@@ -402,15 +404,13 @@ void GpuMatrix::rotate(MatrixPtr& matRot, bool memAlloc, bool clockWise) {
matRot
=
std
::
make_shared
<
GpuMatrix
>
(
width_
,
height_
);
}
else
{
CHECK
(
matRot
!=
NULL
);
CHECK_EQ
(
matRot
->
getHeight
(),
width_
);
CHECK_EQ
(
matRot
->
getWidth
(),
height_
);
}
MatrixPtr
cpuMat
=
std
::
make_shared
<
CpuMatrix
>
(
height_
,
width_
);
cpuMat
->
copyFrom
(
*
this
);
MatrixPtr
cpuMatRot
=
std
::
make_shared
<
CpuMatrix
>
(
width_
,
height_
);
cpuMat
->
rotate
(
cpuMatRot
,
false
,
clockWise
);
matRot
->
copyFrom
(
*
cpuMatRot
);
real
*
dataRot
=
matRot
->
getData
();
real
*
data
=
getData
();
hl_matrix_rotate
(
data
,
dataRot
,
height_
,
width_
,
clockWise
);
}
MatrixPtr
GpuMatrix
::
getInverse
()
{
...
...
@@ -1723,6 +1723,8 @@ void CpuMatrix::transpose(MatrixPtr& matTrans, bool memAlloc) {
matTrans
=
std
::
make_shared
<
CpuMatrix
>
(
width_
,
height_
);
}
else
{
CHECK
(
matTrans
!=
NULL
);
CHECK_EQ
(
matTrans
->
getHeight
(),
width_
);
CHECK_EQ
(
matTrans
->
getWidth
(),
height_
);
}
real
*
dataTrans
=
matTrans
->
getData
();
real
*
data
=
getData
();
...
...
@@ -1741,18 +1743,18 @@ void CpuMatrix::rotate(MatrixPtr& matRot, bool memAlloc, bool clockWise) {
matRot
=
std
::
make_shared
<
CpuMatrix
>
(
width_
,
height_
);
}
else
{
CHECK
(
matRot
!=
NULL
);
CHECK_EQ
(
matRot
->
getHeight
(),
width_
);
CHECK_EQ
(
matRot
->
getWidth
(),
height_
);
}
real
*
dataRot
=
matRot
->
getData
();
real
*
data
=
getData
();
int
lda
=
getStride
();
int
ldc
=
matRot
->
getStride
();
for
(
size_t
i
=
0
;
i
<
height_
;
i
++
)
{
for
(
size_t
j
=
0
;
j
<
width_
;
j
++
)
{
if
(
clockWise
)
{
dataRot
[
j
*
ldc
+
i
]
=
data
[(
height_
-
i
-
1
)
*
lda
+
j
];
dataRot
[
j
*
height_
+
i
]
=
data
[(
height_
-
i
-
1
)
*
width_
+
j
];
}
else
{
dataRot
[
j
*
ldc
+
i
]
=
data
[
i
*
lda
+
(
width_
-
j
-
1
)];
dataRot
[
j
*
height_
+
i
]
=
data
[
i
*
width_
+
(
width_
-
j
-
1
)];
}
}
}
...
...
paddle/math/Matrix.h
浏览文件 @
2558c3f1
...
...
@@ -377,9 +377,19 @@ public:
}
/**
* @brief rotate clock-wise.
* @brief rotate 90 degrees in clock-wise if clockWise=true;
* otherwise rotate in anti clock-wise
* clock-wise:
* \f[
* y(j,i) = x(M-i-1,j)
* \f]
* anti clock-wise:
* \f[
* y(j,i) = x(i, N-1-j)
* \f]
* where \f$x\f$ is (M x N) input, and \f$y\f$ is (N x M) output.
*
* allocate mat
Trans
' memory outside, then set memAlloc as false;
* allocate mat
Rot
' memory outside, then set memAlloc as false;
* else set as true.
*/
virtual
void
rotate
(
MatrixPtr
&
matRot
,
bool
memAlloc
,
bool
clockWise
)
{
...
...
paddle/math/tests/test_matrixCompare.cpp
浏览文件 @
2558c3f1
...
...
@@ -176,11 +176,29 @@ void testMatrixTranspose(int height, int width) {
cpu
->
randomizeUniform
();
gpu
->
copyFrom
(
*
cpu
);
cpu
->
transpose
(
cpuT
,
false
);
gpu
->
transpose
(
gpuT
,
fals
e
);
gpu
->
transpose
(
gpuT
,
tru
e
);
TensorCheckEqual
(
*
cpuT
,
*
gpuT
);
}
void
testMatrixRotate
(
int
height
,
int
width
)
{
MatrixPtr
cpu
=
std
::
make_shared
<
CpuMatrix
>
(
height
,
width
);
MatrixPtr
gpu
=
std
::
make_shared
<
GpuMatrix
>
(
height
,
width
);
MatrixPtr
cpuR
=
std
::
make_shared
<
CpuMatrix
>
(
width
,
height
);
MatrixPtr
gpuR
=
std
::
make_shared
<
GpuMatrix
>
(
width
,
height
);
cpu
->
randomizeUniform
();
gpu
->
copyFrom
(
*
cpu
);
cpu
->
rotate
(
cpuR
,
false
,
true
);
gpu
->
rotate
(
gpuR
,
true
,
true
);
TensorCheckEqual
(
*
cpuR
,
*
gpuR
);
cpu
->
rotate
(
cpuR
,
true
,
false
);
gpu
->
rotate
(
gpuR
,
false
,
false
);
TensorCheckEqual
(
*
cpuR
,
*
gpuR
);
}
void
testMatrixInverse
(
int
height
)
{
MatrixPtr
cpu
=
std
::
make_shared
<
CpuMatrix
>
(
height
,
height
);
MatrixPtr
gpu
=
std
::
make_shared
<
GpuMatrix
>
(
height
,
height
);
...
...
@@ -215,6 +233,7 @@ TEST(Matrix, unary) {
testMatrixZeroAtOffset
(
height
,
width
);
testMatrixGetSum
(
height
,
width
);
testMatrixTranspose
(
height
,
width
);
testMatrixRotate
(
height
,
width
);
}
// inverse
testMatrixInverse
(
height
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录