Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
0ae5ac16
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0ae5ac16
编写于
1月 16, 2017
作者:
H
hedaoyuan
提交者:
GitHub
1月 16, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1080 from tianbingsz/paddle_func_context
Context Projection Paddle Function-- follow comments
上级
c13540a6
e9794214
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
291 addition
and
246 deletion
+291
-246
paddle/function/BufferArg.cpp
paddle/function/BufferArg.cpp
+8
-4
paddle/function/BufferArg.h
paddle/function/BufferArg.h
+24
-9
paddle/function/CMakeLists.txt
paddle/function/CMakeLists.txt
+1
-1
paddle/function/ContextProjectionOp.cpp
paddle/function/ContextProjectionOp.cpp
+154
-105
paddle/function/ContextProjectionOp.h
paddle/function/ContextProjectionOp.h
+11
-11
paddle/function/ContextProjectionOpGpu.cu
paddle/function/ContextProjectionOpGpu.cu
+13
-11
paddle/function/ContextProjectionOpTest.cpp
paddle/function/ContextProjectionOpTest.cpp
+36
-35
paddle/function/Function.cpp
paddle/function/Function.cpp
+6
-0
paddle/function/Function.h
paddle/function/Function.h
+4
-0
paddle/function/FunctionTest.h
paddle/function/FunctionTest.h
+17
-55
paddle/gserver/layers/ContextProjection.cpp
paddle/gserver/layers/ContextProjection.cpp
+17
-15
未找到文件。
paddle/function/BufferArg.cpp
浏览文件 @
0ae5ac16
...
@@ -20,23 +20,27 @@ limitations under the License. */
...
@@ -20,23 +20,27 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
const
SequenceArg
&
BufferArg
::
sequence
()
const
{
const
SequenceArg
&
BufferArg
::
sequence
()
const
{
//
CHECK_EQ(bufferType_, TENSOR_SEQUENCE_DATA);
CHECK_EQ
(
bufferType_
,
TENSOR_SEQUENCE_DATA
);
return
dynamic_cast
<
const
SequenceArg
&>
(
*
this
);
return
dynamic_cast
<
const
SequenceArg
&>
(
*
this
);
}
}
const
SparseMatrixArg
&
BufferArg
::
sparse
()
const
{
const
SparseMatrixArg
&
BufferArg
::
sparse
()
const
{
//
CHECK_EQ(bufferType_, TENSOR_SPARSE);
CHECK_EQ
(
bufferType_
,
TENSOR_SPARSE
);
return
dynamic_cast
<
const
SparseMatrixArg
&>
(
*
this
);
return
dynamic_cast
<
const
SparseMatrixArg
&>
(
*
this
);
}
}
SparseMatrixArg
::
SparseMatrixArg
(
const
CpuSparseMatrix
&
sparse
,
ArgType
argType
)
SparseMatrixArg
::
SparseMatrixArg
(
const
CpuSparseMatrix
&
sparse
,
ArgType
argType
)
:
BufferArg
(
sparse
,
argType
),
:
BufferArg
(
sparse
,
argType
),
row_
(
reinterpret_cast
<
void
*>
(
sparse
.
getRows
()),
VALUE_TYPE_INT32
),
row_
(
reinterpret_cast
<
void
*>
(
sparse
.
getRows
()),
VALUE_TYPE_INT32
),
col_
(
reinterpret_cast
<
void
*>
(
sparse
.
getCols
()),
VALUE_TYPE_INT32
)
{}
col_
(
reinterpret_cast
<
void
*>
(
sparse
.
getCols
()),
VALUE_TYPE_INT32
)
{
bufferType_
=
TENSOR_SPARSE
;
}
SparseMatrixArg
::
SparseMatrixArg
(
const
GpuSparseMatrix
&
sparse
,
ArgType
argType
)
SparseMatrixArg
::
SparseMatrixArg
(
const
GpuSparseMatrix
&
sparse
,
ArgType
argType
)
:
BufferArg
(
sparse
,
argType
),
:
BufferArg
(
sparse
,
argType
),
row_
(
reinterpret_cast
<
void
*>
(
sparse
.
getRows
()),
VALUE_TYPE_INT32
),
row_
(
reinterpret_cast
<
void
*>
(
sparse
.
getRows
()),
VALUE_TYPE_INT32
),
col_
(
reinterpret_cast
<
void
*>
(
sparse
.
getCols
()),
VALUE_TYPE_INT32
)
{}
col_
(
reinterpret_cast
<
void
*>
(
sparse
.
getCols
()),
VALUE_TYPE_INT32
)
{
bufferType_
=
TENSOR_SPARSE
;
}
}
// namespace paddle
}
// namespace paddle
paddle/function/BufferArg.h
浏览文件 @
0ae5ac16
...
@@ -23,10 +23,11 @@ limitations under the License. */
...
@@ -23,10 +23,11 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
enum
BufferType
{
enum
BufferType
{
TENSOR_NORMAL
=
0
,
TENSOR_UNKNOWN
=
0
,
TENSOR_SEQUENCE_ID
=
1
,
TENSOR_NORMAL
=
1
,
TENSOR_SEQUENCE_DATA
=
2
,
TENSOR_SEQUENCE_ID
=
2
,
TENSOR_SPARSE
=
3
TENSOR_SEQUENCE_DATA
=
3
,
TENSOR_SPARSE
=
4
};
};
enum
SparseDataType
{
enum
SparseDataType
{
...
@@ -86,6 +87,7 @@ public:
...
@@ -86,6 +87,7 @@ public:
valueType_
(
DataType
<
real
>::
value
),
valueType_
(
DataType
<
real
>::
value
),
shape_
(
2
),
shape_
(
2
),
argType_
(
argType
)
{
argType_
(
argType
)
{
bufferType_
=
TENSOR_NORMAL
;
shape_
.
setDim
(
0
,
matrix
.
getHeight
());
shape_
.
setDim
(
0
,
matrix
.
getHeight
());
shape_
.
setDim
(
1
,
matrix
.
getWidth
());
shape_
.
setDim
(
1
,
matrix
.
getWidth
());
}
}
...
@@ -98,6 +100,7 @@ public:
...
@@ -98,6 +100,7 @@ public:
valueType_
(
DataType
<
real
>::
value
),
valueType_
(
DataType
<
real
>::
value
),
shape_
(
shape
),
shape_
(
shape
),
argType_
(
argType
)
{
argType_
(
argType
)
{
bufferType_
=
TENSOR_NORMAL
;
CHECK_EQ
(
matrix
.
getElementCnt
(),
shape
.
getElements
());
CHECK_EQ
(
matrix
.
getElementCnt
(),
shape
.
getElements
());
}
}
...
@@ -107,6 +110,7 @@ public:
...
@@ -107,6 +110,7 @@ public:
valueType_
(
DataType
<
real
>::
value
),
valueType_
(
DataType
<
real
>::
value
),
shape_
(
1
),
shape_
(
1
),
argType_
(
argType
)
{
argType_
(
argType
)
{
bufferType_
=
TENSOR_NORMAL
;
shape_
.
setDim
(
0
,
vector
.
getSize
());
shape_
.
setDim
(
0
,
vector
.
getSize
());
}
}
...
@@ -116,6 +120,7 @@ public:
...
@@ -116,6 +120,7 @@ public:
valueType_
(
VALUE_TYPE_INT32
),
valueType_
(
VALUE_TYPE_INT32
),
shape_
(
1
),
shape_
(
1
),
argType_
(
argType
)
{
argType_
(
argType
)
{
bufferType_
=
TENSOR_NORMAL
;
shape_
.
setDim
(
0
,
vector
.
getSize
());
shape_
.
setDim
(
0
,
vector
.
getSize
());
}
}
...
@@ -150,6 +155,8 @@ public:
...
@@ -150,6 +155,8 @@ public:
ValueType
valueType
()
const
{
return
valueType_
;
}
ValueType
valueType
()
const
{
return
valueType_
;
}
BufferType
bufferType
()
const
{
return
bufferType_
;
}
BufferType
bufferType
()
const
{
return
bufferType_
;
}
const
TensorShape
&
shape
()
const
{
return
shape_
;
}
const
TensorShape
&
shape
()
const
{
return
shape_
;
}
bool
isSparse
()
const
{
return
(
TENSOR_SPARSE
==
bufferType_
);
}
bool
isSequenceArg
()
const
{
return
TENSOR_SEQUENCE_DATA
==
bufferType_
;
}
const
SequenceArg
&
sequence
()
const
;
const
SequenceArg
&
sequence
()
const
;
const
SparseMatrixArg
&
sparse
()
const
;
const
SparseMatrixArg
&
sparse
()
const
;
...
@@ -158,8 +165,8 @@ protected:
...
@@ -158,8 +165,8 @@ protected:
void
*
buf_
;
void
*
buf_
;
ValueType
valueType_
;
ValueType
valueType_
;
TensorShape
shape_
;
TensorShape
shape_
;
BufferType
bufferType_
;
BufferType
bufferType_
{
TENSOR_UNKNOWN
}
;
ArgType
argType_
=
UNSPECIFIED
;
ArgType
argType_
{
UNSPECIFIED
}
;
// leading dimensions. The size is dims_.size()
// leading dimensions. The size is dims_.size()
// Dims lds_;
// Dims lds_;
};
};
...
@@ -174,11 +181,13 @@ public:
...
@@ -174,11 +181,13 @@ public:
const
TensorShape
&
shape
,
const
TensorShape
&
shape
,
ArgType
argType
=
UNSPECIFIED
)
ArgType
argType
=
UNSPECIFIED
)
:
BufferArg
(
buf
,
VALUE_TYPE_INT32
,
shape
,
argType
)
{
:
BufferArg
(
buf
,
VALUE_TYPE_INT32
,
shape
,
argType
)
{
bufferType_
=
TENSOR_SEQUENCE_ID
;
CHECK_EQ
(
shape_
.
ndims
(),
(
size_t
)
1
);
CHECK_EQ
(
shape_
.
ndims
(),
(
size_t
)
1
);
numSeqs_
=
shape_
[
0
]
-
1
;
numSeqs_
=
shape_
[
0
]
-
1
;
}
}
SequenceIdArg
(
const
IVector
&
vector
)
:
BufferArg
(
vector
)
{
SequenceIdArg
(
const
IVector
&
vector
)
:
BufferArg
(
vector
)
{
bufferType_
=
TENSOR_SEQUENCE_ID
;
numSeqs_
=
shape_
[
0
]
-
1
;
numSeqs_
=
shape_
[
0
]
-
1
;
}
}
...
@@ -190,7 +199,7 @@ private:
...
@@ -190,7 +199,7 @@ private:
size_t
numSeqs_
;
size_t
numSeqs_
;
};
};
// sequence data
// sequence data
{seqId(vec), buf(matrix)}
class
SequenceArg
:
public
BufferArg
{
class
SequenceArg
:
public
BufferArg
{
public:
public:
SequenceArg
(
void
*
buf
,
SequenceArg
(
void
*
buf
,
...
@@ -199,17 +208,22 @@ public:
...
@@ -199,17 +208,22 @@ public:
const
SequenceIdArg
&
startPositions
,
const
SequenceIdArg
&
startPositions
,
ArgType
argType
=
UNSPECIFIED
)
ArgType
argType
=
UNSPECIFIED
)
:
BufferArg
(
buf
,
valueType
,
shape
,
argType
),
:
BufferArg
(
buf
,
valueType
,
shape
,
argType
),
startPositions_
(
startPositions
)
{}
startPositions_
(
startPositions
)
{
bufferType_
=
TENSOR_SEQUENCE_DATA
;
}
SequenceArg
(
const
Matrix
&
matrix
,
SequenceArg
(
const
Matrix
&
matrix
,
const
IVector
&
vector
,
const
IVector
&
vector
,
ArgType
argType
=
UNSPECIFIED
)
ArgType
argType
=
UNSPECIFIED
)
:
BufferArg
(
matrix
,
argType
),
startPositions_
(
vector
)
{}
:
BufferArg
(
matrix
,
argType
),
startPositions_
(
vector
)
{
bufferType_
=
TENSOR_SEQUENCE_DATA
;
}
~
SequenceArg
()
{}
~
SequenceArg
()
{}
void
*
getIdBuf
()
const
{
return
startPositions_
.
data
();
}
void
*
getIdBuf
()
const
{
return
startPositions_
.
data
();
}
size_t
numSeqs
()
const
{
return
startPositions_
.
numSeqs
();
}
size_t
numSeqs
()
const
{
return
startPositions_
.
numSeqs
();
}
const
SequenceIdArg
&
getSequenceIds
()
const
{
return
startPositions_
;
}
private:
private:
SequenceIdArg
startPositions_
;
SequenceIdArg
startPositions_
;
...
@@ -235,6 +249,7 @@ public:
...
@@ -235,6 +249,7 @@ public:
nnz_
(
nnz
),
nnz_
(
nnz
),
format_
(
format
),
format_
(
format
),
type_
(
type
)
{
type_
(
type
)
{
bufferType_
=
TENSOR_SPARSE
;
CHECK
((
valueType
==
VALUE_TYPE_FLOAT
)
||
(
valueType
==
VALUE_TYPE_DOUBLE
));
CHECK
((
valueType
==
VALUE_TYPE_FLOAT
)
||
(
valueType
==
VALUE_TYPE_DOUBLE
));
CHECK_EQ
(
shape_
.
ndims
(),
(
size_t
)
2
);
CHECK_EQ
(
shape_
.
ndims
(),
(
size_t
)
2
);
CHECK_EQ
(
row_
.
shape
().
ndims
(),
(
size_t
)
1
);
CHECK_EQ
(
row_
.
shape
().
ndims
(),
(
size_t
)
1
);
...
...
paddle/function/CMakeLists.txt
浏览文件 @
0ae5ac16
...
@@ -24,7 +24,7 @@ if(WITH_TESTING)
...
@@ -24,7 +24,7 @@ if(WITH_TESTING)
add_simple_unittest
(
TensorTypeTest
)
add_simple_unittest
(
TensorTypeTest
)
add_simple_unittest
(
BufferArgTest
)
add_simple_unittest
(
BufferArgTest
)
add_simple_unittest
(
FunctionTest
)
add_simple_unittest
(
FunctionTest
)
#
add_simple_unittest(ContextProjectionOpTest)
add_simple_unittest
(
ContextProjectionOpTest
)
endif
()
endif
()
endif
()
endif
()
...
...
paddle/function/ContextProjectionOp.cpp
浏览文件 @
0ae5ac16
...
@@ -17,7 +17,10 @@ limitations under the License. */
...
@@ -17,7 +17,10 @@ limitations under the License. */
#include "paddle/math/Vector.h"
#include "paddle/math/Vector.h"
namespace
paddle
{
namespace
paddle
{
/**
* Context Projection Forward with CPU Matrix Device.
*
*/
template
<
>
template
<
>
void
ContextProjectionForward
<
DEVICE_TYPE_CPU
>
(
CpuMatrix
&
out_mat
,
void
ContextProjectionForward
<
DEVICE_TYPE_CPU
>
(
CpuMatrix
&
out_mat
,
const
CpuMatrix
&
input_mat
,
const
CpuMatrix
&
input_mat
,
...
@@ -70,10 +73,30 @@ void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat,
...
@@ -70,10 +73,30 @@ void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat,
}
}
/**
/**
* \param inputs[0] input value.
* Paddle Function for Context Projection Forward.
* \param inputs[1] input weight.
* Calculate the output layer value sequence after context projection.
* \param inputs[2] input sequence.
*
* \param outputs[0] output value.
* What is Context Projection for a sequence?
* For example, assume the input (x) has 4 words and the dimension of each word
* representation is 2. If we pad with zeros instead of a learned weight,
* and the context_length is 3, the output (y) is:
*
* @code
* x = [a1, a2;
* b1, b2;
* c1, c2;
* d1, d2]
* y = [0, 0, a1, a2, b1, b2;
* a1, a2, b1, b2, c1, c2;
* b1, b2, c1, c2, d1, d2;
* c1, c2, d1, d2, 0, 0]
* @endcode
*
* \param outputs[0].matrix output layer value, n * (d * l)
* \param outputs[0].vector start position sequence, n * 1
* \param inputs[0].matrix input layer value, n * d
* \param inputs[0].vector start position sequence, n * 1
* \param inputs[1].matrix input layer weight, pad * d
*/
*/
template
<
DeviceType
Device
>
template
<
DeviceType
Device
>
class
ContextProjectionForwardFunc
:
public
FunctionBase
{
class
ContextProjectionForwardFunc
:
public
FunctionBase
{
...
@@ -85,28 +108,38 @@ public:
...
@@ -85,28 +108,38 @@ public:
}
}
void
calc
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
void
calc
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
CHECK
_EQ
((
size_t
)
3
,
inputs
.
size
());
CHECK
(
1
==
inputs
.
size
()
||
2
==
inputs
.
size
());
CHECK_EQ
((
size_t
)
1
,
outputs
.
size
());
CHECK_EQ
((
size_t
)
1
,
outputs
.
size
());
CHECK
(
inputs
[
0
].
isSequenceArg
()
&&
outputs
[
0
].
isSequenceArg
())
<<
"SequenceArg required here"
;
const
auto
val_seqs
=
dynamic_cast
<
const
SequenceArg
&>
(
inputs
[
0
]);
auto
out_seq
=
dynamic_cast
<
const
SequenceArg
&>
(
outputs
[
0
]);
CHECK
(
outputs
[
0
].
data
()
&&
inputs
[
0
].
data
()
&&
inputs
[
2
].
data
());
CHECK
(
out_seq
.
data
()
&&
val_seqs
.
data
()
&&
CHECK_EQ
(
outputs
[
0
].
shape
().
ndims
(),
(
size_t
)
2
);
val_seqs
.
getSequenceIds
().
data
());
CHECK_EQ
(
inputs
[
0
].
shape
().
ndims
(),
(
size_t
)
2
);
CHECK_EQ
(
out_seq
.
shape
().
ndims
(),
(
size_t
)
2
);
CHECK_EQ
(
inputs
[
1
].
shape
().
ndims
(),
(
size_t
)
2
);
CHECK_EQ
(
val_seqs
.
shape
().
ndims
(),
(
size_t
)
2
);
CHECK_EQ
(
inputs
[
2
].
shape
().
ndims
(),
(
size_t
)
1
);
CHECK_EQ
(
val_seqs
.
getSequenceIds
().
shape
().
ndims
(),
(
size_t
)
1
);
if
(
2
==
inputs
.
size
())
{
CHECK_EQ
(
inputs
[
1
].
shape
().
ndims
(),
(
size_t
)
2
);
}
/// dim of output = dim of input * context_length
/// dim of output = dim of input * context_length
CHECK_EQ
(
outputs
[
0
].
shape
()[
1
],
inputs
[
0
].
shape
()[
1
]
*
context_length_
);
CHECK_EQ
(
out_seq
.
shape
()[
1
],
val_seqs
.
shape
()[
1
]
*
context_length_
);
/// dim of input == dim of weight
CHECK_EQ
(
inputs
[
0
].
shape
()[
1
],
inputs
[
1
].
shape
()[
1
]);
/// input and output has the same batch_size
/// input and output has the same batch_size
CHECK_EQ
(
inputs
[
0
].
shape
()[
0
],
outputs
[
0
].
shape
()[
0
]);
CHECK_EQ
(
val_seqs
.
shape
()[
0
],
out_seq
.
shape
()[
0
]);
/// dim of input == dim of weight
if
(
2
==
inputs
.
size
())
{
CHECK_EQ
(
val_seqs
.
shape
()[
1
],
inputs
[
1
].
shape
()[
1
]);
}
CHECK_EQ
(
outputs
[
0
].
getArgType
(),
ADD_TO
);
CHECK_EQ
(
out_seq
.
getArgType
(),
ADD_TO
);
auto
out_mat
=
outputs
[
0
].
matrix
<
Device
>
();
auto
out_mat
=
out_seq
.
matrix
<
Device
>
();
auto
in_mat
=
inputs
[
0
].
matrix
<
Device
>
();
const
auto
in_mat
=
val_seqs
.
matrix
<
Device
>
();
auto
w_mat
=
!
inputs
[
1
].
data
()
const
auto
w_mat
=
?
typename
Tensor
<
real
,
Device
>::
Matrix
(
nullptr
,
0
,
0
)
(
2
==
inputs
.
size
())
:
inputs
[
1
].
matrix
<
Device
>
();
?
inputs
[
1
].
matrix
<
Device
>
()
auto
seq_vec
=
inputs
[
2
].
vector
<
int
,
Device
>
();
:
typename
Tensor
<
real
,
Device
>::
Matrix
(
nullptr
,
0
,
0
);
const
auto
seq_vec
=
val_seqs
.
getSequenceIds
().
vector
<
int
,
Device
>
();
ContextProjectionForward
<
Device
>
(
out_mat
,
ContextProjectionForward
<
Device
>
(
out_mat
,
in_mat
,
in_mat
,
w_mat
,
w_mat
,
...
@@ -122,8 +155,12 @@ private:
...
@@ -122,8 +155,12 @@ private:
size_t
begin_pad_
;
size_t
begin_pad_
;
};
};
/**
* Context Projection Backward with CPU Matrix Device.
*
*/
template
<
>
template
<
>
void
ContextProjectionBackward
<
DEVICE_TYPE_CPU
>
(
CpuMatrix
&
out_grad_mat
,
void
ContextProjectionBackward
<
DEVICE_TYPE_CPU
>
(
const
CpuMatrix
&
out_grad_mat
,
CpuMatrix
&
in_grad_mat
,
CpuMatrix
&
in_grad_mat
,
CpuMatrix
&
w_grad_mat
,
CpuMatrix
&
w_grad_mat
,
const
CpuIVector
&
seq_vec
,
const
CpuIVector
&
seq_vec
,
...
@@ -146,7 +183,8 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix& out_grad_mat,
...
@@ -146,7 +183,8 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix& out_grad_mat,
int64_t
pad_size
=
int64_t
pad_size
=
std
::
min
(
starts
[
i
]
-
begin
,
starts
[
i
+
1
]
-
starts
[
i
]);
std
::
min
(
starts
[
i
]
-
begin
,
starts
[
i
+
1
]
-
starts
[
i
]);
if
(
is_padding
&&
w_grad_mat
)
{
if
(
is_padding
&&
w_grad_mat
)
{
MatrixPtr
mat
=
out_grad_mat
.
subMatrix
(
starts
[
i
],
pad_size
);
MatrixPtr
mat
=
const_cast
<
CpuMatrix
&>
(
out_grad_mat
)
.
subMatrix
(
starts
[
i
],
pad_size
);
MatrixPtr
sub
=
w_grad_mat
.
subMatrix
(
j
,
pad_size
);
MatrixPtr
sub
=
w_grad_mat
.
subMatrix
(
j
,
pad_size
);
sub
->
addAtOffset
(
*
mat
,
j
*
input_dim
);
sub
->
addAtOffset
(
*
mat
,
j
*
input_dim
);
}
}
...
@@ -157,8 +195,8 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix& out_grad_mat,
...
@@ -157,8 +195,8 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix& out_grad_mat,
int64_t
pad_size
=
int64_t
pad_size
=
std
::
min
(
end
-
starts
[
i
+
1
],
starts
[
i
+
1
]
-
starts
[
i
]);
std
::
min
(
end
-
starts
[
i
+
1
],
starts
[
i
+
1
]
-
starts
[
i
]);
if
(
is_padding
&&
w_grad_mat
)
{
if
(
is_padding
&&
w_grad_mat
)
{
MatrixPtr
mat
=
MatrixPtr
mat
=
const_cast
<
CpuMatrix
&>
(
out_grad_mat
)
out_grad_mat
.
subMatrix
(
starts
[
i
+
1
]
-
pad_size
,
pad_size
);
.
subMatrix
(
starts
[
i
+
1
]
-
pad_size
,
pad_size
);
MatrixPtr
sub
=
w_grad_mat
.
subMatrix
(
MatrixPtr
sub
=
w_grad_mat
.
subMatrix
(
begin_pad
+
context_start
+
j
-
pad_size
,
pad_size
);
begin_pad
+
context_start
+
j
-
pad_size
,
pad_size
);
sub
->
addAtOffset
(
*
mat
,
j
*
input_dim
);
sub
->
addAtOffset
(
*
mat
,
j
*
input_dim
);
...
@@ -169,17 +207,22 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix& out_grad_mat,
...
@@ -169,17 +207,22 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(CpuMatrix& out_grad_mat,
if
(
end
<=
begin
)
continue
;
if
(
end
<=
begin
)
continue
;
if
(
!
in_grad_mat
)
continue
;
if
(
!
in_grad_mat
)
continue
;
MatrixPtr
src
=
in_grad_mat
.
subMatrix
(
begin
,
end
-
begin
);
MatrixPtr
src
=
in_grad_mat
.
subMatrix
(
begin
,
end
-
begin
);
MatrixPtr
dst
=
out_grad_mat
.
subMatrix
(
dst_begin
,
dst_end
-
dst_begin
);
MatrixPtr
dst
=
const_cast
<
CpuMatrix
&>
(
out_grad_mat
)
.
subMatrix
(
dst_begin
,
dst_end
-
dst_begin
);
src
->
addAtOffset
(
*
dst
,
j
*
input_dim
);
src
->
addAtOffset
(
*
dst
,
j
*
input_dim
);
}
}
}
}
}
}
/**
/**
* \param inputs[0] input grad.
* Context Projection Backward Function.
* \param inputs[1] weight grad.
* Update the weight gradient and input layer gradient with backprop
* \param inputs[2] input sequence.
*
* \param outputs[0] output value.
* \param inputs[0].matrix output layer grad, n * (d * l)
* \param inputs[0].vector start position sequence, n * 1
* \param outputs[0].matrix input layer grad, n * d
* \param outputs[0].vector start position sequence, n * 1
* \param outputs[1] weight grad, pad * d
*/
*/
template
<
DeviceType
Device
>
template
<
DeviceType
Device
>
class
ContextProjectionBackwardFunc
:
public
FunctionBase
{
class
ContextProjectionBackwardFunc
:
public
FunctionBase
{
...
@@ -193,32 +236,36 @@ public:
...
@@ -193,32 +236,36 @@ public:
}
}
void
calc
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
void
calc
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
CHECK_EQ
((
size_t
)
3
,
inputs
.
size
());
CHECK_EQ
((
size_t
)
1
,
inputs
.
size
());
CHECK_EQ
((
size_t
)
1
,
outputs
.
size
());
CHECK_EQ
((
size_t
)
2
,
outputs
.
size
());
CHECK
(
inputs
[
0
].
isSequenceArg
()
&&
outputs
[
0
].
isSequenceArg
())
<<
"SequenceArg required here"
;
const
auto
in_seq
=
dynamic_cast
<
const
SequenceArg
&>
(
inputs
[
0
]);
auto
out_seq
=
dynamic_cast
<
const
SequenceArg
&>
(
outputs
[
0
]);
CHECK
(
in_seq
.
data
()
&&
in_seq
.
getSequenceIds
().
data
());
CHECK_EQ
(
in_seq
.
shape
().
ndims
(),
(
size_t
)
2
);
CHECK_EQ
(
in_seq
.
getSequenceIds
().
shape
().
ndims
(),
(
size_t
)
1
);
CHECK_EQ
(
out_seq
.
shape
().
ndims
(),
(
size_t
)
2
);
CHECK_EQ
(
out_seq
.
getSequenceIds
().
shape
().
ndims
(),
(
size_t
)
1
);
CHECK_EQ
(
outputs
[
1
].
shape
().
ndims
(),
(
size_t
)
2
);
CHECK
(
outputs
[
0
].
data
()
&&
inputs
[
2
].
data
());
/// dim of input grad == dim of weight
CHECK_EQ
(
outputs
[
0
].
shape
().
ndims
(),
(
size_t
)
2
);
CHECK_EQ
(
out_seq
.
shape
()[
1
],
outputs
[
1
].
shape
()[
1
]);
CHECK_EQ
(
inputs
[
0
].
shape
().
ndims
(),
(
size_t
)
2
);
/// input and output grad has the same batch_size
CHECK_EQ
(
inputs
[
1
].
shape
().
ndims
(),
(
size_t
)
2
);
CHECK_EQ
(
out_seq
.
shape
()[
0
],
in_seq
.
shape
()[
0
]);
CHECK_EQ
(
inputs
[
2
].
shape
().
ndims
(),
(
size_t
)
1
);
/// dim of output grad = dim of input grad * context_length
CHECK_EQ
(
in_seq
.
shape
()[
1
],
out_seq
.
shape
()[
1
]
*
context_length_
);
CHECK_EQ
(
out_seq
.
getArgType
(),
ADD_TO
);
CHECK_EQ
(
outputs
[
1
].
getArgType
(),
ADD_TO
);
/// dim of input == dim of weight
const
auto
seq_vec
=
in_seq
.
getSequenceIds
().
vector
<
int
,
Device
>
();
CHECK_EQ
(
inputs
[
0
].
shape
()[
1
],
inputs
[
1
].
shape
()[
1
]);
const
auto
out_grad_mat
=
in_seq
.
matrix
<
Device
>
();
/// input and output has the same batch_size
CHECK_EQ
(
inputs
[
0
].
shape
()[
0
],
outputs
[
0
].
shape
()[
0
]);
/// dim of output = dim of input * context_length
CHECK_EQ
(
outputs
[
0
].
shape
()[
1
],
inputs
[
0
].
shape
()[
1
]
*
context_length_
);
CHECK_EQ
(
outputs
[
0
].
getArgType
(),
ADD_TO
);
auto
out_grad_mat
=
outputs
[
0
].
matrix
<
Device
>
();
auto
in_grad_mat
=
auto
in_grad_mat
=
!
inputs
[
0
]
.
data
()
?
typename
Tensor
<
real
,
Device
>::
Matrix
(
nullptr
,
0
,
0
)
!
out_seq
.
data
()
?
typename
Tensor
<
real
,
Device
>::
Matrix
(
nullptr
,
0
,
0
)
:
inputs
[
0
]
.
matrix
<
Device
>
();
:
out_seq
.
matrix
<
Device
>
();
auto
w_grad_mat
=
!
in
puts
[
1
].
data
()
auto
w_grad_mat
=
!
out
puts
[
1
].
data
()
?
typename
Tensor
<
real
,
Device
>::
Matrix
(
nullptr
,
0
,
0
)
?
typename
Tensor
<
real
,
Device
>::
Matrix
(
nullptr
,
0
,
0
)
:
inputs
[
1
].
matrix
<
Device
>
();
:
outputs
[
1
].
matrix
<
Device
>
();
auto
seq_vec
=
inputs
[
2
].
vector
<
int
,
Device
>
();
ContextProjectionBackward
<
Device
>
(
out_grad_mat
,
ContextProjectionBackward
<
Device
>
(
out_grad_mat
,
in_grad_mat
,
in_grad_mat
,
w_grad_mat
,
w_grad_mat
,
...
@@ -238,11 +285,16 @@ private:
...
@@ -238,11 +285,16 @@ private:
size_t
total_pad_
;
size_t
total_pad_
;
};
};
#if 0
/**
/**
* \param inputs[0] input grad.
* Context Projection Backward Data Function
* \param inputs[1] input sequence.
* Update input layer grad
* \param outputs[0] output grad.
* input: sequence of output layer grad
* output: sequence of input layer grad
*
* \param outputs[0].matrix input layer grad, n * d
* \param outputs[0].vector start position sequence, n * 1
* \param inputs[0].matrix output layer grad, n * (d * l)
* \param inputs[0].vector start position sequence, n * 1
*/
*/
template
<
DeviceType
Device
>
template
<
DeviceType
Device
>
class
ContextProjectionBackwardDataFunc
:
public
FunctionBase
{
class
ContextProjectionBackwardDataFunc
:
public
FunctionBase
{
...
@@ -252,32 +304,30 @@ public:
...
@@ -252,32 +304,30 @@ public:
context_start_
=
config
.
get
<
int
>
(
"context_start"
);
context_start_
=
config
.
get
<
int
>
(
"context_start"
);
}
}
void calc(const Arguments& inputs,
void
calc
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
const Arguments& outputs,
CHECK_EQ
(
1
,
static_cast
<
int
>
(
inputs
.
size
()));
const Arguments& inouts) override {
CHECK_EQ(2, static_cast<int>(inputs.size()));
CHECK_EQ
(
1
,
static_cast
<
int
>
(
outputs
.
size
()));
CHECK_EQ
(
1
,
static_cast
<
int
>
(
outputs
.
size
()));
CHECK_EQ(0, static_cast<int>(inouts.size()));
CHECK
(
inputs
[
0
].
isSequenceArg
()
&&
outputs
[
0
].
isSequenceArg
())
CHECK(inputs[0].getData() && outputs[0].getData() && inputs[1].getData());
<<
"SequenceArg required here"
;
CHECK_EQ(static_cast<int>(outputs[0].dims_.size()), 2);
const
auto
in_seq
=
dynamic_cast
<
const
SequenceArg
&>
(
inputs
[
0
]);
CHECK_EQ(static_cast<int>(inputs[0].dims_.size()), 2);
const
auto
out_seq
=
dynamic_cast
<
const
SequenceArg
&>
(
outputs
[
0
]);
CHECK_EQ(static_cast<int>(inputs[1].dims_.size()), 1);
CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_);
CHECK
(
in_seq
.
data
()
&&
out_seq
.
data
()
&&
in_seq
.
getSequenceIds
().
data
());
CHECK_EQ
(
static_cast
<
int
>
(
out_seq
.
shape
().
ndims
()),
2
);
CHECK_EQ
(
static_cast
<
int
>
(
in_seq
.
shape
().
ndims
()),
2
);
CHECK_EQ
(
static_cast
<
int
>
(
in_seq
.
getSequenceIds
().
shape
().
ndims
()),
1
);
/// output layer grad dim == input layer grad dim * context_length_
CHECK_EQ
(
in_seq
.
shape
().
ndims
(),
out_seq
.
shape
().
ndims
()
*
context_length_
);
/// input and output has the same batch_size
/// input and output has the same batch_size
CHECK_EQ(inputs[0].dims_[0], outputs[0].dims_[0]);
CHECK_EQ
(
in_seq
.
shape
()[
0
],
out_seq
.
shape
()[
0
]);
CHECK_EQ
(
outputs
[
0
].
getArgType
(),
ASSIGN_TO
);
auto out_grad_mat = std::make_shared<typename MatrixT<Device>::type>(
const
auto
out_grad_mat
=
in_seq
.
matrix
<
Device
>
();
outputs[0].getData(), outputs[0].dims_[0], outputs[0].dims_[1]);
const
auto
seq_vec
=
in_seq
.
getSequenceIds
().
vector
<
int
,
Device
>
();
const auto in_grad_mat = std::make_shared<typename MatrixT<Device>::type>(
auto
in_grad_mat
=
out_seq
.
matrix
<
Device
>
();
inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]);
typename SequenceT<Device>::type seq_vec(
inputs[1].dims_[0], reinterpret_cast<int*>(inputs[1].getData()));
ContextProjectionBackwardData<Device>(out_grad_mat.get(),
ContextProjectionBackwardData
<
Device
>
(
in_grad_mat.get(),
out_grad_mat
,
in_grad_mat
,
seq_vec
,
context_length_
,
context_start_
);
seq_vec,
context_length_,
context_start_);
}
}
private:
private:
...
@@ -286,9 +336,14 @@ private:
...
@@ -286,9 +336,14 @@ private:
};
};
/**
/**
* \param inputs[0] weight grad.
* Context Projection Backward Weight Function
* \param inputs[1] input sequence.
* Update weight grad by backprop
* \param outputs[0] output grad.
* input: sequence of output layer grad
* output: weight grad
*
* \param outputs[0] weight grad, pad * d
* \param inputs[0].matrix output layer grad, n * (d * l)
* \param inputs[0].vector start position sequence, n * 1
*/
*/
template
<
DeviceType
Device
>
template
<
DeviceType
Device
>
class
ContextProjectionBackwardWeightFunc
:
public
FunctionBase
{
class
ContextProjectionBackwardWeightFunc
:
public
FunctionBase
{
...
@@ -300,28 +355,25 @@ public:
...
@@ -300,28 +355,25 @@ public:
total_pad_
=
config
.
get
<
size_t
>
(
"total_pad"
);
total_pad_
=
config
.
get
<
size_t
>
(
"total_pad"
);
}
}
void calc(const Arguments& inputs,
void
calc
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
const Arguments& outputs,
CHECK_EQ
(
1
,
static_cast
<
int
>
(
inputs
.
size
()));
const Arguments& inouts) override {
CHECK_EQ(2, static_cast<int>(inputs.size()));
CHECK_EQ
(
1
,
static_cast
<
int
>
(
outputs
.
size
()));
CHECK_EQ
(
1
,
static_cast
<
int
>
(
outputs
.
size
()));
CHECK_EQ(0, static_cast<int>(inouts.size()));
CHECK
(
inputs
[
0
].
isSequenceArg
())
<<
"SequenceArg required here"
;
const
auto
in_seq
=
dynamic_cast
<
const
SequenceArg
&>
(
inputs
[
0
]);
CHECK(inputs[0].getData() && outputs[0].getData() && inputs[1].getData());
CHECK
(
in_seq
.
data
()
&&
in_seq
.
getSequenceIds
().
data
()
&&
outputs
[
0
].
data
());
CHECK_EQ(static_cast<int>(outputs[0].dims_.size()), 2);
CHECK_EQ
(
static_cast
<
int
>
(
outputs
[
0
].
shape
().
ndims
()),
2
);
CHECK_EQ(static_cast<int>(inputs[0].dims_.size()), 2);
CHECK_EQ
(
static_cast
<
int
>
(
in_seq
.
shape
().
ndims
()),
2
);
CHECK_EQ(static_cast<int>(inputs[1].dims_.size()), 1);
CHECK_EQ
(
static_cast
<
int
>
(
in_seq
.
getSequenceIds
().
shape
().
ndims
()),
1
);
CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_);
CHECK_EQ
(
in_seq
.
shape
()[
0
],
outputs
[
0
].
shape
()[
0
]);
/// output layer grad dim == weight dim * context_length_
auto out_grad_mat = std::make_shared<typename MatrixT<Device>::type>(
CHECK_EQ
(
in_seq
.
shape
()[
1
],
outputs
[
0
].
shape
()[
1
]
*
context_length_
);
outputs[0].getData(), outputs[0].dims_[0], outputs[0].dims_[1]);
CHECK_EQ
(
outputs
[
0
].
getArgType
(),
ADD_TO
);
auto w_grad_mat = std::make_shared<typename MatrixT<Device>::type>(
inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]);
typename SequenceT<Device>::type seq_vec(
inputs[1].dims_[0], reinterpret_cast<int*>(inputs[1].getData()));
ContextProjectionBackwardWeight<Device>(out_grad_mat.get(),
const
auto
seq_vec
=
in_seq
.
getSequenceIds
().
vector
<
int
,
Device
>
();
w_grad_mat.get(),
const
auto
out_grad_mat
=
in_seq
.
matrix
<
Device
>
();
auto
w_grad_mat
=
outputs
[
0
].
matrix
<
Device
>
();
ContextProjectionBackwardWeight
<
Device
>
(
out_grad_mat
,
w_grad_mat
,
seq_vec
,
seq_vec
,
context_length_
,
context_length_
,
context_start_
,
context_start_
,
...
@@ -335,7 +387,6 @@ private:
...
@@ -335,7 +387,6 @@ private:
size_t
begin_pad_
;
size_t
begin_pad_
;
size_t
total_pad_
;
size_t
total_pad_
;
};
};
#endif
REGISTER_TYPED_FUNC
(
ContextProjectionForward
,
REGISTER_TYPED_FUNC
(
ContextProjectionForward
,
CPU
,
CPU
,
...
@@ -350,7 +401,6 @@ REGISTER_TYPED_FUNC(ContextProjectionForward,
...
@@ -350,7 +401,6 @@ REGISTER_TYPED_FUNC(ContextProjectionForward,
REGISTER_TYPED_FUNC
(
ContextProjectionBackward
,
REGISTER_TYPED_FUNC
(
ContextProjectionBackward
,
GPU
,
GPU
,
ContextProjectionBackwardFunc
);
ContextProjectionBackwardFunc
);
#if 0
REGISTER_TYPED_FUNC
(
ContextProjectionBackwardData
,
REGISTER_TYPED_FUNC
(
ContextProjectionBackwardData
,
GPU
,
GPU
,
ContextProjectionBackwardDataFunc
);
ContextProjectionBackwardDataFunc
);
...
@@ -358,5 +408,4 @@ REGISTER_TYPED_FUNC(ContextProjectionBackwardWeight,
...
@@ -358,5 +408,4 @@ REGISTER_TYPED_FUNC(ContextProjectionBackwardWeight,
GPU
,
GPU
,
ContextProjectionBackwardWeightFunc
);
ContextProjectionBackwardWeightFunc
);
#endif
#endif
#endif
}
// namespace paddle
}
// namespace paddle
paddle/function/ContextProjectionOp.h
浏览文件 @
0ae5ac16
...
@@ -21,14 +21,14 @@ namespace paddle {
...
@@ -21,14 +21,14 @@ namespace paddle {
/**
/**
* \brief Context Projection Forward.
* \brief Context Projection Forward.
*
*
* \param[out] outputs output data.
* \param[
in/
out] outputs output data.
* \param[in] input input data.
* \param[in]
input input data.
* \param[in] weight input weight.
* \param[in]
weight input weight.
* \param[in] sequence input data.
* \param[in]
sequence input data.
* \param[in] context_length consecutive rows for concatenation.
* \param[in]
context_length consecutive rows for concatenation.
* \param[in] context_start context start position.
* \param[in]
context_start context start position.
* \param[in] begin_pad beginning pad position.
* \param[in]
begin_pad beginning pad position.
* \param[in] is_padding whether padding 0 or not.
* \param[in]
is_padding whether padding 0 or not.
*
*
*/
*/
template
<
DeviceType
DType
>
template
<
DeviceType
DType
>
...
@@ -56,7 +56,7 @@ void ContextProjectionForward(
...
@@ -56,7 +56,7 @@ void ContextProjectionForward(
*/
*/
template
<
DeviceType
DType
>
template
<
DeviceType
DType
>
void
ContextProjectionBackward
(
void
ContextProjectionBackward
(
typename
Tensor
<
real
,
DType
>::
Matrix
&
out_grad
,
const
typename
Tensor
<
real
,
DType
>::
Matrix
&
out_grad
,
typename
Tensor
<
real
,
DType
>::
Matrix
&
in_grad
,
typename
Tensor
<
real
,
DType
>::
Matrix
&
in_grad
,
typename
Tensor
<
real
,
DType
>::
Matrix
&
w_grad
,
typename
Tensor
<
real
,
DType
>::
Matrix
&
w_grad
,
const
typename
Tensor
<
int
,
DType
>::
Vector
&
seq_vec
,
const
typename
Tensor
<
int
,
DType
>::
Vector
&
seq_vec
,
...
@@ -68,7 +68,7 @@ void ContextProjectionBackward(
...
@@ -68,7 +68,7 @@ void ContextProjectionBackward(
template
<
DeviceType
DType
>
template
<
DeviceType
DType
>
void
ContextProjectionBackwardData
(
void
ContextProjectionBackwardData
(
typename
Tensor
<
real
,
DType
>::
Matrix
&
out_grad
,
const
typename
Tensor
<
real
,
DType
>::
Matrix
&
out_grad
,
typename
Tensor
<
real
,
DType
>::
Matrix
&
in_grad
,
typename
Tensor
<
real
,
DType
>::
Matrix
&
in_grad
,
const
typename
Tensor
<
int
,
DType
>::
Vector
&
sequence
,
const
typename
Tensor
<
int
,
DType
>::
Vector
&
sequence
,
size_t
context_length
,
size_t
context_length
,
...
@@ -76,7 +76,7 @@ void ContextProjectionBackwardData(
...
@@ -76,7 +76,7 @@ void ContextProjectionBackwardData(
template
<
DeviceType
DType
>
template
<
DeviceType
DType
>
void
ContextProjectionBackwardWeight
(
void
ContextProjectionBackwardWeight
(
typename
Tensor
<
real
,
DType
>::
Matrix
&
out_grad
,
const
typename
Tensor
<
real
,
DType
>::
Matrix
&
out_grad
,
typename
Tensor
<
real
,
DType
>::
Matrix
&
w_grad
,
typename
Tensor
<
real
,
DType
>::
Matrix
&
w_grad
,
const
typename
Tensor
<
int
,
DType
>::
Vector
&
seq_vec
,
const
typename
Tensor
<
int
,
DType
>::
Vector
&
seq_vec
,
size_t
context_length
,
size_t
context_length
,
...
...
paddle/function/ContextProjectionOpGpu.cu
浏览文件 @
0ae5ac16
...
@@ -138,10 +138,10 @@ void ContextProjectionForward<DEVICE_TYPE_GPU>(GpuMatrix& output,
...
@@ -138,10 +138,10 @@ void ContextProjectionForward<DEVICE_TYPE_GPU>(GpuMatrix& output,
begin_pad
);
begin_pad
);
}
}
__global__
void
KeContextProjectionBackwardData
(
real
*
out_grad
,
__global__
void
KeContextProjectionBackwardData
(
const
real
*
out_grad
,
const
int
*
sequence
,
const
int
*
sequence
,
real
*
in_grad
,
real
*
in_grad
,
in
t
input_dim
,
size_
t
input_dim
,
int
context_length
,
int
context_length
,
int
context_start
)
{
int
context_start
)
{
int
idx
=
threadIdx
.
x
;
int
idx
=
threadIdx
.
x
;
...
@@ -152,7 +152,8 @@ __global__ void KeContextProjectionBackwardData(real* out_grad,
...
@@ -152,7 +152,8 @@ __global__ void KeContextProjectionBackwardData(real* out_grad,
real
value
=
0
;
real
value
=
0
;
int
instances
=
seq_end
-
seq_start
+
context_length
-
1
;
int
instances
=
seq_end
-
seq_start
+
context_length
-
1
;
out_grad
+=
seq_start
*
input_dim
*
context_length
;
auto
out
=
const_cast
<
real
*>
(
out_grad
);
out
+=
seq_start
*
input_dim
*
context_length
;
in_grad
+=
seq_start
*
input_dim
;
in_grad
+=
seq_start
*
input_dim
;
for
(
int
k
=
0
;
k
<=
input_dim
/
block_size
;
k
++
)
{
for
(
int
k
=
0
;
k
<=
input_dim
/
block_size
;
k
++
)
{
if
(
idx
<
input_dim
)
{
if
(
idx
<
input_dim
)
{
...
@@ -169,7 +170,7 @@ __global__ void KeContextProjectionBackwardData(real* out_grad,
...
@@ -169,7 +170,7 @@ __global__ void KeContextProjectionBackwardData(real* out_grad,
int
outx
=
(
i
-
context_length
)
<
0
?
i
:
(
context_length
-
1
);
int
outx
=
(
i
-
context_length
)
<
0
?
i
:
(
context_length
-
1
);
int
outy
=
(
i
-
context_length
)
<
0
?
0
:
(
i
-
(
context_length
-
1
));
int
outy
=
(
i
-
context_length
)
<
0
?
0
:
(
i
-
(
context_length
-
1
));
real
*
output_r
=
real
*
output_r
=
out
_grad
+
outy
*
input_dim
*
context_length
+
outx
*
input_dim
;
out
+
outy
*
input_dim
*
context_length
+
outx
*
input_dim
;
for
(
int
j
=
outy
;
j
<
seq_end
-
seq_start
;
j
++
)
{
for
(
int
j
=
outy
;
j
<
seq_end
-
seq_start
;
j
++
)
{
value
+=
output_r
[
idx
];
value
+=
output_r
[
idx
];
if
(
j
-
outy
==
outx
)
break
;
if
(
j
-
outy
==
outx
)
break
;
...
@@ -194,7 +195,7 @@ __global__ void KeContextProjectionBackwardData(real* out_grad,
...
@@ -194,7 +195,7 @@ __global__ void KeContextProjectionBackwardData(real* out_grad,
* @param[in] context_start context start.
* @param[in] context_start context start.
*
*
*/
*/
void
hl_context_projection_backward_data
(
real
*
out_grad
,
void
hl_context_projection_backward_data
(
const
real
*
out_grad
,
const
int
*
sequence
,
const
int
*
sequence
,
real
*
input_grad
,
real
*
input_grad
,
size_t
num_sequences
,
size_t
num_sequences
,
...
@@ -216,7 +217,7 @@ void hl_context_projection_backward_data(real* out_grad,
...
@@ -216,7 +217,7 @@ void hl_context_projection_backward_data(real* out_grad,
}
}
template
<
>
template
<
>
void
ContextProjectionBackwardData
<
DEVICE_TYPE_GPU
>
(
GpuMatrix
&
out_grad
,
void
ContextProjectionBackwardData
<
DEVICE_TYPE_GPU
>
(
const
GpuMatrix
&
out_grad
,
GpuMatrix
&
in_grad
,
GpuMatrix
&
in_grad
,
const
GpuIVector
&
sequence
,
const
GpuIVector
&
sequence
,
size_t
context_length
,
size_t
context_length
,
...
@@ -231,7 +232,7 @@ void ContextProjectionBackwardData<DEVICE_TYPE_GPU>(GpuMatrix& out_grad,
...
@@ -231,7 +232,7 @@ void ContextProjectionBackwardData<DEVICE_TYPE_GPU>(GpuMatrix& out_grad,
}
}
template
<
int
THREADS_X
,
int
THREADS_Y
>
template
<
int
THREADS_X
,
int
THREADS_Y
>
__global__
void
KeContextProjectionBackwardWeight
(
real
*
out_grad
,
__global__
void
KeContextProjectionBackwardWeight
(
const
real
*
out_grad
,
const
int
*
sequence
,
const
int
*
sequence
,
real
*
w_grad
,
real
*
w_grad
,
int
num_sequences
,
int
num_sequences
,
...
@@ -254,7 +255,8 @@ __global__ void KeContextProjectionBackwardWeight(real* out_grad,
...
@@ -254,7 +255,8 @@ __global__ void KeContextProjectionBackwardWeight(real* out_grad,
for
(
int
seqId
=
idy
;
seqId
<
num_sequences
;
seqId
+=
THREADS_Y
)
{
for
(
int
seqId
=
idy
;
seqId
<
num_sequences
;
seqId
+=
THREADS_Y
)
{
int
seq_start
=
sequence
[
seqId
];
int
seq_start
=
sequence
[
seqId
];
int
seq_end
=
sequence
[
seqId
+
1
];
int
seq_end
=
sequence
[
seqId
+
1
];
output_r
=
out_grad
+
seq_start
*
w_dim
*
context_length
;
output_r
=
const_cast
<
real
*>
(
out_grad
)
+
seq_start
*
w_dim
*
context_length
;
if
(
context_start
<
0
)
{
if
(
context_start
<
0
)
{
if
(
padId
+
context_start
<
0
)
{
if
(
padId
+
context_start
<
0
)
{
...
@@ -318,7 +320,7 @@ __global__ void KeContextProjectionBackwardWeight(real* out_grad,
...
@@ -318,7 +320,7 @@ __global__ void KeContextProjectionBackwardWeight(real* out_grad,
* beginning.
* beginning.
*
*
*/
*/
void
hl_context_projection_backward_weight
(
real
*
out_grad
,
void
hl_context_projection_backward_weight
(
const
real
*
out_grad
,
const
int
*
sequence
,
const
int
*
sequence
,
real
*
w_grad
,
real
*
w_grad
,
size_t
num_sequences
,
size_t
num_sequences
,
...
@@ -346,7 +348,7 @@ void hl_context_projection_backward_weight(real* out_grad,
...
@@ -346,7 +348,7 @@ void hl_context_projection_backward_weight(real* out_grad,
template
<
>
template
<
>
void
ContextProjectionBackwardWeight
<
DEVICE_TYPE_GPU
>
(
void
ContextProjectionBackwardWeight
<
DEVICE_TYPE_GPU
>
(
GpuMatrix
&
out_grad
,
const
GpuMatrix
&
out_grad
,
GpuMatrix
&
w_grad
,
GpuMatrix
&
w_grad
,
const
GpuIVector
&
seq_vec
,
const
GpuIVector
&
seq_vec
,
size_t
context_length
,
size_t
context_length
,
...
@@ -365,7 +367,7 @@ void ContextProjectionBackwardWeight<DEVICE_TYPE_GPU>(
...
@@ -365,7 +367,7 @@ void ContextProjectionBackwardWeight<DEVICE_TYPE_GPU>(
}
}
template
<
>
template
<
>
void
ContextProjectionBackward
<
DEVICE_TYPE_GPU
>
(
GpuMatrix
&
out_grad
,
void
ContextProjectionBackward
<
DEVICE_TYPE_GPU
>
(
const
GpuMatrix
&
out_grad
,
GpuMatrix
&
in_grad
,
GpuMatrix
&
in_grad
,
GpuMatrix
&
w_grad
,
GpuMatrix
&
w_grad
,
const
GpuIVector
&
sequence
,
const
GpuIVector
&
sequence
,
...
...
paddle/function/ContextProjectionOpTest.cpp
浏览文件 @
0ae5ac16
...
@@ -56,22 +56,25 @@ void testMatrixProjectionForward(int context_start,
...
@@ -56,22 +56,25 @@ void testMatrixProjectionForward(int context_start,
cpu_out
.
randomizeUniform
();
cpu_out
.
randomizeUniform
();
gpu_out
.
copyFrom
(
cpu_out
);
gpu_out
.
copyFrom
(
cpu_out
);
compare
.
getCpuFunction
()
->
calc
(
BufferArgs
cpu_inputs
;
{
Tensor
(
cpu_in
.
getData
(),
Dims
{
batch_size
,
input_dim
}),
BufferArgs
cpu_outputs
;
Tensor
(
cpu_weight
?
cpu_weight
->
getData
()
:
nullptr
,
cpu_inputs
.
addArg
(
cpu_in
,
*
cpu_seq
);
Dims
{
pad
,
input_dim
}),
if
(
cpu_weight
)
{
Tensor
(
reinterpret_cast
<
real
*>
(
cpu_seq
->
getData
()),
cpu_inputs
.
addArg
(
*
cpu_weight
,
*
cpu_seq
);
Dims
{
cpu_seq
->
getSize
()})},
}
{
Tensor
(
cpu_out
.
getData
(),
Dims
{
batch_size
,
input_dim
*
context_length
})},
cpu_outputs
.
addArg
(
cpu_out
,
*
cpu_seq
,
ADD_TO
);
{});
compare
.
getGpuFunction
()
->
calc
(
compare
.
getCpuFunction
()
->
calc
(
cpu_inputs
,
cpu_outputs
);
{
Tensor
(
gpu_in
.
getData
(),
Dims
{
batch_size
,
input_dim
}),
Tensor
(
gpu_weight
?
gpu_weight
->
getData
()
:
nullptr
,
BufferArgs
gpu_inputs
;
Dims
{
pad
,
input_dim
}),
BufferArgs
gpu_outputs
;
Tensor
(
reinterpret_cast
<
real
*>
(
gpu_seq
->
getData
()),
gpu_inputs
.
addArg
(
gpu_in
,
*
gpu_seq
);
Dims
{
gpu_seq
->
getSize
()})},
if
(
gpu_weight
)
{
{
Tensor
(
gpu_out
.
getData
(),
Dims
{
batch_size
,
input_dim
*
context_length
})},
gpu_inputs
.
addArg
(
*
gpu_weight
,
*
gpu_seq
);
{});
}
gpu_outputs
.
addArg
(
gpu_out
,
*
gpu_seq
,
ADD_TO
);
compare
.
getGpuFunction
()
->
calc
(
gpu_inputs
,
gpu_outputs
);
autotest
::
TensorCheckEqual
(
cpu_out
,
gpu_out
);
autotest
::
TensorCheckEqual
(
cpu_out
,
gpu_out
);
}
}
...
@@ -117,25 +120,23 @@ void testMatrixProjectionBackward(int context_start,
...
@@ -117,25 +120,23 @@ void testMatrixProjectionBackward(int context_start,
gpu_w_grad
->
copyFrom
(
*
cpu_w_grad
);
gpu_w_grad
->
copyFrom
(
*
cpu_w_grad
);
}
}
compare
.
getCpuFunction
()
->
calc
(
BufferArgs
cpu_inputs
;
{
Tensor
(
cpu_in_grad
.
getData
(),
Dims
{
batch_size
,
input_dim
}),
BufferArgs
cpu_outputs
;
Tensor
(
cpu_w_grad
?
cpu_w_grad
->
getData
()
:
nullptr
,
cpu_inputs
.
addArg
(
cpu_out_grad
,
*
cpu_seq
);
Dims
{
pad
,
input_dim
}),
cpu_outputs
.
addArg
(
cpu_in_grad
,
*
cpu_seq
,
ADD_TO
);
Tensor
(
reinterpret_cast
<
real
*>
(
cpu_seq
->
getData
()),
cpu_outputs
.
addArg
(
Dims
{
cpu_seq
->
getSize
()})},
cpu_w_grad
?
*
cpu_w_grad
:
CpuMatrix
(
nullptr
,
0
,
input_dim
),
ADD_TO
);
{
Tensor
(
cpu_out_grad
.
getData
(),
Dims
{
batch_size
,
input_dim
*
context_length
})},
compare
.
getCpuFunction
()
->
calc
(
cpu_inputs
,
cpu_outputs
);
{});
BufferArgs
gpu_inputs
;
compare
.
getGpuFunction
()
->
calc
(
BufferArgs
gpu_outputs
;
{
Tensor
(
gpu_in_grad
.
getData
(),
Dims
{
batch_size
,
input_dim
}),
gpu_inputs
.
addArg
(
gpu_out_grad
,
*
gpu_seq
);
Tensor
(
gpu_w_grad
?
gpu_w_grad
->
getData
()
:
nullptr
,
gpu_outputs
.
addArg
(
gpu_in_grad
,
*
gpu_seq
,
ADD_TO
);
Dims
{
pad
,
input_dim
}),
gpu_outputs
.
addArg
(
Tensor
(
reinterpret_cast
<
real
*>
(
gpu_seq
->
getData
()),
gpu_w_grad
?
*
gpu_w_grad
:
GpuMatrix
(
nullptr
,
0
,
input_dim
),
ADD_TO
);
Dims
{
gpu_seq
->
getSize
()})},
{
Tensor
(
gpu_out_grad
.
getData
(),
compare
.
getGpuFunction
()
->
calc
(
gpu_inputs
,
gpu_outputs
);
Dims
{
batch_size
,
input_dim
*
context_length
})},
{});
autotest
::
TensorCheckErr
(
cpu_in_grad
,
gpu_in_grad
);
autotest
::
TensorCheckErr
(
cpu_in_grad
,
gpu_in_grad
);
if
(
is_padding
)
{
if
(
is_padding
)
{
...
...
paddle/function/Function.cpp
浏览文件 @
0ae5ac16
...
@@ -90,6 +90,12 @@ void BufferArgs::addArg(const GpuSparseMatrix& arg, ArgType argType) {
...
@@ -90,6 +90,12 @@ void BufferArgs::addArg(const GpuSparseMatrix& arg, ArgType argType) {
args_
.
push_back
(
std
::
make_shared
<
SparseMatrixArg
>
(
arg
,
argType
));
args_
.
push_back
(
std
::
make_shared
<
SparseMatrixArg
>
(
arg
,
argType
));
}
}
void
BufferArgs
::
addArg
(
const
Matrix
&
matrix
,
const
IVector
&
vector
,
ArgType
argType
)
{
args_
.
push_back
(
std
::
make_shared
<
SequenceArg
>
(
matrix
,
vector
,
argType
));
}
ClassRegistrar
<
FunctionBase
>
FunctionBase
::
funcRegistrar_
;
ClassRegistrar
<
FunctionBase
>
FunctionBase
::
funcRegistrar_
;
}
// namespace paddle
}
// namespace paddle
paddle/function/Function.h
浏览文件 @
0ae5ac16
...
@@ -77,6 +77,10 @@ public:
...
@@ -77,6 +77,10 @@ public:
void
addArg
(
const
CpuSparseMatrix
&
arg
,
ArgType
argType
=
UNSPECIFIED
);
void
addArg
(
const
CpuSparseMatrix
&
arg
,
ArgType
argType
=
UNSPECIFIED
);
void
addArg
(
const
GpuSparseMatrix
&
arg
,
ArgType
argType
=
UNSPECIFIED
);
void
addArg
(
const
GpuSparseMatrix
&
arg
,
ArgType
argType
=
UNSPECIFIED
);
void
addArg
(
const
Matrix
&
matrix
,
const
IVector
&
vector
,
ArgType
argType
=
UNSPECIFIED
);
// get argument
// get argument
const
BufferArg
&
operator
[](
size_t
num
)
const
{
const
BufferArg
&
operator
[](
size_t
num
)
const
{
CHECK_LT
(
num
,
args_
.
size
());
CHECK_LT
(
num
,
args_
.
size
());
...
...
paddle/function/FunctionTest.h
浏览文件 @
0ae5ac16
...
@@ -27,66 +27,28 @@ public:
...
@@ -27,66 +27,28 @@ public:
gpu
->
init
(
config
);
gpu
->
init
(
config
);
}
}
void
cmpWithArg
(
const
Argument
s
&
inputs
,
void
cmpWithArg
(
const
BufferArg
s
&
inputs
,
const
Argument
s
&
outputs
,
const
BufferArg
s
&
outputs
,
const
Argument
s
&
inouts
)
{
const
BufferArg
s
&
inouts
)
{
// init cpu and gpu arguments
// init cpu and gpu arguments
auto
initArgs
=
[
=
](
auto
initArgs
=
[
=
](
Arguments
&
cpuArgs
,
Arguments
&
gpuArgs
,
const
Arguments
&
inArgs
)
{
BufferArgs
&
cpuArgs
,
BufferArgs
&
gpuArgs
,
const
BufferArgs
&
inArgs
)
{
for
(
const
auto
arg
:
inArgs
)
{
/// leave it empty to pass the compile of ContextProjectionTest
size_t
size
=
sizeof
(
real
);
/// Daoyuan is working on FunctionTest
for
(
const
auto
dim
:
arg
.
dims_
)
{
/// and I will further merge with it
size
*=
dim
;
}
if
(
arg
.
getData
())
{
// todo(tianbing), waste unnecessary mem here
cpuMemory
.
emplace_back
(
std
::
make_shared
<
CpuMemoryHandle
>
(
size
));
gpuMemory
.
emplace_back
(
std
::
make_shared
<
GpuMemoryHandle
>
(
size
));
cpuArgs
.
emplace_back
(
Tensor
((
real
*
)
arg
.
getData
(),
arg
.
dims_
));
gpuArgs
.
emplace_back
(
Tensor
((
real
*
)
arg
.
getData
(),
arg
.
dims_
));
// already init outside
}
else
{
cpuMemory
.
emplace_back
(
std
::
make_shared
<
CpuMemoryHandle
>
(
size
));
gpuMemory
.
emplace_back
(
std
::
make_shared
<
GpuMemoryHandle
>
(
size
));
cpuArgs
.
emplace_back
(
Tensor
((
real
*
)
cpuMemory
.
back
()
->
getBuf
(),
arg
.
dims_
));
gpuArgs
.
emplace_back
(
Tensor
((
real
*
)
gpuMemory
.
back
()
->
getBuf
(),
arg
.
dims_
));
// will use an api to refactor this code.
CpuVector
cpuVector
(
size
/
sizeof
(
real
),
(
real
*
)
cpuArgs
.
back
().
getData
());
GpuVector
gpuVector
(
size
/
sizeof
(
real
),
(
real
*
)
gpuArgs
.
back
().
getData
());
cpuVector
.
uniform
(
0.001
,
1
);
gpuVector
.
copyFrom
(
cpuVector
);
}
}
};
};
initArgs
(
cpuInputs
,
gpuInputs
,
inputs
);
initArgs
(
cpuInputs
,
gpuInputs
,
inputs
);
initArgs
(
cpuOutputs
,
gpuOutputs
,
outputs
);
initArgs
(
cpuOutputs
,
gpuOutputs
,
outputs
);
initArgs
(
cpuInouts
,
gpuInouts
,
inouts
);
// function calculate
// function calculate
cpu
->
calc
(
cpuInputs
,
cpuOutputs
,
cpuInouts
);
cpu
->
calc
(
cpuInputs
,
cpuOutputs
);
gpu
->
calc
(
gpuInputs
,
gpuOutputs
,
gpuInouts
);
gpu
->
calc
(
gpuInputs
,
gpuOutputs
);
// check outputs and inouts
// check outputs and inouts
auto
checkArgs
=
[
=
](
const
Arguments
&
cpuArgs
,
const
Arguments
&
gpuArgs
)
{
auto
checkArgs
=
[
=
](
const
BufferArgs
&
cpuArgs
,
const
BufferArgs
&
gpuArgs
)
{
for
(
size_t
i
=
0
;
i
<
cpuArgs
.
size
();
i
++
)
{
/// leave it open
auto
cpu
=
cpuArgs
[
i
];
auto
gpu
=
gpuArgs
[
i
];
size_t
size
=
1
;
for
(
auto
dim
:
cpu
.
dims_
)
{
size
*=
dim
;
}
CpuVector
cpuVector
(
size
,
(
real
*
)
cpu
.
getData
());
GpuVector
gpuVector
(
size
,
(
real
*
)
gpu
.
getData
());
autotest
::
TensorCheckErr
(
cpuVector
,
gpuVector
);
}
};
};
checkArgs
(
cpuOutputs
,
gpuOutputs
);
checkArgs
(
cpuOutputs
,
gpuOutputs
);
checkArgs
(
cpuInouts
,
gpuInouts
);
}
}
std
::
shared_ptr
<
FunctionBase
>
getCpuFunction
()
const
{
return
cpu
;
}
std
::
shared_ptr
<
FunctionBase
>
getCpuFunction
()
const
{
return
cpu
;
}
...
@@ -98,12 +60,12 @@ protected:
...
@@ -98,12 +60,12 @@ protected:
std
::
shared_ptr
<
FunctionBase
>
gpu
;
std
::
shared_ptr
<
FunctionBase
>
gpu
;
std
::
vector
<
CpuMemHandlePtr
>
cpuMemory
;
std
::
vector
<
CpuMemHandlePtr
>
cpuMemory
;
std
::
vector
<
GpuMemHandlePtr
>
gpuMemory
;
std
::
vector
<
GpuMemHandlePtr
>
gpuMemory
;
Argument
s
cpuInputs
;
BufferArg
s
cpuInputs
;
Argument
s
cpuOutputs
;
BufferArg
s
cpuOutputs
;
Argument
s
cpuInouts
;
BufferArg
s
cpuInouts
;
Argument
s
gpuInputs
;
BufferArg
s
gpuInputs
;
Argument
s
gpuOutputs
;
BufferArg
s
gpuOutputs
;
Argument
s
gpuInouts
;
BufferArg
s
gpuInouts
;
};
};
}
// namespace paddle
}
// namespace paddle
paddle/gserver/layers/ContextProjection.cpp
浏览文件 @
0ae5ac16
...
@@ -118,16 +118,15 @@ void ContextProjection::forward() {
...
@@ -118,16 +118,15 @@ void ContextProjection::forward() {
/// first use state_, otherwise use weight_(padding false === w nullptr)
/// first use state_, otherwise use weight_(padding false === w nullptr)
auto
w_ptr
=
auto
w_ptr
=
state_
?
state_
.
get
()
:
is_padding
?
weight_
->
getW
().
get
()
:
nullptr
;
state_
?
state_
.
get
()
:
is_padding
?
weight_
->
getW
().
get
()
:
nullptr
;
auto
start_pos
=
in_
->
sequenceStartPositions
;
const
auto
start_pos
=
in_
->
sequenceStartPositions
->
getVector
(
useGpu_
);
BufferArgs
inputs
;
BufferArgs
inputs
;
BufferArgs
outputs
;
BufferArgs
outputs
;
inputs
.
addArg
(
*
in_
->
value
);
inputs
.
addArg
(
*
in_
->
value
,
*
start_pos
);
i
nputs
.
addArg
(
CpuMatrix
(
w_ptr
?
w_ptr
->
getData
()
:
nullptr
,
i
f
(
w_ptr
)
{
w_ptr
?
w_ptr
->
getHeight
()
:
0
,
inputs
.
addArg
(
CpuMatrix
(
w_ptr
->
getData
(),
w_ptr
->
getHeight
(),
input_dim
)
,
input_dim
)
);
*
start_pos
);
inputs
.
addArg
(
*
in_
->
sequenceStartPositions
->
getVector
(
useGpu_
));
}
outputs
.
addArg
(
*
out_
->
value
,
ADD_TO
);
outputs
.
addArg
(
*
out_
->
value
,
*
start_pos
,
ADD_TO
);
forward_
[
0
]
->
calc
(
inputs
,
outputs
);
forward_
[
0
]
->
calc
(
inputs
,
outputs
);
if
(
state_
&&
config_
.
context_start
()
<
0
)
{
if
(
state_
&&
config_
.
context_start
()
<
0
)
{
...
@@ -166,13 +165,16 @@ void ContextProjection::backward(const UpdateCallback& callback) {
...
@@ -166,13 +165,16 @@ void ContextProjection::backward(const UpdateCallback& callback) {
BufferArgs
inputs
;
BufferArgs
inputs
;
BufferArgs
outputs
;
BufferArgs
outputs
;
inputs
.
addArg
(
CpuMatrix
(
inputs
.
addArg
(
*
out_
->
grad
,
*
in_
->
sequenceStartPositions
->
getVector
(
useGpu_
));
in_
->
grad
?
in_
->
grad
->
getData
()
:
nullptr
,
batch_size
,
input_dim
));
outputs
.
addArg
(
inputs
.
addArg
(
CpuMatrix
(
w_ptr
?
w_ptr
->
getData
()
:
nullptr
,
CpuMatrix
(
w_ptr
?
w_ptr
->
getHeight
()
:
0
,
in_
->
grad
?
in_
->
grad
->
getData
()
:
nullptr
,
batch_size
,
input_dim
),
input_dim
));
*
in_
->
sequenceStartPositions
->
getVector
(
useGpu_
),
inputs
.
addArg
(
*
in_
->
sequenceStartPositions
->
getVector
(
useGpu_
));
ADD_TO
);
outputs
.
addArg
(
*
out_
->
grad
,
ADD_TO
);
outputs
.
addArg
(
CpuMatrix
(
w_ptr
?
w_ptr
->
getData
()
:
nullptr
,
w_ptr
?
w_ptr
->
getHeight
()
:
0
,
input_dim
),
ADD_TO
);
backward_
[
0
]
->
calc
(
inputs
,
outputs
);
backward_
[
0
]
->
calc
(
inputs
,
outputs
);
if
(
config_
.
trainable_padding
())
{
if
(
config_
.
trainable_padding
())
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录