Commit 89de5d5e authored by: W wanghaoshuang

Fix cuda kernel of sequence scale functor

Parent 9eb3fb29
@@ -22,16 +22,16 @@ template <typename T>
 __global__ void SequenceScaleKernel(T* seq, size_t* lod, const T* scales,
                                     const size_t num_seq,
                                     const size_t seq_width) {
-  size_t idx = blockIdx.x * blockDim.y + threadIdx.y;
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  if (idx < lod[num_seq]) {
+  if (idx < lod[num_seq] * seq_width) {
     size_t i = 0;
     for (i = 0; i < num_seq; ++i) {
       if (idx < lod[i + 1] * seq_width) {
         break;
       }
     }
-    seq[i] *= scales[i];
+    seq[idx] *= scales[i];
   }
 }
......
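The fix gives each CUDA thread one element of the flattened [total_sequence_length, sequence_width] tensor: the bounds check now covers lod[num_seq] * seq_width elements, the owning sequence i is found from the absolute-offset LoD, and the thread scales its own element seq[idx] instead of seq[i]. Below is a minimal host-side reference of the same logic, useful for checking the kernel's output; the function name and the standalone form are illustrative and not part of the patch.

#include <cstddef>

// Mirrors the fixed SequenceScaleKernel on the CPU: element `idx` belongs to
// sequence `i` when idx < lod[i + 1] * seq_width, and is multiplied by
// scales[i].  `lod` holds absolute offsets, so lod[0] == 0 and
// lod[num_seq] == total_sequence_length.
template <typename T>
void SequenceScaleReference(T* seq, const size_t* lod, const T* scales,
                            const size_t num_seq, const size_t seq_width) {
  for (size_t idx = 0; idx < lod[num_seq] * seq_width; ++idx) {
    size_t i = 0;
    for (i = 0; i < num_seq; ++i) {
      if (idx < lod[i + 1] * seq_width) {
        break;
      }
    }
    seq[idx] *= scales[i];
  }
}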
@@ -27,19 +27,21 @@ namespace math {
  * All sequences will be padded to the same length and stored in a transposed
  * shape.
  * Example:
- *   seq     (s0, s0, s0, s0; s1, s1; s2, s2, s2; s3)
- *   padding (s0, s1, s2, s3; s0, s1, s2, 0; s0, 0, s2, 0; s0, 0, 0, 0)
+ *   Given:
+ *     seq = (s0, s0, s0, s0; s1, s1; s2, s2, s2; s3)
+ *     scales = (2, 3, 4, 5)
+ *   then:
+ *     result = (2*s0, 2*s0, 2*s0, 2*s0; 3*s1, 3*s1; 4*s2, 4*s2, 4*s2; 5*s3)
  *
- * \param context  device context of this functor.
+ * \param context  Device context of this functor.
  * \param seq      LoDTensor which is stored in sequence format, the shape
  *                 is [total_sequence_length, sequence_width] where
  *                 total_sequence_length is the sum of all sequences'
  *                 length.
- * \param padding  Tensor which is padded to the same length, the shape is
- *                 [max_sequence_length, num_sequences, sequence_width].
- * \param norm_by_times  whether dividing sequence's length.
+ * \param scales   Array<T>. The i-th sequence will be scaled by scales[i].
+ * \param num_seq  Number of sequence
  *
  * \note transposition is also done in this functor.
  */
 template <typename DeviceContext, typename T>
 class ScaleLoDTensorFunctor {
......
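With concrete numbers, the Given/then example in the new comment works out as shown below. The seq_width of 1, the absolute-offset lod = {0, 4, 6, 9, 10} (sequence lengths 4, 2, 3, 1) and the element values of 1 are assumptions made here to keep the arithmetic visible; they are not spelled out in the patch.

#include <cstddef>
#include <cstdio>

int main() {
  const size_t num_seq = 4;
  const size_t lod[] = {0, 4, 6, 9, 10};  // sequence lengths 4, 2, 3, 1
  const float scales[] = {2, 3, 4, 5};
  float seq[10] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1};  // seq_width == 1

  // Scale every element of sequence i by scales[i].
  for (size_t idx = 0; idx < lod[num_seq]; ++idx) {
    size_t i = 0;
    for (i = 0; i < num_seq; ++i) {
      if (idx < lod[i + 1]) break;
    }
    seq[idx] *= scales[i];
  }

  // Prints: 2 2 2 2 3 3 4 4 4 5
  for (size_t idx = 0; idx < lod[num_seq]; ++idx) {
    printf("%g ", seq[idx]);
  }
  printf("\n");
  return 0;
}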
@@ -14,7 +14,6 @@ limitations under the License. */
 #pragma once
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
-#include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/sequence_padding.h"
@@ -209,12 +208,6 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
     auto* logits_grad = ctx.Output<LoDTensor>(framework::GradVarName("Logits"));
     const Tensor* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
-    // LOG(ERROR) << "loss_grad_dims: " << loss_grad_dims;
-    // for (int i=0; i<loss_grad->numel();i++) {
-    //   LOG(ERROR) << "loss_grad: " << loss_grad_data[i];
-    // }
-    // T* logits_grad_data =
     logits_grad->mutable_data<T>(ctx.GetPlace());
     bool norm_by_times = ctx.Attr<bool>("norm_by_times");
     math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
@@ -226,18 +219,6 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
     math::ScaleLoDTensorFunctor<DeviceContext, T>()(
         ctx.template device_context<DeviceContext>(), *logits_grad,
         loss_grad_data, num_seq);
-    /*
-    int level = 0;
-    auto logits_grad_lod = framework::ToAbsOffset(logits_grad->lod());
-    const size_t num_sequences = logits_grad_lod[level].size() - 1;
-    for (int seq_index = 0; seq_index < num_sequences; ++seq_index) {
-      for (int token_index = logits_grad_lod[level][seq_index];
-           token_index < logits_grad_lod[level][seq_index + 1];
-           ++token_index) {
-        logits_grad_data[token_index] *= loss_grad_data[seq_index];
-      }
-    }
-    */
   }
 };
......
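Taken together, the two functor calls that replace the commented-out loop first copy the time-major padded gradient back into LoD layout, optionally normalizing by sequence length, and then scale every token of sequence s by the incoming loss gradient. The sketch below shows that composition on plain arrays; it is a behavioral summary, not Paddle's actual functor code. The function name and the std::vector interface are assumptions, and the padded layout [max_sequence_length, num_sequences, sequence_width] is taken from the \param padding description in the diff above.

#include <cstddef>
#include <vector>

// Sketch of UnpaddingLoDTensorFunctor followed by ScaleLoDTensorFunctor as
// used in WarpCTCGradKernel: unpad the time-major gradient, divide by the
// sequence length when norm_by_times is set, then multiply every token of
// sequence `s` by loss_grad[s].
template <typename T>
void WarpCTCGradReference(
    const std::vector<T>& padded_grad,    // [max_len, num_seq, width]
    const std::vector<size_t>& lod,       // absolute offsets, size num_seq + 1
    const std::vector<T>& loss_grad,      // one value per sequence
    const size_t width, const bool norm_by_times,
    std::vector<T>* logits_grad) {        // [lod.back(), width]
  const size_t num_seq = lod.size() - 1;
  logits_grad->assign(lod.back() * width, T(0));
  for (size_t s = 0; s < num_seq; ++s) {
    const size_t len = lod[s + 1] - lod[s];
    for (size_t t = 0; t < len; ++t) {
      for (size_t w = 0; w < width; ++w) {
        T g = padded_grad[(t * num_seq + s) * width + w];
        if (norm_by_times) {
          g /= static_cast<T>(len);
        }
        (*logits_grad)[(lod[s] + t) * width + w] = g * loss_grad[s];
      }
    }
  }
}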