Commit 89de5d5e, authored by wanghaoshuang

Fix cuda kernel of sequence scale functor

Parent commit: 9eb3fb29
@@ -22,16 +22,16 @@ template <typename T>
 __global__ void SequenceScaleKernel(T* seq, size_t* lod, const T* scales,
                                     const size_t num_seq,
                                     const size_t seq_width) {
-  size_t idx = blockIdx.x * blockDim.y + threadIdx.y;
-  if (idx < lod[num_seq]) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < lod[num_seq] * seq_width) {
     size_t i = 0;
     for (i = 0; i < num_seq; ++i) {
       if (idx < lod[i + 1] * seq_width) {
         break;
       }
     }
-    seq[i] *= scales[i];
+    seq[idx] *= scales[i];
   }
 }
...
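As a reading aid (not part of the commit), the following standalone harness drops the fixed kernel from the hunk above into a toy example. The kernel body is copied from the new side of the diff; everything else (the main function, the 256-thread block size, and the float, width-1 test data) is an illustrative assumption:

```cuda
// Standalone sketch, not part of this commit: the fixed kernel from the diff
// above, launched over a toy LoD tensor to show the per-element indexing.
#include <cstdio>
#include <cuda_runtime.h>

template <typename T>
__global__ void SequenceScaleKernel(T* seq, size_t* lod, const T* scales,
                                    const size_t num_seq,
                                    const size_t seq_width) {
  // One thread per element of the flattened [total_length, seq_width] tensor.
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < lod[num_seq] * seq_width) {
    // Find the sequence this element belongs to, then scale it.
    size_t i = 0;
    for (i = 0; i < num_seq; ++i) {
      if (idx < lod[i + 1] * seq_width) break;
    }
    seq[idx] *= scales[i];
  }
}

int main() {
  // Example from the functor's doc comment: 4 sequences of lengths 4, 2, 3, 1,
  // width 1, scaled by (2, 3, 4, 5).
  const size_t num_seq = 4, seq_width = 1;
  size_t h_lod[] = {0, 4, 6, 9, 10};  // absolute offsets
  float h_seq[10] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
  float h_scales[] = {2, 3, 4, 5};

  size_t* d_lod;
  float *d_seq, *d_scales;
  cudaMalloc((void**)&d_lod, sizeof(h_lod));
  cudaMalloc((void**)&d_seq, sizeof(h_seq));
  cudaMalloc((void**)&d_scales, sizeof(h_scales));
  cudaMemcpy(d_lod, h_lod, sizeof(h_lod), cudaMemcpyHostToDevice);
  cudaMemcpy(d_seq, h_seq, sizeof(h_seq), cudaMemcpyHostToDevice);
  cudaMemcpy(d_scales, h_scales, sizeof(h_scales), cudaMemcpyHostToDevice);

  const int threads = 256;  // assumed block size, not taken from the commit
  const size_t total = h_lod[num_seq] * seq_width;
  const int blocks = static_cast<int>((total + threads - 1) / threads);
  SequenceScaleKernel<float><<<blocks, threads>>>(d_seq, d_lod, d_scales,
                                                  num_seq, seq_width);
  cudaMemcpy(h_seq, d_seq, sizeof(h_seq), cudaMemcpyDeviceToHost);
  for (int i = 0; i < 10; ++i) printf("%g ", h_seq[i]);  // 2 2 2 2 3 3 4 4 4 5
  printf("\n");
  cudaFree(d_lod);
  cudaFree(d_seq);
  cudaFree(d_scales);
  return 0;
}
```

Compared with the old left-hand side of the hunk, the fix replaces the blockDim.y/threadIdx.y indexing with a flat 1-D index, bounds it by lod[num_seq] * seq_width (total elements rather than total tokens), and writes to seq[idx] instead of seq[i], so every element of sequence i is scaled by scales[i] exactly once.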
@@ -27,19 +27,21 @@ namespace math {
  * All sequences will be padded to the same length and stored in a transposed
  * shape.
  * Example:
- *    seq      (s0, s0, s0, s0; s1, s1; s2, s2, s2; s3)
- *    padding  (s0, s1, s2, s3; s0, s1, s2, 0; s0, 0, s2, 0; s0, 0, 0, 0)
+ * Given:
+ *    seq = (s0, s0, s0, s0; s1, s1; s2, s2, s2; s3)
+ *    scales = (2, 3, 4, 5)
+ * then:
+ *    result = (2*s0, 2*s0, 2*s0, 2*s0; 3*s1, 3*s1; 4*s2, 4*s2, 4*s2; 5*s3)
  *
- * \param context       device context of this functor.
+ * \param context       Device context of this functor.
  * \param seq           LoDTensor which is stored in sequence format, the shape
  *                      is [total_sequence_length, sequence_width] where
  *                      total_sequence_length is the sum of all sequences'
  *                      length.
- * \param padding       Tensor which is padded to the same length, the shape is
- *                      [max_sequence_length, num_sequences, sequence_width].
- * \param norm_by_times whether dividing sequence's length.
+ * \param scales        Array<T>. The i-th sequence will be scaled by scales[i].
+ * \param num_seq       Number of sequence
  *
- * \note  transposition is also done in this functor.
  */
 template <typename DeviceContext, typename T>
 class ScaleLoDTensorFunctor {
...
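The Example block in the updated comment can also be read as a plain host-side loop: every element of sequence i, across all seq_width columns, is multiplied by scales[i]. Below is a minimal CPU reference sketch of that behavior (a hypothetical helper for illustration, not part of PaddlePaddle or of this commit):

```cpp
// CPU reference for the documented semantics of ScaleLoDTensorFunctor.
#include <cstddef>
#include <vector>

template <typename T>
void ScaleLoDTensorRef(std::vector<T>& seq, const std::vector<size_t>& lod,
                       const std::vector<T>& scales, size_t seq_width) {
  const size_t num_seq = lod.size() - 1;  // lod holds absolute offsets
  for (size_t i = 0; i < num_seq; ++i) {
    for (size_t token = lod[i]; token < lod[i + 1]; ++token) {
      for (size_t w = 0; w < seq_width; ++w) {
        seq[token * seq_width + w] *= scales[i];
      }
    }
  }
}
```

This mirrors, sequentially, what the CUDA kernel in the first hunk does in parallel, and what the hand-written loop removed from WarpCTCGradKernel below used to do for the gradient case.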
@@ -14,7 +14,6 @@ limitations under the License. */
 #pragma once
-#include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/sequence_padding.h"
@@ -209,12 +208,6 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
     auto* logits_grad = ctx.Output<LoDTensor>(framework::GradVarName("Logits"));
     const Tensor* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
-    // LOG(ERROR) << "loss_grad_dims: " << loss_grad_dims;
-    // for (int i=0; i<loss_grad->numel();i++) {
-    //   LOG(ERROR) << "loss_grad: " << loss_grad_data[i];
-    // }
-    // T* logits_grad_data =
     logits_grad->mutable_data<T>(ctx.GetPlace());
     bool norm_by_times = ctx.Attr<bool>("norm_by_times");
     math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
@@ -226,18 +219,6 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
     math::ScaleLoDTensorFunctor<DeviceContext, T>()(
         ctx.template device_context<DeviceContext>(), *logits_grad,
         loss_grad_data, num_seq);
-    /*
-    int level = 0;
-    auto logits_grad_lod = framework::ToAbsOffset(logits_grad->lod());
-    const size_t num_sequences = logits_grad_lod[level].size() - 1;
-    for (int seq_index = 0; seq_index < num_sequences; ++seq_index) {
-      for (int token_index = logits_grad_lod[level][seq_index];
-           token_index < logits_grad_lod[level][seq_index + 1];
-           ++token_index) {
-        logits_grad_data[token_index] *= loss_grad_data[seq_index];
-      }
-    }
-    */
   }
 };
...