Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ec6b13db
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ec6b13db
编写于
12月 29, 2016
作者:
X
xutianbing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
clean up unused code.
上级
ea4d08da
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
53 addition
and
747 deletion
+53
-747
paddle/cuda/include/hl_sequence.h
paddle/cuda/include/hl_sequence.h
+0
-72
paddle/cuda/include/stub/hl_sequence_stub.h
paddle/cuda/include/stub/hl_sequence_stub.h
+0
-29
paddle/cuda/src/hl_cuda_sequence.cu
paddle/cuda/src/hl_cuda_sequence.cu
+0
-252
paddle/function/ContextProjectionOpGpu.cu
paddle/function/ContextProjectionOpGpu.cu
+53
-11
paddle/math/Matrix.cpp
paddle/math/Matrix.cpp
+0
-169
paddle/math/Matrix.h
paddle/math/Matrix.h
+0
-72
paddle/math/tests/test_matrixCompare.cpp
paddle/math/tests/test_matrixCompare.cpp
+0
-142
未找到文件。
paddle/cuda/include/hl_sequence.h
浏览文件 @
ec6b13db
...
...
@@ -48,78 +48,6 @@ extern void hl_max_sequence_forward(real* input,
extern
void
hl_max_sequence_backward
(
real
*
outputGrad
,
int
*
index
,
real
*
inputGrad
,
int
numSequences
,
int
dim
);
/**
* @brief Context projection forward.
*
* @param[in] input input sequence.
* @param[in] sequence sequence index.
* @param[in] weightData padding data.
* @param[out] output output sequence.
* @param[in] numSequences number of sequences.
* @param[in] inputDim input sequence dimension.
* @param[in] contextLength context length.
* @param[in] contextStart context start.
* @param[in] beginPad number of extra timesteps added at the
* beginning.
* @param[in] isPadding trainable padding.
*
*/
extern
void
hl_context_projection_forward
(
real
*
input
,
const
int
*
sequence
,
real
*
weightData
,
real
*
output
,
int
numSequences
,
int
inputDim
,
int
contextLength
,
int
contextStart
,
int
beginPad
,
bool
isPadding
);
/**
* @brief Context projection backward data.
*
* @param[in] outputGrad output gradient.
* @param[in] sequence sequence index.
* @param[out] inputGrad input gradient.
* @param[in] numSequences number of sequences.
* @param[in] inputDim input sequence dimension.
* @param[in] contextLength context length.
* @param[in] contextStart context start.
*
*/
extern
void
hl_context_projection_backward_data
(
real
*
outputGrad
,
const
int
*
sequence
,
real
*
inputGrad
,
int
numSequences
,
int
inputDim
,
int
contextLength
,
int
contextStart
);
/**
* @brief Context projection backward weight.
*
* @param[in] outputGrad output gradient.
* @param[in] sequence sequence index.
* @param[out] weightGrad weight gradient.
* @param[in] numSequences number of sequences.
* @param[in] weightDim input sequence dimension.
* @param[in] totalPad number of extra timesteps.
* @param[in] contextLength context length.
* @param[in] contextStart context start.
* @param[in] beginPad number of extra timesteps added at the
* beginning.
*
*/
extern
void
hl_context_projection_backward_weight
(
real
*
outputGrad
,
const
int
*
sequence
,
real
*
weightGrad
,
int
numSequences
,
int
weightDim
,
int
totalPad
,
int
contextLength
,
int
contextStart
,
int
beginPad
);
/**
* @brief Memory copy from sequence to batch.
*
...
...
paddle/cuda/include/stub/hl_sequence_stub.h
浏览文件 @
ec6b13db
...
...
@@ -27,35 +27,6 @@ inline void hl_max_sequence_forward(real* input,
inline
void
hl_max_sequence_backward
(
real
*
outputGrad
,
int
*
index
,
real
*
inputGrad
,
int
numSequences
,
int
dim
)
{}
inline
void
hl_context_projection_forward
(
real
*
input
,
const
int
*
sequence
,
real
*
weightData
,
real
*
output
,
int
numSequences
,
int
inputDim
,
int
contextLength
,
int
contextStart
,
int
beginPad
,
bool
isPadding
)
{}
inline
void
hl_context_projection_backward_data
(
real
*
outputGrad
,
const
int
*
sequence
,
real
*
inputGrad
,
int
numSequences
,
int
inputDim
,
int
contextLength
,
int
contextStart
)
{}
inline
void
hl_context_projection_backward_weight
(
real
*
outputGrad
,
const
int
*
sequence
,
real
*
weightGrad
,
int
numSequences
,
int
weightDim
,
int
totalPad
,
int
contextLength
,
int
contextStart
,
int
beginPad
)
{}
inline
void
hl_sequence2batch_copy
(
real
*
batch
,
real
*
sequence
,
const
int
*
batchIndex
,
...
...
paddle/cuda/src/hl_cuda_sequence.cu
浏览文件 @
ec6b13db
...
...
@@ -90,258 +90,6 @@ void hl_max_sequence_backward(real* outputGrad,
CHECK_SYNC
(
"hl_max_sequence_backward failed"
);
}
template
<
bool
padding
>
__global__
void
KeContextProjectionForward
(
real
*
input
,
const
int
*
sequence
,
real
*
weightData
,
real
*
output
,
int
inputDim
,
int
contextLength
,
int
contextStart
,
int
beginPad
)
{
int
idx
=
threadIdx
.
x
;
int
blockSize
=
blockDim
.
x
;
int
sequenceId
=
blockIdx
.
x
;
int
seqStart
=
sequence
[
sequenceId
];
int
seqEnd
=
sequence
[
sequenceId
+
1
];
real
value
=
0
;
int
instances
=
seqEnd
-
seqStart
+
contextLength
-
1
;
output
+=
seqStart
*
inputDim
*
contextLength
;
input
+=
seqStart
*
inputDim
;
for
(
int
k
=
0
;
k
<=
inputDim
/
blockSize
;
k
++
)
{
if
(
idx
<
inputDim
)
{
for
(
int
i
=
0
;
i
<
instances
;
i
++
)
{
// i + contextStart;
if
((
i
+
contextStart
)
<
0
)
{
if
(
padding
)
{
value
=
weightData
[
i
*
inputDim
+
idx
];
}
else
{
continue
;
}
}
else
if
((
i
+
contextStart
)
>=
(
seqEnd
-
seqStart
))
{
if
(
padding
)
{
value
=
weightData
[(
beginPad
+
i
+
contextStart
-
(
seqEnd
-
seqStart
))
*
inputDim
+
idx
];
}
else
{
continue
;
}
}
else
{
value
=
input
[(
i
+
contextStart
)
*
inputDim
+
idx
];
}
int
outx
=
(
i
-
contextLength
)
<
0
?
i
:
(
contextLength
-
1
);
int
outy
=
(
i
-
contextLength
)
<
0
?
0
:
(
i
-
(
contextLength
-
1
));
real
*
output_r
=
output
+
outy
*
inputDim
*
contextLength
+
outx
*
inputDim
;
for
(
int
j
=
outy
;
j
<
seqEnd
-
seqStart
;
j
++
)
{
output_r
[
idx
]
+=
value
;
if
(
j
-
outy
==
outx
)
break
;
output_r
+=
(
contextLength
-
1
)
*
inputDim
;
}
}
}
idx
+=
blockSize
;
}
}
void
hl_context_projection_forward
(
real
*
input
,
const
int
*
sequence
,
real
*
weightData
,
real
*
output
,
int
numSequences
,
int
inputDim
,
int
contextLength
,
int
contextStart
,
int
beginPad
,
bool
isPadding
)
{
CHECK_NOTNULL
(
input
);
CHECK_NOTNULL
(
sequence
);
CHECK_NOTNULL
(
output
);
CHECK
(
!
isPadding
||
weightData
);
int
blockSize
=
128
;
int
blocksX
=
numSequences
;
int
blocksY
=
1
;
dim3
threads
(
blockSize
,
1
);
dim3
grid
(
blocksX
,
blocksY
);
if
(
isPadding
)
{
KeContextProjectionForward
<
true
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
input
,
sequence
,
weightData
,
output
,
inputDim
,
contextLength
,
contextStart
,
beginPad
);
}
else
{
KeContextProjectionForward
<
false
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
input
,
sequence
,
weightData
,
output
,
inputDim
,
contextLength
,
contextStart
,
beginPad
);
}
CHECK_SYNC
(
"hl_context_projection_forward failed"
);
}
__global__
void
KeContextProjectionBackwardData
(
real
*
outputGrad
,
const
int
*
sequence
,
real
*
inputGrad
,
int
inputDim
,
int
contextLength
,
int
contextStart
)
{
int
idx
=
threadIdx
.
x
;
int
blockSize
=
blockDim
.
x
;
int
sequenceId
=
blockIdx
.
x
;
int
seqStart
=
sequence
[
sequenceId
];
int
seqEnd
=
sequence
[
sequenceId
+
1
];
real
value
=
0
;
int
instances
=
seqEnd
-
seqStart
+
contextLength
-
1
;
outputGrad
+=
seqStart
*
inputDim
*
contextLength
;
inputGrad
+=
seqStart
*
inputDim
;
for
(
int
k
=
0
;
k
<=
inputDim
/
blockSize
;
k
++
)
{
if
(
idx
<
inputDim
)
{
for
(
int
i
=
0
;
i
<
instances
;
i
++
)
{
if
((
i
+
contextStart
)
<
0
)
{
continue
;
}
else
if
((
i
+
contextStart
)
>=
(
seqEnd
-
seqStart
))
{
continue
;
}
else
{
// value = 0;
value
=
inputGrad
[(
i
+
contextStart
)
*
inputDim
+
idx
];
}
int
outx
=
(
i
-
contextLength
)
<
0
?
i
:
(
contextLength
-
1
);
int
outy
=
(
i
-
contextLength
)
<
0
?
0
:
(
i
-
(
contextLength
-
1
));
real
*
output_r
=
outputGrad
+
outy
*
inputDim
*
contextLength
+
outx
*
inputDim
;
for
(
int
j
=
outy
;
j
<
seqEnd
-
seqStart
;
j
++
)
{
value
+=
output_r
[
idx
];
if
(
j
-
outy
==
outx
)
break
;
output_r
+=
(
contextLength
-
1
)
*
inputDim
;
}
inputGrad
[(
i
+
contextStart
)
*
inputDim
+
idx
]
=
value
;
}
}
idx
+=
blockSize
;
}
}
void
hl_context_projection_backward_data
(
real
*
outputGrad
,
const
int
*
sequence
,
real
*
inputGrad
,
int
numSequences
,
int
inputDim
,
int
contextLength
,
int
contextStart
)
{
CHECK_NOTNULL
(
outputGrad
);
CHECK_NOTNULL
(
sequence
);
CHECK_NOTNULL
(
inputGrad
);
int
blockSize
=
128
;
int
blocksX
=
numSequences
;
int
blocksY
=
1
;
dim3
threads
(
blockSize
,
1
);
dim3
grid
(
blocksX
,
blocksY
);
KeContextProjectionBackwardData
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
outputGrad
,
sequence
,
inputGrad
,
inputDim
,
contextLength
,
contextStart
);
CHECK_SYNC
(
"hl_context_projection_backward_data failed"
);
}
template
<
int
THREADS_X
,
int
THREADS_Y
>
__global__
void
KeContextProjectionBackwardWeight
(
real
*
outputGrad
,
const
int
*
sequence
,
real
*
weightGrad
,
int
numSequences
,
int
weightDim
,
int
contextLength
,
int
contextStart
,
int
beginPad
)
{
__shared__
real
sum_s
[
THREADS_Y
][
THREADS_X
];
int
padOfBlock
=
(
weightDim
+
THREADS_X
-
1
)
/
THREADS_X
;
const
int
idx
=
threadIdx
.
x
;
const
int
idy
=
threadIdx
.
y
;
int
padId
=
blockIdx
.
x
/
padOfBlock
;
int
weightIdx
=
idx
+
THREADS_X
*
(
blockIdx
.
x
%
padOfBlock
);
int
instanceId
;
real
value
=
0
;
real
*
output_r
;
sum_s
[
idy
][
idx
]
=
0.0
f
;
if
(
weightIdx
<
weightDim
)
{
for
(
int
seqId
=
idy
;
seqId
<
numSequences
;
seqId
+=
THREADS_Y
)
{
int
seqStart
=
sequence
[
seqId
];
int
seqEnd
=
sequence
[
seqId
+
1
];
output_r
=
outputGrad
+
seqStart
*
weightDim
*
contextLength
;
if
(
contextStart
<
0
)
{
if
(
padId
+
contextStart
<
0
)
{
instanceId
=
padId
;
}
else
{
// beginPad > 0;
instanceId
=
(
padId
-
beginPad
)
+
(
seqEnd
-
seqStart
)
-
contextStart
;
}
}
else
{
if
(
padId
+
(
seqEnd
-
seqStart
)
<
contextStart
)
{
continue
;
}
else
{
// beginPad == 0;
instanceId
=
padId
+
(
seqEnd
-
seqStart
)
-
contextStart
;
}
}
int
outx
=
(
instanceId
-
contextLength
)
<
0
?
instanceId
:
(
contextLength
-
1
);
int
outy
=
(
instanceId
-
contextLength
)
<
0
?
0
:
(
instanceId
-
(
contextLength
-
1
));
output_r
+=
outy
*
weightDim
*
contextLength
+
outx
*
weightDim
;
for
(
int
j
=
outy
;
j
<
seqEnd
-
seqStart
;
j
++
)
{
value
+=
output_r
[
weightIdx
];
if
(
j
-
outy
==
outx
)
break
;
output_r
+=
(
contextLength
-
1
)
*
weightDim
;
}
}
sum_s
[
idy
][
idx
]
=
value
;
}
__syncthreads
();
for
(
int
stride
=
THREADS_Y
/
2
;
stride
>
0
;
stride
=
stride
/
2
)
{
if
(
idy
<
stride
)
{
sum_s
[
idy
][
idx
]
+=
sum_s
[
idy
+
stride
][
idx
];
}
__syncthreads
();
}
__syncthreads
();
if
(
weightIdx
<
weightDim
)
{
if
(
idy
==
0
)
{
weightGrad
[
padId
*
weightDim
+
weightIdx
]
+=
sum_s
[
0
][
idx
];
}
}
}
void
hl_context_projection_backward_weight
(
real
*
outputGrad
,
const
int
*
sequence
,
real
*
weightGrad
,
int
numSequences
,
int
weightDim
,
int
totalPad
,
int
contextLength
,
int
contextStart
,
int
beginPad
)
{
CHECK_NOTNULL
(
outputGrad
);
CHECK_NOTNULL
(
sequence
);
CHECK_NOTNULL
(
weightGrad
);
int
threadsX
=
32
;
int
threadsY
=
32
;
int
blocksX
=
totalPad
*
((
weightDim
+
threadsX
-
1
)
/
threadsX
);
dim3
threads
(
threadsX
,
threadsY
);
dim3
grid
(
blocksX
,
1
);
KeContextProjectionBackwardWeight
<
32
,
32
>
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
outputGrad
,
sequence
,
weightGrad
,
numSequences
,
weightDim
,
contextLength
,
contextStart
,
beginPad
);
CHECK_SYNC
(
"hl_context_projection_backward_weight failed"
);
}
template
<
int
blockDimX
,
int
blockDimY
,
int
gridDimX
,
bool
AddRow
>
__global__
void
KeMatrixAddRows
(
real
*
output
,
real
*
table
,
...
...
paddle/function/ContextProjectionOpGpu.cu
浏览文件 @
ec6b13db
...
...
@@ -73,15 +73,30 @@ __global__ void KeContextProjectionForward(const real* input,
}
}
/**
* @brief Context projection forward.
*
* @param[in] input input sequence.
* @param[in] sequence sequence index.
* @param[in] weight padding data.
* @param[out] output output sequence.
* @param[in] num_sequences number of sequences.
* @param[in] input_dim input sequence dimension.
* @param[in] context_length context length.
* @param[in] context_start context start.
* @param[in] begin_pad number of extra timesteps added at the
* beginning.
*
*/
void
hl_context_projection_forward
(
const
real
*
input
,
const
int
*
sequence
,
const
real
*
weight
,
real
*
output
,
in
t
num_sequences
,
in
t
input_dim
,
in
t
context_length
,
size_
t
num_sequences
,
size_
t
input_dim
,
size_
t
context_length
,
int
context_start
,
in
t
begin_pad
)
{
size_
t
begin_pad
)
{
CHECK_NOTNULL
(
input
);
CHECK_NOTNULL
(
sequence
);
CHECK_NOTNULL
(
output
);
...
...
@@ -168,12 +183,24 @@ __global__ void KeContextProjectionBackwardData(real* out_grad,
}
}
/**
* @brief Context projection backward data.
*
* @param[in] out_grad output gradient.
* @param[in] sequence sequence index.
* @param[out] input_grad input gradient.
* @param[in] num_sequences number of sequences.
* @param[in] input_dim input sequence dimension.
* @param[in] context_length context length.
* @param[in] context_start context start.
*
*/
void
hl_context_projection_backward_data
(
real
*
out_grad
,
const
int
*
sequence
,
real
*
input_grad
,
in
t
num_sequences
,
in
t
input_dim
,
in
t
context_length
,
size_
t
num_sequences
,
size_
t
input_dim
,
size_
t
context_length
,
int
context_start
)
{
CHECK_NOTNULL
(
out_grad
);
CHECK_NOTNULL
(
sequence
);
...
...
@@ -278,15 +305,30 @@ __global__ void KeContextProjectionBackwardWeight(real* out_grad,
}
}
/**
* @brief Context projection backward weight.
*
* @param[in] out_grad output gradient.
* @param[in] sequence sequence index.
* @param[out] w_grad weight gradient.
* @param[in] num_sequences number of sequences.
* @param[in] w_dim input sequence dimension.
* @param[in] total_pad number of extra timesteps.
* @param[in] context_length context length.
* @param[in] context_start context start.
* @param[in] begin_pad number of extra timesteps added at the
* beginning.
*
*/
void
hl_context_projection_backward_weight
(
real
*
out_grad
,
const
int
*
sequence
,
real
*
w_grad
,
in
t
num_sequences
,
in
t
w_dim
,
size_
t
num_sequences
,
size_
t
w_dim
,
size_t
total_pad
,
in
t
context_length
,
size_
t
context_length
,
int
context_start
,
in
t
begin_pad
)
{
size_
t
begin_pad
)
{
CHECK_NOTNULL
(
out_grad
);
CHECK_NOTNULL
(
sequence
);
CHECK_NOTNULL
(
w_grad
);
...
...
paddle/math/Matrix.cpp
浏览文件 @
ec6b13db
...
...
@@ -1304,68 +1304,6 @@ void GpuMatrix::maxSequenceBackward(Matrix& outputGrad,
hl_max_sequence_backward
(
outGrad
,
maxIndex
,
inputGrad
,
numSequences
,
dim
);
}
void
GpuMatrix
::
contextProjectionForward
(
Matrix
&
input
,
Matrix
*
weight
,
const
IVector
&
sequence
,
int
contextLength
,
int
contextStart
,
size_t
beginPad
,
bool
isPadding
)
{
CHECK
(
dynamic_cast
<
GpuMatrix
*>
(
&
input
));
CHECK
(
dynamic_cast
<
const
GpuIVector
*>
(
&
sequence
));
if
(
weight
)
CHECK
(
dynamic_cast
<
GpuMatrix
*>
(
weight
));
CHECK_EQ
(
getWidth
(),
input
.
getWidth
()
*
contextLength
);
hl_context_projection_forward
(
input
.
getData
(),
sequence
.
getData
(),
isPadding
?
weight
->
getData
()
:
NULL
,
getData
(),
sequence
.
getSize
()
-
1
,
input
.
getWidth
(),
contextLength
,
contextStart
,
beginPad
,
isPadding
);
}
void
GpuMatrix
::
contextProjectionBackwardData
(
Matrix
&
inputGrad
,
const
IVector
&
sequence
,
int
contextLength
,
int
contextStart
)
{
CHECK
(
dynamic_cast
<
GpuMatrix
*>
(
&
inputGrad
));
CHECK
(
dynamic_cast
<
const
GpuIVector
*>
(
&
sequence
));
CHECK_EQ
(
getWidth
(),
inputGrad
.
getWidth
()
*
contextLength
);
hl_context_projection_backward_data
(
getData
(),
sequence
.
getData
(),
inputGrad
.
getData
(),
sequence
.
getSize
()
-
1
,
inputGrad
.
getWidth
(),
contextLength
,
contextStart
);
}
void
GpuMatrix
::
contextProjectionBackwardWeight
(
Matrix
&
weightGrad
,
const
IVector
&
sequence
,
int
contextLength
,
int
contextStart
,
int
totalPad
,
size_t
beginPad
)
{
CHECK
(
dynamic_cast
<
GpuMatrix
*>
(
&
weightGrad
));
CHECK
(
dynamic_cast
<
const
GpuIVector
*>
(
&
sequence
));
CHECK_EQ
(
getWidth
(),
weightGrad
.
getWidth
()
*
contextLength
);
hl_context_projection_backward_weight
(
getData
(),
sequence
.
getData
(),
weightGrad
.
getData
(),
sequence
.
getSize
()
-
1
,
weightGrad
.
getWidth
(),
totalPad
,
contextLength
,
contextStart
,
beginPad
);
}
void
GpuMatrix
::
paramReluForward
(
Matrix
&
data
,
Matrix
&
W
)
{
CHECK
(
data
.
useGpu_
==
true
&&
W
.
useGpu_
==
true
)
<<
"Matrix type are not equal"
;
...
...
@@ -2203,113 +2141,6 @@ void CpuMatrix::maxSequenceBackward(Matrix& outputGrad,
}
}
void
CpuMatrix
::
contextProjectionForward
(
Matrix
&
input
,
Matrix
*
weight
,
const
IVector
&
sequence
,
int
contextLength
,
int
contextStart
,
size_t
beginPad
,
bool
isPadding
)
{
auto
input_ptr
=
dynamic_cast
<
CpuMatrix
*>
(
&
input
);
auto
seq_ptr
=
dynamic_cast
<
const
CpuIVector
*>
(
&
sequence
);
CHECK
(
input_ptr
&&
seq_ptr
);
if
(
weight
)
CHECK
(
dynamic_cast
<
CpuMatrix
*>
(
weight
));
CHECK_EQ
(
getWidth
(),
input_ptr
->
getWidth
()
*
contextLength
);
const
int
*
starts
=
seq_ptr
->
getData
();
size_t
numSequences
=
seq_ptr
->
getSize
()
-
1
;
for
(
size_t
i
=
0
;
i
<
numSequences
;
++
i
)
{
for
(
int
j
=
0
;
j
<
contextLength
;
++
j
)
{
int
begin
=
starts
[
i
]
+
contextStart
+
j
;
int
end
=
starts
[
i
+
1
]
+
contextStart
+
j
;
int
dstBegin
=
starts
[
i
];
int
dstEnd
=
starts
[
i
+
1
];
if
(
begin
<
starts
[
i
])
{
int64_t
padSize
=
std
::
min
(
starts
[
i
]
-
begin
,
starts
[
i
+
1
]
-
starts
[
i
]);
MatrixPtr
mat
=
this
->
subMatrix
(
starts
[
i
],
padSize
);
if
(
isPadding
)
{
MatrixPtr
sub
=
weight
->
subMatrix
(
j
,
padSize
);
mat
->
addAtOffset
(
*
sub
,
j
*
input_ptr
->
getWidth
());
}
dstBegin
=
starts
[
i
]
+
padSize
;
begin
=
starts
[
i
];
}
if
(
end
>
starts
[
i
+
1
])
{
int64_t
padSize
=
std
::
min
(
end
-
starts
[
i
+
1
],
starts
[
i
+
1
]
-
starts
[
i
]);
MatrixPtr
mat
=
this
->
subMatrix
(
starts
[
i
+
1
]
-
padSize
,
padSize
);
if
(
isPadding
)
{
MatrixPtr
sub
=
weight
->
subMatrix
(
beginPad
+
contextStart
+
j
-
padSize
,
padSize
);
mat
->
addAtOffset
(
*
sub
,
j
*
input_ptr
->
getWidth
());
}
dstEnd
=
starts
[
i
+
1
]
-
padSize
;
end
=
starts
[
i
+
1
];
}
if
(
end
<=
begin
)
continue
;
MatrixPtr
src
=
input_ptr
->
subMatrix
(
begin
,
end
-
begin
);
MatrixPtr
dst
=
this
->
subMatrix
(
dstBegin
,
dstEnd
-
dstBegin
);
dst
->
addAtOffset
(
*
src
,
j
*
input_ptr
->
getWidth
());
}
}
}
void
CpuMatrix
::
contextProjectionBackward
(
Matrix
*
inputGrad
,
Matrix
*
weightGrad
,
const
IVector
&
sequence
,
int
contextLength
,
int
contextStart
,
size_t
beginPad
,
bool
isPadding
)
{
if
(
inputGrad
)
CHECK
(
dynamic_cast
<
CpuMatrix
*>
(
inputGrad
));
if
(
weightGrad
)
CHECK
(
dynamic_cast
<
CpuMatrix
*>
(
weightGrad
));
CHECK
(
dynamic_cast
<
const
CpuIVector
*>
(
&
sequence
));
int64_t
inputDim
=
inputGrad
?
inputGrad
->
getWidth
()
:
weightGrad
?
weightGrad
->
getWidth
()
:
0
;
CHECK_EQ
(
getWidth
(),
inputDim
*
contextLength
);
const
int
*
starts
=
sequence
.
getData
();
size_t
numSequences
=
sequence
.
getSize
()
-
1
;
for
(
size_t
i
=
0
;
i
<
numSequences
;
++
i
)
{
for
(
int
j
=
0
;
j
<
contextLength
;
++
j
)
{
int
begin
=
starts
[
i
]
+
contextStart
+
j
;
int
end
=
starts
[
i
+
1
]
+
contextStart
+
j
;
int
dstBegin
=
starts
[
i
];
int
dstEnd
=
starts
[
i
+
1
];
if
(
begin
<
starts
[
i
])
{
int64_t
padSize
=
std
::
min
(
starts
[
i
]
-
begin
,
starts
[
i
+
1
]
-
starts
[
i
]);
if
(
isPadding
&&
weightGrad
)
{
MatrixPtr
mat
=
this
->
subMatrix
(
starts
[
i
],
padSize
);
MatrixPtr
sub
=
weightGrad
->
subMatrix
(
j
,
padSize
);
sub
->
addAtOffset
(
*
mat
,
j
*
inputDim
);
}
dstBegin
=
starts
[
i
]
+
padSize
;
begin
=
starts
[
i
];
}
if
(
end
>
starts
[
i
+
1
])
{
int64_t
padSize
=
std
::
min
(
end
-
starts
[
i
+
1
],
starts
[
i
+
1
]
-
starts
[
i
]);
if
(
isPadding
&&
weightGrad
)
{
MatrixPtr
mat
=
this
->
subMatrix
(
starts
[
i
+
1
]
-
padSize
,
padSize
);
MatrixPtr
sub
=
weightGrad
->
subMatrix
(
beginPad
+
contextStart
+
j
-
padSize
,
padSize
);
sub
->
addAtOffset
(
*
mat
,
j
*
inputDim
);
}
dstEnd
=
starts
[
i
+
1
]
-
padSize
;
end
=
starts
[
i
+
1
];
}
if
(
end
<=
begin
)
continue
;
if
(
!
inputGrad
)
continue
;
MatrixPtr
src
=
inputGrad
->
subMatrix
(
begin
,
end
-
begin
);
MatrixPtr
dst
=
this
->
subMatrix
(
dstBegin
,
dstEnd
-
dstBegin
);
src
->
addAtOffset
(
*
dst
,
j
*
inputDim
);
}
}
}
inline
void
vecAddTo
(
real
*
a
,
const
real
*
b
,
size_t
len
)
{
for
(
unsigned
int
i
=
0
;
i
<
len
;
++
i
)
{
a
[
i
]
+=
b
[
i
];
...
...
paddle/math/Matrix.h
浏览文件 @
ec6b13db
...
...
@@ -972,42 +972,6 @@ public:
LOG
(
FATAL
)
<<
"Not implemeted"
;
}
virtual
void
contextProjectionForward
(
Matrix
&
input
,
Matrix
*
weight
,
const
IVector
&
sequence
,
int
contextLength
,
int
contextStart
,
size_t
beginPad
,
bool
isPadding
)
{
LOG
(
FATAL
)
<<
"Not implemeted"
;
}
virtual
void
contextProjectionBackward
(
Matrix
*
inputGrad
,
Matrix
*
weightGrad
,
const
IVector
&
sequence
,
int
contextLength
,
int
contextStart
,
size_t
beginPad
,
bool
isPadding
)
{
LOG
(
FATAL
)
<<
"Not implemeted"
;
}
virtual
void
contextProjectionBackwardData
(
Matrix
&
inputGrad
,
const
IVector
&
sequence
,
int
contextLength
,
int
contextStart
)
{
LOG
(
FATAL
)
<<
"Not implemeted"
;
}
virtual
void
contextProjectionBackwardWeight
(
Matrix
&
weightGrad
,
const
IVector
&
sequence
,
int
contextLength
,
int
contextStart
,
int
totalPad
,
size_t
beginPad
)
{
LOG
(
FATAL
)
<<
"Not implemeted"
;
}
/**
* @code
* this.row[i] += table.row[ids[i]]
...
...
@@ -1442,26 +1406,6 @@ public:
const
IVector
&
sequence
,
IVector
&
index
);
void
contextProjectionForward
(
Matrix
&
input
,
Matrix
*
weight
,
const
IVector
&
sequence
,
int
contextLength
,
int
contextStart
,
size_t
beginPad
,
bool
isPadding
);
void
contextProjectionBackwardData
(
Matrix
&
inputGrad
,
const
IVector
&
sequence
,
int
contextLength
,
int
contextStart
);
void
contextProjectionBackwardWeight
(
Matrix
&
weightGrad
,
const
IVector
&
sequence
,
int
contextLength
,
int
contextStart
,
int
totalPad
,
size_t
beginPad
);
void
bilinearForward
(
const
Matrix
&
in
,
const
size_t
inImgH
,
const
size_t
inImgW
,
...
...
@@ -1648,22 +1592,6 @@ public:
const
IVector
&
sequence
,
IVector
&
index
);
void
contextProjectionForward
(
Matrix
&
input
,
Matrix
*
weight
,
const
IVector
&
sequence
,
int
contextLength
,
int
contextStart
,
size_t
beginPad
,
bool
isPadding
);
void
contextProjectionBackward
(
Matrix
*
inputGrad
,
Matrix
*
weightGrad
,
const
IVector
&
sequence
,
int
contextLength
,
int
contextStart
,
size_t
beginPad
,
bool
isPadding
);
real
*
getRow
(
size_t
row
)
{
return
BaseMatrix
::
rowBuf
(
row
);
}
virtual
real
*
getRowBuf
(
size_t
row
)
{
return
getRow
(
row
);
}
...
...
paddle/math/tests/test_matrixCompare.cpp
浏览文件 @
ec6b13db
...
...
@@ -29,148 +29,6 @@ using namespace std; // NOLINT
using
autotest
::
TensorCheckEqual
;
using
autotest
::
TensorCheckErr
;
void
testMatrixProjectionForward
(
int
contextStart
,
int
contextLength
,
bool
padding
,
int
batchSize
,
int
inputDim
)
{
MatrixPtr
cpuInput
=
std
::
make_shared
<
CpuMatrix
>
(
batchSize
,
inputDim
);
MatrixPtr
gpuInput
=
std
::
make_shared
<
GpuMatrix
>
(
batchSize
,
inputDim
);
cpuInput
->
randomizeUniform
();
gpuInput
->
copyFrom
(
*
cpuInput
);
int
pad
=
std
::
max
(
0
,
-
contextStart
)
+
std
::
max
(
0
,
contextStart
+
contextLength
-
1
);
if
(
pad
==
0
)
padding
=
false
;
MatrixPtr
cpuWeight
=
nullptr
;
MatrixPtr
gpuWeight
=
nullptr
;
if
(
padding
)
{
cpuWeight
=
std
::
make_shared
<
CpuMatrix
>
(
pad
,
inputDim
);
gpuWeight
=
std
::
make_shared
<
GpuMatrix
>
(
pad
,
inputDim
);
cpuWeight
->
randomizeUniform
();
gpuWeight
->
copyFrom
(
*
cpuWeight
);
}
IVectorPtr
cpuSequence
;
generateSequenceStartPositions
(
batchSize
,
cpuSequence
);
IVectorPtr
gpuSequence
=
IVector
::
create
(
cpuSequence
->
getSize
(),
true
);
gpuSequence
->
copyFrom
(
*
cpuSequence
);
MatrixPtr
cpuOutput
=
std
::
make_shared
<
CpuMatrix
>
(
batchSize
,
inputDim
*
contextLength
);
MatrixPtr
gpuOutput
=
std
::
make_shared
<
GpuMatrix
>
(
batchSize
,
inputDim
*
contextLength
);
cpuOutput
->
randomizeUniform
();
gpuOutput
->
copyFrom
(
*
cpuOutput
);
// calculate
int
beginPad
=
std
::
max
(
0
,
-
contextStart
);
cpuOutput
->
contextProjectionForward
(
*
cpuInput
,
cpuWeight
.
get
(),
*
cpuSequence
,
contextLength
,
contextStart
,
beginPad
,
padding
);
gpuOutput
->
contextProjectionForward
(
*
gpuInput
,
gpuWeight
.
get
(),
*
gpuSequence
,
contextLength
,
contextStart
,
beginPad
,
padding
);
TensorCheckEqual
(
*
cpuOutput
,
*
gpuOutput
);
}
void
testMatrixProjectionBackward
(
int
contextStart
,
int
contextLength
,
bool
padding
,
int
batchSize
,
int
inputDim
)
{
MatrixPtr
cpuOutputGrad
=
std
::
make_shared
<
CpuMatrix
>
(
batchSize
,
inputDim
*
contextLength
);
MatrixPtr
gpuOutputGrad
=
std
::
make_shared
<
GpuMatrix
>
(
batchSize
,
inputDim
*
contextLength
);
cpuOutputGrad
->
randomizeUniform
();
gpuOutputGrad
->
copyFrom
(
*
cpuOutputGrad
);
IVectorPtr
cpuSequence
;
generateSequenceStartPositions
(
batchSize
,
cpuSequence
);
IVectorPtr
gpuSequence
=
IVector
::
create
(
cpuSequence
->
getSize
(),
true
);
gpuSequence
->
copyFrom
(
*
cpuSequence
);
MatrixPtr
cpuInputGrad
=
std
::
make_shared
<
CpuMatrix
>
(
batchSize
,
inputDim
);
MatrixPtr
gpuInputGrad
=
std
::
make_shared
<
GpuMatrix
>
(
batchSize
,
inputDim
);
cpuInputGrad
->
randomizeUniform
();
gpuInputGrad
->
copyFrom
(
*
cpuInputGrad
);
int
pad
=
std
::
max
(
0
,
-
contextStart
)
+
std
::
max
(
0
,
contextStart
+
contextLength
-
1
);
if
(
pad
==
0
)
padding
=
false
;
MatrixPtr
cpuWeightGrad
=
nullptr
;
MatrixPtr
gpuWeightGrad
=
nullptr
;
if
(
padding
)
{
cpuWeightGrad
=
std
::
make_shared
<
CpuMatrix
>
(
pad
,
inputDim
);
gpuWeightGrad
=
std
::
make_shared
<
GpuMatrix
>
(
pad
,
inputDim
);
cpuWeightGrad
->
randomizeUniform
();
gpuWeightGrad
->
copyFrom
(
*
cpuWeightGrad
);
}
// calculate
int
beginPad
=
std
::
max
(
0
,
-
contextStart
);
cpuOutputGrad
->
contextProjectionBackward
(
cpuInputGrad
.
get
(),
cpuWeightGrad
.
get
(),
*
cpuSequence
,
contextLength
,
contextStart
,
beginPad
,
padding
);
gpuOutputGrad
->
contextProjectionBackwardData
(
*
gpuInputGrad
,
*
gpuSequence
,
contextLength
,
contextStart
);
if
(
padding
)
{
gpuOutputGrad
->
contextProjectionBackwardWeight
(
*
gpuWeightGrad
,
*
gpuSequence
,
contextLength
,
contextStart
,
pad
,
beginPad
);
}
TensorCheckErr
(
*
cpuInputGrad
,
*
gpuInputGrad
);
if
(
padding
)
{
TensorCheckErr
(
*
cpuWeightGrad
,
*
gpuWeightGrad
);
}
}
TEST
(
Matrix
,
projection
)
{
for
(
auto
contextStart
:
{
-
5
,
-
3
,
-
1
,
0
,
3
})
{
for
(
auto
contextLength
:
{
1
,
2
,
5
,
7
})
{
for
(
auto
trainablePadding
:
{
false
,
true
})
{
for
(
auto
batchSize
:
{
1
,
2
,
5
,
20
,
100
})
{
for
(
auto
inputDim
:
{
15
,
32
,
63
,
128
,
200
})
{
VLOG
(
3
)
<<
" contextStart="
<<
contextStart
<<
" contextLength="
<<
contextLength
<<
" trainablePadding="
<<
trainablePadding
<<
" batchSize="
<<
batchSize
<<
" inputDim="
<<
inputDim
;
testMatrixProjectionForward
(
contextStart
,
contextLength
,
trainablePadding
,
batchSize
,
inputDim
);
testMatrixProjectionBackward
(
contextStart
,
contextLength
,
trainablePadding
,
batchSize
,
inputDim
);
}
}
}
}
}
}
void
testMatrixMaxSequence
(
int
batchSize
,
int
inputDim
)
{
// forward
MatrixPtr
cpuInput
=
std
::
make_shared
<
CpuMatrix
>
(
batchSize
,
inputDim
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录