未验证 提交 4c7d196d 编写于 作者: W whs 提交者: GitHub

Add norm_by_time for warpctc op in padding mode. (#17580)

上级 e89c16b9
......@@ -267,6 +267,11 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
size_t num_sequences = warpctc_grad->dims()[1];
size_t seq_width = warpctc_grad->dims()[2];
auto* logits_length = ctx.Input<framework::Tensor>("LogitsLength");
framework::Tensor logits_length_cpu;
framework::TensorCopy(*logits_length, platform::CPUPlace(),
&logits_length_cpu);
LoDTensor logits_grad_with_lod;
auto logits_grad_dims =
framework::make_ddim({static_cast<int64_t>(max_seq_length),
......@@ -289,10 +294,14 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
const T* loss_grad_data = loss_grad_cpu.data<T>();
for (size_t i = 0; i < max_seq_length; ++i) {
for (size_t j = 0; j < num_sequences; ++j) {
T scale = 1.0;
if (norm_by_times) {
scale = 1.0 / static_cast<T>(logits_length_cpu.data<int64_t>()[j]);
}
for (size_t k = 0; k < seq_width; ++k) {
size_t idx = i * (num_sequences * seq_width) + j * seq_width + k;
scaled_logits_data[idx] =
logits_grad_cpu_data[idx] * loss_grad_data[j];
logits_grad_cpu_data[idx] * loss_grad_data[j] * scale;
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册