未验证 提交 7b84c580 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #12824 from JiayiFeng/dev_sequence_padding_op

Sequence pad op
......@@ -113,6 +113,7 @@ paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size
paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
paddle.fluid.layers.conv3d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
paddle.fluid.layers.sequence_expand ArgSpec(args=['x', 'y', 'ref_level', 'name'], varargs=None, keywords=None, defaults=(-1, None))
paddle.fluid.layers.sequence_pad ArgSpec(args=['x', 'pad_value', 'maxlen'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.lstm_unit ArgSpec(args=['x_t', 'hidden_t_prev', 'cell_t_prev', 'forget_bias', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(0.0, None, None, None))
paddle.fluid.layers.reduce_sum ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None))
paddle.fluid.layers.reduce_mean ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None))
......
......@@ -291,6 +291,7 @@ op_library(unsqueeze_op DEPS reshape_op)
op_library(squeeze_op DEPS reshape_op)
op_library(extract_rows_op DEPS memory)
op_library(flatten_op DEPS reshape_op)
op_library(sequence_pad_op DEPS sequence_padding)
op_library(unstack_op DEPS stack_op)
if (WITH_GPU)
......
......@@ -18,65 +18,86 @@ namespace paddle {
namespace operators {
namespace math {
template <typename T>
void CopyValidData(framework::Tensor* dst_tensor,
const framework::Tensor* src_tensor,
const framework::Vector<size_t>& seq_offsets,
int pad_seq_len, int step_width, bool norm_by_len,
CopyType type, PadLayout layout) {
int seq_num = seq_offsets.size() - 1;
const T* src_data = src_tensor->data<T>();
T* dst_data = dst_tensor->data<T>();
int seq_cpy_gap = step_width;
int pad_cpy_gap =
layout == kBatchLengthWidth ? step_width : seq_num * step_width;
for (int seq_idx = 0; seq_idx < seq_num; ++seq_idx) {
int valid_seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx];
PADDLE_ENFORCE_GE(
pad_seq_len, valid_seq_len,
"The padded sequence length can not be less than its original length.");
int seq_data_offset = seq_offsets[seq_idx] * step_width;
int pad_data_offset = layout == kBatchLengthWidth
? seq_idx * pad_seq_len * step_width
: seq_idx * step_width;
float scale = 1.0f / static_cast<float>(valid_seq_len);
for (int step_idx = 0; step_idx < valid_seq_len; ++step_idx) {
const T* src =
src_data + (type == kSeqToPad ? seq_data_offset : pad_data_offset);
T* dst =
dst_data + (type == kSeqToPad ? pad_data_offset : seq_data_offset);
memcpy(dst, src, step_width * sizeof(T));
if (norm_by_len) {
for (int i = 0; i < step_width; ++i) {
*(dst + i) *= scale;
}
}
seq_data_offset += seq_cpy_gap;
pad_data_offset += pad_cpy_gap;
}
}
}
template <typename T>
class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::LoDTensor& seq, framework::Tensor* padding,
bool norm_by_times) {
auto lod = seq.lod();
PADDLE_ENFORCE_GT(lod.size(), 0UL,
"The LoD of LoDTensor seq should not be null.");
const size_t level = 0;
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
auto seq_dims = seq.dims();
PADDLE_ENFORCE_EQ(seq_dims[0],
static_cast<int64_t>(abs_offset_lod[level].back()),
"The first dimension of LoDTensor seq should be "
"equal to the sum of all sequences's length.");
auto padding_dims = padding->dims();
PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
"The input padding should be a 3-D Tensor of shape "
"[max_sequence_length, num_sequences, sequence_width].");
const int64_t max_sequence_length = MaximumSequenceLength(lod, level);
PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
"The first dimension of Tensor padding should be the "
"maximum length of all sequences in LoDTensor seq.");
const int64_t num_sequences = abs_offset_lod[level].size() - 1;
PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
"The second dimension of Tensor padding should be the "
"number of sequences in LoDTensor seq.");
const int64_t sequence_width = seq.numel() / seq_dims[0];
PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
"The third dimension of Tensor padding should be the "
"width of sequence in LoDTensor seq.");
const T* seq_data = seq.data<T>();
T* padding_data = padding->data<T>();
for (int64_t i = 0; i < max_sequence_length; ++i) {
for (int64_t j = 0; j < num_sequences; ++j) {
int64_t start_pos = abs_offset_lod[level][j];
int64_t sequence_length = abs_offset_lod[level][j + 1] - start_pos;
if (i < sequence_length) {
// i > 0 => sequence_length > 0
T scale =
norm_by_times ? (1.0f / static_cast<T>(sequence_length)) : 1.0f;
for (int64_t k = 0; k < sequence_width; ++k) {
padding_data[(i * num_sequences + j) * sequence_width + k] =
seq_data[(start_pos + i) * sequence_width + k] * scale;
}
} else {
memset(padding_data + (i * num_sequences + j) * sequence_width, 0,
sequence_width * sizeof(T));
}
const framework::LoDTensor& seq_tensor,
framework::LoDTensor* pad_tensor,
const framework::LoDTensor& pad_value, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
const PadLayout layout = kBatchLengthWidth) {
auto seq_lod = seq_tensor.lod();
const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
const auto& seq_tensor_dims = seq_tensor.dims();
const auto& pad_tensor_dims = pad_tensor->dims();
if (pad_seq_len == -1) {
pad_seq_len = MaximumSequenceLength(seq_offsets);
}
int step_width = seq_tensor.numel() / seq_tensor_dims[0];
CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
step_width, layout);
PADDLE_ENFORCE(pad_value.numel() == 1 || pad_value.numel() == step_width,
"The numel of 'pad_value' can only be 1 or be equal to the "
"'step_width'.");
// fill padding value
T* pad_data = pad_tensor->data<T>();
const T* pad_value_data = pad_value.data<T>();
if (pad_value.numel() == 1) {
for (int i = 0; i < pad_tensor->numel(); ++i) {
pad_data[i] = *pad_value_data;
}
} else {
for (int i = 0; i < pad_tensor->numel(); i += step_width) {
memcpy(pad_data + i, pad_value_data, step_width * sizeof(T));
}
}
CopyValidData<T>(pad_tensor, &seq_tensor, seq_offsets, pad_seq_len,
step_width, norm_by_times, kSeqToPad, layout);
}
};
......@@ -84,62 +105,35 @@ template <typename T>
class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
framework::LoDTensor* seq, const framework::Tensor& padding,
bool norm_by_times) {
auto lod = seq->lod();
PADDLE_ENFORCE_GT(lod.size(), 0UL,
"The LoD of LoDTensor seq should not be null.");
const size_t level = 0;
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
auto seq_dims = seq->dims();
PADDLE_ENFORCE_EQ(seq_dims[0],
static_cast<int64_t>(abs_offset_lod[level].back()),
"The first dimension of LoDTensor seq should be "
"equal to the sum of all sequences's length.");
auto padding_dims = padding.dims();
PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
"The input padding should be a 3-D Tensor of shape "
"[max_sequnece_length, num_sequences, sequence_width].");
const int64_t max_sequence_length = MaximumSequenceLength(lod, level);
PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
"The first dimension of Tensor padding should be "
"the maximum length of all sequences in LoDTensor seq.");
const int64_t num_sequences = abs_offset_lod[level].size() - 1;
PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
"The second dimension of Tensor padding should be "
"the number of sequences in LoDTensor seq.");
const int64_t sequence_width = seq->numel() / seq_dims[0];
PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
"The third dimension of Tensor padding should be the "
"width of sequence in LoDTensor seq.");
const T* padding_data = padding.data<T>();
T* seq_data = seq->data<T>();
for (int64_t i = 0; i < num_sequences; ++i) {
int64_t start_pos = abs_offset_lod[level][i];
int64_t sequence_length = abs_offset_lod[level][i + 1] - start_pos;
for (int64_t j = 0; j < sequence_length; ++j) {
// sequence_width > j > 0
T scale =
norm_by_times ? (1.0f / static_cast<T>(sequence_length)) : 1.0f;
for (int64_t k = 0; k < sequence_width; ++k) {
seq_data[(start_pos + j) * sequence_width + k] =
padding_data[(j * num_sequences + i) * sequence_width + k] *
scale;
}
}
const framework::LoDTensor& pad_tensor,
framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
const PadLayout layout = kBatchLengthWidth) {
auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
const auto& seq_tensor_dims = seq_tensor->dims();
const auto& pad_tensor_dims = pad_tensor.dims();
if (pad_seq_len == -1) {
pad_seq_len = MaximumSequenceLength(seq_offsets);
}
int step_width = seq_tensor->numel() / seq_tensor_dims[0];
CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
step_width, layout);
CopyValidData<T>(seq_tensor, &pad_tensor, seq_offsets, pad_seq_len,
step_width, norm_by_times, kPadToSeq, layout);
}
};
template class PaddingLoDTensorFunctor<platform::CPUDeviceContext, int>;
template class PaddingLoDTensorFunctor<platform::CPUDeviceContext, int64_t>;
template class PaddingLoDTensorFunctor<platform::CPUDeviceContext, float>;
template class PaddingLoDTensorFunctor<platform::CPUDeviceContext, double>;
template class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, int>;
template class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, int64_t>;
template class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, float>;
template class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
......
......@@ -19,41 +19,32 @@ namespace paddle {
namespace operators {
namespace math {
template <typename T, bool NormByTimes, bool Padding>
__global__ void SequencePaddingKernel(T* padding, T* sequence,
const size_t* sequence_start_positions,
const size_t sequence_width,
const size_t max_sequence_length,
const size_t num_sequences) {
size_t padding_idx = blockIdx.y;
size_t start_pos = sequence_start_positions[padding_idx];
size_t sequence_length =
sequence_start_positions[padding_idx + 1] - start_pos;
size_t sequence_idx = blockIdx.x * blockDim.y + threadIdx.y;
size_t padding_base_idx =
(sequence_idx * num_sequences + padding_idx) * sequence_width;
size_t sequence_base_idx = (start_pos + sequence_idx) * sequence_width;
if (sequence_idx < sequence_length) {
T scale = NormByTimes ? (1.0f / static_cast<T>(sequence_length)) : 1.0f;
if (Padding) {
/* sequence -> padding */
for (size_t i = threadIdx.x; i < sequence_width; i += blockDim.x) {
padding[padding_base_idx + i] = scale * sequence[sequence_base_idx + i];
}
} else {
/* padding -> sequence */
for (size_t i = threadIdx.x; i < sequence_width; i += blockDim.x) {
sequence[sequence_base_idx + i] = scale * padding[padding_base_idx + i];
}
template <typename T, CopyType Type>
__global__ void SequencePaddingKernel(
T* dst, const T* src, const T* pad_value, bool is_constant_pad,
const size_t* seq_offsets, const size_t seq_num, const size_t pad_seq_len,
const size_t step_width, bool norm_by_len, const PadLayout layout) {
size_t seq_idx = blockIdx.y;
size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx];
size_t step_idx = blockIdx.x * blockDim.y + threadIdx.y;
size_t seq_data_offset = (seq_offsets[seq_idx] + step_idx) * step_width;
size_t pad_data_offset = layout == kBatchLengthWidth
? (seq_idx * pad_seq_len + step_idx) * step_width
: (step_idx * seq_num + seq_idx) * step_width;
T* dst_data = dst + (Type == kSeqToPad ? pad_data_offset : seq_data_offset);
const T* src_data =
src + (Type == kSeqToPad ? seq_data_offset : pad_data_offset);
if (step_idx < seq_len) {
float scale = norm_by_len ? (1.0f / static_cast<float>(seq_len)) : 1.0f;
for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
dst_data[i] = scale * src_data[i];
}
} else if (sequence_idx < max_sequence_length) {
if (Padding) {
/* sequence -> padding */
for (size_t i = threadIdx.x; i < sequence_width; i += blockDim.x) {
padding[padding_base_idx + i] = 0;
}
} else if (step_idx < pad_seq_len && Type == kSeqToPad) {
for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
dst_data[i] = is_constant_pad ? pad_value[0] : pad_value[i];
}
}
}
......@@ -62,74 +53,59 @@ template <typename T>
class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::LoDTensor& seq, framework::Tensor* padding,
bool norm_by_times) {
auto lod = seq.lod();
PADDLE_ENFORCE_GT(lod.size(), 0UL,
"The lod of LoDTensor seq should not be null.");
const size_t level = 0;
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
auto seq_dims = seq.dims();
PADDLE_ENFORCE_EQ(seq_dims[0],
static_cast<int64_t>(abs_offset_lod[level].back()),
"The first dimension of LoDTensor seq should be "
"equal to the sum of all sequences's length.");
auto padding_dims = padding->dims();
PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
"The input padding should be a 3-D Tensor of shape "
"[max_sequence_length, num_sequences, sequence_width].");
int64_t max_sequence_length = MaximumSequenceLength(lod, level);
PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
"The first dimension of Tensor padding should be the "
"maximum length of all sequences in LoDTensor seq.");
const int64_t num_sequences = abs_offset_lod[level].size() - 1;
PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
"The second dimension of Tensor padding should be the "
"number of sequences in LoDTensor seq.");
const int64_t sequence_width = seq.numel() / seq_dims[0];
PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
"The third dimension of Tensor padding should be the "
"width of sequence in LoDTensor seq.");
if (!norm_by_times && num_sequences == 1UL) {
TensorCopy(seq, context.GetPlace(), context, padding);
padding->Resize(padding_dims);
const framework::LoDTensor& seq_tensor,
framework::LoDTensor* pad_tensor,
const framework::LoDTensor& pad_value, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
const PadLayout layout = kBatchLengthWidth) {
auto seq_lod = seq_tensor.lod();
const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
const auto& seq_tensor_dims = seq_tensor.dims();
const auto& pad_tensor_dims = pad_tensor->dims();
int max_seq_len = MaximumSequenceLength(seq_offsets);
if (pad_seq_len == -1) {
pad_seq_len = max_seq_len;
}
PADDLE_ENFORCE_GE(pad_seq_len, max_seq_len,
"The pad_seq_len must be equal to or greater than the "
"original max sequence length.");
int step_width = seq_tensor.numel() / seq_tensor_dims[0];
int seq_num = seq_offsets.size() - 1;
CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
step_width, layout);
PADDLE_ENFORCE(pad_value.numel() == 1 || pad_value.numel() == step_width,
"The numel of 'pad_value' can only be 1 or be equal to the "
"'step_width'.");
if (!norm_by_times && seq_num == 1UL && pad_seq_len == max_seq_len) {
TensorCopy(seq_tensor, context.GetPlace(), context, pad_tensor);
pad_tensor->Resize(pad_tensor_dims);
return;
}
const int64_t kBlockSize = 512;
const int kBlockSize = 512;
/* At least use 32 threads to copy sequence_width elements,
* and at least 8 elements for each thread.
*/
size_t block_dim_x =
std::min(((((sequence_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
size_t block_dim_y = kBlockSize / block_dim_x;
dim3 threads(block_dim_x, block_dim_y);
size_t grid_dim_x = (max_sequence_length + block_dim_y - 1) / block_dim_y;
size_t grid_dim_y = num_sequences;
size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
size_t grid_dim_y = seq_num;
dim3 grid(grid_dim_x, grid_dim_y);
const T* seq_data = seq.data<T>();
T* padding_data = padding->data<T>();
if (norm_by_times) {
SequencePaddingKernel<T, 1, 1><<<grid, threads, 0, context.stream()>>>(
padding_data, const_cast<T*>(seq_data),
abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width,
max_sequence_length, num_sequences);
} else {
SequencePaddingKernel<T, 0, 1><<<grid, threads, 0, context.stream()>>>(
padding_data, const_cast<T*>(seq_data),
abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width,
max_sequence_length, num_sequences);
}
const T* seq_data = seq_tensor.data<T>();
T* pad_data = pad_tensor->data<T>();
const T* pad_value_data = pad_value.data<T>();
SequencePaddingKernel<T, kSeqToPad><<<grid, threads, 0, context.stream()>>>(
pad_data, seq_data, pad_value_data, pad_value.numel() == 1,
seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
step_width, norm_by_times, layout);
}
};
......@@ -137,79 +113,62 @@ template <typename T>
class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
framework::LoDTensor* seq, const framework::Tensor& padding,
bool norm_by_times) {
auto lod = seq->lod();
PADDLE_ENFORCE_GT(lod.size(), 0UL,
"The lod of LoDTensor seq should not be null.");
const size_t level = 0;
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
auto seq_dims = seq->dims();
PADDLE_ENFORCE_EQ(seq_dims[0],
static_cast<int64_t>(abs_offset_lod[level].back()),
"The first dimension of LoDTensor seq should be "
"equal to the sum of all sequences's length.");
auto padding_dims = padding.dims();
PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
"The input padding should be a 3-D Tensor of shape "
"[max_sequnece_length, num_sequences, sequence_width].");
int64_t max_sequence_length = MaximumSequenceLength(lod, level);
PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
"The first dimension of Tensor padding should be "
"the maximum length of all sequences in LoDTensor seq.");
const int64_t num_sequences = abs_offset_lod[level].size() - 1;
PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
"The second dimension of Tensor padding should be "
"the number of sequences in LoDTensor seq.");
const int64_t sequence_width = seq->numel() / seq_dims[0];
PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
"The third dimension of Tensor padding should be the "
"width of sequence in LoDTensor seq.");
if (!norm_by_times && num_sequences == 1UL) {
TensorCopy(padding, context.GetPlace(), context, seq);
seq->Resize(seq_dims);
const framework::LoDTensor& pad_tensor,
framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
const PadLayout layout = kBatchLengthWidth) {
auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
const auto& seq_tensor_dims = seq_tensor->dims();
const auto& pad_tensor_dims = pad_tensor.dims();
int max_seq_len = MaximumSequenceLength(seq_offsets);
if (pad_seq_len == -1) {
pad_seq_len = max_seq_len;
}
int step_width = seq_tensor->numel() / seq_tensor_dims[0];
int seq_num = seq_offsets.size() - 1;
CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
step_width, layout);
if (!norm_by_times && seq_num == 1UL && pad_seq_len == max_seq_len) {
TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor);
seq_tensor->Resize(seq_tensor_dims);
return;
}
const int64_t kBlockSize = 512;
const int kBlockSize = 512;
/* At least use 32 threads to copy sequence_width elements,
* and at least 8 elements for each thread.
*/
size_t block_dim_x =
std::min(((((sequence_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
size_t block_dim_y = kBlockSize / block_dim_x;
dim3 threads(block_dim_x, block_dim_y);
size_t grid_dim_x = (max_sequence_length + block_dim_y - 1) / block_dim_y;
size_t grid_dim_y = num_sequences;
size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
size_t grid_dim_y = seq_num;
dim3 grid(grid_dim_x, grid_dim_y);
const T* padding_data = padding.data<T>();
T* seq_data = seq->data<T>();
if (norm_by_times) {
SequencePaddingKernel<T, 1, 0><<<grid, threads, 0, context.stream()>>>(
const_cast<T*>(padding_data), seq_data,
abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width,
max_sequence_length, num_sequences);
} else {
SequencePaddingKernel<T, 0, 0><<<grid, threads, 0, context.stream()>>>(
const_cast<T*>(padding_data), seq_data,
abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width,
max_sequence_length, num_sequences);
}
const T* pad_data = pad_tensor.data<T>();
T* seq_data = seq_tensor->data<T>();
SequencePaddingKernel<T, kPadToSeq><<<grid, threads, 0, context.stream()>>>(
seq_data, pad_data, nullptr, false,
seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
step_width, norm_by_times, layout);
}
};
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, int>;
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, int64_t>;
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, double>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, int>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, int64_t>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/device_context.h"
......@@ -22,17 +23,33 @@ namespace paddle {
namespace operators {
namespace math {
inline static size_t MaximumSequenceLength(const framework::LoD& lod,
const size_t level) {
const size_t num_sequences = lod[level].size() - 1;
size_t max_sequence_length = 0;
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
for (size_t i = 0; i < num_sequences; ++i) {
max_sequence_length =
std::max(max_sequence_length,
abs_offset_lod[level][i + 1] - abs_offset_lod[level][i]);
enum PadLayout { kBatchLengthWidth = 0, kLengthBatchWidth };
enum CopyType { kSeqToPad, kPadToSeq };
inline static size_t MaximumSequenceLength(
const framework::Vector<size_t>& seq_offset) {
size_t seq_num = seq_offset.size() - 1;
size_t max_seq_len = 0;
for (size_t i = 0; i < seq_num; ++i) {
max_seq_len = std::max(max_seq_len, seq_offset[i + 1] - seq_offset[i]);
}
return max_sequence_length;
return max_seq_len;
}
inline static void CheckDims(const framework::DDim& seq_tensor_dims,
const framework::DDim& pad_tensor_dims,
const framework::Vector<size_t>& seq_offset,
int64_t padded_seq_len, int64_t step_width,
const PadLayout& layout) {
PADDLE_ENFORCE_EQ(static_cast<size_t>(seq_tensor_dims[0]), seq_offset.back(),
"Value of 1st dimension of the sequence tensor should be "
"equal to sum of lengths of all sequences.");
PADDLE_ENFORCE(seq_tensor_dims.size() + 1 == pad_tensor_dims.size() ||
seq_tensor_dims.size() == pad_tensor_dims.size(),
"pad_tensor's rank should be 1 greater than seq_tensor's "
"rank, or be equal with it.");
}
/*
......@@ -64,15 +81,22 @@ inline static size_t MaximumSequenceLength(const framework::LoD& lod,
template <typename DeviceContext, typename T>
class PaddingLoDTensorFunctor {
public:
void operator()(const DeviceContext& context, const framework::LoDTensor& seq,
framework::Tensor* padding, bool norm_by_times);
void operator()(const DeviceContext& context,
const framework::LoDTensor& seq_tensor,
framework::LoDTensor* pad_tensor,
const framework::LoDTensor& pad_value, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
const PadLayout layout = kBatchLengthWidth);
};
template <typename DeviceContext, typename T>
class UnpaddingLoDTensorFunctor {
public:
void operator()(const DeviceContext& context, framework::LoDTensor* seq,
const framework::Tensor& padding, bool norm_by_times);
void operator()(const DeviceContext& context,
const framework::LoDTensor& pad_tensor,
framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
const PadLayout layout = kBatchLengthWidth);
};
} // namespace math
......
......@@ -23,7 +23,9 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
paddle::framework::LoDTensor cpu_seq_back;
paddle::framework::LoDTensor seq;
paddle::framework::LoDTensor seq_back;
paddle::framework::Tensor padding;
paddle::framework::LoDTensor padding;
paddle::framework::LoDTensor cpu_pad_value;
paddle::framework::LoDTensor pad_value;
const size_t level = lod.size() - 1;
auto seq_dims =
......@@ -46,20 +48,33 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
}
const size_t max_sequence_length =
paddle::operators::math::MaximumSequenceLength(lod, level);
paddle::operators::math::MaximumSequenceLength(lod[level]);
const size_t num_sequences = lod[level].size() - 1;
auto padding_dims =
paddle::framework::make_ddim({static_cast<int64_t>(max_sequence_length),
static_cast<int64_t>(num_sequences),
static_cast<int64_t>(sequence_width)});
padding.mutable_data<T>(padding_dims, *place);
T* pad_value_data =
cpu_pad_value.mutable_data<T>({1}, paddle::platform::CPUPlace());
*pad_value_data = static_cast<T>(0);
if (paddle::platform::is_cpu_place(*place)) {
pad_value = cpu_pad_value;
} else {
TensorCopySync(cpu_pad_value, *place, &pad_value);
}
paddle::operators::math::PaddingLoDTensorFunctor<DeviceContext, T>()(
*context, seq, &padding, false);
*context, seq, &padding, pad_value, -1, 0, false,
paddle::operators::math::kLengthBatchWidth);
seq_back.set_lod(lod);
seq_back.mutable_data<T>(seq_dims, *place);
paddle::operators::math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
*context, &seq_back, padding, false);
*context, padding, &seq_back, -1, 0, false,
paddle::operators::math::kLengthBatchWidth);
if (paddle::platform::is_cpu_place(*place)) {
cpu_seq_back = seq_back;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sequence_pad_op.h"
namespace paddle {
namespace operators {
class SequencePadOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequencePadOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("PadValue"),
"Input(PadValue) of SequencePadOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequencePadOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
"The rank of Input(x) can't be less than 2.");
auto time_step_dims = framework::slice_ddim(x_dims, 1, x_dims.size());
auto pad_value_dims = ctx->GetInputDim("PadValue");
PADDLE_ENFORCE(pad_value_dims == framework::make_ddim({1}) ||
pad_value_dims == time_step_dims,
"The Input(PadValue) must be a scalar or a tensor whose "
"shape equals to time steps in sequences");
int out_dim_0 = -1;
int out_dim_1 = -1;
if (ctx->IsRuntime()) {
// run time
framework::Variable* x_var =
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
const auto& x_lod = x_var->Get<LoDTensor>().lod();
PADDLE_ENFORCE(!x_lod.empty(), "The Input(X) must hold lod info.");
const auto& x_lod_0 = x_lod[0];
PADDLE_ENFORCE_GE(x_lod_0.size(), 2,
"The Input(X)'s lod info is corrupted.");
PADDLE_ENFORCE_EQ(
x_dims[0], static_cast<int64_t>(x_lod_0.back()),
"The Input(X)'s lod info mismatches the actual tensor shape.");
int seq_num = x_lod_0.size() - 1;
int max_seq_len = math::MaximumSequenceLength(x_lod_0);
int padded_length = ctx->Attrs().Get<int>("padded_length");
if (padded_length == -1) {
padded_length = max_seq_len;
}
PADDLE_ENFORCE_GE(padded_length, max_seq_len,
"The Attr(padded_length) must be -1 or an int greater "
"than the length of the longest original sequence.");
out_dim_0 = seq_num;
out_dim_1 = padded_length;
} else {
// compile time
framework::VarDesc* x_desc =
boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("X")[0]);
PADDLE_ENFORCE_GE(x_desc->GetLoDLevel(), 1);
}
std::vector<int> out_dims_vec{out_dim_0, out_dim_1};
auto time_step_dims_vec = framework::vectorize2int(time_step_dims);
out_dims_vec.insert(out_dims_vec.end(), time_step_dims_vec.begin(),
time_step_dims_vec.end());
ctx->SetOutputDim("Out", framework::make_ddim(out_dims_vec));
}
};
class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(LoDTensor, default LoDTensor<float>) Input variable which "
"should contain lod information.");
AddInput("PadValue",
"(LoDTensor), this Tensor holds values that will be fill into "
"padded steps. It can be a scalar or a tensor whose shape equals "
"to time steps in sequences. If it's a scalar, it will be "
"automatically broadcasted to the shape of time step.");
AddOutput(
"Out",
"(LoDTensor) The output vairable, which contains padded sequences.");
AddAttr<int>(
"padded_length",
"The length of padded sequences. It can be setted to -1 or "
"any positive int. When it is -1, all sequences will be padded up to "
"the length of the longest one among them; when it a certain positive "
"value, it must be greater than the length of the longest original "
"sequence.")
.SetDefault(-1);
AddComment(R"DOC(
Sequence Pad Operator
This operator pads sequences in a same batch to a consistent length.
The length is specified by attribute 'padded_length'. New elements,
whose values are specified by input 'PadValue', will be appended to
the end of each sequence, to make their final lengths consistent.
Following are cases to better explain how this works:
Case 1:
Given a 1-level LoDTensor input(X):
X.lod = [[0, 2, 5]]
X.data = [a, b, c, d, e]
and Input(PadValue):
PadValue.data = [0]
and attribite 'padded_length' = 4,
then we get LoDTensor:
Out.data = [[a, b, 0, 0],
[c, d, e, 0]]
Case 2:
Given a 1-level LoDTensor input(X):
X.lod = [[0, 2, 5]]
X.data = [[a1, a2], [b1, b2], [c1, c2], [d1, d2], [e1, e2]]
and Input(PadValue):
PadValue.data = [0]
and attribite 'padded_length' = -1, which mean using the length
of longest input sequence(3 in this case),
then we get LoDTensor:
Out.data = [[[a1, a2], [b1, b2], [0, 0]],
[[c1, c2], [d1, d2], [e1, e2]]]
Case 3:
Given a 1-level LoDTensor input(X):
X.lod = [[0, 2, 5]]
X.data = [[a1, a2], [b1, b2], [c1, c2], [d1, d2], [e1, e2]]
and Input(PadValue):
PadValue.data = [p1, p2]
and attribite 'padded_length' = -1, which mean using the length
of longest input sequence(3 in this case),
then we get LoDTensor:
Out.data = [[[a1, a2], [b1, b2], [p1, p2]],
[[c1, c2], [d1, d2], [e1, e2]]]
)DOC");
}
};
class SequencePadGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequencePadGradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) of SequencePadGradOp should not be null.");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_pad, ops::SequencePadOp, ops::SequencePadOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(sequence_pad_grad, ops::SequencePadGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_pad,
ops::SequencePadOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequencePadOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SequencePadOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::SequencePadOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
sequence_pad_grad,
ops::SequencePadGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequencePadGradOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SequencePadGradOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::SequencePadGradOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sequence_pad_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
sequence_pad,
ops::SequencePadOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequencePadOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SequencePadOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::SequencePadOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
sequence_pad_grad,
ops::SequencePadGradOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequencePadGradOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SequencePadGradOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::SequencePadGradOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence_padding.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;
template <typename DeviceContext, typename T>
class SequencePadOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* x = ctx.Input<LoDTensor>("X");
auto* out = ctx.Output<LoDTensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
const auto* pad_value = ctx.Input<LoDTensor>("PadValue");
int padded_length = ctx.Attr<int>("padded_length");
math::PaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *x, out, *pad_value,
padded_length, 0, false, math::kBatchLengthWidth);
}
};
template <typename DeviceContext, typename T>
class SequencePadGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* d_x = ctx.Output<LoDTensor>(framework::GradVarName("X"));
if (d_x) {
const auto* d_out = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
d_x->mutable_data<T>(ctx.GetPlace());
int padded_length = ctx.Attr<int>("padded_length");
math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *d_out, d_x,
padded_length, 0, false, math::kBatchLengthWidth);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -153,17 +153,29 @@ class WarpCTCKernel : public framework::OpKernel<T> {
framework::make_ddim({static_cast<int64_t>(num_sequences), 1});
// warpctc needs sequences data stored in transposed padding format
Tensor warpctc_logits;
LoDTensor warpctc_logits;
const size_t max_sequence_length =
math::MaximumSequenceLength(logits_lod, level);
math::MaximumSequenceLength(logits_lod[level]);
auto warpctc_logits_dims =
framework::make_ddim({static_cast<int64_t>(max_sequence_length),
static_cast<int64_t>(num_sequences),
static_cast<int64_t>(sequence_width)});
warpctc_logits.mutable_data<T>(warpctc_logits_dims, ctx.GetPlace());
LoDTensor cpu_pad_value;
T* pad_value_data =
cpu_pad_value.mutable_data<T>({1}, platform::CPUPlace());
*pad_value_data = static_cast<T>(0);
LoDTensor pad_value;
if (platform::is_cpu_place(ctx.GetPlace())) {
pad_value = cpu_pad_value;
} else {
TensorCopySync(cpu_pad_value, ctx.GetPlace(), &pad_value);
}
math::PaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *logits, &warpctc_logits,
false);
pad_value, -1, 0, false /* norm_by_times */, math::kLengthBatchWidth);
const T* warpctc_logits_data = warpctc_logits.data<T>();
std::vector<int> warpctc_label_lengths(num_sequences);
......@@ -209,15 +221,15 @@ template <typename DeviceContext, typename T>
class WarpCTCGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* warpctc_grad = ctx.Input<Tensor>("WarpCTCGrad");
auto* warpctc_grad = ctx.Input<LoDTensor>("WarpCTCGrad");
auto* logits_grad = ctx.Output<LoDTensor>(framework::GradVarName("Logits"));
const Tensor* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
logits_grad->mutable_data<T>(ctx.GetPlace());
bool norm_by_times = ctx.Attr<bool>("norm_by_times");
math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), logits_grad,
*warpctc_grad, norm_by_times);
ctx.template device_context<DeviceContext>(), *warpctc_grad,
logits_grad, -1, 0, norm_by_times, math::kLengthBatchWidth);
const T* loss_grad_data = loss_grad->data<T>();
math::ScaleLoDTensorFunctor<DeviceContext, T>()(
......
......@@ -54,6 +54,7 @@ __all__ = [
'conv2d_transpose',
'conv3d_transpose',
'sequence_expand',
'sequence_pad',
'lstm_unit',
'reduce_sum',
'reduce_mean',
......@@ -2656,6 +2657,51 @@ def sequence_expand(x, y, ref_level=-1, name=None):
return tmp
@templatedoc()
def sequence_pad(x, pad_value, maxlen=None):
"""
${comment}
Args:
x(Variable): Input variable which should contain lod information.
pad_value(Variable): The Variable that holds values that will be fill
into padded steps. It can be a scalar or a tensor whose shape
equals to time steps in sequences. If it's a scalar, it will be
automatically broadcasted to the shape of time step.
maxlen(int, default None): The length of padded sequences. It can be
None or any positive int. When it is None, all sequences will be
padded up to the length of the longest one among them; when it a
certain positive value, it must be greater than the length of the
longest original sequence."
Returns:
Variable: The padded sequence batch. All sequences has the same length.
Examples:
.. code-block:: python
import numpy
x = fluid.layers.data(name='y', shape=[10, 5],
dtype='float32', lod_level=1)
pad_value = fluid.layers.assign(input=numpy.array([0]))
out = fluid.layers.sequence_pad(x=x, pad_value=pad_value)
"""
helper = LayerHelper('sequence_pad', input=x, **locals())
dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype)
if maxlen is None:
maxlen = -1
helper.append_op(
type='sequence_pad',
inputs={'X': x,
'PadValue': pad_value},
outputs={'Out': out},
attrs={'padded_length': maxlen})
return out
def beam_search(pre_ids,
pre_scores,
ids,
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class TestSequencePadOp(OpTest):
def set_attr(self):
self.x_shape = [12, 4]
self.x_len_lod = [[2, 3, 4, 3]]
self.pad_value = [1.0]
self.padded_length = -1
self.dtype = 'float32'
def set_data(self):
x_data = np.random.uniform(0.1, 0.5, self.x_shape).astype(self.dtype)
pad_value_data = np.array(self.pad_value).astype(self.dtype)
self.inputs = {
'X': (x_data, self.x_len_lod),
'PadValue': pad_value_data
}
self.attrs = {'padded_length': self.padded_length}
def compute(self):
# get padded length
padded_length = self.padded_length
x_len_lod_0 = self.x_len_lod[0]
if padded_length == -1:
max_seq_len = 0
for l in x_len_lod_0:
max_seq_len = max(max_seq_len, l)
padded_length = max_seq_len
# do padding
x_data = self.inputs['X'][0]
pad_value_data = self.inputs['PadValue']
if pad_value_data.shape == (1, ):
pad_value_data = np.broadcast_to(
pad_value_data, shape=x_data.shape[1:])
padded_sequences = []
start_idx = 0
for l in x_len_lod_0:
end_idx = start_idx + l
seq = x_data[start_idx:end_idx]
to_pad_len = padded_length - l
for _ in range(to_pad_len):
seq = np.append(seq, pad_value_data[np.newaxis, :], axis=0)
padded_sequences.append(seq)
start_idx = end_idx
out_data = np.array(padded_sequences)
self.outputs = {'Out': out_data}
def setUp(self):
self.op_type = 'sequence_pad'
self.set_attr()
self.set_data()
self.compute()
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestSequencePadOp2(TestSequencePadOp):
def set_attr(self):
self.x_shape = [12, 4]
self.x_len_lod = [[2, 3, 4, 3]]
self.pad_value = [1.0, 2.0, 3.0, 4.0]
self.padded_length = -1
self.dtype = 'float32'
class TestSequencePadOp3(TestSequencePadOp):
def set_attr(self):
self.x_shape = [12, 4]
self.x_len_lod = [[2, 3, 4, 3]]
self.pad_value = [1.0]
self.padded_length = 7
self.dtype = 'float32'
class TestSequencePadOp4(TestSequencePadOp):
def set_attr(self):
self.x_shape = [12, 4]
self.x_len_lod = [[2, 3, 4, 3]]
self.pad_value = [1.0, 2.0, 3.0, 4.0]
self.padded_length = 7
self.dtype = 'float32'
class TestSequencePadOp5(TestSequencePadOp):
def set_attr(self):
self.x_shape = [12, 2, 2]
self.x_len_lod = [[2, 3, 4, 3]]
self.pad_value = [1.0]
self.padded_length = -1
self.dtype = 'float32'
class TestSequencePadOp6(TestSequencePadOp):
def set_attr(self):
self.x_shape = [12, 2, 2]
self.x_len_lod = [[2, 3, 4, 3]]
self.pad_value = [[1.0, 2.0], [3.0, 4.0]]
self.padded_length = -1
self.dtype = 'float32'
class TestSequencePadOp7(TestSequencePadOp):
def set_attr(self):
self.x_shape = [12, 2, 2]
self.x_len_lod = [[2, 3, 4, 3]]
self.pad_value = [1.0]
self.padded_length = 7
self.dtype = 'float32'
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册