Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
590ecba3
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
590ecba3
编写于
12月 27, 2016
作者:
X
xutianbing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add ContextProjectionBackward, ContextProjectionBackwardData, ContextProjectionBackwardWeightw
上级
838ef366
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
548 addition
and
6 deletion
+548
-6
paddle/function/CMakeLists.txt
paddle/function/CMakeLists.txt
+0
-1
paddle/function/context_projection_op.cpp
paddle/function/context_projection_op.cpp
+196
-1
paddle/function/context_projection_op.h
paddle/function/context_projection_op.h
+43
-3
paddle/function/context_projection_op_gpu.cu
paddle/function/context_projection_op_gpu.cu
+210
-0
paddle/function/context_projection_op_test.cpp
paddle/function/context_projection_op_test.cpp
+99
-1
未找到文件。
paddle/function/CMakeLists.txt
浏览文件 @
590ecba3
...
...
@@ -19,7 +19,6 @@ if(WITH_TESTING)
add_simple_unittest
(
CrossMapNormalOpTest
)
add_unittest
(
ContextProjectionOpTest
ContextProjectionOpTest.cpp
ContextProjectionOpGpu.cu
../gserver/tests/TestUtil.cpp
)
endif
()
endif
()
...
...
paddle/function/context_projection_op.cpp
浏览文件 @
590ecba3
...
...
@@ -41,7 +41,7 @@ void ContextProjectionForward<DEVICE_TYPE_CPU>(Tensor& output,
!
weight
.
getData
()
?
nullptr
:
std
::
make_shared
<
CpuMatrix
>
(
weight
.
getData
(),
weight
.
dims_
[
0
],
inpu
t
.
dims_
[
1
]);
weight
.
getData
(),
weight
.
dims_
[
0
],
weigh
t
.
dims_
[
1
]);
CpuIVector
seq_vec
(
sequence
.
dims_
[
0
],
reinterpret_cast
<
int
*>
(
sequence
.
getData
()));
CHECK_EQ
(
out_mat
->
getWidth
(),
in_mat
->
getWidth
()
*
context_length
);
...
...
@@ -125,12 +125,207 @@ private:
bool
is_padding_
;
};
template
<
>
void
ContextProjectionBackward
<
DEVICE_TYPE_CPU
>
(
Tensor
&
out_grad
,
const
Tensor
&
in_grad
,
const
Tensor
&
w_grad
,
const
Tensor
&
sequence
,
size_t
context_length
,
int
context_start
,
size_t
begin_pad
,
bool
is_padding
)
{
CHECK
(
out_grad
.
getData
()
&&
sequence
.
getData
());
CHECK_EQ
(
out_grad
.
dims_
.
size
(),
2
);
CHECK_EQ
(
in_grad
.
dims_
.
size
(),
2
);
CHECK_EQ
(
w_grad
.
dims_
.
size
(),
2
);
CHECK_EQ
(
sequence
.
dims_
.
size
(),
1
);
auto
out_grad_mat
=
std
::
make_shared
<
CpuMatrix
>
(
out_grad
.
getData
(),
out_grad
.
dims_
[
0
],
out_grad
.
dims_
[
1
]);
const
auto
in_grad_mat
=
!
in_grad
.
getData
()
?
nullptr
:
std
::
make_shared
<
CpuMatrix
>
(
in_grad
.
getData
(),
in_grad
.
dims_
[
0
],
in_grad
.
dims_
[
1
]);
const
auto
w_grad_mat
=
!
w_grad
.
getData
()
?
nullptr
:
std
::
make_shared
<
CpuMatrix
>
(
w_grad
.
getData
(),
w_grad
.
dims_
[
0
],
w_grad
.
dims_
[
1
]);
CpuIVector
seq_vec
(
sequence
.
dims_
[
0
],
reinterpret_cast
<
int
*>
(
sequence
.
getData
()));
CHECK_EQ
(
out_grad_mat
->
getWidth
(),
in_grad_mat
->
getWidth
()
*
context_length
);
size_t
input_dim
=
in_grad_mat
?
in_grad_mat
->
getWidth
()
:
w_grad_mat
?
w_grad_mat
->
getWidth
()
:
0
;
CHECK_EQ
(
out_grad_mat
->
getWidth
(),
input_dim
*
context_length
);
const
int
*
starts
=
seq_vec
.
getData
();
size_t
num_sequences
=
seq_vec
.
getSize
()
-
1
;
for
(
size_t
i
=
0
;
i
<
num_sequences
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
context_length
;
++
j
)
{
int
begin
=
starts
[
i
]
+
context_start
+
j
;
int
end
=
starts
[
i
+
1
]
+
context_start
+
j
;
int
dst_begin
=
starts
[
i
];
int
dst_end
=
starts
[
i
+
1
];
if
(
begin
<
starts
[
i
])
{
int64_t
pad_size
=
std
::
min
(
starts
[
i
]
-
begin
,
starts
[
i
+
1
]
-
starts
[
i
]);
if
(
is_padding
&&
w_grad_mat
)
{
MatrixPtr
mat
=
out_grad_mat
->
subMatrix
(
starts
[
i
],
pad_size
);
MatrixPtr
sub
=
w_grad_mat
->
subMatrix
(
j
,
pad_size
);
sub
->
addAtOffset
(
*
mat
,
j
*
input_dim
);
}
dst_begin
=
starts
[
i
]
+
pad_size
;
begin
=
starts
[
i
];
}
if
(
end
>
starts
[
i
+
1
])
{
int64_t
pad_size
=
std
::
min
(
end
-
starts
[
i
+
1
],
starts
[
i
+
1
]
-
starts
[
i
]);
if
(
is_padding
&&
w_grad_mat
)
{
MatrixPtr
mat
=
out_grad_mat
->
subMatrix
(
starts
[
i
+
1
]
-
pad_size
,
pad_size
);
MatrixPtr
sub
=
w_grad_mat
->
subMatrix
(
begin_pad
+
context_start
+
j
-
pad_size
,
pad_size
);
sub
->
addAtOffset
(
*
mat
,
j
*
input_dim
);
}
dst_end
=
starts
[
i
+
1
]
-
pad_size
;
end
=
starts
[
i
+
1
];
}
if
(
end
<=
begin
)
continue
;
if
(
!
in_grad_mat
)
continue
;
MatrixPtr
src
=
in_grad_mat
->
subMatrix
(
begin
,
end
-
begin
);
MatrixPtr
dst
=
out_grad_mat
->
subMatrix
(
dst_begin
,
dst_end
-
dst_begin
);
src
->
addAtOffset
(
*
dst
,
j
*
input_dim
);
}
}
}
/**
* \param inputs[0] input value.
* \param inputs[1] input weight.
* \param inputs[2] input sequence.
* \param outputs[0] output value.
*/
template
<
DeviceType
Device
>
class
ContextProjectionBackwardFunc
:
public
FunctionBase
{
public:
void
init
(
const
FuncConfig
&
config
)
override
{
context_length_
=
config
.
get
<
size_t
>
(
"context_length"
);
context_start_
=
config
.
get
<
int
>
(
"context_start"
);
begin_pad_
=
config
.
get
<
size_t
>
(
"begin_pad"
);
is_padding_
=
config
.
get
<
bool
>
(
"is_padding"
);
}
void
calc
(
const
Arguments
&
inputs
,
const
Arguments
&
outputs
,
const
Arguments
&
inouts
)
override
{
CHECK_EQ
(
3
,
inputs
.
size
());
CHECK_EQ
(
1
,
outputs
.
size
());
CHECK_EQ
(
0
,
inouts
.
size
());
ContextProjectionBackward
<
Device
>
((
Tensor
&
)
outputs
[
0
],
inputs
[
0
],
inputs
[
1
],
inputs
[
2
],
context_length_
,
context_start_
,
begin_pad_
,
is_padding_
);
}
private:
size_t
context_length_
;
int
context_start_
;
size_t
begin_pad_
;
bool
is_padding_
;
};
/**
* \param inputs[0] input grad.
* \param inputs[1] input sequence.
* \param outputs[0] output grad.
*/
template
<
DeviceType
Device
>
class
ContextProjectionBackwardDataFunc
:
public
FunctionBase
{
public:
void
init
(
const
FuncConfig
&
config
)
override
{
context_length_
=
config
.
get
<
size_t
>
(
"context_length"
);
context_start_
=
config
.
get
<
int
>
(
"context_start"
);
}
void
calc
(
const
Arguments
&
inputs
,
const
Arguments
&
outputs
,
const
Arguments
&
inouts
)
override
{
CHECK_EQ
(
2
,
inputs
.
size
());
CHECK_EQ
(
1
,
outputs
.
size
());
CHECK_EQ
(
0
,
inouts
.
size
());
ContextProjectionBackwardData
<
Device
>
((
Tensor
&
)
outputs
[
0
],
(
Tensor
&
)
inputs
[
0
],
inputs
[
1
],
context_length_
,
context_start_
);
}
private:
size_t
context_length_
;
int
context_start_
;
};
/**
* \param inputs[0] weight grad.
* \param inputs[1] input sequence.
* \param outputs[0] output grad.
*/
template
<
DeviceType
Device
>
class
ContextProjectionBackwardWeightFunc
:
public
FunctionBase
{
public:
void
init
(
const
FuncConfig
&
config
)
override
{
context_length_
=
config
.
get
<
size_t
>
(
"context_length"
);
context_start_
=
config
.
get
<
int
>
(
"context_start"
);
begin_pad_
=
config
.
get
<
size_t
>
(
"begin_pad"
);
total_pad_
=
config
.
get
<
size_t
>
(
"total_pad"
);
}
void
calc
(
const
Arguments
&
inputs
,
const
Arguments
&
outputs
,
const
Arguments
&
inouts
)
override
{
CHECK_EQ
(
2
,
inputs
.
size
());
CHECK_EQ
(
1
,
outputs
.
size
());
CHECK_EQ
(
0
,
inouts
.
size
());
ContextProjectionBackwardWeight
<
Device
>
((
Tensor
&
)
outputs
[
0
],
(
Tensor
&
)
inputs
[
0
],
inputs
[
1
],
context_length_
,
context_start_
,
total_pad_
,
begin_pad_
);
}
private:
size_t
context_length_
;
int
context_start_
;
size_t
begin_pad_
;
size_t
total_pad_
;
};
REGISTER_TYPED_FUNC
(
ContextProjectionForward
,
CPU
,
ContextProjectionForwardFunc
);
REGISTER_TYPED_FUNC
(
ContextProjectionBackward
,
CPU
,
ContextProjectionBackwardFunc
);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC
(
ContextProjectionForward
,
GPU
,
ContextProjectionForwardFunc
);
REGISTER_TYPED_FUNC
(
ContextProjectionBackwardData
,
GPU
,
ContextProjectionBackwardDataFunc
);
REGISTER_TYPED_FUNC
(
ContextProjectionBackwardWeight
,
GPU
,
ContextProjectionBackwardWeightFunc
);
#endif
}
// namespace paddle
paddle/function/context_projection_op.h
浏览文件 @
590ecba3
...
...
@@ -25,9 +25,10 @@ namespace paddle {
* \param[in] input input data.
* \param[in] weight input weight.
* \param[in] sequence input data.
* \param[in] context_length consecutive rows for concatenation.
* \param[in] begin_pad context start position.
* \param[in] is_padding whether padding 0 or not.
* \param[in] context_length consecutive rows for concatenation.
* \param[in] context_start context start position.
* \param[in] begin_pad begining pad position.
* \param[in] is_padding whether padding 0 or not.
*
*/
template
<
DeviceType
Device
>
...
...
@@ -40,4 +41,43 @@ void ContextProjectionForward(Tensor& output,
size_t
begin_pad
,
bool
is_padding
);
/**
* \brief Context Projection Backward.
*
* \param[out] outputs output gradient.
* \param[in] input input gradient.
* \param[in] weight input weight gradient.
* \param[in] sequence input data.
* \param[in] context_length consecutive rows for concatenation.
* \param[in] context_start context start position.
* \param[in] begin_pad begining pad position.
* \param[in] is_padding whether padding 0 or not.
*
*/
template
<
DeviceType
Device
>
void
ContextProjectionBackward
(
Tensor
&
out_grad
,
const
Tensor
&
in_grad
,
const
Tensor
&
w_grad
,
const
Tensor
&
sequence
,
size_t
context_length
,
int
context_start
,
size_t
begin_pad
,
bool
is_padding
);
template
<
DeviceType
Device
>
void
ContextProjectionBackwardData
(
Tensor
&
out_grad
,
Tensor
&
in_grad
,
const
Tensor
&
sequence
,
size_t
context_length
,
int
context_start
);
template
<
DeviceType
Device
>
void
ContextProjectionBackwardWeight
(
Tensor
&
out_grad
,
Tensor
&
w_grad
,
const
Tensor
&
sequence
,
size_t
context_length
,
int
context_start
,
size_t
total_pad
,
size_t
begin_pad
);
}
// namespace paddle
paddle/function/context_projection_op_gpu.cu
浏览文件 @
590ecba3
...
...
@@ -134,4 +134,214 @@ void ContextProjectionForward<DEVICE_TYPE_GPU>(Tensor& output,
is_padding
);
}
__global__
void
KeContextProjectionBackwardData
(
real
*
out_grad
,
const
int
*
sequence
,
real
*
in_grad
,
int
input_dim
,
int
context_length
,
int
context_start
)
{
int
idx
=
threadIdx
.
x
;
int
block_size
=
blockDim
.
x
;
int
sequenceId
=
blockIdx
.
x
;
int
seq_start
=
sequence
[
sequenceId
];
int
seq_end
=
sequence
[
sequenceId
+
1
];
real
value
=
0
;
int
instances
=
seq_end
-
seq_start
+
context_length
-
1
;
out_grad
+=
seq_start
*
input_dim
*
context_length
;
in_grad
+=
seq_start
*
input_dim
;
for
(
int
k
=
0
;
k
<=
input_dim
/
block_size
;
k
++
)
{
if
(
idx
<
input_dim
)
{
for
(
int
i
=
0
;
i
<
instances
;
i
++
)
{
if
((
i
+
context_start
)
<
0
)
{
continue
;
}
else
if
((
i
+
context_start
)
>=
(
seq_end
-
seq_start
))
{
continue
;
}
else
{
// value = 0;
value
=
in_grad
[(
i
+
context_start
)
*
input_dim
+
idx
];
}
int
outx
=
(
i
-
context_length
)
<
0
?
i
:
(
context_length
-
1
);
int
outy
=
(
i
-
context_length
)
<
0
?
0
:
(
i
-
(
context_length
-
1
));
real
*
output_r
=
out_grad
+
outy
*
input_dim
*
context_length
+
outx
*
input_dim
;
for
(
int
j
=
outy
;
j
<
seq_end
-
seq_start
;
j
++
)
{
value
+=
output_r
[
idx
];
if
(
j
-
outy
==
outx
)
break
;
output_r
+=
(
context_length
-
1
)
*
input_dim
;
}
in_grad
[(
i
+
context_start
)
*
input_dim
+
idx
]
=
value
;
}
}
idx
+=
block_size
;
}
}
void
hl_context_projection_backward_data
(
real
*
out_grad
,
const
int
*
sequence
,
real
*
input_grad
,
int
num_sequences
,
int
input_dim
,
int
context_length
,
int
context_start
)
{
CHECK_NOTNULL
(
out_grad
);
CHECK_NOTNULL
(
sequence
);
CHECK_NOTNULL
(
input_grad
);
int
block_size
=
128
;
int
blocks_x
=
num_sequences
;
int
blocks_y
=
1
;
dim3
threads
(
block_size
,
1
);
dim3
grid
(
blocks_x
,
blocks_y
);
KeContextProjectionBackwardData
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
out_grad
,
sequence
,
input_grad
,
input_dim
,
context_length
,
context_start
);
CHECK_SYNC
(
"hl_context_projection_backward_data failed"
);
}
template
<
>
void
ContextProjectionBackwardData
<
DEVICE_TYPE_GPU
>
(
Tensor
&
out_grad
,
Tensor
&
in_grad
,
const
Tensor
&
sequence
,
size_t
context_length
,
int
context_start
)
{
CHECK
(
in_grad
.
getData
()
&&
out_grad
.
getData
()
&&
sequence
.
getData
());
CHECK_EQ
(
out_grad
.
dims_
.
size
(),
2
);
CHECK_EQ
(
in_grad
.
dims_
.
size
(),
2
);
CHECK_EQ
(
sequence
.
dims_
.
size
(),
1
);
CHECK_EQ
(
out_grad
.
dims_
[
1
],
in_grad
.
dims_
[
1
]
*
context_length
);
hl_context_projection_backward_data
(
out_grad
.
getData
(),
reinterpret_cast
<
int
*>
(
sequence
.
getData
()),
in_grad
.
getData
(),
sequence
.
dims_
[
0
]
-
1
,
in_grad
.
dims_
[
1
],
context_length
,
context_start
);
}
template
<
int
THREADS_X
,
int
THREADS_Y
>
__global__
void
KeContextProjectionBackwardWeight
(
real
*
out_grad
,
const
int
*
sequence
,
real
*
w_grad
,
int
num_sequences
,
int
w_dim
,
int
context_length
,
int
context_start
,
int
begin_pad
)
{
__shared__
real
sum_s
[
THREADS_Y
][
THREADS_X
];
int
pad_of_block
=
(
w_dim
+
THREADS_X
-
1
)
/
THREADS_X
;
const
int
idx
=
threadIdx
.
x
;
const
int
idy
=
threadIdx
.
y
;
int
padId
=
blockIdx
.
x
/
pad_of_block
;
int
weight_idx
=
idx
+
THREADS_X
*
(
blockIdx
.
x
%
pad_of_block
);
int
instanceId
;
real
value
=
0
;
real
*
output_r
;
sum_s
[
idy
][
idx
]
=
0.0
f
;
if
(
weight_idx
<
w_dim
)
{
for
(
int
seqId
=
idy
;
seqId
<
num_sequences
;
seqId
+=
THREADS_Y
)
{
int
seq_start
=
sequence
[
seqId
];
int
seq_end
=
sequence
[
seqId
+
1
];
output_r
=
out_grad
+
seq_start
*
w_dim
*
context_length
;
if
(
context_start
<
0
)
{
if
(
padId
+
context_start
<
0
)
{
instanceId
=
padId
;
}
else
{
// begin_pad > 0;
instanceId
=
(
padId
-
begin_pad
)
+
(
seq_end
-
seq_start
)
-
context_start
;
}
}
else
{
if
(
padId
+
(
seq_end
-
seq_start
)
<
context_start
)
{
continue
;
}
else
{
// begin_pad == 0;
instanceId
=
padId
+
(
seq_end
-
seq_start
)
-
context_start
;
}
}
int
outx
=
(
instanceId
-
context_length
)
<
0
?
instanceId
:
(
context_length
-
1
);
int
outy
=
(
instanceId
-
context_length
)
<
0
?
0
:
(
instanceId
-
(
context_length
-
1
));
output_r
+=
outy
*
w_dim
*
context_length
+
outx
*
w_dim
;
for
(
int
j
=
outy
;
j
<
seq_end
-
seq_start
;
j
++
)
{
value
+=
output_r
[
weight_idx
];
if
(
j
-
outy
==
outx
)
break
;
output_r
+=
(
context_length
-
1
)
*
w_dim
;
}
}
sum_s
[
idy
][
idx
]
=
value
;
}
__syncthreads
();
for
(
int
stride
=
THREADS_Y
/
2
;
stride
>
0
;
stride
=
stride
/
2
)
{
if
(
idy
<
stride
)
{
sum_s
[
idy
][
idx
]
+=
sum_s
[
idy
+
stride
][
idx
];
}
__syncthreads
();
}
__syncthreads
();
if
(
weight_idx
<
w_dim
)
{
if
(
idy
==
0
)
{
w_grad
[
padId
*
w_dim
+
weight_idx
]
+=
sum_s
[
0
][
idx
];
}
}
}
void
hl_context_projection_backward_weight
(
real
*
out_grad
,
const
int
*
sequence
,
real
*
w_grad
,
int
num_sequences
,
int
w_dim
,
size_t
total_pad
,
int
context_length
,
int
context_start
,
int
begin_pad
)
{
CHECK_NOTNULL
(
out_grad
);
CHECK_NOTNULL
(
sequence
);
CHECK_NOTNULL
(
w_grad
);
int
threads_x
=
32
;
int
threads_y
=
32
;
int
blocks_x
=
total_pad
*
((
w_dim
+
threads_x
-
1
)
/
threads_x
);
dim3
threads
(
threads_x
,
threads_y
);
dim3
grid
(
blocks_x
,
1
);
KeContextProjectionBackwardWeight
<
32
,
32
>
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
out_grad
,
sequence
,
w_grad
,
num_sequences
,
w_dim
,
context_length
,
context_start
,
begin_pad
);
CHECK_SYNC
(
"hl_context_projection_backward_weight failed"
);
}
template
<
>
void
ContextProjectionBackwardWeight
<
DEVICE_TYPE_GPU
>
(
Tensor
&
out_grad
,
Tensor
&
w_grad
,
const
Tensor
&
sequence
,
size_t
context_length
,
int
context_start
,
size_t
total_pad
,
size_t
begin_pad
)
{
CHECK
(
w_grad
.
getData
()
&&
out_grad
.
getData
());
CHECK_EQ
(
out_grad
.
dims_
.
size
(),
2
);
CHECK_EQ
(
w_grad
.
dims_
.
size
(),
2
);
CHECK_EQ
(
sequence
.
dims_
.
size
(),
1
);
CHECK_EQ
(
out_grad
.
dims_
[
1
],
w_grad
.
dims_
[
1
]
*
context_length
);
hl_context_projection_backward_weight
(
out_grad
.
getData
(),
reinterpret_cast
<
int
*>
(
sequence
.
getData
()),
w_grad
.
getData
(),
sequence
.
dims_
[
0
]
-
1
,
w_grad
.
dims_
[
1
],
total_pad
,
context_length
,
context_start
,
begin_pad
);
}
}
// namespace paddle
paddle/function/context_projection_op_test.cpp
浏览文件 @
590ecba3
...
...
@@ -77,7 +77,100 @@ void testMatrixProjectionForward(int context_start,
autotest
::
TensorCheckEqual
(
cpu_out
,
gpu_out
);
}
TEST
(
ContextProjectionForward
,
projection
)
{
void
testMatrixProjectionBackward
(
int
context_start
,
int
context_length
,
bool
is_padding
,
size_t
batch_size
,
size_t
input_dim
)
{
size_t
pad
=
std
::
max
(
0
,
-
context_start
)
+
std
::
max
(
0
,
(
int
)(
context_start
+
context_length
-
1
));
if
(
pad
==
0
)
is_padding
=
false
;
std
::
shared_ptr
<
FunctionBase
>
cpu_func
(
FunctionBase
::
funcRegistrar_
.
createByType
(
"ContextProjectionBackward-CPU"
));
FuncConfig
cpu_config
;
cpu_config
.
set
(
"context_length"
,
context_length
)
.
set
(
"context_start"
,
context_start
)
.
set
(
"begin_pad"
,
std
::
max
(
0
,
-
context_start
))
.
set
(
"is_padding"
,
is_padding
);
cpu_func
->
init
(
cpu_config
);
std
::
shared_ptr
<
FunctionBase
>
gpu_data_func
(
FunctionBase
::
funcRegistrar_
.
createByType
(
"ContextProjectionBackwardData-GPU"
));
FuncConfig
gpu_data_config
;
gpu_data_config
.
set
(
"context_length"
,
context_length
)
.
set
(
"context_start"
,
context_start
);
gpu_data_func
->
init
(
gpu_data_config
);
std
::
shared_ptr
<
FunctionBase
>
gpu_w_func
(
FunctionBase
::
funcRegistrar_
.
createByType
(
"ContextProjectionBackwardWeight-GPU"
));
FuncConfig
gpu_w_config
;
gpu_w_config
.
set
(
"context_length"
,
context_length
)
.
set
(
"context_start"
,
context_start
)
.
set
(
"begin_pad"
,
std
::
max
(
0
,
-
context_start
))
.
set
(
"total_pad"
,
pad
);
gpu_w_func
->
init
(
gpu_w_config
);
CpuMatrix
cpu_in_grad
(
batch_size
,
input_dim
);
cpu_in_grad
.
randomizeUniform
();
GpuMatrix
gpu_in_grad
(
batch_size
,
input_dim
);
gpu_in_grad
.
copyFrom
(
cpu_in_grad
);
CpuMatrix
cpu_out_grad
(
batch_size
,
input_dim
*
context_length
);
cpu_out_grad
.
randomizeUniform
();
GpuMatrix
gpu_out_grad
(
batch_size
,
input_dim
*
context_length
);
gpu_out_grad
.
copyFrom
(
cpu_out_grad
);
IVectorPtr
cpu_seq
;
generateSequenceStartPositions
(
batch_size
,
cpu_seq
);
IVectorPtr
gpu_seq
=
IVector
::
create
(
cpu_seq
->
getSize
(),
true
);
gpu_seq
->
copyFrom
(
*
cpu_seq
);
auto
cpu_w_grad
=
is_padding
?
std
::
make_shared
<
CpuMatrix
>
(
pad
,
input_dim
)
:
nullptr
;
auto
gpu_w_grad
=
is_padding
?
std
::
make_shared
<
GpuMatrix
>
(
pad
,
input_dim
)
:
nullptr
;
if
(
is_padding
)
{
cpu_w_grad
->
randomizeUniform
();
gpu_w_grad
->
copyFrom
(
*
cpu_w_grad
);
}
cpu_func
->
calc
({
Tensor
(
cpu_in_grad
.
getData
(),
Dims
{
batch_size
,
input_dim
}),
Tensor
(
cpu_w_grad
?
cpu_w_grad
->
getData
()
:
nullptr
,
Dims
{
pad
,
input_dim
}),
Tensor
(
reinterpret_cast
<
real
*>
(
cpu_seq
->
getData
()),
Dims
{
cpu_seq
->
getSize
()})},
{
Tensor
(
cpu_out_grad
.
getData
(),
Dims
{
batch_size
,
input_dim
*
context_length
})},
{});
gpu_data_func
->
calc
(
{
Tensor
(
gpu_in_grad
.
getData
(),
Dims
{
batch_size
,
input_dim
}),
Tensor
(
reinterpret_cast
<
real
*>
(
gpu_seq
->
getData
()),
Dims
{
gpu_seq
->
getSize
()})},
{
Tensor
(
gpu_out_grad
.
getData
(),
Dims
{
batch_size
,
input_dim
*
context_length
})},
{});
if
(
is_padding
&&
gpu_w_grad
)
{
gpu_w_func
->
calc
({
Tensor
(
gpu_w_grad
->
getData
(),
Dims
{
pad
,
input_dim
}),
Tensor
(
reinterpret_cast
<
real
*>
(
gpu_seq
->
getData
()),
Dims
{
gpu_seq
->
getSize
()})},
{
Tensor
(
gpu_out_grad
.
getData
(),
Dims
{
batch_size
,
input_dim
*
context_length
})},
{});
}
autotest
::
TensorCheckErr
(
cpu_in_grad
,
gpu_in_grad
);
if
(
is_padding
)
{
autotest
::
TensorCheckErr
(
*
cpu_w_grad
,
*
gpu_w_grad
);
}
}
TEST
(
ContextProjection
,
projection
)
{
for
(
auto
context_start
:
{
-
5
,
-
3
,
-
1
,
0
,
3
})
{
for
(
auto
context_length
:
{
1
,
2
,
5
,
7
})
{
for
(
auto
trainable_padding
:
{
false
,
true
})
{
...
...
@@ -93,6 +186,11 @@ TEST(ContextProjectionForward, projection) {
trainable_padding
,
batch_size
,
input_dim
);
testMatrixProjectionBackward
(
context_start
,
context_length
,
trainable_padding
,
batch_size
,
input_dim
);
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录