From 3c375751f8a8983257ea7f7e6086ab3a5fb555e0 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Sat, 20 Apr 2019 09:45:53 +0800 Subject: [PATCH] Support seq len equal to 0 in sequence ops (#16935) * Support seq len equal to 0 in sequence ops test=develop * Add more test cases * Fix some comments test=develop * Fix py3 error test=develop --- paddle/fluid/framework/lod_tensor.cc | 4 +- paddle/fluid/framework/lod_tensor.h | 4 +- paddle/fluid/operators/crf_decoding_op.h | 1 + paddle/fluid/operators/math/context_project.h | 8 +++ .../operators/math/detail/lstm_gpu_kernel.h | 6 +- .../sequence_ops/sequence_concat_op.h | 5 +- .../sequence_ops/sequence_enumerate_op.h | 2 + .../sequence_ops/sequence_expand_op.h | 1 + .../sequence_ops/sequence_slice_op.h | 8 ++- .../tests/unittests/test_crf_decoding_op.py | 21 +++++-- .../tests/unittests/test_edit_distance_op.py | 54 +++++++++++------ .../fluid/tests/unittests/test_gru_op.py | 28 ++++++++- .../unittests/test_linear_chain_crf_op.py | 5 +- .../fluid/tests/unittests/test_lstm_op.py | 21 ++++++- .../fluid/tests/unittests/test_lstmp_op.py | 10 ++++ .../fluid/tests/unittests/test_seq_conv.py | 19 +++++- .../tests/unittests/test_sequence_concat.py | 47 ++++++++++++--- .../unittests/test_sequence_enumerate_op.py | 11 ++++ .../tests/unittests/test_sequence_erase_op.py | 15 +++++ .../tests/unittests/test_sequence_expand.py | 18 ++++++ .../unittests/test_sequence_expand_as.py | 9 +++ .../tests/unittests/test_sequence_pad_op.py | 9 +++ .../tests/unittests/test_sequence_reshape.py | 58 +++++++++---------- .../tests/unittests/test_sequence_reverse.py | 12 ++++ .../unittests/test_sequence_scatter_op.py | 31 ++++++++-- .../tests/unittests/test_sequence_slice_op.py | 24 ++++++++ .../unittests/test_sequence_softmax_op.py | 36 +++++++++--- .../tests/unittests/test_sequence_unpad_op.py | 14 +++++ 28 files changed, 389 insertions(+), 92 deletions(-) diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc index f46bdf96b..2b4683f9e 100644 --- a/paddle/fluid/framework/lod_tensor.cc +++ b/paddle/fluid/framework/lod_tensor.cc @@ -158,7 +158,7 @@ bool CheckLoD(const LoD &in, int tensor_height) { if (level.size() < 2) return false; // check: the first offset(the begin offset) of each level should be 0. if (level.front() != 0) return false; - // check: all the offsets in a level should be ascending(allow same items) + // check: all the offsets in a level should be non-descending if (!std::is_sorted(level.begin(), level.end())) { return false; } @@ -182,7 +182,7 @@ bool CheckAbsLoD(const LoD &in, int tensor_height) { if (in.empty()) return true; for (const auto &level : in) { // check: all the offsets in a level should be ascending(no same items - // allows). + // allowed). if (!std::is_sorted(level.begin(), level.begin(), [](size_t a, size_t b) { if (a < b) return true; return false; diff --git a/paddle/fluid/framework/lod_tensor.h b/paddle/fluid/framework/lod_tensor.h index fb6e781fd..5e20ba7c1 100644 --- a/paddle/fluid/framework/lod_tensor.h +++ b/paddle/fluid/framework/lod_tensor.h @@ -79,7 +79,7 @@ bool operator==(const LoD& a, const LoD& b); * * It will check two things: * - * 1. all the offsets in a level should be ascending(no same items allows). + * 1. all the offsets in a level should be non-descending. * 2. there should be more than 2 offsets existing in each level. * 3. the higher level's last offset should equals the lower level's size-1. * 4. the first offset(the begin offset) of each level should be 0. @@ -95,7 +95,7 @@ bool CheckLoD(const LoD& in, int tensor_height = -1); * - Empty lod is treated as valid. * * It will check two things: - * 1. all the offsets in a level should be ascending(no same items allows) + * 1. all the offsets in a level should be ascending(no same items allowed). * 2. there should be more than 2 offsets existing in each level. * 3. the first offset of each level should be 0, and the last should be the * same(the height of underlying tensor) or `tensor_height` if diff --git a/paddle/fluid/operators/crf_decoding_op.h b/paddle/fluid/operators/crf_decoding_op.h index d6b54038e..13a587dc4 100644 --- a/paddle/fluid/operators/crf_decoding_op.h +++ b/paddle/fluid/operators/crf_decoding_op.h @@ -46,6 +46,7 @@ class CRFDecodingOpKernel : public framework::OpKernel { math::SetConstant()( ctx.template device_context(), decoded_path, 0); for (size_t i = 0; i < seq_num; ++i) { + if (lod[level][i] == lod[level][i + 1]) continue; int start_pos = static_cast(lod[level][i]); int end_pos = static_cast(lod[level][i + 1]); Tensor decoded_path_one_seq = decoded_path->Slice(start_pos, end_pos); diff --git a/paddle/fluid/operators/math/context_project.h b/paddle/fluid/operators/math/context_project.h index d6a4793a8..f60943695 100644 --- a/paddle/fluid/operators/math/context_project.h +++ b/paddle/fluid/operators/math/context_project.h @@ -104,6 +104,8 @@ class ContextProjectFunctor { sequence_width = in.dims()[1]; for (int i = 0; i < static_cast(lod_level_0.size()) - 1; ++i) { + if (lod_level_0[i] == lod_level_0[i + 1]) continue; + input_row_begin = (context_start > 0) ? static_cast(lod_level_0[i]) + context_start : static_cast(lod_level_0[i]); @@ -134,6 +136,8 @@ class ContextProjectFunctor { if (padding_trainable) { PADDLE_ENFORCE_NOT_NULL(padding_data); for (int i = 0; i < static_cast(lod_level_0.size()) - 1; ++i) { + if (lod_level_0[i] == lod_level_0[i + 1]) continue; + Tensor out_t = col->Slice(static_cast(lod_level_0[i]), static_cast(lod_level_0[i + 1])); @@ -216,6 +220,8 @@ class ContextProjectGradFunctor { if (input_grad) { for (int i = 0; i < static_cast(lod_level_0.size()) - 1; ++i) { + if (lod_level_0[i] == lod_level_0[i + 1]) continue; + input_row_begin = (context_start > 0) ? static_cast(lod_level_0[i]) + context_start : static_cast(lod_level_0[i]); @@ -248,6 +254,8 @@ class ContextProjectGradFunctor { if (pad_grad) { if (padding_trainable) { for (int i = 0; i < static_cast(lod_level_0.size()) - 1; ++i) { + if (lod_level_0[i] == lod_level_0[i + 1]) continue; + Tensor out_t = col->Slice(static_cast(lod_level_0[i]), static_cast(lod_level_0[i + 1])); diff --git a/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h b/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h index e0ca9e7f5..24885d370 100644 --- a/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h +++ b/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h @@ -197,9 +197,9 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op, threads = dim3(frame_per_block, 1); grid = dim3(frame_blocks, 1); } else { - /* frame_per_block = 32 batch_per_block = 32 */ - threads = dim3(32, 32); - grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32); + /* frame_per_block = 32 batch_per_block = 16 */ + threads = dim3(32, 16); + grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 16 - 1) / 16); } auto stream = diff --git a/paddle/fluid/operators/sequence_ops/sequence_concat_op.h b/paddle/fluid/operators/sequence_ops/sequence_concat_op.h index f9b2ed384..dd31f9f17 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_concat_op.h +++ b/paddle/fluid/operators/sequence_ops/sequence_concat_op.h @@ -36,7 +36,9 @@ inline framework::LoD ConcatLoD(const Container &xs, for (size_t j = 0; j < xs.size(); ++j) { auto &x_lod = xs[j].get().lod()[0]; const framework::Tensor &tensor = xs[j].get(); - xs_in_order->emplace_back(tensor.Slice(x_lod[i - 1], x_lod[i])); + if (x_lod[i - 1] < x_lod[i]) { + xs_in_order->emplace_back(tensor.Slice(x_lod[i - 1], x_lod[i])); + } sum += x_lod[i]; } result[i] = sum; @@ -102,6 +104,7 @@ class SeqConcatGradKernel : public framework::OpKernel { framework::LoDTensor *dx = dxs[j]; auto &x_lod = x->lod()[0]; + if (x_lod[i - 1] == x_lod[i]) continue; auto prev_lod = x_lod[i - 1]; auto next_lod = x_lod[i]; diff --git a/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h index 6a1eb6e62..6c5a2e968 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h +++ b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h @@ -47,8 +47,10 @@ class SequenceEnumerateKernel : public framework::OpKernel { out->set_lod(in->lod()); auto out_data = out->mutable_data(context.GetPlace()); for (size_t i = 0; i < lod0.size() - 1; ++i) { + if (lod0[i] == lod0[i + 1]) continue; int start = lod0[i]; int end = lod0[i + 1]; + int copy_size = win_size < end - start + 1 ? win_size : end - start + 1; int mid = end + 1 - copy_size; int pad_num = win_size - copy_size; diff --git a/paddle/fluid/operators/sequence_ops/sequence_expand_op.h b/paddle/fluid/operators/sequence_ops/sequence_expand_op.h index 9228c8131..fac63f3fa 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_expand_op.h +++ b/paddle/fluid/operators/sequence_ops/sequence_expand_op.h @@ -160,6 +160,7 @@ struct SequenceExpandGradFunctor { int x_start = x_lod[i - 1]; int x_end = x_lod[i]; int x_seq_len = x_end - x_start; + if (x_seq_len == 0) continue; auto dx_sub = dx->Slice(x_start, x_end); dx_sub.Resize(flatten_to_1d(dx_sub.dims())); int dout_end = dout_offset + repeat_num * x_seq_len; diff --git a/paddle/fluid/operators/sequence_ops/sequence_slice_op.h b/paddle/fluid/operators/sequence_ops/sequence_slice_op.h index 4bded0efb..146b5cc9b 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_slice_op.h +++ b/paddle/fluid/operators/sequence_ops/sequence_slice_op.h @@ -76,9 +76,9 @@ class SequenceSliceOpKernel : public framework::OpKernel { for (size_t i = 0; i < n; ++i) { PADDLE_ENFORCE_LE(0, offset_data[i], - "The offset[%d] must greater than zero.", i); - PADDLE_ENFORCE_LT(0, length_data[i], - "The length[%d] must greater than zero.", i); + "The offset[%d] must be nonnegative.", i); + PADDLE_ENFORCE_LE(0, length_data[i], + "The length[%d] must be nonnegative.", i); PADDLE_ENFORCE_LE(lod[0][i] + offset_data[i] + length_data[i], lod[0][i + 1], "The target tensor's length overflow."); } @@ -95,6 +95,7 @@ class SequenceSliceOpKernel : public framework::OpKernel { size_t out_offset = 0; for (size_t i = 0; i < n; ++i) { + if (length_data[i] == 0) continue; Tensor in_t = in->Slice( static_cast(lod[0][i] + offset_data[i]), static_cast(lod[0][i] + offset_data[i] + length_data[i])); @@ -144,6 +145,7 @@ class SequenceSliceGradOpKernel : public framework::OpKernel { static_cast(0)); for (size_t i = 0; i < out_lod[0].size() - 1; ++i) { + if (length_data[i] == 0) continue; Tensor out_grad_t = out_grad->Slice(static_cast(out_lod[0][i]), static_cast(out_lod[0][i + 1])); diff --git a/python/paddle/fluid/tests/unittests/test_crf_decoding_op.py b/python/paddle/fluid/tests/unittests/test_crf_decoding_op.py index 51bd1300e..89af72107 100644 --- a/python/paddle/fluid/tests/unittests/test_crf_decoding_op.py +++ b/python/paddle/fluid/tests/unittests/test_crf_decoding_op.py @@ -128,12 +128,15 @@ class TestCRFDecodingOp2(OpTest): ground truth being given. """ + def init_lod(self): + self.lod = [[1, 2, 3, 4]] + def setUp(self): self.op_type = "crf_decoding" TAG_NUM = 5 - lod = [[1, 2, 3, 4]] - total_len = sum(lod[-1]) + self.init_lod() + total_len = sum(self.lod[-1]) transition = np.repeat( np.arange( TAG_NUM, dtype="float64").reshape(1, TAG_NUM), @@ -152,9 +155,9 @@ class TestCRFDecodingOp2(OpTest): expected_output = (labels == predicted_labels).astype("int64") self.inputs = { - "Emission": (emission, lod), + "Emission": (emission, self.lod), "Transition": transition, - "Label": (labels, lod) + "Label": (labels, self.lod) } self.outputs = {"ViterbiPath": expected_output} @@ -163,5 +166,15 @@ class TestCRFDecodingOp2(OpTest): self.check_output() +class TestCRFDecodingOp3(TestCRFDecodingOp2): + def init_lod(self): + self.lod = [[1, 0, 0, 4]] + + +class TestCRFDecodingOp4(TestCRFDecodingOp2): + def init_lod(self): + self.lod = [[0, 2, 3, 0]] + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_edit_distance_op.py b/python/paddle/fluid/tests/unittests/test_edit_distance_op.py index 4d0352302..0a334197a 100644 --- a/python/paddle/fluid/tests/unittests/test_edit_distance_op.py +++ b/python/paddle/fluid/tests/unittests/test_edit_distance_op.py @@ -58,10 +58,10 @@ class TestEditDistanceOp(OpTest): x2 = np.array([[12, 4, 7, 8]]).astype("int64") x1 = np.transpose(x1) x2 = np.transpose(x2) - x1_lod = [1, 4] - x2_lod = [3, 1] + self.x1_lod = [1, 4] + self.x2_lod = [3, 1] - num_strs = len(x1_lod) + num_strs = len(self.x1_lod) distance = np.zeros((num_strs, 1)).astype("float32") sequence_num = np.array(2).astype("int64") @@ -69,23 +69,26 @@ class TestEditDistanceOp(OpTest): x2_offset = 0 for i in range(0, num_strs): distance[i] = Levenshtein( - hyp=x1[x1_offset:(x1_offset + x1_lod[i])], - ref=x2[x2_offset:(x2_offset + x2_lod[i])]) - x1_offset += x1_lod[i] - x2_offset += x2_lod[i] + hyp=x1[x1_offset:(x1_offset + self.x1_lod[i])], + ref=x2[x2_offset:(x2_offset + self.x2_lod[i])]) + x1_offset += self.x1_lod[i] + x2_offset += self.x2_lod[i] if normalized is True: - len_ref = x2_lod[i] + len_ref = self.x2_lod[i] distance[i] = distance[i] / len_ref self.attrs = {'normalized': normalized} - self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])} + self.inputs = {'Hyps': (x1, [self.x1_lod]), 'Refs': (x2, [self.x2_lod])} self.outputs = {'Out': distance, 'SequenceNum': sequence_num} def test_check_output(self): self.check_output() -class TestEditDistanceOpNormalized(OpTest): +class TestEditDistanceOpNormalizedCase0(OpTest): + def reset_config(self): + pass + def setUp(self): self.op_type = "edit_distance" normalized = True @@ -93,10 +96,11 @@ class TestEditDistanceOpNormalized(OpTest): x2 = np.array([[10, 4, 6, 7, 8]]).astype("int64") x1 = np.transpose(x1) x2 = np.transpose(x2) - x1_lod = [1, 2, 3] - x2_lod = [2, 1, 2] + self.x1_lod = [3, 0, 3] + self.x2_lod = [2, 1, 2] + self.reset_config() - num_strs = len(x1_lod) + num_strs = len(self.x1_lod) distance = np.zeros((num_strs, 1)).astype("float32") sequence_num = np.array(3).astype("int64") @@ -104,21 +108,33 @@ class TestEditDistanceOpNormalized(OpTest): x2_offset = 0 for i in range(0, num_strs): distance[i] = Levenshtein( - hyp=x1[x1_offset:(x1_offset + x1_lod[i])], - ref=x2[x2_offset:(x2_offset + x2_lod[i])]) - x1_offset += x1_lod[i] - x2_offset += x2_lod[i] + hyp=x1[x1_offset:(x1_offset + self.x1_lod[i])], + ref=x2[x2_offset:(x2_offset + self.x2_lod[i])]) + x1_offset += self.x1_lod[i] + x2_offset += self.x2_lod[i] if normalized is True: - len_ref = x2_lod[i] + len_ref = self.x2_lod[i] distance[i] = distance[i] / len_ref self.attrs = {'normalized': normalized} - self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])} + self.inputs = {'Hyps': (x1, [self.x1_lod]), 'Refs': (x2, [self.x2_lod])} self.outputs = {'Out': distance, 'SequenceNum': sequence_num} def test_check_output(self): self.check_output() +class TestEditDistanceOpNormalizedCase1(TestEditDistanceOpNormalizedCase0): + def reset_config(self): + self.x1_lod = [0, 6, 0] + self.x2_lod = [2, 1, 2] + + +class TestEditDistanceOpNormalizedCase2(TestEditDistanceOpNormalizedCase0): + def reset_config(self): + self.x1_lod = [0, 0, 6] + self.x2_lod = [2, 2, 1] + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_gru_op.py b/python/paddle/fluid/tests/unittests/test_gru_op.py index c66d59ace..17af1d88d 100644 --- a/python/paddle/fluid/tests/unittests/test_gru_op.py +++ b/python/paddle/fluid/tests/unittests/test_gru_op.py @@ -82,9 +82,9 @@ def gru( hidden = np.zeros((T, D), dtype=dtype) idx_in_seq_list, sorted_seqs = _seq_to_batch(lod, is_reverse) - h_p = h0[sorted_seqs] + h_p = h0[[seq for seq in sorted_seqs if lod[0][seq] > 0]] + max_seq_len = len(idx_in_seq_list) - assert len(idx_in_seq_list[0]) == N end_idx = 0 for batch_idx in range(max_seq_len): x = input[idx_in_seq_list[batch_idx]] @@ -119,7 +119,6 @@ class TestGRUOp(OpTest): T = sum(self.lod[0]) N = len(self.lod[0]) - input = np.random.rand(T, 3 * self.D).astype(self.dtype) weight = np.random.rand(self.D, 3 * self.D).astype(self.dtype) bias = np.random.rand( @@ -173,6 +172,13 @@ class TestGRUOp2(TestGRUOp): self.dtype = 'float32' +class TestGRUOp2Len0(TestGRUOp): + def set_confs(self): + self.D = 19 + self.lod = [[2, 0, 4]] + self.dtype = 'float32' + + class TestGRUOp2OriginMode(TestGRUOp): def set_confs(self): self.D = 19 @@ -180,6 +186,22 @@ class TestGRUOp2OriginMode(TestGRUOp): self.origin_mode = True +class TestGRUOp2OriginModeLen0(TestGRUOp): + def set_confs(self): + self.D = 19 + self.lod = [[0, 3, 4]] + self.dtype = 'float32' + self.origin_mode = True + + +class TestGRUOp2OriginModeLastLen0(TestGRUOp): + def set_confs(self): + self.D = 19 + self.lod = [[0, 3, 0]] + self.dtype = 'float32' + self.origin_mode = True + + class TestGRUOpNoInitial(TestGRUOp): def set_confs(self): self.with_h0 = False diff --git a/python/paddle/fluid/tests/unittests/test_linear_chain_crf_op.py b/python/paddle/fluid/tests/unittests/test_linear_chain_crf_op.py index 6e31e9204..b365e1642 100644 --- a/python/paddle/fluid/tests/unittests/test_linear_chain_crf_op.py +++ b/python/paddle/fluid/tests/unittests/test_linear_chain_crf_op.py @@ -89,7 +89,8 @@ class LinearChainCrfForward(object): for i in range(self.seq_num): start = self.seq_start_positions[i] end = self.seq_start_positions[i + 1] - + if start >= end: + continue self.log_likelihood[i] = self._forward_a_sequence( self.x[start:end, :], self.x_row_max[start:end, :], self.x_exps[start:end, :], self.labels[start:end, :], @@ -110,7 +111,7 @@ class TestLinearChainCrfOp(OpTest): lod = [[]] seq_start_pos = [0] for i in range(SEQ_NUM): - lod[-1].append(random.randint(1, MAX_SEQ_LEN)) + lod[-1].append(random.randint(0, MAX_SEQ_LEN)) seq_start_pos.append(seq_start_pos[-1] + lod[-1][-1]) emission = np.random.uniform( -1, 1, [seq_start_pos[-1], TAG_NUM]).astype("float64") diff --git a/python/paddle/fluid/tests/unittests/test_lstm_op.py b/python/paddle/fluid/tests/unittests/test_lstm_op.py index 76a24123f..7ee33c6e9 100644 --- a/python/paddle/fluid/tests/unittests/test_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_lstm_op.py @@ -127,8 +127,11 @@ def lstm( class TestLstmOp(OpTest): - def set_argument(self): + def set_lod(self): self.lod = [[2, 3, 2]] + + def set_argument(self): + self.set_lod() self.D = 16 self.act_gate = 'sigmoid' @@ -142,7 +145,6 @@ class TestLstmOp(OpTest): def setUp(self): self.set_argument() self.op_type = 'lstm' - T = sum(self.lod[0]) N = len(self.lod[0]) @@ -198,6 +200,21 @@ class TestLstmOp(OpTest): ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4) +class TestLstmOpCase1(TestLstmOp): + def set_lod(self): + self.lod = [[0, 3, 2]] + + +class TestLstmOpCase2(TestLstmOp): + def set_lod(self): + self.lod = [[0, 3, 0]] + + +class TestLstmOpCase3(TestLstmOp): + def set_lod(self): + self.lod = [[2, 0, 4]] + + # class TestLstmOpHasInitial(TestLstmOp): # def set_argument(self): # self.lod = [[2, 3, 2]] diff --git a/python/paddle/fluid/tests/unittests/test_lstmp_op.py b/python/paddle/fluid/tests/unittests/test_lstmp_op.py index 0645cfedb..70a0af6c9 100644 --- a/python/paddle/fluid/tests/unittests/test_lstmp_op.py +++ b/python/paddle/fluid/tests/unittests/test_lstmp_op.py @@ -305,5 +305,15 @@ class TestLstmpOpLinearProjection(TestLstmpOp): self.act_proj = 'identity' +class TestLstmpOpLen0Case1(TestLstmpOp): + def reset_argument(self): + self.lod = [[0, 4, 0]] + + +class TestLstmpOpLen0Case2(TestLstmpOp): + def reset_argument(self): + self.lod = [[2, 0, 3]] + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_seq_conv.py b/python/paddle/fluid/tests/unittests/test_seq_conv.py index 2285e9496..da111f9b7 100644 --- a/python/paddle/fluid/tests/unittests/test_seq_conv.py +++ b/python/paddle/fluid/tests/unittests/test_seq_conv.py @@ -204,7 +204,24 @@ class TestSeqProjectCase1(TestSeqProject): self.output_represention = 8 # output feature size -class TestSeqProjectCase2(TestSeqProject): +class TestSeqProjectCase2Len0(TestSeqProject): + def init_test_case(self): + self.input_row = 11 + self.context_start = -1 + self.context_length = 3 + self.padding_trainable = True + self.context_stride = 1 + + self.input_size = [self.input_row, 23] + offset_lod = [[0, 0, 4, 5, 5, 8, self.input_row, self.input_row]] + self.lod = [[]] + # convert from offset-based lod to length-based lod + for i in range(len(offset_lod[0]) - 1): + self.lod[0].append(offset_lod[0][i + 1] - offset_lod[0][i]) + self.output_represention = 8 # output feature size + + +class TestSeqProjectCase3(TestSeqProject): def init_test_case(self): self.input_row = 25 self.context_start = 2 diff --git a/python/paddle/fluid/tests/unittests/test_sequence_concat.py b/python/paddle/fluid/tests/unittests/test_sequence_concat.py index db99001ce..b4a40edc6 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_concat.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_concat.py @@ -20,19 +20,24 @@ from op_test import OpTest class TestSequenceConcat(OpTest): + def setLoD(self): + self.lod1 = [7, 3] + self.lod2 = [12, 8] + self.out_lod = [19, 11] + def setUp(self): x1 = np.random.random(size=(10, 80)) - lod1 = [7, 3] x2 = np.random.random(size=(20, 80)) - lod2 = [12, 8] + self.setLoD() - out = np.concatenate((x1[0:lod1[0]], x2[0:lod2[0]], x1[lod1[0]:], - x2[lod2[0]:])) - out_lod = [19, 11] + out = np.concatenate((x1[0:self.lod1[0]], x2[0:self.lod2[0]], + x1[self.lod1[0]:], x2[self.lod2[0]:])) self.op_type = "sequence_concat" - self.inputs = {'X': [("x1", (x1, [lod1])), ("x2", (x2, [lod2]))]} - self.outputs = {"Out": (out, [out_lod])} + self.inputs = { + 'X': [("x1", (x1, [self.lod1])), ("x2", (x2, [self.lod2]))] + } + self.outputs = {"Out": (out, [self.out_lod])} def test_output(self): self.check_output(1e-3) @@ -41,5 +46,33 @@ class TestSequenceConcat(OpTest): self.check_grad(inputs_to_check=['x1', 'x2'], output_names="Out") +class TestSequenceConcatCase2(TestSequenceConcat): + def setLoD(self): + self.lod1 = [10, 0] + self.lod2 = [12, 8] + self.out_lod = [22, 8] + + +class TestSequenceConcatCase3(TestSequenceConcat): + def setLoD(self): + self.lod1 = [10, 0] + self.lod2 = [20, 0] + self.out_lod = [30, 0] + + +class TestSequenceConcatCase4(TestSequenceConcat): + def setLoD(self): + self.lod1 = [0, 10] + self.lod2 = [0, 20] + self.out_lod = [0, 30] + + +class TestSequenceConcatCase5(TestSequenceConcat): + def setLoD(self): + self.lod1 = [0, 10] + self.lod2 = [20, 0] + self.out_lod = [20, 10] + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py b/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py index 9814ec0a1..99bb33a0a 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py @@ -101,5 +101,16 @@ class TestSequenceEnumerateOpLargePadValue(TestSequenceEnumerateOp): self.out_seq = np.array(out_seq).astype("int32") +class TestSequenceEnumerateOpLargePadValueSeqLen0(TestSequenceEnumerateOp): + def init_test_case(self): + self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + self.lod = [[0, 14, 0, 16, 0]] + self.win_size = 5 + self.pad_value = 5 + out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size, + self.pad_value) + self.out_seq = np.array(out_seq).astype("int32") + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_sequence_erase_op.py b/python/paddle/fluid/tests/unittests/test_sequence_erase_op.py index b49249538..53bb301e9 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_erase_op.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_erase_op.py @@ -79,6 +79,21 @@ class TestSequenceEraseOpInt64(OpTest): self.check_output() +class TestSequenceEraseOpInt64SeqLen0(OpTest): + def setUp(self): + self.op_type = "sequence_erase" + in_seq = np.random.randint(0, 10, (30, 1)).astype("int64") + lod = [[0, 9, 0, 0, 10, 11, 0]] + tokens = [2, 3, 5] + out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens) + self.attrs = {'tokens': tokens} + self.inputs = {'X': (in_seq, lod)} + self.outputs = {'Out': (out_seq, [new_lod0])} + + def test_check_output(self): + self.check_output() + + class TestSequenceEraseOpEmpty(OpTest): def setUp(self): self.op_type = "sequence_erase" diff --git a/python/paddle/fluid/tests/unittests/test_sequence_expand.py b/python/paddle/fluid/tests/unittests/test_sequence_expand.py index d33a57f67..1e4d11197 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_expand.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_expand.py @@ -116,5 +116,23 @@ class TestSequenceExpandCase4(TestSequenceExpand): self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} +class TestSequenceExpandCase5(TestSequenceExpand): + def set_data(self): + x_data = np.random.uniform(0.1, 1, [6, 1]).astype('float32') + y_data = np.random.uniform(0.1, 1, [13, 1]).astype('float32') + y_lod = [[2, 4], [2, 2, 3, 0, 3, 3]] + self.inputs = {'X': x_data, 'Y': (y_data, y_lod)} + self.attrs = {'ref_level': 1} + + +class TestSequenceExpandCase6(TestSequenceExpand): + def set_data(self): + x_data = np.random.uniform(0.1, 1, [4, 1]).astype('float32') + x_lod = [[1, 1, 0, 1, 1]] + y_data = np.random.uniform(0.1, 1, [8, 1]).astype('float32') + y_lod = [[0, 2, 4, 2, 0]] + self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_sequence_expand_as.py b/python/paddle/fluid/tests/unittests/test_sequence_expand_as.py index 4ac97f7ed..30c487eea 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_expand_as.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_expand_as.py @@ -65,6 +65,15 @@ class TestSequenceExpandAsCase1(TestSequenceExpandAs): class TestSequenceExpandAsCase2(TestSequenceExpandAs): + def set_data(self): + x_data = np.random.uniform(0.1, 1, [5, 1]).astype('float32') + x_lod = [[2, 3]] + y_data = np.random.uniform(0.1, 1, [10, 1]).astype('float32') + y_lod = [[0, 4, 0, 6, 0]] + self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)} + + +class TestSequenceExpandAsCase3(TestSequenceExpandAs): def set_data(self): x_data = np.random.uniform(0.1, 1, [1, 2, 2]).astype('float32') x_lod = [[1]] diff --git a/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py b/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py index 3067294e5..d5ab9e89f 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py @@ -132,5 +132,14 @@ class TestSequencePadOp7(TestSequencePadOp): self.dtype = 'float32' +class TestSequencePadOp8(TestSequencePadOp): + def set_attr(self): + self.x_shape = [12, 2, 2] + self.x_len_lod = [[0, 8, 0, 4, 0]] + self.pad_value = [1.0] + self.padded_length = 10 + self.dtype = 'float32' + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_sequence_reshape.py b/python/paddle/fluid/tests/unittests/test_sequence_reshape.py index f11fa6c39..e2e7837da 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_reshape.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_reshape.py @@ -21,17 +21,17 @@ from op_test import OpTest class TestSequenceReshape(OpTest): + def init_data(self): + self.dimension = 12 + self.x_lod = [[4, 1, 3, 3]] + self.x = np.random.uniform(0.1, 1, [11, 24]).astype('float32') + def setUp(self): + self.init_data() self.op_type = 'sequence_reshape' - dimension = 12 - x_lod = [[4, 1, 3, 3]] - x = np.random.uniform(0.1, 1, [11, 24]).astype('float32') - - self.inputs = {'X': (x, x_lod)} - self.attrs = {'new_dim': dimension} - - out, out_lod = self.compute_output(x, x_lod, dimension) - + self.inputs = {'X': (self.x, self.x_lod)} + self.attrs = {'new_dim': self.dimension} + out, out_lod = self.compute_output(self.x, self.x_lod, self.dimension) self.outputs = {'Out': (out, out_lod)} def compute_output(self, x, x_lod, dimension): @@ -54,33 +54,31 @@ class TestSequenceReshape(OpTest): class TestSequenceReshape_reduce(TestSequenceReshape): - def setUp(self): - self.op_type = 'sequence_reshape' - dimension = 24 - x_lod = [[4, 2, 2, 4]] - x = np.random.uniform(0.1, 1, [12, 12]).astype('float32') - - self.inputs = {'X': (x, x_lod)} - self.attrs = {'new_dim': dimension} - - out, out_lod = self.compute_output(x, x_lod, dimension) - - self.outputs = {'Out': (out, out_lod)} + def init_data(self): + self.dimension = 24 + self.x_lod = [[4, 2, 2, 4]] + self.x = np.random.uniform(0.1, 1, [12, 12]).astype('float32') class TestSequenceReshape_same(TestSequenceReshape): - def setUp(self): - self.op_type = 'sequence_reshape' - dimension = 12 - x_lod = [[4, 2, 2, 4]] - x = np.random.uniform(0.1, 1, [12, 12]).astype('float32') + def init_data(self): + self.dimension = 12 + self.x_lod = [[4, 2, 2, 4]] + self.x = np.random.uniform(0.1, 1, [12, 12]).astype('float32') - self.inputs = {'X': (x, x_lod)} - self.attrs = {'new_dim': dimension} - out, out_lod = self.compute_output(x, x_lod, dimension) +class TestSequenceReshape_reduce_seq_len0(TestSequenceReshape): + def init_data(self): + self.dimension = 24 + self.x_lod = [[0, 6, 0, 2, 4]] + self.x = np.random.uniform(0.1, 1, [12, 12]).astype('float32') - self.outputs = {'Out': (out, out_lod)} + +class TestSequenceReshape_reduce_seq_len0_case1(TestSequenceReshape): + def init_data(self): + self.dimension = 24 + self.x_lod = [[0, 2, 8, 2, 0]] + self.x = np.random.uniform(0.1, 1, [12, 12]).astype('float32') if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_sequence_reverse.py b/python/paddle/fluid/tests/unittests/test_sequence_reverse.py index eebd25e09..09fb068ae 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_reverse.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_reverse.py @@ -65,5 +65,17 @@ class TestSequenceReverse2(TestSequenceReverseBase): self.lod = [12] +class TestSequenceReverse3(TestSequenceReverseBase): + def initParameters(self): + self.size = (12, 10) + self.lod = [3, 0, 6, 3] + + +class TestSequenceReverse3(TestSequenceReverseBase): + def initParameters(self): + self.size = (12, 10) + self.lod = [0, 2, 10, 0] + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_sequence_scatter_op.py b/python/paddle/fluid/tests/unittests/test_sequence_scatter_op.py index f3d239e9c..4ffe2c2a1 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_scatter_op.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_scatter_op.py @@ -18,20 +18,26 @@ from op_test import OpTest class TestSequenceScatterOp(OpTest): + def init_lod(self): + return [[3, 5, 4]] + def setUp(self): self.op_type = "sequence_scatter" X_data = np.random.uniform(0.1, 1.0, [3, 6]).astype('float32') - Ids_data = np.array([[0], [1], [2], [5], [4], [3], [2], [1], [3], [2], + Ids_data = np.array([[0], [1], [2], [5], [4], [3], [0], [1], [3], [2], [5], [4]]).astype('int64') - Ids_lod = [[3, 5, 4]] + Ids_lod = self.init_lod() + Updates_data = np.random.uniform(0.1, 1.0, [12, 1]).astype('float32') Updates_lod = Ids_lod Out_data = np.copy(X_data) - Out_data[0][Ids_data[0:3]] += Updates_data[0:3] - Out_data[1][Ids_data[3:8]] += Updates_data[3:8] - Out_data[2][Ids_data[8:]] += Updates_data[8:] + offset = 0 + for i in range(3): + Out_data[i][Ids_data[offset:(offset + Ids_lod[0][ + i])]] += Updates_data[offset:(offset + Ids_lod[0][i])] + offset += Ids_lod[0][i] self.inputs = { 'X': X_data, @@ -47,5 +53,20 @@ class TestSequenceScatterOp(OpTest): self.check_grad(['Updates'], 'Out', in_place=True) +class TestSequenceScatterOpSeqLen0(TestSequenceScatterOp): + def init_lod(self): + return [[6, 0, 6]] + + +class TestSequenceScatterOpSeqLen0Case1(TestSequenceScatterOp): + def init_lod(self): + return [[0, 6, 6]] + + +class TestSequenceScatterOpSeqLen0Case2(TestSequenceScatterOp): + def init_lod(self): + return [[6, 6, 0]] + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_sequence_slice_op.py b/python/paddle/fluid/tests/unittests/test_sequence_slice_op.py index 156149008..9c5492b5b 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_slice_op.py @@ -59,5 +59,29 @@ class TestSequenceSliceOp(OpTest): self.check_grad(['X'], 'Out') +class TestSequenceSliceOpSeqlen0Case0(TestSequenceSliceOp): + def init_test_case(self): + self.x_dim = (100, 3, 2) + self.x_lod = [[20, 30, 0, 30, 20]] + self.offset = [[1], [2], [0], [4], [5]] + self.length = [[10], [8], [0], [4], [2]] + + +class TestSequenceSliceOpSeqlen0Case1(TestSequenceSliceOp): + def init_test_case(self): + self.x_dim = (100, 3, 2) + self.x_lod = [[0, 70, 0, 30, 0]] + self.offset = [[0], [2], [0], [4], [0]] + self.length = [[0], [8], [0], [4], [0]] + + +class TestSequenceSliceOpSeqlen0Case2(TestSequenceSliceOp): + def init_test_case(self): + self.x_dim = (100, 3, 2) + self.x_lod = [[0, 100, 0, 0, 0]] + self.offset = [[0], [2], [0], [0], [0]] + self.length = [[0], [8], [0], [0], [0]] + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_sequence_softmax_op.py b/python/paddle/fluid/tests/unittests/test_sequence_softmax_op.py index 3e00e7d95..154a53ee8 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_softmax_op.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_softmax_op.py @@ -28,21 +28,26 @@ class TestSequenceSoftmaxOp(OpTest): self.init_op_type() x = np.random.uniform(0.1, 1, (11, 1)).astype("float32") - lod = [[4, 1, 3, 3]] - + self.init_lod() out = np.zeros((11, 1)).astype("float32") offset = 0 - for i in range(len(lod[0])): - sub_x = x[offset:offset + lod[0][i], :] - sub_x = sub_x.reshape(1, lod[0][i]) + for i in range(len(self.lod[0])): + if (self.lod[0][i] == 0): + continue + sub_x = x[offset:offset + self.lod[0][i], :] + sub_x = sub_x.reshape(1, self.lod[0][i]) sub_out = stable_softmax(sub_x) - out[offset:offset + lod[0][i], :] = sub_out.reshape(lod[0][i], 1) - offset += lod[0][i] + out[offset:offset + self.lod[0][i], :] = sub_out.reshape( + self.lod[0][i], 1) + offset += self.lod[0][i] - self.inputs = {"X": (x, lod)} + self.inputs = {"X": (x, self.lod)} self.outputs = {"Out": out} self.attrs = {'use_cudnn': self.use_cudnn, } + def init_lod(self): + self.lod = [[4, 1, 3, 3]] + def init_op_type(self): pass @@ -70,5 +75,20 @@ class TestSequenceSoftmaxCUDNNOp(TestSequenceSoftmaxOp): self.use_cudnn = True +class TestSequenceSoftmaxOpSeqLen0Case0(TestSequenceSoftmaxOp): + def init_lod(self): + self.lod = [[4, 0, 4, 3]] + + +class TestSequenceSoftmaxOpSeqLen0Case1(TestSequenceSoftmaxOp): + def init_lod(self): + self.lod = [[0, 4, 7, 0]] + + +class TestSequenceSoftmaxOpSeqLen0Case2(TestSequenceSoftmaxOp): + def init_lod(self): + self.lod = [[0, 0, 0, 11]] + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_sequence_unpad_op.py b/python/paddle/fluid/tests/unittests/test_sequence_unpad_op.py index 673b0ea18..0e65108c7 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_unpad_op.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_unpad_op.py @@ -71,5 +71,19 @@ class TestSequenceUnpadOp3(TestSequenceUnpadOp): self.dtype = "float64" +class TestSequenceUnpadOp4(TestSequenceUnpadOp): + def init(self): + self.length = [5, 0, 0, 4] + self.x_shape = (4, 5, 3, 3, 6) + self.dtype = "float64" + + +class TestSequenceUnpadOp4(TestSequenceUnpadOp): + def init(self): + self.length = [0, 4, 3, 0] + self.x_shape = (4, 5, 3, 3, 6) + self.dtype = "float64" + + if __name__ == '__main__': unittest.main() -- GitLab