diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index aad64406f7f60e1582a7c4c8c12ad97a127188ae..a9ca26062158007e20ce6f5e27af7decceef6d8c 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -113,6 +113,7 @@ paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)) paddle.fluid.layers.conv3d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)) paddle.fluid.layers.sequence_expand ArgSpec(args=['x', 'y', 'ref_level', 'name'], varargs=None, keywords=None, defaults=(-1, None)) +paddle.fluid.layers.sequence_pad ArgSpec(args=['x', 'pad_value', 'maxlen'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.lstm_unit ArgSpec(args=['x_t', 'hidden_t_prev', 'cell_t_prev', 'forget_bias', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(0.0, None, None, None)) paddle.fluid.layers.reduce_sum ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None)) paddle.fluid.layers.reduce_mean ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None)) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index e73d31562a93af4e9bdb3d5806e9182d0eea8167..2f8fadcefe5d5b6f428092914c7e999ec7524862 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -291,6 +291,7 @@ op_library(unsqueeze_op DEPS reshape_op) op_library(squeeze_op DEPS reshape_op) op_library(extract_rows_op DEPS memory) op_library(flatten_op DEPS reshape_op) +op_library(sequence_pad_op DEPS sequence_padding) op_library(unstack_op DEPS stack_op) if (WITH_GPU) diff --git a/paddle/fluid/operators/math/sequence_padding.cc b/paddle/fluid/operators/math/sequence_padding.cc index d63c6c4ed55331235188c1c750468d4e75b9b7f2..25f06a25a0638cbb394df58d35f88307941d117f 100644 --- a/paddle/fluid/operators/math/sequence_padding.cc +++ b/paddle/fluid/operators/math/sequence_padding.cc @@ -18,65 +18,86 @@ namespace paddle { namespace operators { namespace math { +template +void CopyValidData(framework::Tensor* dst_tensor, + const framework::Tensor* src_tensor, + const framework::Vector& seq_offsets, + int pad_seq_len, int step_width, bool norm_by_len, + CopyType type, PadLayout layout) { + int seq_num = seq_offsets.size() - 1; + const T* src_data = src_tensor->data(); + T* dst_data = dst_tensor->data(); + + int seq_cpy_gap = step_width; + int pad_cpy_gap = + layout == kBatchLengthWidth ? step_width : seq_num * step_width; + for (int seq_idx = 0; seq_idx < seq_num; ++seq_idx) { + int valid_seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx]; + PADDLE_ENFORCE_GE( + pad_seq_len, valid_seq_len, + "The padded sequence length can not be less than its original length."); + int seq_data_offset = seq_offsets[seq_idx] * step_width; + int pad_data_offset = layout == kBatchLengthWidth + ? seq_idx * pad_seq_len * step_width + : seq_idx * step_width; + float scale = 1.0f / static_cast(valid_seq_len); + + for (int step_idx = 0; step_idx < valid_seq_len; ++step_idx) { + const T* src = + src_data + (type == kSeqToPad ? seq_data_offset : pad_data_offset); + T* dst = + dst_data + (type == kSeqToPad ? pad_data_offset : seq_data_offset); + memcpy(dst, src, step_width * sizeof(T)); + if (norm_by_len) { + for (int i = 0; i < step_width; ++i) { + *(dst + i) *= scale; + } + } + seq_data_offset += seq_cpy_gap; + pad_data_offset += pad_cpy_gap; + } + } +} + template class PaddingLoDTensorFunctor { public: void operator()(const platform::CPUDeviceContext& context, - const framework::LoDTensor& seq, framework::Tensor* padding, - bool norm_by_times) { - auto lod = seq.lod(); - PADDLE_ENFORCE_GT(lod.size(), 0UL, - "The LoD of LoDTensor seq should not be null."); - - const size_t level = 0; - framework::LoD abs_offset_lod = framework::ToAbsOffset(lod); - - auto seq_dims = seq.dims(); - PADDLE_ENFORCE_EQ(seq_dims[0], - static_cast(abs_offset_lod[level].back()), - "The first dimension of LoDTensor seq should be " - "equal to the sum of all sequences's length."); - - auto padding_dims = padding->dims(); - PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL, - "The input padding should be a 3-D Tensor of shape " - "[max_sequence_length, num_sequences, sequence_width]."); - - const int64_t max_sequence_length = MaximumSequenceLength(lod, level); - PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length, - "The first dimension of Tensor padding should be the " - "maximum length of all sequences in LoDTensor seq."); - - const int64_t num_sequences = abs_offset_lod[level].size() - 1; - PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences, - "The second dimension of Tensor padding should be the " - "number of sequences in LoDTensor seq."); - - const int64_t sequence_width = seq.numel() / seq_dims[0]; - PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width, - "The third dimension of Tensor padding should be the " - "width of sequence in LoDTensor seq."); - - const T* seq_data = seq.data(); - T* padding_data = padding->data(); - for (int64_t i = 0; i < max_sequence_length; ++i) { - for (int64_t j = 0; j < num_sequences; ++j) { - int64_t start_pos = abs_offset_lod[level][j]; - int64_t sequence_length = abs_offset_lod[level][j + 1] - start_pos; - if (i < sequence_length) { - // i > 0 => sequence_length > 0 - T scale = - norm_by_times ? (1.0f / static_cast(sequence_length)) : 1.0f; - for (int64_t k = 0; k < sequence_width; ++k) { - padding_data[(i * num_sequences + j) * sequence_width + k] = - seq_data[(start_pos + i) * sequence_width + k] * scale; - } - } else { - memset(padding_data + (i * num_sequences + j) * sequence_width, 0, - sequence_width * sizeof(T)); - } + const framework::LoDTensor& seq_tensor, + framework::LoDTensor* pad_tensor, + const framework::LoDTensor& pad_value, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout layout = kBatchLengthWidth) { + auto seq_lod = seq_tensor.lod(); + const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level]; + const auto& seq_tensor_dims = seq_tensor.dims(); + const auto& pad_tensor_dims = pad_tensor->dims(); + if (pad_seq_len == -1) { + pad_seq_len = MaximumSequenceLength(seq_offsets); + } + int step_width = seq_tensor.numel() / seq_tensor_dims[0]; + + CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, + step_width, layout); + PADDLE_ENFORCE(pad_value.numel() == 1 || pad_value.numel() == step_width, + "The numel of 'pad_value' can only be 1 or be equal to the " + "'step_width'."); + + // fill padding value + T* pad_data = pad_tensor->data(); + const T* pad_value_data = pad_value.data(); + if (pad_value.numel() == 1) { + for (int i = 0; i < pad_tensor->numel(); ++i) { + pad_data[i] = *pad_value_data; + } + } else { + for (int i = 0; i < pad_tensor->numel(); i += step_width) { + memcpy(pad_data + i, pad_value_data, step_width * sizeof(T)); } } + + CopyValidData(pad_tensor, &seq_tensor, seq_offsets, pad_seq_len, + step_width, norm_by_times, kSeqToPad, layout); } }; @@ -84,62 +105,35 @@ template class UnpaddingLoDTensorFunctor { public: void operator()(const platform::CPUDeviceContext& context, - framework::LoDTensor* seq, const framework::Tensor& padding, - bool norm_by_times) { - auto lod = seq->lod(); - PADDLE_ENFORCE_GT(lod.size(), 0UL, - "The LoD of LoDTensor seq should not be null."); - - const size_t level = 0; - framework::LoD abs_offset_lod = framework::ToAbsOffset(lod); - - auto seq_dims = seq->dims(); - PADDLE_ENFORCE_EQ(seq_dims[0], - static_cast(abs_offset_lod[level].back()), - "The first dimension of LoDTensor seq should be " - "equal to the sum of all sequences's length."); - - auto padding_dims = padding.dims(); - PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL, - "The input padding should be a 3-D Tensor of shape " - "[max_sequnece_length, num_sequences, sequence_width]."); - - const int64_t max_sequence_length = MaximumSequenceLength(lod, level); - PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length, - "The first dimension of Tensor padding should be " - "the maximum length of all sequences in LoDTensor seq."); - - const int64_t num_sequences = abs_offset_lod[level].size() - 1; - PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences, - "The second dimension of Tensor padding should be " - "the number of sequences in LoDTensor seq."); - - const int64_t sequence_width = seq->numel() / seq_dims[0]; - PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width, - "The third dimension of Tensor padding should be the " - "width of sequence in LoDTensor seq."); - - const T* padding_data = padding.data(); - T* seq_data = seq->data(); - for (int64_t i = 0; i < num_sequences; ++i) { - int64_t start_pos = abs_offset_lod[level][i]; - int64_t sequence_length = abs_offset_lod[level][i + 1] - start_pos; - for (int64_t j = 0; j < sequence_length; ++j) { - // sequence_width > j > 0 - T scale = - norm_by_times ? (1.0f / static_cast(sequence_length)) : 1.0f; - for (int64_t k = 0; k < sequence_width; ++k) { - seq_data[(start_pos + j) * sequence_width + k] = - padding_data[(j * num_sequences + i) * sequence_width + k] * - scale; - } - } + const framework::LoDTensor& pad_tensor, + framework::LoDTensor* seq_tensor, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout layout = kBatchLengthWidth) { + auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level]; + const auto& seq_tensor_dims = seq_tensor->dims(); + const auto& pad_tensor_dims = pad_tensor.dims(); + if (pad_seq_len == -1) { + pad_seq_len = MaximumSequenceLength(seq_offsets); } + int step_width = seq_tensor->numel() / seq_tensor_dims[0]; + + CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, + step_width, layout); + + CopyValidData(seq_tensor, &pad_tensor, seq_offsets, pad_seq_len, + step_width, norm_by_times, kPadToSeq, layout); } }; +template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; + +template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/sequence_padding.cu b/paddle/fluid/operators/math/sequence_padding.cu index 0956a0c17d387f4a174c7ed4e9b1b1f816dcf4ae..035e10dcbe4e2083723e47d7dda75ce267a9f141 100644 --- a/paddle/fluid/operators/math/sequence_padding.cu +++ b/paddle/fluid/operators/math/sequence_padding.cu @@ -19,41 +19,32 @@ namespace paddle { namespace operators { namespace math { -template -__global__ void SequencePaddingKernel(T* padding, T* sequence, - const size_t* sequence_start_positions, - const size_t sequence_width, - const size_t max_sequence_length, - const size_t num_sequences) { - size_t padding_idx = blockIdx.y; - size_t start_pos = sequence_start_positions[padding_idx]; - size_t sequence_length = - sequence_start_positions[padding_idx + 1] - start_pos; - - size_t sequence_idx = blockIdx.x * blockDim.y + threadIdx.y; - size_t padding_base_idx = - (sequence_idx * num_sequences + padding_idx) * sequence_width; - size_t sequence_base_idx = (start_pos + sequence_idx) * sequence_width; - - if (sequence_idx < sequence_length) { - T scale = NormByTimes ? (1.0f / static_cast(sequence_length)) : 1.0f; - if (Padding) { - /* sequence -> padding */ - for (size_t i = threadIdx.x; i < sequence_width; i += blockDim.x) { - padding[padding_base_idx + i] = scale * sequence[sequence_base_idx + i]; - } - } else { - /* padding -> sequence */ - for (size_t i = threadIdx.x; i < sequence_width; i += blockDim.x) { - sequence[sequence_base_idx + i] = scale * padding[padding_base_idx + i]; - } +template +__global__ void SequencePaddingKernel( + T* dst, const T* src, const T* pad_value, bool is_constant_pad, + const size_t* seq_offsets, const size_t seq_num, const size_t pad_seq_len, + const size_t step_width, bool norm_by_len, const PadLayout layout) { + size_t seq_idx = blockIdx.y; + size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx]; + + size_t step_idx = blockIdx.x * blockDim.y + threadIdx.y; + size_t seq_data_offset = (seq_offsets[seq_idx] + step_idx) * step_width; + size_t pad_data_offset = layout == kBatchLengthWidth + ? (seq_idx * pad_seq_len + step_idx) * step_width + : (step_idx * seq_num + seq_idx) * step_width; + + T* dst_data = dst + (Type == kSeqToPad ? pad_data_offset : seq_data_offset); + const T* src_data = + src + (Type == kSeqToPad ? seq_data_offset : pad_data_offset); + + if (step_idx < seq_len) { + float scale = norm_by_len ? (1.0f / static_cast(seq_len)) : 1.0f; + for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) { + dst_data[i] = scale * src_data[i]; } - } else if (sequence_idx < max_sequence_length) { - if (Padding) { - /* sequence -> padding */ - for (size_t i = threadIdx.x; i < sequence_width; i += blockDim.x) { - padding[padding_base_idx + i] = 0; - } + } else if (step_idx < pad_seq_len && Type == kSeqToPad) { + for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) { + dst_data[i] = is_constant_pad ? pad_value[0] : pad_value[i]; } } } @@ -62,74 +53,59 @@ template class PaddingLoDTensorFunctor { public: void operator()(const platform::CUDADeviceContext& context, - const framework::LoDTensor& seq, framework::Tensor* padding, - bool norm_by_times) { - auto lod = seq.lod(); - PADDLE_ENFORCE_GT(lod.size(), 0UL, - "The lod of LoDTensor seq should not be null."); - - const size_t level = 0; - framework::LoD abs_offset_lod = framework::ToAbsOffset(lod); - - auto seq_dims = seq.dims(); - PADDLE_ENFORCE_EQ(seq_dims[0], - static_cast(abs_offset_lod[level].back()), - "The first dimension of LoDTensor seq should be " - "equal to the sum of all sequences's length."); - - auto padding_dims = padding->dims(); - PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL, - "The input padding should be a 3-D Tensor of shape " - "[max_sequence_length, num_sequences, sequence_width]."); - - int64_t max_sequence_length = MaximumSequenceLength(lod, level); - PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length, - "The first dimension of Tensor padding should be the " - "maximum length of all sequences in LoDTensor seq."); - - const int64_t num_sequences = abs_offset_lod[level].size() - 1; - PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences, - "The second dimension of Tensor padding should be the " - "number of sequences in LoDTensor seq."); - - const int64_t sequence_width = seq.numel() / seq_dims[0]; - PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width, - "The third dimension of Tensor padding should be the " - "width of sequence in LoDTensor seq."); - - if (!norm_by_times && num_sequences == 1UL) { - TensorCopy(seq, context.GetPlace(), context, padding); - padding->Resize(padding_dims); + const framework::LoDTensor& seq_tensor, + framework::LoDTensor* pad_tensor, + const framework::LoDTensor& pad_value, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout layout = kBatchLengthWidth) { + auto seq_lod = seq_tensor.lod(); + const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level]; + const auto& seq_tensor_dims = seq_tensor.dims(); + const auto& pad_tensor_dims = pad_tensor->dims(); + int max_seq_len = MaximumSequenceLength(seq_offsets); + if (pad_seq_len == -1) { + pad_seq_len = max_seq_len; + } + PADDLE_ENFORCE_GE(pad_seq_len, max_seq_len, + "The pad_seq_len must be equal to or greater than the " + "original max sequence length."); + int step_width = seq_tensor.numel() / seq_tensor_dims[0]; + int seq_num = seq_offsets.size() - 1; + + CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, + step_width, layout); + PADDLE_ENFORCE(pad_value.numel() == 1 || pad_value.numel() == step_width, + "The numel of 'pad_value' can only be 1 or be equal to the " + "'step_width'."); + + if (!norm_by_times && seq_num == 1UL && pad_seq_len == max_seq_len) { + TensorCopy(seq_tensor, context.GetPlace(), context, pad_tensor); + pad_tensor->Resize(pad_tensor_dims); return; } - const int64_t kBlockSize = 512; + const int kBlockSize = 512; /* At least use 32 threads to copy sequence_width elements, * and at least 8 elements for each thread. */ size_t block_dim_x = - std::min(((((sequence_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); + std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); size_t block_dim_y = kBlockSize / block_dim_x; dim3 threads(block_dim_x, block_dim_y); - size_t grid_dim_x = (max_sequence_length + block_dim_y - 1) / block_dim_y; - size_t grid_dim_y = num_sequences; + size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y; + size_t grid_dim_y = seq_num; dim3 grid(grid_dim_x, grid_dim_y); - const T* seq_data = seq.data(); - T* padding_data = padding->data(); - if (norm_by_times) { - SequencePaddingKernel<<>>( - padding_data, const_cast(seq_data), - abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width, - max_sequence_length, num_sequences); - } else { - SequencePaddingKernel<<>>( - padding_data, const_cast(seq_data), - abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width, - max_sequence_length, num_sequences); - } + const T* seq_data = seq_tensor.data(); + T* pad_data = pad_tensor->data(); + const T* pad_value_data = pad_value.data(); + + SequencePaddingKernel<<>>( + pad_data, seq_data, pad_value_data, pad_value.numel() == 1, + seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len, + step_width, norm_by_times, layout); } }; @@ -137,79 +113,62 @@ template class UnpaddingLoDTensorFunctor { public: void operator()(const platform::CUDADeviceContext& context, - framework::LoDTensor* seq, const framework::Tensor& padding, - bool norm_by_times) { - auto lod = seq->lod(); - PADDLE_ENFORCE_GT(lod.size(), 0UL, - "The lod of LoDTensor seq should not be null."); - - const size_t level = 0; - framework::LoD abs_offset_lod = framework::ToAbsOffset(lod); - - auto seq_dims = seq->dims(); - PADDLE_ENFORCE_EQ(seq_dims[0], - static_cast(abs_offset_lod[level].back()), - "The first dimension of LoDTensor seq should be " - "equal to the sum of all sequences's length."); - - auto padding_dims = padding.dims(); - PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL, - "The input padding should be a 3-D Tensor of shape " - "[max_sequnece_length, num_sequences, sequence_width]."); - - int64_t max_sequence_length = MaximumSequenceLength(lod, level); - PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length, - "The first dimension of Tensor padding should be " - "the maximum length of all sequences in LoDTensor seq."); - - const int64_t num_sequences = abs_offset_lod[level].size() - 1; - PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences, - "The second dimension of Tensor padding should be " - "the number of sequences in LoDTensor seq."); - - const int64_t sequence_width = seq->numel() / seq_dims[0]; - PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width, - "The third dimension of Tensor padding should be the " - "width of sequence in LoDTensor seq."); - - if (!norm_by_times && num_sequences == 1UL) { - TensorCopy(padding, context.GetPlace(), context, seq); - seq->Resize(seq_dims); + const framework::LoDTensor& pad_tensor, + framework::LoDTensor* seq_tensor, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout layout = kBatchLengthWidth) { + auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level]; + const auto& seq_tensor_dims = seq_tensor->dims(); + const auto& pad_tensor_dims = pad_tensor.dims(); + int max_seq_len = MaximumSequenceLength(seq_offsets); + if (pad_seq_len == -1) { + pad_seq_len = max_seq_len; + } + int step_width = seq_tensor->numel() / seq_tensor_dims[0]; + int seq_num = seq_offsets.size() - 1; + + CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, + step_width, layout); + + if (!norm_by_times && seq_num == 1UL && pad_seq_len == max_seq_len) { + TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor); + seq_tensor->Resize(seq_tensor_dims); return; } - const int64_t kBlockSize = 512; + const int kBlockSize = 512; /* At least use 32 threads to copy sequence_width elements, * and at least 8 elements for each thread. */ size_t block_dim_x = - std::min(((((sequence_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); + std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); size_t block_dim_y = kBlockSize / block_dim_x; dim3 threads(block_dim_x, block_dim_y); - size_t grid_dim_x = (max_sequence_length + block_dim_y - 1) / block_dim_y; - size_t grid_dim_y = num_sequences; + size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y; + size_t grid_dim_y = seq_num; dim3 grid(grid_dim_x, grid_dim_y); - const T* padding_data = padding.data(); - T* seq_data = seq->data(); - if (norm_by_times) { - SequencePaddingKernel<<>>( - const_cast(padding_data), seq_data, - abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width, - max_sequence_length, num_sequences); - } else { - SequencePaddingKernel<<>>( - const_cast(padding_data), seq_data, - abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width, - max_sequence_length, num_sequences); - } + const T* pad_data = pad_tensor.data(); + T* seq_data = seq_tensor->data(); + + SequencePaddingKernel<<>>( + seq_data, pad_data, nullptr, false, + seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len, + step_width, norm_by_times, layout); } }; +template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; + +template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/sequence_padding.h b/paddle/fluid/operators/math/sequence_padding.h index b56e6db1ebdac1a00561c07845c03bb8fbd8d35a..e752aa58979dddba4d010071d2c4b5dc3e0c6756 100644 --- a/paddle/fluid/operators/math/sequence_padding.h +++ b/paddle/fluid/operators/math/sequence_padding.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/platform/device_context.h" @@ -22,17 +23,33 @@ namespace paddle { namespace operators { namespace math { -inline static size_t MaximumSequenceLength(const framework::LoD& lod, - const size_t level) { - const size_t num_sequences = lod[level].size() - 1; - size_t max_sequence_length = 0; - framework::LoD abs_offset_lod = framework::ToAbsOffset(lod); - for (size_t i = 0; i < num_sequences; ++i) { - max_sequence_length = - std::max(max_sequence_length, - abs_offset_lod[level][i + 1] - abs_offset_lod[level][i]); +enum PadLayout { kBatchLengthWidth = 0, kLengthBatchWidth }; + +enum CopyType { kSeqToPad, kPadToSeq }; + +inline static size_t MaximumSequenceLength( + const framework::Vector& seq_offset) { + size_t seq_num = seq_offset.size() - 1; + size_t max_seq_len = 0; + for (size_t i = 0; i < seq_num; ++i) { + max_seq_len = std::max(max_seq_len, seq_offset[i + 1] - seq_offset[i]); } - return max_sequence_length; + return max_seq_len; +} + +inline static void CheckDims(const framework::DDim& seq_tensor_dims, + const framework::DDim& pad_tensor_dims, + const framework::Vector& seq_offset, + int64_t padded_seq_len, int64_t step_width, + const PadLayout& layout) { + PADDLE_ENFORCE_EQ(static_cast(seq_tensor_dims[0]), seq_offset.back(), + "Value of 1st dimension of the sequence tensor should be " + "equal to sum of lengths of all sequences."); + + PADDLE_ENFORCE(seq_tensor_dims.size() + 1 == pad_tensor_dims.size() || + seq_tensor_dims.size() == pad_tensor_dims.size(), + "pad_tensor's rank should be 1 greater than seq_tensor's " + "rank, or be equal with it."); } /* @@ -64,15 +81,22 @@ inline static size_t MaximumSequenceLength(const framework::LoD& lod, template class PaddingLoDTensorFunctor { public: - void operator()(const DeviceContext& context, const framework::LoDTensor& seq, - framework::Tensor* padding, bool norm_by_times); + void operator()(const DeviceContext& context, + const framework::LoDTensor& seq_tensor, + framework::LoDTensor* pad_tensor, + const framework::LoDTensor& pad_value, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout layout = kBatchLengthWidth); }; template class UnpaddingLoDTensorFunctor { public: - void operator()(const DeviceContext& context, framework::LoDTensor* seq, - const framework::Tensor& padding, bool norm_by_times); + void operator()(const DeviceContext& context, + const framework::LoDTensor& pad_tensor, + framework::LoDTensor* seq_tensor, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout layout = kBatchLengthWidth); }; } // namespace math diff --git a/paddle/fluid/operators/math/sequence_padding_test.cc b/paddle/fluid/operators/math/sequence_padding_test.cc index b0c201db0ccbe81d8f57cd984d2cdfd2f6a48f25..4f61b1029c65aedaf4fce771866964fe1d0d6112 100644 --- a/paddle/fluid/operators/math/sequence_padding_test.cc +++ b/paddle/fluid/operators/math/sequence_padding_test.cc @@ -23,7 +23,9 @@ void TestSequencePadding(const paddle::framework::LoD& lod, paddle::framework::LoDTensor cpu_seq_back; paddle::framework::LoDTensor seq; paddle::framework::LoDTensor seq_back; - paddle::framework::Tensor padding; + paddle::framework::LoDTensor padding; + paddle::framework::LoDTensor cpu_pad_value; + paddle::framework::LoDTensor pad_value; const size_t level = lod.size() - 1; auto seq_dims = @@ -46,20 +48,33 @@ void TestSequencePadding(const paddle::framework::LoD& lod, } const size_t max_sequence_length = - paddle::operators::math::MaximumSequenceLength(lod, level); + paddle::operators::math::MaximumSequenceLength(lod[level]); const size_t num_sequences = lod[level].size() - 1; auto padding_dims = paddle::framework::make_ddim({static_cast(max_sequence_length), static_cast(num_sequences), static_cast(sequence_width)}); + padding.mutable_data(padding_dims, *place); + + T* pad_value_data = + cpu_pad_value.mutable_data({1}, paddle::platform::CPUPlace()); + *pad_value_data = static_cast(0); + if (paddle::platform::is_cpu_place(*place)) { + pad_value = cpu_pad_value; + } else { + TensorCopySync(cpu_pad_value, *place, &pad_value); + } + paddle::operators::math::PaddingLoDTensorFunctor()( - *context, seq, &padding, false); + *context, seq, &padding, pad_value, -1, 0, false, + paddle::operators::math::kLengthBatchWidth); seq_back.set_lod(lod); seq_back.mutable_data(seq_dims, *place); paddle::operators::math::UnpaddingLoDTensorFunctor()( - *context, &seq_back, padding, false); + *context, padding, &seq_back, -1, 0, false, + paddle::operators::math::kLengthBatchWidth); if (paddle::platform::is_cpu_place(*place)) { cpu_seq_back = seq_back; diff --git a/paddle/fluid/operators/sequence_pad_op.cc b/paddle/fluid/operators/sequence_pad_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..44d73aa4076abfe15c906478702ac7c4a55303d4 --- /dev/null +++ b/paddle/fluid/operators/sequence_pad_op.cc @@ -0,0 +1,194 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/sequence_pad_op.h" + +namespace paddle { +namespace operators { + +class SequencePadOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequencePadOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("PadValue"), + "Input(PadValue) of SequencePadOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SequencePadOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_GE(x_dims.size(), 2, + "The rank of Input(x) can't be less than 2."); + auto time_step_dims = framework::slice_ddim(x_dims, 1, x_dims.size()); + auto pad_value_dims = ctx->GetInputDim("PadValue"); + PADDLE_ENFORCE(pad_value_dims == framework::make_ddim({1}) || + pad_value_dims == time_step_dims, + "The Input(PadValue) must be a scalar or a tensor whose " + "shape equals to time steps in sequences"); + + int out_dim_0 = -1; + int out_dim_1 = -1; + + if (ctx->IsRuntime()) { + // run time + framework::Variable* x_var = + boost::get(ctx->GetInputVarPtrs("X")[0]); + const auto& x_lod = x_var->Get().lod(); + PADDLE_ENFORCE(!x_lod.empty(), "The Input(X) must hold lod info."); + const auto& x_lod_0 = x_lod[0]; + PADDLE_ENFORCE_GE(x_lod_0.size(), 2, + "The Input(X)'s lod info is corrupted."); + PADDLE_ENFORCE_EQ( + x_dims[0], static_cast(x_lod_0.back()), + "The Input(X)'s lod info mismatches the actual tensor shape."); + + int seq_num = x_lod_0.size() - 1; + int max_seq_len = math::MaximumSequenceLength(x_lod_0); + int padded_length = ctx->Attrs().Get("padded_length"); + if (padded_length == -1) { + padded_length = max_seq_len; + } + PADDLE_ENFORCE_GE(padded_length, max_seq_len, + "The Attr(padded_length) must be -1 or an int greater " + "than the length of the longest original sequence."); + out_dim_0 = seq_num; + out_dim_1 = padded_length; + } else { + // compile time + framework::VarDesc* x_desc = + boost::get(ctx->GetInputVarPtrs("X")[0]); + PADDLE_ENFORCE_GE(x_desc->GetLoDLevel(), 1); + } + + std::vector out_dims_vec{out_dim_0, out_dim_1}; + auto time_step_dims_vec = framework::vectorize2int(time_step_dims); + out_dims_vec.insert(out_dims_vec.end(), time_step_dims_vec.begin(), + time_step_dims_vec.end()); + ctx->SetOutputDim("Out", framework::make_ddim(out_dims_vec)); + } +}; + +class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(LoDTensor, default LoDTensor) Input variable which " + "should contain lod information."); + AddInput("PadValue", + "(LoDTensor), this Tensor holds values that will be fill into " + "padded steps. It can be a scalar or a tensor whose shape equals " + "to time steps in sequences. If it's a scalar, it will be " + "automatically broadcasted to the shape of time step."); + AddOutput( + "Out", + "(LoDTensor) The output vairable, which contains padded sequences."); + AddAttr( + "padded_length", + "The length of padded sequences. It can be setted to -1 or " + "any positive int. When it is -1, all sequences will be padded up to " + "the length of the longest one among them; when it a certain positive " + "value, it must be greater than the length of the longest original " + "sequence.") + .SetDefault(-1); + AddComment(R"DOC( + Sequence Pad Operator + + This operator pads sequences in a same batch to a consistent length. + The length is specified by attribute 'padded_length'. New elements, + whose values are specified by input 'PadValue', will be appended to + the end of each sequence, to make their final lengths consistent. + + Following are cases to better explain how this works: + + Case 1: + + Given a 1-level LoDTensor input(X): + X.lod = [[0, 2, 5]] + X.data = [a, b, c, d, e] + and Input(PadValue): + PadValue.data = [0] + and attribite 'padded_length' = 4, + then we get LoDTensor: + Out.data = [[a, b, 0, 0], + [c, d, e, 0]] + + Case 2: + + Given a 1-level LoDTensor input(X): + X.lod = [[0, 2, 5]] + X.data = [[a1, a2], [b1, b2], [c1, c2], [d1, d2], [e1, e2]] + and Input(PadValue): + PadValue.data = [0] + and attribite 'padded_length' = -1, which mean using the length + of longest input sequence(3 in this case), + then we get LoDTensor: + Out.data = [[[a1, a2], [b1, b2], [0, 0]], + [[c1, c2], [d1, d2], [e1, e2]]] + + Case 3: + + Given a 1-level LoDTensor input(X): + X.lod = [[0, 2, 5]] + X.data = [[a1, a2], [b1, b2], [c1, c2], [d1, d2], [e1, e2]] + and Input(PadValue): + PadValue.data = [p1, p2] + and attribite 'padded_length' = -1, which mean using the length + of longest input sequence(3 in this case), + then we get LoDTensor: + Out.data = [[[a1, a2], [b1, b2], [p1, p2]], + [[c1, c2], [d1, d2], [e1, e2]]] + + )DOC"); + } +}; + +class SequencePadGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequencePadGradOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) of SequencePadGradOp should not be null."); + + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ framework::GradVarName("X")); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(sequence_pad, ops::SequencePadOp, ops::SequencePadOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(sequence_pad_grad, ops::SequencePadGradOp); +REGISTER_OP_CPU_KERNEL( + sequence_pad, + ops::SequencePadOpKernel, + ops::SequencePadOpKernel, + ops::SequencePadOpKernel, + ops::SequencePadOpKernel); +REGISTER_OP_CPU_KERNEL( + sequence_pad_grad, + ops::SequencePadGradOpKernel, + ops::SequencePadGradOpKernel, + ops::SequencePadGradOpKernel, + ops::SequencePadGradOpKernel); diff --git a/paddle/fluid/operators/sequence_pad_op.cu b/paddle/fluid/operators/sequence_pad_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..ff8f81a2f0ec4a72befc3be2a5fc48c3a586c824 --- /dev/null +++ b/paddle/fluid/operators/sequence_pad_op.cu @@ -0,0 +1,29 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/sequence_pad_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + sequence_pad, + ops::SequencePadOpKernel, + ops::SequencePadOpKernel, + ops::SequencePadOpKernel, + ops::SequencePadOpKernel); +REGISTER_OP_CUDA_KERNEL( + sequence_pad_grad, + ops::SequencePadGradOpKernel, + ops::SequencePadGradOpKernel, + ops::SequencePadGradOpKernel, + ops::SequencePadGradOpKernel); diff --git a/paddle/fluid/operators/sequence_pad_op.h b/paddle/fluid/operators/sequence_pad_op.h new file mode 100644 index 0000000000000000000000000000000000000000..5fc9da69d787ff3aeffa716689d44772ad8f7bd2 --- /dev/null +++ b/paddle/fluid/operators/sequence_pad_op.h @@ -0,0 +1,66 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/sequence_padding.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using LoD = framework::LoD; + +template +class SequencePadOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + const auto* pad_value = ctx.Input("PadValue"); + + int padded_length = ctx.Attr("padded_length"); + + math::PaddingLoDTensorFunctor()( + ctx.template device_context(), *x, out, *pad_value, + padded_length, 0, false, math::kBatchLengthWidth); + } +}; + +template +class SequencePadGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* d_x = ctx.Output(framework::GradVarName("X")); + if (d_x) { + const auto* d_out = ctx.Input(framework::GradVarName("Out")); + d_x->mutable_data(ctx.GetPlace()); + + int padded_length = ctx.Attr("padded_length"); + + math::UnpaddingLoDTensorFunctor()( + ctx.template device_context(), *d_out, d_x, + padded_length, 0, false, math::kBatchLengthWidth); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/warpctc_op.h b/paddle/fluid/operators/warpctc_op.h index ab70c1f0592d122ba248a101db487e64c0bdae6f..444265f58de732f07c5db2abd87811a063016866 100644 --- a/paddle/fluid/operators/warpctc_op.h +++ b/paddle/fluid/operators/warpctc_op.h @@ -153,17 +153,29 @@ class WarpCTCKernel : public framework::OpKernel { framework::make_ddim({static_cast(num_sequences), 1}); // warpctc needs sequences data stored in transposed padding format - Tensor warpctc_logits; + LoDTensor warpctc_logits; const size_t max_sequence_length = - math::MaximumSequenceLength(logits_lod, level); + math::MaximumSequenceLength(logits_lod[level]); auto warpctc_logits_dims = framework::make_ddim({static_cast(max_sequence_length), static_cast(num_sequences), static_cast(sequence_width)}); warpctc_logits.mutable_data(warpctc_logits_dims, ctx.GetPlace()); + + LoDTensor cpu_pad_value; + T* pad_value_data = + cpu_pad_value.mutable_data({1}, platform::CPUPlace()); + *pad_value_data = static_cast(0); + LoDTensor pad_value; + if (platform::is_cpu_place(ctx.GetPlace())) { + pad_value = cpu_pad_value; + } else { + TensorCopySync(cpu_pad_value, ctx.GetPlace(), &pad_value); + } + math::PaddingLoDTensorFunctor()( ctx.template device_context(), *logits, &warpctc_logits, - false); + pad_value, -1, 0, false /* norm_by_times */, math::kLengthBatchWidth); const T* warpctc_logits_data = warpctc_logits.data(); std::vector warpctc_label_lengths(num_sequences); @@ -209,15 +221,15 @@ template class WarpCTCGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* warpctc_grad = ctx.Input("WarpCTCGrad"); + auto* warpctc_grad = ctx.Input("WarpCTCGrad"); auto* logits_grad = ctx.Output(framework::GradVarName("Logits")); const Tensor* loss_grad = ctx.Input(framework::GradVarName("Loss")); logits_grad->mutable_data(ctx.GetPlace()); bool norm_by_times = ctx.Attr("norm_by_times"); math::UnpaddingLoDTensorFunctor()( - ctx.template device_context(), logits_grad, - *warpctc_grad, norm_by_times); + ctx.template device_context(), *warpctc_grad, + logits_grad, -1, 0, norm_by_times, math::kLengthBatchWidth); const T* loss_grad_data = loss_grad->data(); math::ScaleLoDTensorFunctor()( diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 9a6aa14c4ffb8865344140381a410036749fa959..f98b18afa7cd7822269d04fa36fd1ce117a1741a 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -54,6 +54,7 @@ __all__ = [ 'conv2d_transpose', 'conv3d_transpose', 'sequence_expand', + 'sequence_pad', 'lstm_unit', 'reduce_sum', 'reduce_mean', @@ -2656,6 +2657,51 @@ def sequence_expand(x, y, ref_level=-1, name=None): return tmp +@templatedoc() +def sequence_pad(x, pad_value, maxlen=None): + """ + ${comment} + + Args: + x(Variable): Input variable which should contain lod information. + pad_value(Variable): The Variable that holds values that will be fill + into padded steps. It can be a scalar or a tensor whose shape + equals to time steps in sequences. If it's a scalar, it will be + automatically broadcasted to the shape of time step. + maxlen(int, default None): The length of padded sequences. It can be + None or any positive int. When it is None, all sequences will be + padded up to the length of the longest one among them; when it a + certain positive value, it must be greater than the length of the + longest original sequence." + + Returns: + Variable: The padded sequence batch. All sequences has the same length. + + Examples: + .. code-block:: python + + import numpy + + x = fluid.layers.data(name='y', shape=[10, 5], + dtype='float32', lod_level=1) + pad_value = fluid.layers.assign(input=numpy.array([0])) + out = fluid.layers.sequence_pad(x=x, pad_value=pad_value) + """ + + helper = LayerHelper('sequence_pad', input=x, **locals()) + dtype = helper.input_dtype() + out = helper.create_tmp_variable(dtype) + if maxlen is None: + maxlen = -1 + helper.append_op( + type='sequence_pad', + inputs={'X': x, + 'PadValue': pad_value}, + outputs={'Out': out}, + attrs={'padded_length': maxlen}) + return out + + def beam_search(pre_ids, pre_scores, ids, diff --git a/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py b/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py new file mode 100644 index 0000000000000000000000000000000000000000..471515c817541976a06eb024fa3d4f77b78f920d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py @@ -0,0 +1,131 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +class TestSequencePadOp(OpTest): + def set_attr(self): + self.x_shape = [12, 4] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [1.0] + self.padded_length = -1 + self.dtype = 'float32' + + def set_data(self): + x_data = np.random.uniform(0.1, 0.5, self.x_shape).astype(self.dtype) + pad_value_data = np.array(self.pad_value).astype(self.dtype) + self.inputs = { + 'X': (x_data, self.x_len_lod), + 'PadValue': pad_value_data + } + self.attrs = {'padded_length': self.padded_length} + + def compute(self): + # get padded length + padded_length = self.padded_length + x_len_lod_0 = self.x_len_lod[0] + if padded_length == -1: + max_seq_len = 0 + for l in x_len_lod_0: + max_seq_len = max(max_seq_len, l) + padded_length = max_seq_len + + # do padding + x_data = self.inputs['X'][0] + pad_value_data = self.inputs['PadValue'] + if pad_value_data.shape == (1, ): + pad_value_data = np.broadcast_to( + pad_value_data, shape=x_data.shape[1:]) + padded_sequences = [] + start_idx = 0 + for l in x_len_lod_0: + end_idx = start_idx + l + seq = x_data[start_idx:end_idx] + to_pad_len = padded_length - l + for _ in range(to_pad_len): + seq = np.append(seq, pad_value_data[np.newaxis, :], axis=0) + padded_sequences.append(seq) + start_idx = end_idx + + out_data = np.array(padded_sequences) + self.outputs = {'Out': out_data} + + def setUp(self): + self.op_type = 'sequence_pad' + self.set_attr() + self.set_data() + self.compute() + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestSequencePadOp2(TestSequencePadOp): + def set_attr(self): + self.x_shape = [12, 4] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [1.0, 2.0, 3.0, 4.0] + self.padded_length = -1 + self.dtype = 'float32' + + +class TestSequencePadOp3(TestSequencePadOp): + def set_attr(self): + self.x_shape = [12, 4] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [1.0] + self.padded_length = 7 + self.dtype = 'float32' + + +class TestSequencePadOp4(TestSequencePadOp): + def set_attr(self): + self.x_shape = [12, 4] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [1.0, 2.0, 3.0, 4.0] + self.padded_length = 7 + self.dtype = 'float32' + + +class TestSequencePadOp5(TestSequencePadOp): + def set_attr(self): + self.x_shape = [12, 2, 2] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [1.0] + self.padded_length = -1 + self.dtype = 'float32' + + +class TestSequencePadOp6(TestSequencePadOp): + def set_attr(self): + self.x_shape = [12, 2, 2] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [[1.0, 2.0], [3.0, 4.0]] + self.padded_length = -1 + self.dtype = 'float32' + + +class TestSequencePadOp7(TestSequencePadOp): + def set_attr(self): + self.x_shape = [12, 2, 2] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [1.0] + self.padded_length = 7 + self.dtype = 'float32'