From 33d1e565060e3d79f57de1a7e3b35e769571abfd Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Mon, 10 Jun 2019 16:43:03 +0800
Subject: [PATCH] Enable seq_pool op to accept len 0 input (#17284)

* Enable seq_pool op to accept len 0 input

test=develop

* Update sequence_pool's api

test=develop

* Add more unittest cases for seq_pool op

test=develop

* Remove legacy comments

test=develop

* Don't use template in op maker

test=develop
---
 paddle/fluid/API.spec                         |   2 +-
 .../fluid/operators/math/sequence_pooling.cc  |  88 +++++--
 .../fluid/operators/math/sequence_pooling.cu  | 129 +++++----
 .../fluid/operators/math/sequence_pooling.h   |   5 +-
 .../sequence_ops/sequence_pool_op.cc          |   5 +
 .../operators/sequence_ops/sequence_pool_op.h |   5 +-
 python/paddle/fluid/layers/nn.py              |  36 ++-
 .../fluid/tests/unittests/test_seq_pool.py    | 244 +++++++++++++-----
 8 files changed, 357 insertions(+), 157 deletions(-)
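In brief, the change works as follows: for a zero-length sequence in the input LoD, every pool type now writes attr(pad_value) into that sequence's output row instead of reading past sequence boundaries, and MAX additionally records -1 in its index output so the backward pass can skip such rows. A minimal NumPy sketch of these semantics — the helper name seq_pool_ref and the offset-based LoD layout are illustrative only, not part of the patch:

.. code-block:: python

    import numpy as np

    def seq_pool_ref(x, offsets, pool_type="SUM", pad_value=0.0):
        # One output row per sequence; offsets is the offset-based LoD,
        # so sequence i spans rows offsets[i]:offsets[i + 1].
        out = np.empty((len(offsets) - 1, x.shape[1]), dtype=x.dtype)
        for i in range(len(offsets) - 1):
            begin, end = offsets[i], offsets[i + 1]
            if begin == end:  # zero-length sequence: pad instead of pooling
                out[i] = pad_value
            elif pool_type == "SUM":
                out[i] = x[begin:end].sum(axis=0)
            elif pool_type == "AVERAGE":
                out[i] = x[begin:end].mean(axis=0)
            elif pool_type == "MAX":
                out[i] = x[begin:end].max(axis=0)
        return out

    x = np.arange(8, dtype=np.float32).reshape(4, 2)
    print(seq_pool_ref(x, offsets=[0, 2, 2, 4], pad_value=-1.0))
    # -> [[ 2.  4.] [-1. -1.] [10. 12.]]  (middle sequence is empty)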
diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 67f654e72f4..4ab89a7105b 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -87,7 +87,7 @@ paddle.fluid.layers.chunk_eval (ArgSpec(args=['input', 'label', 'chunk_scheme',
 paddle.fluid.layers.sequence_conv (ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None, None)), ('document', '3d8e8f3e0e1cf520156be37605e83ccd'))
 paddle.fluid.layers.conv2d (ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None)), ('document', '8ca6121acd6d23cd8806a93f493c2e17'))
 paddle.fluid.layers.conv3d (ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None)), ('document', '37042620f9bd3a2da6e5d3138b2f724b'))
-paddle.fluid.layers.sequence_pool (ArgSpec(args=['input', 'pool_type', 'is_test'], varargs=None, keywords=None, defaults=(False,)), ('document', 'a194fb80614023f543df3949fbd0d0b8'))
+paddle.fluid.layers.sequence_pool (ArgSpec(args=['input', 'pool_type', 'is_test', 'pad_value'], varargs=None, keywords=None, defaults=(False, 0.0)), ('document', 'e90a93251c52dc4e6fb34fb3991b3f82'))
 paddle.fluid.layers.sequence_softmax (ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None)), ('document', '19ef6f9cdd27feac8a1ae060f19c10b4'))
 paddle.fluid.layers.softmax (ArgSpec(args=['input', 'use_cudnn', 'name', 'axis'], varargs=None, keywords=None, defaults=(False, None, -1)), ('document', 'cee673c79e3ff4582656a24e04f841e5'))
 paddle.fluid.layers.pool2d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True)), ('document', 'bbd84e855e660cd1084bb71a2fd0cdaa'))
diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc
index 7af44f2b2ca..011d45c3965 100644
--- a/paddle/fluid/operators/math/sequence_pooling.cc
+++ b/paddle/fluid/operators/math/sequence_pooling.cc
@@ -36,8 +36,8 @@ template <typename T, bool is_test>
 class MaxSeqPoolFunctor {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  const framework::LoDTensor& input, framework::Tensor* output,
-                  framework::Tensor* index) {
+                  const framework::LoDTensor& input, T pad_value,
+                  framework::Tensor* output, framework::Tensor* index) {
     auto in_dims = input.dims();
     auto out_dims = output->dims();
     auto idx_dims = index->dims();
@@ -56,6 +56,13 @@ class MaxSeqPoolFunctor {
     int64_t num_seq = out_dims[0];
     int64_t dim = output->numel() / num_seq;
     for (int64_t i = 0; i < num_seq; ++i) {
+      if (starts[i] == starts[i + 1]) {
+        for (int64_t k = 0; k < dim; ++k) {
+          out_data[i * dim + k] = pad_value;
+          max_index[i * dim + k] = -1;
+        }
+        continue;
+      }
       for (int64_t k = 0; k < dim; ++k) {
         out_data[i * dim + k] = in_data[starts[i] * dim + k];
         max_index[i * dim + k] = starts[i];
@@ -77,8 +84,8 @@ template <typename T>
 class MaxSeqPoolFunctor<T, true> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  const framework::LoDTensor& input, framework::Tensor* output,
-                  framework::Tensor* index) {
+                  const framework::LoDTensor& input, T pad_value,
+                  framework::Tensor* output, framework::Tensor* index) {
     auto in_dims = input.dims();
     auto out_dims = output->dims();
     PADDLE_ENFORCE_GT(in_dims.size(), 1);
@@ -94,6 +101,12 @@ class MaxSeqPoolFunctor<T, true> {
     int64_t num_seq = out_dims[0];
     int64_t dim = output->numel() / num_seq;
     for (int64_t i = 0; i < num_seq; ++i) {
+      if (starts[i] == starts[i + 1]) {
+        for (int64_t k = 0; k < dim; ++k) {
+          out_data[i * dim + k] = pad_value;
+        }
+        continue;
+      }
       std::memcpy(&out_data[i * dim], &in_data[starts[i] * dim],
                   dim * sizeof(T));
       for (size_t j = starts[i] + 1; j < starts[i + 1]; ++j) {
@@ -134,6 +147,7 @@ class MaxSeqPoolGradFunctor {
     for (int64_t i = 0; i < num_seq; ++i) {
       for (int64_t j = 0; j < dim; ++j) {
         int step_id = max_index[i * dim + j];
+        if (step_id == -1) continue;
         ig_data[step_id * dim + j] = og_data[i * dim + j];
       }
     }
@@ -144,7 +158,7 @@ template <typename T>
 class LastSeqPoolFunctor {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  const framework::LoDTensor& input,
+                  const framework::LoDTensor& input, T pad_value,
                   framework::Tensor* output) {
     // Create pointers to input and output data
     auto* in_data = input.data<T>();
     auto* out_data = output->data<T>();
@@ -157,10 +171,16 @@
     for (int i = 0; i < seq_num; ++i) {
       // Calculate the length of each sequence
       int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
-      // Point to the begin of next sequence
-      in_data += seq_len * item_size;
-      // Copy the last item of sequence to output
-      std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T));
+      if (seq_len == 0) {
+        for (int j = 0; j < item_size; ++j) {
+          out_data[j] = pad_value;
+        }
+      } else {
+        // Point to the begin of next sequence
+        in_data += seq_len * item_size;
+        // Copy the last item of sequence to output
+        std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T));
+      }
       out_data += item_size;
     }
   }
@@ -170,7 +190,7 @@ template <typename T>
 class FirstSeqPoolFunctor {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  const framework::LoDTensor& input,
+                  const framework::LoDTensor& input, T pad_value,
                   framework::Tensor* output) {
     // Create pointers to input and output data
     auto* in_data = input.data<T>();
     auto* out_data = output->data<T>();
@@ -183,10 +203,16 @@
     for (int i = 0; i < seq_num; ++i) {
       // Calculate the length of each sequence
       int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
-      // Copy the first item of sequence to output
-      std::memcpy(out_data, in_data, item_size * sizeof(T));
-      // Point to the next sequence
-      in_data += seq_len * item_size;
+      if (seq_len == 0) {
+        for (int j = 0; j < item_size; ++j) {
+          out_data[j] = pad_value;
+        }
+      } else {
+        // Copy the first item of sequence to output
+        std::memcpy(out_data, in_data, item_size * sizeof(T));
+        // Point to the next sequence
+        in_data += seq_len * item_size;
+      }
       out_data += item_size;
     }
   }
@@ -207,6 +233,7 @@ class SumSeqPoolGradFunctor {
     auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
     for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
       int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
+      if (h == 0) continue;
       int64_t in_offset = lod[i] * in_w;
       const T* out_pos = out_g_data + i * out_w;
       T* in_pos = in_g_data + in_offset;
@@ -222,27 +249,27 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
  public:
   /* max pool has index output */
   void operator()(const platform::CPUDeviceContext& context,
-                  const std::string pooltype, const framework::LoDTensor& input,
-                  framework::Tensor* output, bool is_test,
-                  framework::Tensor* index = nullptr) {
+                  const std::string pooltype, T pad_value,
+                  const framework::LoDTensor& input, framework::Tensor* output,
+                  bool is_test, framework::Tensor* index = nullptr) {
     if (pooltype == "MAX") {
       if (is_test) {
         math::MaxSeqPoolFunctor<T, true> max_pool;
-        max_pool(context, input, output, index);
+        max_pool(context, input, pad_value, output, index);
       } else {
         math::MaxSeqPoolFunctor<T, false> max_pool;
-        max_pool(context, input, output, index);
+        max_pool(context, input, pad_value, output, index);
       }
       return;
     }
     if (pooltype == "LAST") {
       math::LastSeqPoolFunctor<T> last_pool;
-      last_pool(context, input, output);
+      last_pool(context, input, pad_value, output);
       return;
     }
     if (pooltype == "FIRST") {
       math::FirstSeqPoolFunctor<T> first_pool;
-      first_pool(context, input, output);
+      first_pool(context, input, pad_value, output);
       return;
     }
@@ -260,7 +287,13 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
             .At(attr);
     for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
       attr.h = static_cast<int>(lod[i + 1] - lod[i]);
-      seqpool(src, dst, &attr);
+      if (attr.h == 0) {
+        for (int j = 0; j < attr.w; ++j) {
+          dst[j] = pad_value;
+        }
+      } else {
+        seqpool(src, dst, &attr);
+      }
       dst += attr.w;
       src += attr.h * attr.w;
     }
@@ -268,11 +301,17 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
     }
     auto& place = *context.eigen_device();
     for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
+      Tensor out_t = output->Slice(i, i + 1);
+      int64_t w = input.numel() / input.dims()[0];
+      if (lod[i] == lod[i + 1]) {
+        for (int j = 0; j < w; ++j) {
+          out_t.data<T>()[j] = pad_value;
+        }
+        continue;
+      }
       Tensor in_t = input.Slice(static_cast<int>(lod[i]),
                                 static_cast<int>(lod[i + 1]));
-      Tensor out_t = output->Slice(i, i + 1);
       int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
-      int64_t w = input.numel() / input.dims()[0];
       auto in_e = EigenMatrix<T>::From(in_t, framework::make_ddim({h, w}));
       auto out_e = EigenVector<T>::Flatten(out_t);
       if (pooltype == "AVERAGE") {
@@ -316,6 +355,7 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
     auto lod = in_grad->lod()[0];
     auto& place = *context.eigen_device();
     for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
+      if (lod[i] == lod[i + 1]) continue;
      auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]),
                                   static_cast<int>(lod[i + 1]));
       auto out_g_t = out_grad.Slice(i, i + 1);
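Note how the CPU MAX functors above write a -1 sentinel into the index tensor for empty sequences, and MaxSeqPoolGradFunctor skips step_id == -1, so no gradient is ever scattered for a padded output row. A hedged NumPy sketch of that backward rule (function and argument names are hypothetical):

.. code-block:: python

    import numpy as np

    def max_pool_grad_ref(out_grad, max_index, in_len):
        # Scatter each out_grad entry back to its argmax position; an index
        # of -1 marks a padded row from an empty sequence, which has no
        # source element to credit, so it is skipped.
        in_grad = np.zeros((in_len, out_grad.shape[1]), dtype=out_grad.dtype)
        num_seq, dim = out_grad.shape
        for i in range(num_seq):
            for j in range(dim):
                step = max_index[i, j]
                if step == -1:
                    continue
                in_grad[step, j] = out_grad[i, j]
        return in_grad

    og = np.array([[1.0, 2.0], [3.0, 4.0]])
    idx = np.array([[0, 1], [-1, -1]])  # second sequence was empty
    print(max_pool_grad_ref(og, idx, in_len=2))
    # -> [[1. 0.] [0. 2.]]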
diff --git a/paddle/fluid/operators/math/sequence_pooling.cu b/paddle/fluid/operators/math/sequence_pooling.cu
index 51da6de26e2..4de99ba677d 100644
--- a/paddle/fluid/operators/math/sequence_pooling.cu
+++ b/paddle/fluid/operators/math/sequence_pooling.cu
@@ -24,96 +24,122 @@ namespace math {

 template <typename T>
 struct MaxPoolFunctor {
-  HOSTDEVICE void operator()(const T* input, const size_t start,
-                             const size_t end, const size_t item_dim, T* output,
-                             int* index) {
+  HOSTDEVICE void operator()(const T* input, const T pad_value,
+                             const size_t start, const size_t end,
+                             const size_t item_dim, T* output, int* index) {
     for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
       T max_val = static_cast<T>(-FLT_MAX);
       int max_index = -1;
-      for (int i = start; i < end; ++i) {
-        if (max_val < input[item_dim * i + tid]) {
-          max_val = input[item_dim * i + tid];
-          max_index = i;
+      if (start == end) {
+        output[tid] = pad_value;
+        index[tid] = -1;
+      } else {
+        for (int i = start; i < end; ++i) {
+          if (max_val < input[item_dim * i + tid]) {
+            max_val = input[item_dim * i + tid];
+            max_index = i;
+          }
         }
+        output[tid] = max_val;
+        index[tid] = max_index;
       }
-      output[tid] = max_val;
-      index[tid] = max_index;
     }
   }
 };

 template <typename T>
 struct AvgPoolFunctor {
-  HOSTDEVICE void operator()(const T* input, const size_t start,
-                             const size_t end, const size_t item_dim, T* output,
-                             int* index) {
+  HOSTDEVICE void operator()(const T* input, const T pad_value,
+                             const size_t start, const size_t end,
+                             const size_t item_dim, T* output, int* index) {
     for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
-      T val = static_cast<T>(0);
-      for (int i = start; i < end; ++i) {
-        val += input[item_dim * i + tid];
+      if (start == end) {
+        output[tid] = pad_value;
+      } else {
+        T val = static_cast<T>(0);
+        for (int i = start; i < end; ++i) {
+          val += input[item_dim * i + tid];
+        }
+        // end, start is lod, so end - start != 0
+        output[tid] = val / static_cast<T>(end - start);
       }
-      // end, start is lod, so end - start != 0
-      output[tid] = val / static_cast<T>(end - start);
     }
   }
 };

 template <typename T>
 struct SumPoolFunctor {
-  HOSTDEVICE void operator()(const T* input, const size_t start,
-                             const size_t end, const size_t item_dim, T* output,
-                             int* index) {
+  HOSTDEVICE void operator()(const T* input, const T pad_value,
+                             const size_t start, const size_t end,
+                             const size_t item_dim, T* output, int* index) {
     for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
-      T val = static_cast<T>(0);
-      for (int i = start; i < end; ++i) {
-        val += input[item_dim * i + tid];
+      if (start == end) {
+        output[tid] = pad_value;
+      } else {
+        T val = static_cast<T>(0);
+        for (int i = start; i < end; ++i) {
+          val += input[item_dim * i + tid];
+        }
+        output[tid] = val;
       }
-      output[tid] = val;
     }
   }
 };

 template <typename T>
 struct SqrtPoolFunctor {
-  HOSTDEVICE void operator()(const T* input, const size_t start,
-                             const size_t end, const size_t item_dim, T* output,
-                             int* index) {
+  HOSTDEVICE void operator()(const T* input, const T pad_value,
+                             const size_t start, const size_t end,
+                             const size_t item_dim, T* output, int* index) {
     for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
-      T val = static_cast<T>(0);
-      for (int i = start; i < end; ++i) {
-        val += input[item_dim * i + tid];
+      if (start == end) {
+        output[tid] = pad_value;
+      } else {
+        T val = static_cast<T>(0);
+        for (int i = start; i < end; ++i) {
+          val += input[item_dim * i + tid];
+        }
+        // end, start is lod, so end - start != 0
+        output[tid] = val / sqrt(end - start);
       }
-      // end, start is lod, so end - start != 0
-      output[tid] = val / sqrt(end - start);
     }
   }
 };

 template <typename T>
 struct LastPoolFunctor {
-  HOSTDEVICE void operator()(const T* input, const size_t start,
-                             const size_t end, const size_t item_dim, T* output,
-                             int* index) {
+  HOSTDEVICE void operator()(const T* input, const T pad_value,
+                             const size_t start, const size_t end,
+                             const size_t item_dim, T* output, int* index) {
     for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
-      output[tid] = input[item_dim * (end - 1) + tid];
+      if (start == end) {
+        output[tid] = pad_value;
+      } else {
+        output[tid] = input[item_dim * (end - 1) + tid];
+      }
     }
   }
 };

 template <typename T>
 struct FirstPoolFunctor {
-  HOSTDEVICE void operator()(const T* input, const size_t start,
-                             const size_t end, const size_t item_dim, T* output,
-                             int* index) {
+  HOSTDEVICE void operator()(const T* input, const T pad_value,
+                             const size_t start, const size_t end,
+                             const size_t item_dim, T* output, int* index) {
     for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
-      output[tid] = input[item_dim * start + tid];
+      if (start == end) {
+        output[tid] = pad_value;
+      } else {
+        output[tid] = input[item_dim * start + tid];
+      }
     }
   }
 };

 template <typename T, typename Range_OP>
 __global__ void sequence_pool_kernel(Range_OP op, const T* input,
-                                     const size_t* lod, const size_t lod_size,
+                                     const T pad_value, const size_t* lod,
+                                     const size_t lod_size,
                                      const size_t item_dim, T* output,
                                      int* index) {
   int bid = blockIdx.x;
@@ -124,16 +150,17 @@ __global__ void sequence_pool_kernel(Range_OP op, const T* input,
   if (index != nullptr) {
     index_offset = &index[bid * item_dim];
   }
-  op(input, start, end, item_dim, &output[bid * item_dim], index_offset);
+  op(input, pad_value, start, end, item_dim, &output[bid * item_dim],
+     index_offset);
 }

 template <typename T>
 class SequencePoolFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  const std::string pooltype, const framework::LoDTensor& input,
-                  framework::Tensor* output, bool is_test,
-                  framework::Tensor* index = nullptr) {
+                  const std::string pooltype, T pad_value,
+                  const framework::LoDTensor& input, framework::Tensor* output,
+                  bool is_test, framework::Tensor* index = nullptr) {
     auto& lod = input.lod()[0];
     const size_t item_dim = output->numel() / output->dims()[0];
     dim3 threads(1024, 1);
@@ -141,37 +168,37 @@ class SequencePoolFunctor<platform::CUDADeviceContext, T> {
     if (pooltype == "MAX") {
       sequence_pool_kernel<
           T, MaxPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
-          MaxPoolFunctor<T>(), input.data<T>(),
+          MaxPoolFunctor<T>(), input.data<T>(), pad_value,
           lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
           output->mutable_data<T>(context.GetPlace()), index->data<int>());
     } else if (pooltype == "AVERAGE") {
       sequence_pool_kernel<
           T, AvgPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
-          AvgPoolFunctor<T>(), input.data<T>(),
+          AvgPoolFunctor<T>(), input.data<T>(), pad_value,
           lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
           output->mutable_data<T>(context.GetPlace()), nullptr);
     } else if (pooltype == "SUM") {
       sequence_pool_kernel<
           T, SumPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
-          SumPoolFunctor<T>(), input.data<T>(),
+          SumPoolFunctor<T>(), input.data<T>(), pad_value,
           lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
           output->mutable_data<T>(context.GetPlace()), nullptr);
     } else if (pooltype == "SQRT") {
       sequence_pool_kernel<
           T, SqrtPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
-          SqrtPoolFunctor<T>(), input.data<T>(),
+          SqrtPoolFunctor<T>(), input.data<T>(), pad_value,
           lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
           output->mutable_data<T>(context.GetPlace()), nullptr);
     } else if (pooltype == "LAST") {
       sequence_pool_kernel<
           T, LastPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
-          LastPoolFunctor<T>(), input.data<T>(),
+          LastPoolFunctor<T>(), input.data<T>(), pad_value,
           lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
           output->mutable_data<T>(context.GetPlace()), nullptr);
     } else if (pooltype == "FIRST") {
       sequence_pool_kernel<
           T, FirstPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
-          FirstPoolFunctor<T>(), input.data<T>(),
+          FirstPoolFunctor<T>(), input.data<T>(), pad_value,
           lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
           output->mutable_data<T>(context.GetPlace()), nullptr);
     } else {
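The GPU path assigns one CUDA block per sequence (blockIdx.x) while threads stride over item_dim, so the start == end check adds no per-element work. A rough host-side emulation of that indexing for SUM pooling, assuming one block per sequence as the kernel does (helper name hypothetical):

.. code-block:: python

    def emulate_kernel(x_flat, lod, item_dim, pad_value, block_dim=4):
        # lod holds offsets; block bid handles rows lod[bid]:lod[bid + 1],
        # and each thread tid strides over the item_dim feature columns.
        out = [0.0] * ((len(lod) - 1) * item_dim)
        for bid in range(len(lod) - 1):            # blockIdx.x
            start, end = lod[bid], lod[bid + 1]
            for tid0 in range(block_dim):          # threadIdx.x
                for tid in range(tid0, item_dim, block_dim):
                    if start == end:               # empty sequence
                        out[bid * item_dim + tid] = pad_value
                    else:
                        s = sum(x_flat[item_dim * i + tid]
                                for i in range(start, end))
                        out[bid * item_dim + tid] = s
        return out

    print(emulate_kernel([1, 2, 3, 4], lod=[0, 1, 1, 2],
                         item_dim=2, pad_value=9.0))
    # -> [1, 2, 9.0, 9.0, 3, 4]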
diff --git a/paddle/fluid/operators/math/sequence_pooling.h b/paddle/fluid/operators/math/sequence_pooling.h
index a1046ea2160..1dc02eae201 100644
--- a/paddle/fluid/operators/math/sequence_pooling.h
+++ b/paddle/fluid/operators/math/sequence_pooling.h
@@ -27,8 +27,9 @@ class SequencePoolFunctor {
  public:
   /* max pool has index output */
   void operator()(const DeviceContext& context, const std::string pooltype,
-                  const framework::LoDTensor& input, framework::Tensor* output,
-                  bool is_test = false, framework::Tensor* index = nullptr);
+                  T pad_value, const framework::LoDTensor& input,
+                  framework::Tensor* output, bool is_test = false,
+                  framework::Tensor* index = nullptr);
 };

 template <typename DeviceContext, typename T>
diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc
index b4923571df9..f3193fdc556 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc
@@ -57,6 +57,9 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
         "(string, default 'AVERAGE') the pooling pooltype of SequencePoolOp.")
         .SetDefault("AVERAGE")
         .InEnum({"AVERAGE", "SUM", "SQRT", "LAST", "FIRST", "MAX"});
+    AddAttr<float>("pad_value",
+                   "(float, default 0.0) The value to pad for empty sequence.")
+        .SetDefault(0.0);
     AddComment(R"DOC(
 Sequence Pool Operator.

@@ -69,6 +72,8 @@ It supports six pooling types:
 5. FIRST:  Out[i] = first instance in i-th sequence X[i]
 6. MAX:    $$Out[i] = max(X_i)$$

+and for the empty sequence Out[i] = attr(pad_value).
+
 The following example explains how this works:

 For a mini-batch of 3 variable-length sentences,
 containing 2, 3, and 2 time-steps:
diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.h b/paddle/fluid/operators/sequence_ops/sequence_pool_op.h
index f2e4a55dee4..c32734808c3 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.h
+++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.h
@@ -32,6 +32,7 @@ class SequencePoolKernel : public framework::OpKernel<T> {
     auto* in = context.Input<LoDTensor>("X");
     auto* out = context.Output<Tensor>("Out");
     std::string pooltype = context.Attr<std::string>("pooltype");
+    T pad_value = static_cast<T>(context.Attr<float>("pad_value"));

     auto dims = in->dims();
     auto lod = in->lod();
@@ -58,8 +59,8 @@ class SequencePoolKernel : public framework::OpKernel<T> {
       index->mutable_data<int>(context.GetPlace());
     }
     math::SequencePoolFunctor<DeviceContext, T> pool;
-    pool(context.template device_context<DeviceContext>(), pooltype, *in, out,
-         is_test, index);
+    pool(context.template device_context<DeviceContext>(), pooltype, pad_value,
+         *in, out, is_test, index);
   }
 };
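The attribute is declared as a plain float and cast to the kernel's element type T at run time, so double kernels pad with static_cast<double>(pad_value). At the Python layer this surfaces as an optional keyword argument; a sketch of the call the patch enables, mirroring the docstring example in the next diff:

.. code-block:: python

    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[7, 1], dtype='float32',
                          lod_level=1)
    # Output rows for empty sequences come back filled with -1.0;
    # omitting pad_value keeps the default of 0.0.
    sum_x = fluid.layers.sequence_pool(input=x, pool_type='sum',
                                       pad_value=-1.0)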
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index fc010521c43..b9f0dcc134b 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -2346,7 +2346,7 @@ def conv3d(input,
     return helper.append_activation(pre_act)


-def sequence_pool(input, pool_type, is_test=False):
+def sequence_pool(input, pool_type, is_test=False, pad_value=0.0):
     """
     This function add the operator for sequence pooling.
     It pools features of all time-steps of each instance, and is applied
@@ -2361,29 +2361,32 @@ def sequence_pool(input, pool_type, is_test=False):

     .. code-block:: text

-        x is a 1-level LoDTensor:
-            x.lod  = [[2, 3, 2]]
+        x is a 1-level LoDTensor and **pad_value** = 0.0:
+            x.lod  = [[2, 3, 2, 0]]
             x.data = [1, 3, 2, 4, 6, 5, 1]
             x.dims = [7, 1]

         then output is a Tensor:

-            out.dim = [3, 1]
+            out.dim = [4, 1]
             with condition len(x.lod[-1]) == out.dims[0]

         for different pool_type:

-            average: out.data = [2, 4, 3], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2
-            sum    : out.data = [4, 12, 6], where 4=1+3, 12=2+4+6, 6=5+1
-            sqrt   : out.data = [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2),
+            average: out.data = [2, 4, 3, 0.0], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2
+            sum    : out.data = [4, 12, 6, 0.0], where 4=1+3, 12=2+4+6, 6=5+1
+            sqrt   : out.data = [2.82, 6.93, 4.24, 0.0], where 2.82=(1+3)/sqrt(2),
                        6.93=(2+4+6)/sqrt(3), 4.24=(5+1)/sqrt(2)
-            max    : out.data = [3, 6, 5], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1)
-            last   : out.data = [3, 6, 1], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1)
-            first  : out.data = [1, 2, 5], where 1=first(1,3), 2=first(2,4,6), 5=first(5,1)
+            max    : out.data = [3, 6, 5, 0.0], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1)
+            last   : out.data = [3, 6, 1, 0.0], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1)
+            first  : out.data = [1, 2, 5, 0.0], where 1=first(1,3), 2=first(2,4,6), 5=first(5,1)
+
+            and all above 0.0 = **pad_value**.

     Args:
-        input(variable): The input variable which is a LoDTensor.
+        input (variable): The input variable which is a LoDTensor.
         pool_type (string): The pooling type of sequence_pool.
             It supports average, sum, sqrt and max.
-        is_test(bool, Default False): Used distinguish training from scoring mode.
+        is_test (bool): Used to distinguish training from scoring mode. Default False.
+        pad_value (float): Used to pad the pooling result for empty input sequence.

     Returns:
         The sequence pooling variable which is a Tensor.

@@ -2392,6 +2395,8 @@ def sequence_pool(input, pool_type, is_test=False):

     .. code-block:: python

+        import paddle.fluid as fluid
+
         x = fluid.layers.data(name='x', shape=[7, 1],
                          dtype='float32', lod_level=1)
         avg_x = fluid.layers.sequence_pool(input=x, pool_type='average')
@@ -2413,8 +2418,11 @@ def sequence_pool(input, pool_type, is_test=False):
         inputs={"X": input},
         outputs={"Out": pool_out,
                  "MaxIndex": max_index},
-        attrs={"pooltype": pool_type.upper(),
-               "is_test": is_test})
+        attrs={
+            "pooltype": pool_type.upper(),
+            "is_test": is_test,
+            "pad_value": pad_value
+        })

     # when pool_type is max, variable max_index is initialized,
     # so we stop the gradient explicitly here
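To exercise the new path end to end, the LoD fed at run time must itself contain zero lengths. A sketch using fluid.create_lod_tensor with a length-based LoD, assuming the helper accepts zero-length sequences the way the op now does:

.. code-block:: python

    import numpy as np
    import paddle.fluid as fluid

    place = fluid.CPUPlace()
    data = np.random.rand(11, 1).astype('float32')
    # Length-based LoD: sequences of length 0, 4, 0 and 7 (sums to 11 rows).
    x_tensor = fluid.create_lod_tensor(data, [[0, 4, 0, 7]], place)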
diff --git a/python/paddle/fluid/tests/unittests/test_seq_pool.py b/python/paddle/fluid/tests/unittests/test_seq_pool.py
index 176265428c8..aa801b1f5d8 100644
--- a/python/paddle/fluid/tests/unittests/test_seq_pool.py
+++ b/python/paddle/fluid/tests/unittests/test_seq_pool.py
@@ -20,31 +20,42 @@ from op_test import OpTest
 from test_reorder_lod_tensor import convert_to_offset


-def compute_seqpool_sum(x, offset, out):
+def compute_seqpool_sum(x, offset, out, pad_value=0.0):
     for i in range(len(offset[0]) - 1):
-        sub_x = x[offset[0][i]:offset[0][i + 1], :]
-        out[i] = sub_x.sum(axis=0)
+        if offset[0][i] == offset[0][i + 1]:
+            out[i] = pad_value
+        else:
+            sub_x = x[offset[0][i]:offset[0][i + 1], :]
+            out[i] = sub_x.sum(axis=0)


-def compute_seqpool_avg(x, offset, out):
+def compute_seqpool_avg(x, offset, out, pad_value=0.0):
     for i in range(len(offset[0]) - 1):
-        sub_x = x[offset[0][i]:offset[0][i + 1], :]
-        out[i] = sub_x.mean(axis=0)
+        if offset[0][i] == offset[0][i + 1]:
+            out[i] = pad_value
+        else:
+            sub_x = x[offset[0][i]:offset[0][i + 1], :]
+            out[i] = sub_x.mean(axis=0)


-def compute_seqpool_sqrt(x, offset, out):
+def compute_seqpool_sqrt(x, offset, out, pad_value=0.0):
     for i in range(len(offset[0]) - 1):
-        sub_x = x[offset[0][i]:offset[0][i + 1], :]
-        seq_len = offset[0][i + 1] - offset[0][i]
-        out[i] = sub_x.sum(axis=0) / np.sqrt(seq_len)
+        if offset[0][i] == offset[0][i + 1]:
+            out[i] = pad_value
+        else:
+            sub_x = x[offset[0][i]:offset[0][i + 1], :]
+            seq_len = offset[0][i + 1] - offset[0][i]
+            out[i] = sub_x.sum(axis=0) / np.sqrt(seq_len)


 class TestSeqAvgPool(OpTest):
+    def set_lod(self):
+        return [[11]]
+
     def set_data(self):
         self.op_type = 'sequence_pool'
-        # one level, batch size is 4
         x = np.random.uniform(0.1, 1, [11, 23]).astype('float32')
-        lod = [[11]]
+        lod = self.set_lod()
         self.inputs = {'X': (x, lod)}
         offset = convert_to_offset(lod)
         out = np.zeros((len(lod[0]), 23)).astype('float32')
@@ -52,8 +63,8 @@ class TestSeqAvgPool(OpTest):
         return x, offset, out

     def compute(self, x, offset, out):
-        self.attrs = {'pooltype': "AVERAGE"}
-        compute_seqpool_avg(x, offset, out)
+        self.attrs = {"pad_value": 0.0, 'pooltype': "AVERAGE"}
+        compute_seqpool_avg(x, offset, out, self.attrs["pad_value"])

     def setUp(self):
         x, offset, out = self.set_data()
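The tests describe LoD as per-sequence lengths (e.g. [[0, 4, 0, 7, 0]]) and lean on convert_to_offset, imported above from test_reorder_lod_tensor, to produce cumulative offsets. An equivalent reference, assuming that helper is a plain cumulative sum (name suffixed _ref to mark it as a sketch):

.. code-block:: python

    def convert_to_offset_ref(lod):
        # [[0, 4, 0, 7, 0]] -> [[0, 0, 4, 4, 11, 11]]: running sum of lengths,
        # so an empty sequence shows up as two equal consecutive offsets.
        offset = [[0] for _ in lod]
        for i, level in enumerate(lod):
            for seq_len in level:
                offset[i].append(offset[i][-1] + seq_len)
        return offset

    assert convert_to_offset_ref([[0, 4, 0, 7, 0]]) == [[0, 0, 4, 4, 11, 11]]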
@@ -69,95 +80,160 @@ class TestSeqAvgPool(OpTest):
         self.check_grad(["X"], "Out")


+class TestSeqAvgPoolLen0(TestSeqAvgPool):
+    def set_lod(self):
+        return [[0, 4, 0, 7, 0]]
+
+
 class TestSeqSumPool(TestSeqAvgPool):
     def compute(self, x, offset, out):
-        self.attrs = {'pooltype': "SUM"}
-        compute_seqpool_sum(x, offset, out)
+        self.attrs = {"pad_value": 0.1, 'pooltype': "SUM"}
+        compute_seqpool_sum(x, offset, out, self.attrs["pad_value"])
+
+
+class TestSeqSumPoolLen0(TestSeqSumPool):
+    def set_lod(self):
+        return [[0, 4, 0, 7, 0]]


 class TestSeqMaxPool(TestSeqAvgPool):
+    def set_lod(self):
+        return [[13]]
+
     def set_data(self):
         self.op_type = 'sequence_pool'
         x = np.random.uniform(0.1, 1, [13, 23]).astype('float32')
-        lod = [[13]]
+        lod = self.set_lod()
         offset = convert_to_offset(lod)
         for i in range(len(offset[0]) - 1):
             l = offset[0][i + 1] - offset[0][i]
-            x[offset[0][i] + np.random.randint(l), :] += 2.0
+            if l > 0:
+                x[offset[0][i] + np.random.randint(l), :] += 2.0

         self.inputs = {'X': (x, lod)}
-        out = np.zeros((1, 23)).astype('float32')
+        out = np.zeros((len(lod[0]), 23)).astype('float32')
         self.outputs = {'Out': out}
         return x, offset, out

     def compute(self, x, offset, out):
-        self.attrs = {'pooltype': "MAX"}
+        self.attrs = {"pad_value": 0.5, 'pooltype': "MAX"}
         for i in range(len(offset[0]) - 1):
-            sub_x = x[offset[0][i]:offset[0][i + 1], :]
-            out[i] = np.amax(sub_x, axis=0)
+            if offset[0][i] == offset[0][i + 1]:
+                out[i] = self.attrs["pad_value"]
+            else:
+                sub_x = x[offset[0][i]:offset[0][i + 1], :]
+                out[i] = np.amax(sub_x, axis=0)
+
+
+class TestSeqMaxPoolLen0(TestSeqMaxPool):
+    def set_lod(self):
+        return [[0, 1, 1, 5, 6, 0]]


 class TestSeqSqrtPool(TestSeqAvgPool):
     def compute(self, x, offset, out):
-        self.attrs = {'pooltype': "SQRT"}
-        compute_seqpool_sqrt(x, offset, out)
+        self.attrs = {"pad_value": 0.0, 'pooltype': "SQRT"}
+        compute_seqpool_sqrt(x, offset, out, self.attrs["pad_value"])
+
+
+class TestSeqSqrtPoolLen0(TestSeqSqrtPool):
+    def set_lod(self):
+        return [[0, 7, 0, 2, 2, 0]]


 class TestSeqLastPool(TestSeqAvgPool):
     def compute(self, x, offset, out):
-        self.attrs = {'pooltype': "LAST"}
+        self.attrs = {"pad_value": 0.0, 'pooltype': "LAST"}
         for i in range(len(offset[0]) - 1):
-            sub_x = x[offset[0][i]:offset[0][i + 1], :]
-            out[i] = sub_x[-1, :]
+            if offset[0][i] == offset[0][i + 1]:
+                out[i] = self.attrs["pad_value"]
+            else:
+                sub_x = x[offset[0][i]:offset[0][i + 1], :]
+                out[i] = sub_x[-1, :]
+
+
+class TestSeqLastPoolLen0(TestSeqLastPool):
+    def set_lod(self):
+        return [[0, 3, 4, 0, 4, 0]]


 class TestSeqFirstPool(TestSeqAvgPool):
     def compute(self, x, offset, out):
-        self.attrs = {'pooltype': "FIRST"}
+        self.attrs = {"pad_value": 0.3, 'pooltype': "FIRST"}
         for i in range(len(offset[0]) - 1):
-            sub_x = x[offset[0][i]:offset[0][i + 1], :]
-            out[i] = sub_x[0, :]
+            if offset[0][i] == offset[0][i + 1]:
+                out[i] = self.attrs["pad_value"]
+            else:
+                sub_x = x[offset[0][i]:offset[0][i + 1], :]
+                out[i] = sub_x[0, :]
+
+
+class TestSeqFirstPoolLen0(TestSeqFirstPool):
+    def set_lod(self):
+        return [[0, 2, 0, 3, 6, 0]]


 class TestSeqAvgPool2D(TestSeqAvgPool):
+    def set_lod(self):
+        return [[4, 1, 3, 5]]
+
     def set_data(self):
         self.op_type = 'sequence_pool'
-        # one level, batch size is 4
         x = np.random.uniform(0.1, 1, [13, 3, 17]).astype('float32')
-        lod = [[4, 1, 3, 5]]
+        lod = self.set_lod()
         self.inputs = {'X': (x, lod)}
         offset = convert_to_offset(lod)

-        out = np.zeros((4, 3, 17)).astype('float32')
+        out = np.zeros((len(lod[0]), 3, 17)).astype('float32')
         self.outputs = {'Out': out}
         return x, offset, out

     def compute(self, x, offset, out):
-        self.attrs = {'pooltype': "AVERAGE"}
+        self.attrs = {"pad_value": 0.0, 'pooltype': "AVERAGE"}
         for i in range(len(offset[0]) - 1):
-            sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
-                               (-1, 3 * 17))
-            out[i] = np.reshape(sub_x.mean(axis=0), (3, 17))
+            if offset[0][i] == offset[0][i + 1]:
+                out[i] = self.attrs["pad_value"] * np.ones((3, 17))
+            else:
+                sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
+                                   (-1, 3 * 17))
+                out[i] = np.reshape(sub_x.mean(axis=0), (3, 17))
+
+
+class TestSeqAvgPool2DLen0(TestSeqAvgPool2D):
+    def set_lod(self):
+        return [[0, 5, 0, 8, 0]]


 class TestSeqSumPool2D(TestSeqAvgPool2D):
     def compute(self, x, offset, out):
-        self.attrs = {'pooltype': "SUM"}
+        self.attrs = {"pad_value": 0.2, 'pooltype': "SUM"}
         for i in range(len(offset[0]) - 1):
-            sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
-                               (-1, 3 * 17))
-            out[i] = np.reshape(sub_x.sum(axis=0), (3, 17))
+            if offset[0][i] == offset[0][i + 1]:
+                out[i] = self.attrs["pad_value"] * np.ones((3, 17))
+            else:
+                sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
+                                   (-1, 3 * 17))
+                out[i] = np.reshape(sub_x.sum(axis=0), (3, 17))
+
+
+class TestSeqSumPool2DLen0(TestSeqSumPool2D):
+    def set_lod(self):
+        return [[0, 8, 0, 5, 0]]


 class TestSeqSqrtPool2D(TestSeqAvgPool2D):
     def compute(self, x, offset, out):
-        self.attrs = {'pooltype': "SQRT"}
+        self.attrs = {"pad_value": 0.0, 'pooltype': "SQRT"}
         for i in range(len(offset[0]) - 1):
-            sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
-                               (-1, 3 * 17))
-            seq_len = offset[0][i + 1] - offset[0][i]
-            out[i] = np.reshape(sub_x.sum(axis=0) / np.sqrt(seq_len), (3, 17))
+            if offset[0][i] == offset[0][i + 1]:
+                out[i] = self.attrs["pad_value"] * np.ones((3, 17))
+            else:
+                sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
+                                   (-1, 3 * 17))
+                seq_len = offset[0][i + 1] - offset[0][i]
+                out[i] = np.reshape(
+                    sub_x.sum(axis=0) / np.sqrt(seq_len), (3, 17))

     def test_check_grad(self):
         # Remove MaxIndex after check_grad is refined.
@@ -166,36 +242,57 @@ class TestSeqSqrtPool2D(TestSeqAvgPool2D):
         self.check_grad(["X"], "Out", max_relative_error=0.06)


+class TestSeqSqrtPool2DLen0(TestSeqSqrtPool2D):
+    def set_lod(self):
+        return [[0, 8, 0, 5, 0]]
+
+
 class TestSeqMaxPool2D(TestSeqAvgPool2D):
+    def set_lod(self):
+        return [[4, 1, 3, 5]]
+
     def set_data(self):
         self.op_type = 'sequence_pool'
         x = np.random.uniform(0.1, 1, [13, 3, 11]).astype('float32')
-        lod = [[4, 1, 3, 5]]
-        self.inputs = {'X': (x, lod)}
-        offset = convert_to_offset(lod)
+        self.lod = self.set_lod()
+        self.inputs = {'X': (x, self.lod)}
+        offset = convert_to_offset(self.lod)
         for i in range(len(offset[0]) - 1):
             l = offset[0][i + 1] - offset[0][i]
+            if l == 0:
+                continue
             x[offset[0][i] + np.random.randint(l), :] += 1.0

-        out = np.zeros((4, 3, 11)).astype('float32')
+        out = np.zeros((len(self.lod[0]), 3, 11)).astype('float32')
         self.outputs = {'Out': out}
         return x, offset, out

     def compute(self, x, offset, out):
-        self.attrs = {'pooltype': "MAX"}
+        self.attrs = {"pad_value": 0.0, 'pooltype': "MAX"}
         for i in range(len(offset[0]) - 1):
+            if offset[0][i] == offset[0][i + 1]:
+                out[i] = self.attrs["pad_value"] * np.ones((3, 11))
+                continue
             sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
                                (-1, 3 * 11))
             out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 11))


+class TestSeqMaxPool2DLen0(TestSeqMaxPool2D):
+    def set_lod(self):
+        return [[0, 3, 0, 10, 0]]
+
+
 class TestSeqMaxPool2DInference(TestSeqMaxPool2D):
     def compute(self, x, offset, out):
-        self.attrs = {'pooltype': "MAX", 'is_test': True}
+        self.attrs = {"pad_value": 1.0, 'pooltype': "MAX", 'is_test': True}
         for i in range(len(offset[0]) - 1):
-            sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
-                               (-1, 3 * 11))
-            out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 11))
+            if offset[0][i] == offset[0][i + 1]:
+                out[i] = self.attrs["pad_value"] * np.ones((3, 11))
+            else:
+                sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
+                                   (-1, 3 * 11))
+                out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 11))

     def test_check_grad(self):
         """Grad computation does not apply to Sequence MAX
@@ -203,22 +300,43 @@ class TestSeqMaxPool2DInference(TestSeqMaxPool2D):
         return


+class TestSeqMaxPool2DInferenceLen0(TestSeqMaxPool2DInference):
+    def set_lod(self):
+        return [[0, 3, 0, 10, 0]]
+
+
 class TestSeqLastPool2D(TestSeqAvgPool2D):
     def compute(self, x, offset, out):
-        self.attrs = {'pooltype': "LAST"}
+        self.attrs = {"pad_value": 0.0, 'pooltype': "LAST"}
         for i in range(len(offset[0]) - 1):
-            sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
-                               (-1, 3 * 17))
-            out[i] = np.reshape(sub_x[-1, :], (3, 17))
+            if offset[0][i] == offset[0][i + 1]:
+                out[i] = self.attrs["pad_value"] * np.ones((3, 17))
+            else:
+                sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
+                                   (-1, 3 * 17))
+                out[i] = np.reshape(sub_x[-1, :], (3, 17))
+
+
+class TestSeqLastPool2DLen0(TestSeqLastPool2D):
+    def set_lod(self):
+        return [[0, 3, 0, 1, 9, 0]]


 class TestSeqFirstPool2D(TestSeqAvgPool2D):
     def compute(self, x, offset, out):
-        self.attrs = {'pooltype': "FIRST"}
+        self.attrs = {"pad_value": 0.0, 'pooltype': "FIRST"}
         for i in range(len(offset[0]) - 1):
-            sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
-                               (-1, 3 * 17))
-            out[i] = np.reshape(sub_x[0, :], (3, 17))
+            if offset[0][i] == offset[0][i + 1]:
+                out[i] = self.attrs["pad_value"] * np.ones((3, 17))
+            else:
+                sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
+                                   (-1, 3 * 17))
+                out[i] = np.reshape(sub_x[0, :], (3, 17))
+
+
+class TestSeqFirstPool2DLen0(TestSeqFirstPool2D):
+    def set_lod(self):
+        return [[0, 3, 0, 3, 7, 0]]


 if __name__ == '__main__':
--
GitLab
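If needed, the new zero-length cases can be run in isolation with unittest; a sketch assuming the Paddle unittest directory (providing op_test and test_reorder_lod_tensor) is on PYTHONPATH:

.. code-block:: python

    import unittest

    from test_seq_pool import TestSeqAvgPoolLen0, TestSeqMaxPool2DLen0

    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    suite.addTests(loader.loadTestsFromTestCase(TestSeqAvgPoolLen0))
    suite.addTests(loader.loadTestsFromTestCase(TestSeqMaxPool2DLen0))
    unittest.TextTestRunner(verbosity=2).run(suite)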