Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
89de5d5e
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
89de5d5e
编写于
1月 11, 2018
作者:
W
wanghaoshuang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix cuda kernel of sequence scale functor
上级
9eb3fb29
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
12 addition
and
29 deletion
+12
-29
paddle/operators/math/sequence_scale.cu
paddle/operators/math/sequence_scale.cu
+3
-3
paddle/operators/math/sequence_scale.h
paddle/operators/math/sequence_scale.h
+9
-7
paddle/operators/warpctc_op.h
paddle/operators/warpctc_op.h
+0
-19
未找到文件。
paddle/operators/math/sequence_scale.cu
浏览文件 @
89de5d5e
...
...
@@ -22,16 +22,16 @@ template <typename T>
__global__
void
SequenceScaleKernel
(
T
*
seq
,
size_t
*
lod
,
const
T
*
scales
,
const
size_t
num_seq
,
const
size_t
seq_width
)
{
size_t
idx
=
blockIdx
.
x
*
blockDim
.
y
+
threadIdx
.
y
;
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
lod
[
num_seq
])
{
if
(
idx
<
lod
[
num_seq
]
*
seq_width
)
{
size_t
i
=
0
;
for
(
i
=
0
;
i
<
num_seq
;
++
i
)
{
if
(
idx
<
lod
[
i
+
1
]
*
seq_width
)
{
break
;
}
}
seq
[
i
]
*=
scales
[
i
];
seq
[
i
dx
]
*=
scales
[
i
];
}
}
...
...
paddle/operators/math/sequence_scale.h
浏览文件 @
89de5d5e
...
...
@@ -27,19 +27,21 @@ namespace math {
* All sequences will be padded to the same length and stored in a transposed
* shape.
* Example:
* seq (s0, s0, s0, s0; s1, s1; s2, s2, s2; s3)
* padding (s0, s1, s2, s3; s0, s1, s2, 0; s0, 0, s2, 0; s0, 0, 0, 0)
* Given:
* seq = (s0, s0, s0, s0; s1, s1; s2, s2, s2; s3)
* scales = (2, 3, 4, 5)
* then:
* result = (2*s0, 2*s0, 2*s0, 2*s0; 3*s1, 3*s1; 4*s2, 4*s2, 4*s2; 5*s3)
*
* \param context
d
evice context of this functor.
* \param context
D
evice context of this functor.
* \param seq LoDTensor which is stored in sequence format, the shape
* is [total_sequence_length, sequence_width] where
* total_sequence_length is the sum of all sequences'
* length.
* \param padding Tensor which is padded to the same length, the shape is
* [max_sequence_length, num_sequences, sequence_width].
* \param norm_by_times whether dividing sequence's length.
* \param scales Array<T>. The i-th sequence will be scaled by scales[i].
* \param num_seq Number of sequence
*
* \note transposition is also done in this functor.
*/
template
<
typename
DeviceContext
,
typename
T
>
class
ScaleLoDTensorFunctor
{
...
...
paddle/operators/warpctc_op.h
浏览文件 @
89de5d5e
...
...
@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence_padding.h"
...
...
@@ -209,12 +208,6 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
auto
*
logits_grad
=
ctx
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"Logits"
));
const
Tensor
*
loss_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Loss"
));
// LOG(ERROR) << "loss_grad_dims: " << loss_grad_dims;
// for (int i=0; i<loss_grad->numel();i++) {
// LOG(ERROR) << "loss_grad: " << loss_grad_data[i];
//}
// T* logits_grad_data =
logits_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
bool
norm_by_times
=
ctx
.
Attr
<
bool
>
(
"norm_by_times"
);
math
::
UnpaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
...
...
@@ -226,18 +219,6 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
math
::
ScaleLoDTensorFunctor
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
*
logits_grad
,
loss_grad_data
,
num_seq
);
/*
int level = 0;
auto logits_grad_lod = framework::ToAbsOffset(logits_grad->lod());
const size_t num_sequences = logits_grad_lod[level].size() - 1;
for (int seq_index = 0; seq_index < num_sequences; ++seq_index) {
for (int token_index = logits_grad_lod[level][seq_index];
token_index < logits_grad_lod[level][seq_index + 1];
++token_index) {
logits_grad_data[token_index] *= loss_grad_data[seq_index];
}
}
*/
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录