Commit 8f37c3c2 authored by wanghaoshuang

Fix sequence scale functor cuda kernel

1. Fix kernel
2. Add more test cases
Parent 45cf2341
......@@ -13,16 +13,21 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/operators/math/sequence_scale.h"
+#include "paddle/platform/cuda_helper.h"
 
 namespace paddle {
 namespace operators {
 namespace math {
 
-template <typename T>
+using platform::PADDLE_CUDA_NUM_THREADS;
+
+template <typename T, int BlockSize>
 __global__ void SequenceScaleKernel(T* seq, size_t* lod, const T* scales,
                                     const size_t seq_width) {
-  if (threadIdx.x < (lod[blockIdx.x + 1] - lod[blockIdx.x]) * seq_width) {
-    int idx = lod[blockIdx.x] * seq_width + threadIdx.x;
+  for (int i = threadIdx.x;
+       i < (lod[blockIdx.x + 1] - lod[blockIdx.x]) * seq_width;
+       i += BlockSize) {
+    int idx = lod[blockIdx.x] * seq_width + i;
     seq[idx] *= scales[blockIdx.x];
   }
 }
......@@ -39,8 +44,8 @@ class ScaleLoDTensorFunctor<platform::CUDADeviceContext, T> {
     framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
     T* seq_data = seq.mutable_data<T>(context.GetPlace());
-    int threads = 1024;
-    SequenceScaleKernel<T><<<num_seq, threads, 0, context.stream()>>>(
+    SequenceScaleKernel<T, PADDLE_CUDA_NUM_THREADS><<<
+        num_seq, PADDLE_CUDA_NUM_THREADS, 0, context.stream()>>>(
         seq_data, abs_offset_lod[level].data(), scales, seq_width);
   }
 };
......
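For context on the kernel change: the old kernel let each thread scale at most one element, guarded by `if (threadIdx.x < ...)`, so a sequence whose element count (length × seq_width) exceeded blockDim.x was only partially scaled. The fixed kernel keeps one block per sequence but lets each thread stride by `BlockSize`. Below is a minimal standalone sketch of that block-per-sequence, thread-stride pattern; the kernel name `ScaleBySequence` and the toy values in `main()` are illustrative assumptions, not the Paddle sources.

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// One block per sequence; each thread strides over that sequence's elements.
// Without the stride loop (i.e. with a plain `if (threadIdx.x < ...)` guard),
// only the first blockDim.x elements of a sequence would ever be scaled.
template <typename T, int BlockSize>
__global__ void ScaleBySequence(T* data, const size_t* lod, const T* scales,
                                size_t width) {
  size_t begin = lod[blockIdx.x] * width;
  size_t end = lod[blockIdx.x + 1] * width;
  for (size_t i = begin + threadIdx.x; i < end; i += BlockSize) {
    data[i] *= scales[blockIdx.x];
  }
}

int main() {
  const int kThreads = 512;
  const size_t width = 514;                // per-step width larger than one block
  std::vector<size_t> lod = {0, 4, 5};     // two sequences, lengths 4 and 1
  std::vector<float> scales = {2.0f, 3.0f};
  std::vector<float> data(lod.back() * width, 1.0f);

  float *d_data, *d_scales;
  size_t *d_lod;
  cudaMalloc((void**)&d_data, data.size() * sizeof(float));
  cudaMalloc((void**)&d_lod, lod.size() * sizeof(size_t));
  cudaMalloc((void**)&d_scales, scales.size() * sizeof(float));
  cudaMemcpy(d_data, data.data(), data.size() * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_lod, lod.data(), lod.size() * sizeof(size_t),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_scales, scales.data(), scales.size() * sizeof(float),
             cudaMemcpyHostToDevice);

  int num_seq = static_cast<int>(lod.size()) - 1;
  ScaleBySequence<float, kThreads><<<num_seq, kThreads>>>(d_data, d_lod,
                                                          d_scales, width);
  cudaMemcpy(data.data(), d_data, data.size() * sizeof(float),
             cudaMemcpyDeviceToHost);

  // The last element of sequence 0 sits at index 4 * 514 - 1 = 2055, well past
  // the first 512 thread indices; it is scaled only because of the stride loop.
  printf("last element of sequence 0: %f (expect 2.0)\n", data[4 * width - 1]);

  cudaFree(d_data);
  cudaFree(d_lod);
  cudaFree(d_scales);
  return 0;
}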
......@@ -4,6 +4,8 @@ import numpy as np
 from op_test import OpTest
 from test_softmax_op import stable_softmax
 
+CUDA_BLOCK_SIZE = 512
+
 
 class CTCForward(object):
     def __init__(self, softmax, softmax_lod, labels, labels_lod, blank,
......@@ -154,39 +156,45 @@ class CTCForward(object):
 class TestWarpCTCOp(OpTest):
+    def config(self):
+        self.batch_size = 4
+        self.num_classes = 8
+        self.logits_lod = [[0, 4, 5, 8, 11]]
+        self.labels_lod = [[0, 3, 4, 8, 12]]
+        self.blank = self.num_classes - 1
+        self.norm_by_times = False
+
     def setUp(self):
         self.op_type = "warpctc"
+        self.config()
 
-        batch_size = 4
-        num_classes = 8
-        logits_lod = [[0, 4, 5, 8, 11]]
-        logits = np.random.uniform(0.1, 1.0,
-                                   [11, num_classes]).astype("float32")
+        logits = np.random.uniform(
+            0.1, 1.0,
+            [self.logits_lod[0][-1], self.num_classes]).astype("float32")
         softmax = np.apply_along_axis(stable_softmax, 1, logits)
-        labels_lod = [[0, 3, 4, 8, 12]]
         # labels should not be blank
-        labels = np.random.randint(0, num_classes - 1, [12, 1], dtype="int32")
-
-        blank = num_classes - 1
-        norm_by_times = False
+        labels = np.random.randint(
+            0, self.num_classes - 1, [self.labels_lod[0][-1], 1], dtype="int32")
 
-        ctc = CTCForward(softmax, logits_lod, labels, labels_lod, blank,
-                         norm_by_times)
+        ctc = CTCForward(softmax, self.logits_lod, labels, self.labels_lod,
+                         self.blank, self.norm_by_times)
         loss = ctc.forward()
 
         max_sequence_length = 0
-        for i in range(batch_size):
-            max_sequence_length = max(max_sequence_length,
-                                      logits_lod[0][i + 1] - logits_lod[0][i])
+        for i in range(self.batch_size):
+            max_sequence_length = max(
+                max_sequence_length,
+                self.logits_lod[0][i + 1] - self.logits_lod[0][i])
         self.gradient = np.zeros(
-            [max_sequence_length, batch_size, num_classes], dtype="float32")
+            [max_sequence_length, self.batch_size, self.num_classes],
+            dtype="float32")
 
         self.inputs = {
-            "Logits": (logits, logits_lod),
-            "Label": (labels, labels_lod)
+            "Logits": (logits, self.logits_lod),
+            "Label": (labels, self.labels_lod)
         }
         self.outputs = {"Loss": loss}
-        self.attrs = {"blank": blank, "norm_by_times": norm_by_times}
+        self.attrs = {"blank": self.blank, "norm_by_times": self.norm_by_times}
 
     def test_check_output(self):
         self.check_output()
......@@ -196,5 +204,15 @@ class TestWarpCTCOp(OpTest):
         self.check_grad(["Logits"], "Loss", max_relative_error=0.007)
 
 
+class TestWarpCTCOpCase1(TestWarpCTCOp):
+    def config(self):
+        self.batch_size = 4
+        self.num_classes = CUDA_BLOCK_SIZE + 2
+        self.logits_lod = [[0, 4, 5, 8, 11]]
+        self.labels_lod = [[0, 3, 4, 8, 12]]
+        self.blank = 0
+        self.norm_by_times = False
+
+
 if __name__ == "__main__":
     unittest.main()
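Why the new TestWarpCTCOpCase1 exercises the fix: with num_classes = CUDA_BLOCK_SIZE + 2, every sequence in logits_lod spans more values than one CUDA block has threads (assuming PADDLE_CUDA_NUM_THREADS is the 512 that CUDA_BLOCK_SIZE mirrors), so the scaling kernel must loop instead of relying on threadIdx.x alone. A rough back-of-the-envelope check as a standalone snippet, with values copied from the test case; it is not part of the test suite.

#include <cstdio>

int main() {
  const int cuda_block_size = 512;              // CUDA_BLOCK_SIZE in the test
  const int num_classes = cuda_block_size + 2;  // 514 classes per timestep
  const int logits_lod[] = {0, 4, 5, 8, 11};    // sequence offsets from the test

  for (int i = 0; i < 4; ++i) {
    // Values one block must scale for sequence i: length * num_classes.
    int elements = (logits_lod[i + 1] - logits_lod[i]) * num_classes;
    printf("sequence %d: %d values vs. %d threads per block\n", i, elements,
           cuda_block_size);
  }
  return 0;
}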