未验证 提交 cfdd1fc2 编写于 作者: W whs 提交者: GitHub

Fix warpctc in padding mode. (#21033)

上级 8da0cd53
......@@ -37,6 +37,16 @@ inline static size_t MaximumSequenceLength(
return max_seq_len;
}
inline static size_t TotalSequenceLength(
const framework::Vector<size_t>& seq_offset) {
size_t seq_num = seq_offset.size() - 1;
size_t total_seq_len = 0;
for (size_t i = 0; i < seq_num; ++i) {
total_seq_len += seq_offset[i + 1] - seq_offset[i];
}
return total_seq_len;
}
inline static void CheckDims(const framework::DDim& seq_tensor_dims,
const framework::DDim& pad_tensor_dims,
const framework::Vector<size_t>& seq_offset,
......
......@@ -230,8 +230,35 @@ class WarpCTCKernel : public framework::OpKernel<T> {
static_cast<T>(0));
// warpctc accesses labels in CPU memory
Tensor warpctc_label;
LoDTensor warpctc_label;
if (ctx.HasInput("LogitsLength")) {
warpctc_label.mutable_data<int>(
{static_cast<int64_t>(math::TotalSequenceLength(label_lod)), 1},
platform::CPUPlace());
std::vector<framework::Vector<size_t>> lod;
lod.push_back(label_lod);
warpctc_label.set_lod(lod);
if (platform::is_cpu_place(ctx.GetPlace())) {
math::UnpaddingLoDTensorFunctor<DeviceContext, int>()(
ctx.template device_context<DeviceContext>(), *label,
&warpctc_label, label->dims()[1] /*pad_seq_len*/, 0 /*lod_level*/,
false /*norm_by_times*/, math::kBatchLengthWidth);
} else {
LoDTensor gpu_label;
gpu_label.mutable_data<int>(
{static_cast<int64_t>(math::TotalSequenceLength(label_lod)), 1},
ctx.GetPlace());
gpu_label.set_lod(lod);
math::UnpaddingLoDTensorFunctor<DeviceContext, int>()(
ctx.template device_context<DeviceContext>(), *label, &gpu_label,
label->dims()[1] /*pad_seq_len*/, 0 /*lod_level*/,
false /*norm_by_times*/, math::kBatchLengthWidth);
TensorCopySync(gpu_label, platform::CPUPlace(), &warpctc_label);
}
} else {
TensorCopySync(*label, platform::CPUPlace(), &warpctc_label);
}
const int* warpctc_label_data = warpctc_label.data<int>();
// warpctc stores loss in CPU memory
......
......@@ -303,7 +303,7 @@ class TestWarpCTCOpWithPadding(OpTest):
self.inputs = {
"Logits": new_logits,
"Label": labels,
"Label": new_labels,
"LogitsLength": self.logits_length,
"LabelLength": self.labels_length
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册