提交 0c4697f8 编写于 作者: C chenweihang

fix: change to enumerate by sentence

上级 4ec12496
......@@ -72,14 +72,14 @@ Examples:
Case 1:
Input:
X.lod = [[0, 3, 5]]
X.data = [1, 2, 3, 4, 5]
X.data = [[1], [2], [3], [4], [5]]
X.dims = [5, 1]
Attrs:
win_size = 2
pad_value = 0
Output:
Out.lod = [[0, 3, 5]]
Out.data = [[1, 2], [2, 3], [3, 4], [4, 5], [0, 0]]
Out.data = [[1, 2], [2, 3], [3, 0], [4, 5], [5, 0]]
Out.dims = [5, 2]
)DOC");
......
......@@ -23,15 +23,23 @@ using platform::PADDLE_CUDA_NUM_THREADS;
using LoDTensor = framework::LoDTensor;
template <typename T>
__global__ void CalcOutPut(const T* in_data, const int64_t in_len,
const int64_t win_size, const int64_t pad_value,
T* out_data) {
__global__ void CalcOutPut(const T* in_data, const size_t* in_lod,
const size_t lod_len, const int64_t win_size,
const int64_t pad_value, T* out_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < in_len) {
if (index < in_lod[lod_len - 1]) {
int end_idx = 0;
// Get LoD interval of index
for (int i = 1; i < lod_len; ++i) {
if (index < in_lod[i]) {
end_idx = in_lod[i];
break;
}
}
for (size_t i = 0; i < win_size; ++i) {
int word_pos = index + i;
out_data[index * win_size + i] =
word_pos < in_len ? in_data[word_pos] : pad_value;
word_pos < end_idx ? in_data[word_pos] : pad_value;
}
}
}
......@@ -54,13 +62,16 @@ class SequenceEnumerateOpCUDAKernel : public framework::OpKernel<T> {
/* Generate enumerate sequence set */
auto stream = context.cuda_device_context().stream();
auto lod0 = in_lod[0];
auto in_len = in->numel();
auto in_data = in->data<T>();
auto out_data = out->mutable_data<T>(context.GetPlace());
// Copy LoD to GPU
const size_t* dev_in_lod_ptr = lod0.CUDAData(context.GetPlace());
// Calc output tensor
CalcOutPut<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
in_data, in_len, win_size, pad_value, out_data);
in_data, dev_in_lod_ptr, lod0.size(), win_size, pad_value, out_data);
}
};
......
......@@ -37,14 +37,16 @@ class SequenceEnumerateKernel : public framework::OpKernel<T> {
"The actual input data's size mismatched with LoD information.");
// Generate enumerate sequence set
auto seq_length = in_dims[0];
auto lod0 = in_lod[0];
auto in_data = in->data<T>();
auto out_data = out->mutable_data<T>(context.GetPlace());
for (int idx = 0; idx < seq_length; ++idx) {
for (int word_idx = 0; word_idx < win_size; ++word_idx) {
int word_pos = idx + word_idx;
out_data[win_size * idx + word_idx] =
word_pos < seq_length ? in_data[word_pos] : pad_value;
for (size_t i = 0; i < lod0.size() - 1; ++i) {
for (size_t idx = lod0[i]; idx < lod0[i + 1]; ++idx) {
for (int word_idx = 0; word_idx < win_size; ++word_idx) {
size_t word_pos = idx + word_idx;
out_data[win_size * idx + word_idx] =
word_pos < lod0[i + 1] ? in_data[word_pos] : pad_value;
}
}
}
}
......
......@@ -5534,14 +5534,14 @@ def sequence_enumerate(input, win_size, pad_value, name=None):
Case 1:
Input:
X.lod = [[0, 3, 5]]
X.data = [1, 2, 3, 4, 5]
X.data = [[1], [2], [3], [4], [5]]
X.dims = [5, 1]
Attrs:
win_size = 2
pad_value = 0
Output:
Out.lod = [[0, 3, 5]]
Out.data = [[1, 2], [2, 3], [3, 4], [4, 5], [0, 0]]
Out.data = [[1, 2], [2, 3], [3, 0], [4, 5], [5, 0]]
Out.dims = [5, 2]
Args:
......@@ -5567,7 +5567,7 @@ def sequence_enumerate(input, win_size, pad_value, name=None):
attrs={'win_size': win_size,
'pad_value': pad_value})
def sequence_mask(x, maxlen=None, dtype='int64', name=None):
"""
**SequenceMask Layer**
......
......@@ -19,16 +19,20 @@ import numpy as np
from op_test import OpTest
def sequence_enumerate(input_seq, win_size, pad_value):
def sequence_enumerate(input_seq, in_lod, win_size, pad_value):
lod0 = [0]
for i in range(0, len(in_lod[0])):
lod0.append(lod0[i] + in_lod[0][i])
out_seq = []
for idx in range(0, len(input_seq)):
single_seq = []
for word_idx in range(win_size):
word_pos = idx + word_idx
dat = input_seq[word_pos] if word_pos < len(input_seq) \
for i in range(0, len(lod0) - 1):
for idx in range(lod0[i], lod0[i + 1]):
single_seq = []
for word_idx in range(win_size):
word_pos = idx + word_idx
dat = input_seq[word_pos] if word_pos < lod0[i+1] \
else pad_value
single_seq.append(dat)
out_seq.append(single_seq)
single_seq.append(dat)
out_seq.append(single_seq)
return out_seq
......@@ -48,7 +52,8 @@ class TestSequenceEnumerateOp(OpTest):
self.lod = [[9, 4, 11, 6]]
self.win_size = 2
self.pad_value = 0
out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value)
out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size,
self.pad_value)
self.out_seq = np.array(out_seq).astype("int32")
......@@ -58,7 +63,8 @@ class TesSequenceEnumerateOpInt64(TestSequenceEnumerateOp):
self.lod = [[9, 4, 11, 6]]
self.win_size = 2
self.pad_value = 0
out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value)
out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size,
self.pad_value)
self.out_seq = np.array(out_seq).astype("int64")
......@@ -68,7 +74,8 @@ class TestSequenceEnumerateOpMaxWinSize(TestSequenceEnumerateOp):
self.lod = [[9, 4, 11, 6]]
self.win_size = 30
self.pad_value = 0
out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value)
out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size,
self.pad_value)
self.out_seq = np.array(out_seq).astype("int32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册