Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
3c749fae
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3c749fae
编写于
8月 16, 2018
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update CPU sequence_padding functor
上级
a38a8db9
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
108 addition
and
113 deletion
+108
-113
paddle/fluid/operators/math/sequence_padding.cc
paddle/fluid/operators/math/sequence_padding.cc
+78
-71
paddle/fluid/operators/math/sequence_padding.h
paddle/fluid/operators/math/sequence_padding.h
+22
-34
paddle/fluid/operators/math/sequence_padding_test.cc
paddle/fluid/operators/math/sequence_padding_test.cc
+3
-3
paddle/fluid/operators/warpctc_op.h
paddle/fluid/operators/warpctc_op.h
+5
-5
未找到文件。
paddle/fluid/operators/math/sequence_padding.cc
浏览文件 @
3c749fae
...
@@ -18,37 +18,45 @@ namespace paddle {
...
@@ -18,37 +18,45 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
math
{
namespace
math
{
template
<
typename
T
>
enum
CopyType
{
kSeqToPad
,
kPadToSeq
};
void
CopyDataCPU
(
framework
::
LoDTensor
*
seq_tensor
,
framework
::
Tensor
*
pad_tensor
,
const
framework
::
Vector
<
size_t
>&
seq_offset
,
const
int64_t
&
max_seq_len
,
const
int64_t
&
seq_width
,
bool
seq_to_pad
,
bool
norm_by_len
,
OutputLayout
output_layout
)
{
T
*
seq_data
=
seq_tensor
->
data
<
T
>
();
T
*
pad_data
=
pad_tensor
->
data
<
T
>
();
int64_t
seq_num
=
seq_offset
.
size
()
-
1
;
template
<
typename
T
>
void
CopyValidData
(
framework
::
Tensor
*
dst_tensor
,
for
(
int64_t
i
=
0
;
i
<
seq_num
;
++
i
)
{
const
framework
::
Tensor
*
src_tensor
,
int64_t
seq_start
=
seq_offset
[
i
];
const
framework
::
Vector
<
size_t
>&
seq_offsets
,
int64_t
seq_len
=
seq_offset
[
i
+
1
]
-
seq_start
;
int
pad_seq_len
,
int
step_width
,
bool
norm_by_len
,
T
scale
=
norm_by_len
?
(
1.0
f
/
static_cast
<
T
>
(
seq_len
))
:
1.0
f
;
CopyType
type
,
PadLayout
layout
)
{
for
(
int64_t
j
=
0
;
j
<
seq_len
;
++
j
)
{
int
seq_num
=
seq_offsets
.
size
()
-
1
;
for
(
int64_t
k
=
0
;
k
<
seq_width
;
++
k
)
{
const
T
*
src_data
=
src_tensor
->
data
<
T
>
();
size_t
pad_data_idx
=
0
;
T
*
dst_data
=
dst_tensor
->
data
<
T
>
();
size_t
seq_data_idx
=
(
seq_start
+
j
)
*
seq_width
+
k
;
if
(
output_layout
==
kBatchLengthWidth
)
{
int
seq_cpy_gap
=
step_width
;
pad_data_idx
=
(
i
*
max_seq_len
+
j
)
*
seq_width
+
k
;
int
pad_cpy_gap
=
}
else
{
layout
==
kBatchLengthWidth
?
step_width
:
seq_num
*
step_width
;
pad_data_idx
=
(
j
*
seq_num
+
i
)
*
seq_width
+
k
;
for
(
int
seq_idx
=
0
;
seq_idx
<
seq_num
;
++
seq_idx
)
{
}
int
valid_seq_len
=
seq_offsets
[
seq_idx
+
1
]
-
seq_offsets
[
seq_idx
];
if
(
seq_to_pad
)
{
PADDLE_ENFORCE_GE
(
pad_data
[
pad_data_idx
]
=
seq_data
[
seq_data_idx
]
*
scale
;
pad_seq_len
,
valid_seq_len
,
}
else
{
"The padded sequence length can not be less than its original length."
);
seq_data
[
seq_data_idx
]
=
pad_data
[
pad_data_idx
]
*
scale
;
int
seq_data_offset
=
seq_offsets
[
seq_idx
]
*
step_width
;
int
pad_data_offset
=
layout
==
kBatchLengthWidth
?
seq_idx
*
pad_seq_len
*
step_width
:
seq_idx
*
step_width
;
float
scale
=
1.0
f
/
static_cast
<
float
>
(
valid_seq_len
);
for
(
int
step_idx
=
0
;
step_idx
<
valid_seq_len
;
++
step_idx
)
{
const
T
*
src
=
src_data
+
(
type
==
kSeqToPad
?
seq_data_offset
:
pad_data_offset
);
T
*
dst
=
dst_data
+
(
type
==
kSeqToPad
?
pad_data_offset
:
seq_data_offset
);
memcpy
(
dst
,
src
,
step_width
*
sizeof
(
T
));
if
(
norm_by_len
)
{
for
(
int
i
=
0
;
i
<
step_width
;
++
i
)
{
*
(
dst
+
i
)
*=
scale
;
}
}
}
}
seq_data_offset
+=
seq_cpy_gap
;
pad_data_offset
+=
pad_cpy_gap
;
}
}
}
}
}
}
...
@@ -58,31 +66,37 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
...
@@ -58,31 +66,37 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
public:
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
LoDTensor
&
seq_tensor
,
const
framework
::
LoDTensor
&
seq_tensor
,
framework
::
Tensor
*
pad_tensor
,
framework
::
LoDTensor
*
pad_tensor
,
T
pad_value
=
static_cast
<
T
>
(
0
),
bool
norm_by_times
=
false
,
std
::
vector
<
T
>
pad_value
=
{
0
},
int
pad_seq_len
=
-
1
,
size_t
lod_level
=
0
,
int
lod_level
=
0
,
bool
norm_by_times
=
false
,
OutputLayout
output_layout
=
kBatchLengthWidth
)
{
const
PadLayout
layout
=
kBatchLengthWidth
)
{
CheckLoD
(
seq_tensor
,
lod_level
);
auto
seq_offsets
=
framework
::
ToAbsOffset
(
seq_tensor
.
lod
())[
lod_level
];
auto
&
lod
=
seq_tensor
.
lod
();
auto
&
seq_offset
=
framework
::
ToAbsOffset
(
lod
)[
lod_level
];
auto
seq_tensor_dims
=
seq_tensor
.
dims
();
auto
seq_tensor_dims
=
seq_tensor
.
dims
();
auto
pad_tensor_dims
=
pad_tensor
->
dims
();
auto
pad_tensor_dims
=
pad_tensor
->
dims
();
int64_t
max_seq_len
=
MaximumSequenceLength
(
seq_offset
);
if
(
pad_seq_len
==
-
1
)
{
int64_t
seq_num
=
seq_offset
.
size
()
-
1
;
pad_seq_len
=
MaximumSequenceLength
(
seq_offsets
);
int64_t
seq_width
=
seq_tensor
.
numel
()
/
seq_tensor_dims
[
0
];
}
int
step_width
=
seq_tensor
.
numel
()
/
seq_tensor_dims
[
0
];
CheckDims
(
seq_tensor_dims
,
seq_offset
.
back
(),
pad_tensor_dims
,
max_seq_len
,
CheckDims
(
seq_tensor_dims
,
pad_tensor_dims
,
seq_offsets
,
pad_seq_len
,
seq_num
,
seq_width
,
output_layout
);
step_width
,
layout
);
PADDLE_ENFORCE
(
pad_value
.
size
()
==
1
||
static_cast
<
int
>
(
pad_value
.
size
())
==
step_width
,
"The size of 'pad_value' can only be 1 or be equal to the "
"'step_width'."
);
T
*
pad_data
=
pad_tensor
->
data
<
T
>
();
if
(
pad_value
.
size
()
==
1
)
{
pad_value
=
std
::
vector
<
T
>
(
step_width
,
pad_value
[
0
]);
}
memset
(
pad_data
,
pad_value
,
max_seq_len
*
seq_num
*
seq_width
*
sizeof
(
T
));
// fill padding value
T
*
pad_data
=
pad_tensor
->
data
<
T
>
();
for
(
int
i
=
0
;
i
<
pad_tensor
->
numel
()
/
step_width
;
++
i
)
{
memcpy
(
pad_data
,
pad_value
.
data
(),
step_width
*
sizeof
(
T
));
}
CopyDataCPU
<
T
>
(
const_cast
<
framework
::
LoDTensor
*>
(
&
seq_tensor
),
pad_tensor
,
CopyValidData
<
T
>
(
pad_tensor
,
&
seq_tensor
,
seq_offsets
,
pad_seq_len
,
seq_offset
,
max_seq_len
,
seq_width
,
true
/* seq_to_pad */
,
step_width
,
norm_by_times
,
kSeqToPad
,
layout
);
norm_by_times
,
output_layout
);
}
}
};
};
...
@@ -90,30 +104,23 @@ template <typename T>
...
@@ -90,30 +104,23 @@ template <typename T>
class
UnpaddingLoDTensorFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
class
UnpaddingLoDTensorFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
public:
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
framework
::
LoDTensor
*
seq_tensor
,
const
framework
::
LoDTensor
&
pad_tensor
,
const
framework
::
Tensor
&
pad_tensor
,
framework
::
LoDTensor
*
seq_tensor
,
int
pad_seq_len
=
-
1
,
bool
norm_by_times
=
false
,
size_t
lod_level
=
0
,
int
lod_level
=
0
,
bool
norm_by_times
=
false
,
OutputLayout
output_layout
=
kBatchLengthWidth
)
{
const
PadLayout
&
layout
=
kBatchLengthWidth
)
{
CheckLoD
(
*
seq_tensor
,
lod_level
);
auto
seq_offsets
=
framework
::
ToAbsOffset
(
seq_tensor
->
lod
())[
lod_level
];
auto
seq_tensor_dims
=
seq_tensor
->
dims
();
auto
&
lod
=
seq_tensor
->
lod
();
auto
pad_tensor_dims
=
pad_tensor
.
dims
();
auto
&
seq_offset
=
framework
::
ToAbsOffset
(
lod
)[
lod_level
];
if
(
pad_seq_len
==
-
1
)
{
pad_seq_len
=
MaximumSequenceLength
(
seq_offsets
);
auto
&
seq_tensor_dims
=
seq_tensor
->
dims
();
}
auto
&
pad_tensor_dims
=
pad_tensor
.
dims
();
int
step_width
=
seq_tensor
->
numel
()
/
seq_tensor_dims
[
0
];
int64_t
max_seq_len
=
MaximumSequenceLength
(
seq_offset
);
int64_t
seq_num
=
seq_offset
.
size
()
-
1
;
CheckDims
(
seq_tensor_dims
,
pad_tensor_dims
,
seq_offsets
,
pad_seq_len
,
int64_t
seq_width
=
seq_tensor
->
numel
()
/
seq_tensor_dims
[
0
];
step_width
,
layout
);
CheckDims
(
seq_tensor_dims
,
seq_offset
.
back
(),
pad_tensor_dims
,
max_seq_len
,
CopyValidData
<
T
>
(
seq_tensor
,
&
pad_tensor
,
seq_offsets
,
pad_seq_len
,
seq_num
,
seq_width
,
output_layout
);
step_width
,
norm_by_times
,
kPadToSeq
,
layout
);
T
*
seq_data
=
seq_tensor
->
data
<
T
>
();
memset
(
seq_data
,
static_cast
<
T
>
(
0
),
seq_tensor
->
numel
()
*
sizeof
(
T
));
CopyDataCPU
<
T
>
(
seq_tensor
,
const_cast
<
framework
::
Tensor
*>
(
&
pad_tensor
),
seq_offset
,
max_seq_len
,
seq_width
,
false
/* seq_to_pad */
,
norm_by_times
,
output_layout
);
}
}
};
};
...
...
paddle/fluid/operators/math/sequence_padding.h
浏览文件 @
3c749fae
...
@@ -15,6 +15,7 @@ limitations under the License. */
...
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#pragma once
#include <algorithm>
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
...
@@ -22,7 +23,7 @@ namespace paddle {
...
@@ -22,7 +23,7 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
math
{
namespace
math
{
enum
Output
Layout
{
kBatchLengthWidth
=
0
,
kLengthBatchWidth
};
enum
Pad
Layout
{
kBatchLengthWidth
=
0
,
kLengthBatchWidth
};
inline
static
size_t
MaximumSequenceLength
(
inline
static
size_t
MaximumSequenceLength
(
const
framework
::
Vector
<
size_t
>&
seq_offset
)
{
const
framework
::
Vector
<
size_t
>&
seq_offset
)
{
...
@@ -34,35 +35,22 @@ inline static size_t MaximumSequenceLength(
...
@@ -34,35 +35,22 @@ inline static size_t MaximumSequenceLength(
return
max_seq_len
;
return
max_seq_len
;
}
}
inline
static
void
CheckLoD
(
const
framework
::
LoDTensor
&
seq_tensor
,
const
size_t
&
lod_level
)
{
PADDLE_ENFORCE
(
lod_level
<
seq_tensor
.
lod
().
size
(),
"Invalid lod level which should be at least 0 and less "
"than maximum lod level of sequence tensor."
);
}
inline
static
void
CheckDims
(
const
framework
::
DDim
&
seq_tensor_dims
,
inline
static
void
CheckDims
(
const
framework
::
DDim
&
seq_tensor_dims
,
const
size_t
&
last_offset
,
const
framework
::
DDim
&
pad_tensor_dims
,
const
framework
::
DDim
&
pad_tensor_dims
,
const
int64_t
&
max_seq_len
,
const
int64_t
&
seq_num
,
const
framework
::
Vector
<
size_t
>&
seq_offset
,
const
int64_t
&
seq
_width
,
int64_t
padded_seq_len
,
int64_t
step
_width
,
const
OutputLayout
&
output_
layout
)
{
const
PadLayout
&
layout
)
{
PADDLE_ENFORCE_EQ
(
static_cast
<
size_t
>
(
seq_tensor_dims
[
0
]),
last_offset
,
PADDLE_ENFORCE_EQ
(
static_cast
<
size_t
>
(
seq_tensor_dims
[
0
]),
seq_offset
.
back
()
,
"Value of 1st dimension of the sequence tensor should be "
"Value of 1st dimension of the sequence tensor should be "
"equal to sum of lengths of all sequences."
);
"equal to sum of lengths of all sequences."
);
PADDLE_ENFORCE
_EQ
(
pad_tensor_dims
.
size
(),
3UL
,
PADDLE_ENFORCE
(
seq_tensor_dims
.
size
()
==
1
||
seq_tensor_dims
.
size
()
==
2
,
"Padded tensor should be a 3-D tensor
."
);
"seq_tensor's rank should be 1 or 2
."
);
if
(
output_layout
==
kBatchLengthWidth
)
{
PADDLE_ENFORCE
(
seq_tensor_dims
.
size
()
+
1
==
pad_tensor_dims
.
size
()
||
PADDLE_ENFORCE_EQ
(
pad_tensor_dims
,
seq_tensor_dims
.
size
()
==
pad_tensor_dims
.
size
(),
framework
::
make_ddim
({
seq_num
,
max_seq_len
,
seq_width
}));
"pad_tensor's rank should be 1 greater than seq_tensor's "
}
else
if
(
output_layout
==
kLengthBatchWidth
)
{
"rank, or be equal with it."
);
PADDLE_ENFORCE_EQ
(
pad_tensor_dims
,
framework
::
make_ddim
({
max_seq_len
,
seq_num
,
seq_width
}));
}
else
{
PADDLE_THROW
(
"Unsupported output layout."
);
}
}
}
/*
/*
...
@@ -94,22 +82,22 @@ inline static void CheckDims(const framework::DDim& seq_tensor_dims,
...
@@ -94,22 +82,22 @@ inline static void CheckDims(const framework::DDim& seq_tensor_dims,
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
PaddingLoDTensorFunctor
{
class
PaddingLoDTensorFunctor
{
public:
public:
void
operator
()(
const
DeviceContext
&
context
,
void
operator
()(
const
platform
::
CPU
DeviceContext
&
context
,
const
framework
::
LoDTensor
&
seq_tensor
,
const
framework
::
LoDTensor
&
seq_tensor
,
framework
::
Tensor
*
pad_tensor
,
framework
::
LoD
Tensor
*
pad_tensor
,
T
pad_value
=
static_cast
<
T
>
(
0
),
bool
norm_by_times
=
false
,
std
::
vector
<
T
>
pad_value
=
{
0
},
int
pad_seq_len
=
-
1
,
size_t
lod_level
=
0
,
int
lod_level
=
0
,
bool
norm_by_times
=
false
,
OutputLayout
output_
layout
=
kBatchLengthWidth
);
const
PadLayout
layout
=
kBatchLengthWidth
);
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
UnpaddingLoDTensorFunctor
{
class
UnpaddingLoDTensorFunctor
{
public:
public:
void
operator
()(
const
DeviceContext
&
context
,
void
operator
()(
const
platform
::
CPU
DeviceContext
&
context
,
framework
::
LoDTensor
*
seq
_tensor
,
const
framework
::
LoDTensor
&
pad
_tensor
,
const
framework
::
Tensor
&
pad_tensor
,
framework
::
LoDTensor
*
seq_tensor
,
int
pad_seq_len
=
-
1
,
bool
norm_by_times
=
false
,
size_t
lod_level
=
0
,
int
lod_level
=
0
,
bool
norm_by_times
=
false
,
OutputLayout
output_
layout
=
kBatchLengthWidth
);
const
PadLayout
&
layout
=
kBatchLengthWidth
);
};
};
}
// namespace math
}
// namespace math
...
...
paddle/fluid/operators/math/sequence_padding_test.cc
浏览文件 @
3c749fae
...
@@ -23,7 +23,7 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
...
@@ -23,7 +23,7 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
paddle
::
framework
::
LoDTensor
cpu_seq_back
;
paddle
::
framework
::
LoDTensor
cpu_seq_back
;
paddle
::
framework
::
LoDTensor
seq
;
paddle
::
framework
::
LoDTensor
seq
;
paddle
::
framework
::
LoDTensor
seq_back
;
paddle
::
framework
::
LoDTensor
seq_back
;
paddle
::
framework
::
Tensor
padding
;
paddle
::
framework
::
LoD
Tensor
padding
;
const
size_t
level
=
lod
.
size
()
-
1
;
const
size_t
level
=
lod
.
size
()
-
1
;
auto
seq_dims
=
auto
seq_dims
=
...
@@ -56,13 +56,13 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
...
@@ -56,13 +56,13 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
padding
.
mutable_data
<
T
>
(
padding_dims
,
*
place
);
padding
.
mutable_data
<
T
>
(
padding_dims
,
*
place
);
paddle
::
operators
::
math
::
PaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
paddle
::
operators
::
math
::
PaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
*
context
,
seq
,
&
padding
,
0
,
false
,
0
,
*
context
,
seq
,
&
padding
,
{
0
},
-
1
,
0
,
false
,
paddle
::
operators
::
math
::
kLengthBatchWidth
);
paddle
::
operators
::
math
::
kLengthBatchWidth
);
seq_back
.
set_lod
(
lod
);
seq_back
.
set_lod
(
lod
);
seq_back
.
mutable_data
<
T
>
(
seq_dims
,
*
place
);
seq_back
.
mutable_data
<
T
>
(
seq_dims
,
*
place
);
paddle
::
operators
::
math
::
UnpaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
paddle
::
operators
::
math
::
UnpaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
*
context
,
&
seq_back
,
padding
,
false
,
0
,
*
context
,
padding
,
&
seq_back
,
-
1
,
0
,
false
,
paddle
::
operators
::
math
::
kLengthBatchWidth
);
paddle
::
operators
::
math
::
kLengthBatchWidth
);
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
...
...
paddle/fluid/operators/warpctc_op.h
浏览文件 @
3c749fae
...
@@ -153,7 +153,7 @@ class WarpCTCKernel : public framework::OpKernel<T> {
...
@@ -153,7 +153,7 @@ class WarpCTCKernel : public framework::OpKernel<T> {
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
num_sequences
),
1
});
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
num_sequences
),
1
});
// warpctc needs sequences data stored in transposed padding format
// warpctc needs sequences data stored in transposed padding format
Tensor
warpctc_logits
;
LoD
Tensor
warpctc_logits
;
const
size_t
max_sequence_length
=
const
size_t
max_sequence_length
=
math
::
MaximumSequenceLength
(
logits_lod
[
level
]);
math
::
MaximumSequenceLength
(
logits_lod
[
level
]);
auto
warpctc_logits_dims
=
auto
warpctc_logits_dims
=
...
@@ -163,7 +163,7 @@ class WarpCTCKernel : public framework::OpKernel<T> {
...
@@ -163,7 +163,7 @@ class WarpCTCKernel : public framework::OpKernel<T> {
warpctc_logits
.
mutable_data
<
T
>
(
warpctc_logits_dims
,
ctx
.
GetPlace
());
warpctc_logits
.
mutable_data
<
T
>
(
warpctc_logits_dims
,
ctx
.
GetPlace
());
math
::
PaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
math
::
PaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
*
logits
,
&
warpctc_logits
,
ctx
.
template
device_context
<
DeviceContext
>(),
*
logits
,
&
warpctc_logits
,
static_cast
<
T
>
(
0
),
false
/* norm_by_times */
,
0
,
{
static_cast
<
T
>
(
0
)},
-
1
,
0
,
false
/* norm_by_times */
,
math
::
kLengthBatchWidth
);
math
::
kLengthBatchWidth
);
const
T
*
warpctc_logits_data
=
warpctc_logits
.
data
<
T
>
();
const
T
*
warpctc_logits_data
=
warpctc_logits
.
data
<
T
>
();
...
@@ -210,15 +210,15 @@ template <typename DeviceContext, typename T>
...
@@ -210,15 +210,15 @@ template <typename DeviceContext, typename T>
class
WarpCTCGradKernel
:
public
framework
::
OpKernel
<
T
>
{
class
WarpCTCGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
warpctc_grad
=
ctx
.
Input
<
Tensor
>
(
"WarpCTCGrad"
);
auto
*
warpctc_grad
=
ctx
.
Input
<
LoD
Tensor
>
(
"WarpCTCGrad"
);
auto
*
logits_grad
=
ctx
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"Logits"
));
auto
*
logits_grad
=
ctx
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"Logits"
));
const
Tensor
*
loss_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Loss"
));
const
Tensor
*
loss_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Loss"
));
logits_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
logits_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
bool
norm_by_times
=
ctx
.
Attr
<
bool
>
(
"norm_by_times"
);
bool
norm_by_times
=
ctx
.
Attr
<
bool
>
(
"norm_by_times"
);
math
::
UnpaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
math
::
UnpaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
logits
_grad
,
ctx
.
template
device_context
<
DeviceContext
>(),
*
warpctc
_grad
,
*
warpctc_grad
,
norm_by_times
,
0
,
math
::
kLengthBatchWidth
);
logits_grad
,
-
1
,
0
,
norm_by_times
,
math
::
kLengthBatchWidth
);
const
T
*
loss_grad_data
=
loss_grad
->
data
<
T
>
();
const
T
*
loss_grad_data
=
loss_grad
->
data
<
T
>
();
math
::
ScaleLoDTensorFunctor
<
DeviceContext
,
T
>
()(
math
::
ScaleLoDTensorFunctor
<
DeviceContext
,
T
>
()(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录