提交 9a18e78e 编写于 作者: W wanghaox

update sequence slice op, fix some error

上级 29c25828
......@@ -75,14 +75,17 @@ class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker {
"the input of SequenceSliceOp.");
AddInput("Offset",
"(Tensor), "
"A vector<int> to describes offset for sub sequence item.");
"a vector<int> to describe the offset of every input sequence for "
"sub sequence item.");
AddInput("Length",
"(Tensor), "
"A vector<int> to describes length for sub sequence item.");
"a vector<int> to describe the length of every input sequence for "
"sub sequence item.");
AddOutput("Out",
"(LoDTensor), output of sequence slice Op.");
"(LoDTensor), The output of SequenceSliceOp.");
AddComment(R"DOC(
Sequence slice operator
The operator crop a subsequence from given sequence with given start offset and subsequence length.
It only supports sequence (LoD Tensor with level number is 1).
- Case:
......@@ -91,13 +94,13 @@ It only supports sequence (LoD Tensor with level number is 1).
c1, c2]
[d1, d2;
e1, e2]]
LoD(X) = {{0, 3, 5}}; Dims(X) = (4, 1, 2)
Offset = (0, 1); Length = (2, 1)
LoD(X) = {{0, 3, 5}}; Dims(X) = (5, 2)
Offset = [0, 1]; Length = [2, 1]
Out = [[a1, a2;
b1, b2]
[e1, e2]]
LoD(Out) = {{0, 2, 3}}
LoD(Out) = {{0, 2, 3}}; Dims(Out) = (3, 2)
NOTE: The length of the input, offset and length should be the same. The offset start from 0.
)DOC");
}
......
......@@ -87,9 +87,10 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
out->mutable_data<T>(ctx.GetPlace());
auto out_lod = SequenceSliceLoD(*in, offset_data, length_data);
auto out_dims = in->dims();
out_dims[0] = out_lod[0][out_lod[0].size() - 1];
out->Resize(out_dims);
out->set_lod(out_lod);
math::SetConstant<Place, T> set_zero;
set_zero(ctx.device_context(), out, static_cast<T>(0));
auto in_stride = framework::stride(in->dims());
auto out_stride = framework::stride(out->dims());
......
......@@ -5,25 +5,32 @@ from op_test import OpTest
class TestSequenceSliceOp(OpTest):
def set_data(self):
self.init_test_case()
# only supprot one level LoD
x = np.random.random((100, 3, 2)).astype('float32')
lod = [[0, 20, 40, 60, 80, 100]]
offset = np.array([1, 2, 3, 4, 5]).flatten().astype("int64")
length = np.array([10, 8, 6, 4, 2]).flatten().astype("int64")
x = np.random.random(self.x_dim).astype('float32')
lod = self.x_lod
offset = np.array(self.offset).flatten().astype("int64")
length = np.array(self.length).flatten().astype("int64")
self.inputs = {'X': (x, lod), 'Offset': offset, 'Length': length}
outs = np.zeros((100, 3, 2)).astype('float32')
outs = [] #np.zeros((100, 3, 2)).astype('float32')
out_lod = [[0]]
out_lod_offset = 0
for i in range(len(offset)):
sub_x = x[lod[0][i] + offset[i]: lod[0]
[i] + offset[i] + length[i], :]
out_lod_offset = out_lod_offset + len(sub_x)
outs[out_lod[0][i]: out_lod_offset, :] = sub_x
outs.append(sub_x)
out_lod[0].append(out_lod_offset)
outs = np.concatenate(outs, axis=0)
self.outputs = {'Out': (outs, out_lod)}
def init_test_case(self):
self.x_dim = (100, 3, 2)
self.x_lod = [[0, 20, 40, 60, 80, 100]]
self.offset = [1, 2, 3, 4, 5]
self.length = [10, 8, 6, 4, 2]
def setUp(self):
self.op_type = "sequence_slice"
self.set_data()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册