提交 0c4697f8 编写于 作者: C chenweihang

fix: change to enumerate by sentence

上级 4ec12496
...@@ -72,14 +72,14 @@ Examples: ...@@ -72,14 +72,14 @@ Examples:
Case 1: Case 1:
Input: Input:
X.lod = [[0, 3, 5]] X.lod = [[0, 3, 5]]
X.data = [1, 2, 3, 4, 5] X.data = [[1], [2], [3], [4], [5]]
X.dims = [5, 1] X.dims = [5, 1]
Attrs: Attrs:
win_size = 2 win_size = 2
pad_value = 0 pad_value = 0
Output: Output:
Out.lod = [[0, 3, 5]] Out.lod = [[0, 3, 5]]
Out.data = [[1, 2], [2, 3], [3, 4], [4, 5], [0, 0]] Out.data = [[1, 2], [2, 3], [3, 0], [4, 5], [5, 0]]
Out.dims = [5, 2] Out.dims = [5, 2]
)DOC"); )DOC");
......
...@@ -23,15 +23,23 @@ using platform::PADDLE_CUDA_NUM_THREADS; ...@@ -23,15 +23,23 @@ using platform::PADDLE_CUDA_NUM_THREADS;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
template <typename T> template <typename T>
__global__ void CalcOutPut(const T* in_data, const int64_t in_len, __global__ void CalcOutPut(const T* in_data, const size_t* in_lod,
const int64_t win_size, const int64_t pad_value, const size_t lod_len, const int64_t win_size,
T* out_data) { const int64_t pad_value, T* out_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x; int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < in_len) { if (index < in_lod[lod_len - 1]) {
int end_idx = 0;
// Get LoD interval of index
for (int i = 1; i < lod_len; ++i) {
if (index < in_lod[i]) {
end_idx = in_lod[i];
break;
}
}
for (size_t i = 0; i < win_size; ++i) { for (size_t i = 0; i < win_size; ++i) {
int word_pos = index + i; int word_pos = index + i;
out_data[index * win_size + i] = out_data[index * win_size + i] =
word_pos < in_len ? in_data[word_pos] : pad_value; word_pos < end_idx ? in_data[word_pos] : pad_value;
} }
} }
} }
...@@ -54,13 +62,16 @@ class SequenceEnumerateOpCUDAKernel : public framework::OpKernel<T> { ...@@ -54,13 +62,16 @@ class SequenceEnumerateOpCUDAKernel : public framework::OpKernel<T> {
/* Generate enumerate sequence set */ /* Generate enumerate sequence set */
auto stream = context.cuda_device_context().stream(); auto stream = context.cuda_device_context().stream();
auto lod0 = in_lod[0];
auto in_len = in->numel(); auto in_len = in->numel();
auto in_data = in->data<T>(); auto in_data = in->data<T>();
auto out_data = out->mutable_data<T>(context.GetPlace()); auto out_data = out->mutable_data<T>(context.GetPlace());
// Copy LoD to GPU
const size_t* dev_in_lod_ptr = lod0.CUDAData(context.GetPlace());
// Calc output tensor // Calc output tensor
CalcOutPut<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, CalcOutPut<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>( PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
in_data, in_len, win_size, pad_value, out_data); in_data, dev_in_lod_ptr, lod0.size(), win_size, pad_value, out_data);
} }
}; };
......
...@@ -37,14 +37,16 @@ class SequenceEnumerateKernel : public framework::OpKernel<T> { ...@@ -37,14 +37,16 @@ class SequenceEnumerateKernel : public framework::OpKernel<T> {
"The actual input data's size mismatched with LoD information."); "The actual input data's size mismatched with LoD information.");
// Generate enumerate sequence set // Generate enumerate sequence set
auto seq_length = in_dims[0]; auto lod0 = in_lod[0];
auto in_data = in->data<T>(); auto in_data = in->data<T>();
auto out_data = out->mutable_data<T>(context.GetPlace()); auto out_data = out->mutable_data<T>(context.GetPlace());
for (int idx = 0; idx < seq_length; ++idx) { for (size_t i = 0; i < lod0.size() - 1; ++i) {
for (int word_idx = 0; word_idx < win_size; ++word_idx) { for (size_t idx = lod0[i]; idx < lod0[i + 1]; ++idx) {
int word_pos = idx + word_idx; for (int word_idx = 0; word_idx < win_size; ++word_idx) {
out_data[win_size * idx + word_idx] = size_t word_pos = idx + word_idx;
word_pos < seq_length ? in_data[word_pos] : pad_value; out_data[win_size * idx + word_idx] =
word_pos < lod0[i + 1] ? in_data[word_pos] : pad_value;
}
} }
} }
} }
......
...@@ -5534,14 +5534,14 @@ def sequence_enumerate(input, win_size, pad_value, name=None): ...@@ -5534,14 +5534,14 @@ def sequence_enumerate(input, win_size, pad_value, name=None):
Case 1: Case 1:
Input: Input:
X.lod = [[0, 3, 5]] X.lod = [[0, 3, 5]]
X.data = [1, 2, 3, 4, 5] X.data = [[1], [2], [3], [4], [5]]
X.dims = [5, 1] X.dims = [5, 1]
Attrs: Attrs:
win_size = 2 win_size = 2
pad_value = 0 pad_value = 0
Output: Output:
Out.lod = [[0, 3, 5]] Out.lod = [[0, 3, 5]]
Out.data = [[1, 2], [2, 3], [3, 4], [4, 5], [0, 0]] Out.data = [[1, 2], [2, 3], [3, 0], [4, 5], [5, 0]]
Out.dims = [5, 2] Out.dims = [5, 2]
Args: Args:
...@@ -5567,7 +5567,7 @@ def sequence_enumerate(input, win_size, pad_value, name=None): ...@@ -5567,7 +5567,7 @@ def sequence_enumerate(input, win_size, pad_value, name=None):
attrs={'win_size': win_size, attrs={'win_size': win_size,
'pad_value': pad_value}) 'pad_value': pad_value})
def sequence_mask(x, maxlen=None, dtype='int64', name=None): def sequence_mask(x, maxlen=None, dtype='int64', name=None):
""" """
**SequenceMask Layer** **SequenceMask Layer**
......
...@@ -19,16 +19,20 @@ import numpy as np ...@@ -19,16 +19,20 @@ import numpy as np
from op_test import OpTest from op_test import OpTest
def sequence_enumerate(input_seq, win_size, pad_value): def sequence_enumerate(input_seq, in_lod, win_size, pad_value):
lod0 = [0]
for i in range(0, len(in_lod[0])):
lod0.append(lod0[i] + in_lod[0][i])
out_seq = [] out_seq = []
for idx in range(0, len(input_seq)): for i in range(0, len(lod0) - 1):
single_seq = [] for idx in range(lod0[i], lod0[i + 1]):
for word_idx in range(win_size): single_seq = []
word_pos = idx + word_idx for word_idx in range(win_size):
dat = input_seq[word_pos] if word_pos < len(input_seq) \ word_pos = idx + word_idx
dat = input_seq[word_pos] if word_pos < lod0[i+1] \
else pad_value else pad_value
single_seq.append(dat) single_seq.append(dat)
out_seq.append(single_seq) out_seq.append(single_seq)
return out_seq return out_seq
...@@ -48,7 +52,8 @@ class TestSequenceEnumerateOp(OpTest): ...@@ -48,7 +52,8 @@ class TestSequenceEnumerateOp(OpTest):
self.lod = [[9, 4, 11, 6]] self.lod = [[9, 4, 11, 6]]
self.win_size = 2 self.win_size = 2
self.pad_value = 0 self.pad_value = 0
out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value) out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size,
self.pad_value)
self.out_seq = np.array(out_seq).astype("int32") self.out_seq = np.array(out_seq).astype("int32")
...@@ -58,7 +63,8 @@ class TesSequenceEnumerateOpInt64(TestSequenceEnumerateOp): ...@@ -58,7 +63,8 @@ class TesSequenceEnumerateOpInt64(TestSequenceEnumerateOp):
self.lod = [[9, 4, 11, 6]] self.lod = [[9, 4, 11, 6]]
self.win_size = 2 self.win_size = 2
self.pad_value = 0 self.pad_value = 0
out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value) out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size,
self.pad_value)
self.out_seq = np.array(out_seq).astype("int64") self.out_seq = np.array(out_seq).astype("int64")
...@@ -68,7 +74,8 @@ class TestSequenceEnumerateOpMaxWinSize(TestSequenceEnumerateOp): ...@@ -68,7 +74,8 @@ class TestSequenceEnumerateOpMaxWinSize(TestSequenceEnumerateOp):
self.lod = [[9, 4, 11, 6]] self.lod = [[9, 4, 11, 6]]
self.win_size = 30 self.win_size = 30
self.pad_value = 0 self.pad_value = 0
out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value) out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size,
self.pad_value)
self.out_seq = np.array(out_seq).astype("int32") self.out_seq = np.array(out_seq).astype("int32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册