From 8425c2c859b22f263e213d4fed454890b598948c Mon Sep 17 00:00:00 2001
From: dzhwinter
Date: Thu, 29 Mar 2018 16:34:33 +0800
Subject: [PATCH] Speed/sequence op1 (#9217)

* "add functors"
* "remove old code"
* "fix"
* "fix ci"
* "add details"
* "fix ci"
* "fix ci"
* "fix ci"
* "fix ci"
* "remove unused code"
---
 .../fluid/operators/math/sequence_pooling.cc | 112 ++++-
 .../fluid/operators/math/sequence_pooling.cu | 381 ++++++++++++++----
 .../fluid/operators/math/sequence_pooling.h  |  20 +-
 paddle/fluid/operators/sequence_pool_op.h    | 102 +----
 .../fluid/tests/unittests/test_seq_pool.py   | 110 ++---
 5 files changed, 484 insertions(+), 241 deletions(-)

diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc
index f7a6f2bdf..5ae42ab97 100644
--- a/paddle/fluid/operators/math/sequence_pooling.cc
+++ b/paddle/fluid/operators/math/sequence_pooling.cc
@@ -19,8 +19,17 @@ namespace paddle {
 namespace operators {
 namespace math {
 
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
 template <typename T>
-class MaxSeqPoolFunctor<platform::CPUDeviceContext, T> {
+class MaxSeqPoolFunctor {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::LoDTensor& input, framework::Tensor* output,
@@ -60,7 +69,7 @@ class MaxSeqPoolFunctor {
 };
 
 template <typename T>
-class MaxSeqPoolGradFunctor<platform::CPUDeviceContext, T> {
+class MaxSeqPoolGradFunctor {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::Tensor& out_grad,
@@ -93,10 +102,101 @@ class MaxSeqPoolGradFunctor {
   }
 };
 
-template class MaxSeqPoolFunctor<platform::CPUDeviceContext, float>;
-template class MaxSeqPoolFunctor<platform::CPUDeviceContext, double>;
-template class MaxSeqPoolGradFunctor<platform::CPUDeviceContext, float>;
-template class MaxSeqPoolGradFunctor<platform::CPUDeviceContext, double>;
+template <typename T>
+class SequencePoolFunctor<platform::CPUDeviceContext, T> {
+ public:
+  /* max pool has index output */
+  void operator()(const platform::CPUDeviceContext& context,
+                  const std::string pooltype,
+                  const framework::LoDTensor& input, framework::Tensor* output,
+                  framework::Tensor* index = nullptr) {
+    if (pooltype == "MAX") {
+      math::MaxSeqPoolFunctor<T> max_pool;
+      max_pool(context, input, output, index);
+      return;
+    }
+    auto lod = input.lod()[0];
+    auto& place = *context.eigen_device();
+    for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
+      Tensor in_t =
+          input.Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
+      Tensor out_t = output->Slice(i, i + 1);
+      int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
+      int64_t w = input.numel() / input.dims()[0];
+      auto in_e = EigenMatrix<T>::From(in_t, framework::make_ddim({h, w}));
+      auto out_e = EigenVector<T>::Flatten(out_t);
+      if (pooltype == "AVERAGE") {
+        out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
+      } else if (pooltype == "SUM") {
+        out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}}));
+      } else if (pooltype == "SQRT") {
+        out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
+                              std::sqrt(static_cast<T>(h));
+      } else if (pooltype == "LAST") {
+        out_e.device(place) = in_e.chip(h - 1, 0);
+      } else if (pooltype == "FIRST") {
+        out_e.device(place) = in_e.chip(0, 0);
+      } else {
+        PADDLE_THROW("unsupported pooling pooltype");
+      }
+    }
+  }
+};
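Review note: for readers new to LoD tensors, `lod` here is a vector of row offsets, so sequence `i` occupies rows `[lod[i], lod[i+1])` of the flattened input. A minimal, framework-free C++ sketch of the same segment arithmetic (function and variable names are mine, not part of this patch; MAX/LAST/FIRST are omitted for brevity):

```cpp
// Sketch only: reference semantics of the SUM/AVERAGE/SQRT branches above.
#include <cmath>
#include <cstddef>
#include <string>
#include <vector>

void SeqPoolRef(const std::string& pooltype, const std::vector<float>& in,
                const std::vector<size_t>& lod, size_t width,
                std::vector<float>* out) {
  size_t num_seq = lod.size() - 1;  // offsets outnumber sequences by one
  out->assign(num_seq * width, 0.f);
  for (size_t i = 0; i < num_seq; ++i) {
    size_t h = lod[i + 1] - lod[i];  // rows in sequence i
    for (size_t c = 0; c < width; ++c) {
      float acc = 0.f;
      for (size_t r = lod[i]; r < lod[i + 1]; ++r) acc += in[r * width + c];
      if (pooltype == "AVERAGE") acc /= h;
      if (pooltype == "SQRT") acc /= std::sqrt(static_cast<float>(h));
      (*out)[i * width + c] = acc;  // SUM keeps the raw accumulation
    }
  }
}
```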
+
+template <typename T>
+class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
+ public:
+  void operator()(const platform::CPUDeviceContext& context,
+                  const std::string pooltype,
+                  const framework::Tensor& out_grad,
+                  framework::LoDTensor* in_grad,
+                  /* max pool has index */
+                  const framework::Tensor* index = nullptr) {
+    if (pooltype == "MAX") {
+      math::MaxSeqPoolGradFunctor<T> max_pool_grad;
+      max_pool_grad(context, out_grad, *index, in_grad);
+      return;
+    }
+
+    if (pooltype == "LAST" || pooltype == "FIRST") {
+      // set X@Grad be zero at first when pooltype is LAST/FIRST
+      math::SetConstant<platform::CPUDeviceContext, T> functor;
+      functor(context, in_grad, 0);
+    }
+    auto lod = in_grad->lod()[0];
+    auto& place = *context.eigen_device();
+    for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
+      auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]),
+                                   static_cast<int>(lod[i + 1]));
+      auto out_g_t = out_grad.Slice(i, i + 1);
+      int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
+      int64_t w = in_grad->numel() / in_grad->dims()[0];
+      auto in_g_e = EigenMatrix<T>::From(in_g_t, {h, w});
+      auto out_g_e = EigenMatrix<T>::From(out_g_t, {1, w});
+      auto out_g_e_v = EigenVector<T>::Flatten(out_g_t);
+      Eigen::DSizes<int, 2> bcast(h, 1);
+
+      if (pooltype == "AVERAGE") {
+        in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
+      } else if (pooltype == "SUM") {
+        in_g_e.device(place) = (out_g_e).broadcast(bcast);
+      } else if (pooltype == "SQRT") {
+        in_g_e.device(place) =
+            (out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
+      } else if (pooltype == "LAST") {
+        in_g_e.chip(h - 1, 0).device(place) = out_g_e_v;
+      } else if (pooltype == "FIRST") {
+        in_g_e.chip(0, 0).device(place) = out_g_e_v;
+      } else {
+        PADDLE_THROW("unsupported pooling pooltype");
+      }
+    }
+  }
+};
+
+template class SequencePoolFunctor<platform::CPUDeviceContext, float>;
+template class SequencePoolFunctor<platform::CPUDeviceContext, double>;
+template class SequencePoolGradFunctor<platform::CPUDeviceContext, float>;
+template class SequencePoolGradFunctor<platform::CPUDeviceContext, double>;
 
 }  // namespace math
 }  // namespace operators
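Review note: the grad functor above is the exact mirror of the forward rules: each sequence's output-gradient row is broadcast back over its `h` input rows with scale 1 (SUM), 1/h (AVERAGE), or 1/sqrt(h) (SQRT). A plain-loop sketch of that rule, with hypothetical names of my own:

```cpp
// Sketch only: the backward broadcast-and-scale rule the CPU grad
// functor implements with Eigen above.
#include <cmath>
#include <cstddef>
#include <string>
#include <vector>

void SeqPoolGradRef(const std::string& pooltype,
                    const std::vector<float>& out_grad,
                    const std::vector<size_t>& lod, size_t width,
                    std::vector<float>* in_grad) {
  in_grad->assign(lod.back() * width, 0.f);
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    size_t h = lod[i + 1] - lod[i];
    float scale = 1.f;  // SUM: pass the gradient through unchanged
    if (pooltype == "AVERAGE") scale = 1.f / h;
    if (pooltype == "SQRT") scale = 1.f / std::sqrt(static_cast<float>(h));
    for (size_t r = lod[i]; r < lod[i + 1]; ++r)
      for (size_t c = 0; c < width; ++c)
        (*in_grad)[r * width + c] = out_grad[i * width + c] * scale;
  }
}
```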
diff --git a/paddle/fluid/operators/math/sequence_pooling.cu b/paddle/fluid/operators/math/sequence_pooling.cu
index d61407c02..1935364da 100644
--- a/paddle/fluid/operators/math/sequence_pooling.cu
+++ b/paddle/fluid/operators/math/sequence_pooling.cu
@@ -14,6 +14,7 @@ limitations under the License.
 */
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/sequence_pooling.h"
+#include "paddle/fluid/platform/cuda_helper.h"
 
 namespace paddle {
 namespace operators {
@@ -22,113 +23,331 @@ namespace math {
 #define FLT_MAX __FLT_MAX__
 
 template <typename T>
-__global__ void KeMaxSequencePool(const T* input, const size_t* starts,
-                                  T* output, int* index, int64_t num_seq,
-                                  int64_t dim) {
-  int dim_idx = threadIdx.x;
-  int seq_id = blockIdx.x;
-  if (seq_id >= num_seq) return;
-  size_t start = starts[seq_id];
-  size_t end = starts[seq_id + 1];
-
-  for (int64_t i = dim_idx; i < dim; i += blockDim.x) {
-    T max_val = static_cast<T>(-FLT_MAX);
-    int max_id = -1;
-    for (size_t step_id = start; step_id < end; step_id++) {
-      if (max_val < input[step_id * dim + i]) {
-        max_val = input[step_id * dim + i];
-        max_id = step_id;
-      }
-    }
-    output[seq_id * dim + i] = max_val;
-    index[seq_id * dim + i] = max_id;
-  }
-}
+struct MaxPoolFunctor {
+  HOSTDEVICE void operator()(const T* input, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* output, int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      T max_val = static_cast<T>(-FLT_MAX);
+      int max_index = -1;
+      for (int i = start; i < end; ++i) {
+        if (max_val < input[item_dim * i + tid]) {
+          max_val = input[item_dim * i + tid];
+          max_index = i;
+        }
+      }
+      output[tid] = max_val;
+      index[tid] = max_index;
+    }
+  }
+};
 
 template <typename T>
-class MaxSeqPoolFunctor<platform::CUDADeviceContext, T> {
- public:
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::LoDTensor& input, framework::Tensor* output,
-                  framework::Tensor* index) {
-    auto in_dims = input.dims();
-    auto out_dims = output->dims();
-    auto idx_dims = index->dims();
-    PADDLE_ENFORCE_GT(in_dims.size(), static_cast<int64_t>(1));
-    PADDLE_ENFORCE_GT(out_dims.size(), 1);
-    for (int64_t i = 1; i < in_dims.size(); ++i) {
-      PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i]);
-    }
-    PADDLE_ENFORCE_EQ(idx_dims, out_dims);
-
-    auto starts = input.lod()[0];
-    const T* in_data = input.data<T>();
-    T* out_data = output->data<T>();
-    int* max_index = index->data<int>();
-
-    int64_t num_seq = out_dims[0];
-    int64_t dim = output->numel() / num_seq;
-
-    dim3 threads(256, 1);
-    dim3 grid(num_seq, 1);
-    auto stream = context.stream();
-    KeMaxSequencePool<T><<<grid, threads, 0, stream>>>(
-        in_data, starts.CUDAData(context.GetPlace()), out_data, max_index,
-        num_seq, dim);
-  }
-};
+struct AvgPoolFunctor {
+  HOSTDEVICE void operator()(const T* input, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* output, int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      T val = static_cast<T>(0);
+      for (int i = start; i < end; ++i) {
+        val += input[item_dim * i + tid];
+      }
+      // end, start is lod, so end - start != 0
+      output[tid] = val / static_cast<T>(end - start);
+    }
+  }
+};
+
+template <typename T>
+struct SumPoolFunctor {
+  HOSTDEVICE void operator()(const T* input, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* output, int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      T val = static_cast<T>(0);
+      for (int i = start; i < end; ++i) {
+        val += input[item_dim * i + tid];
+      }
+      output[tid] = val;
+    }
+  }
+};
+
+template <typename T>
+struct SqrtPoolFunctor {
+  HOSTDEVICE void operator()(const T* input, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* output, int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      T val = static_cast<T>(0);
+      for (int i = start; i < end; ++i) {
+        val += input[item_dim * i + tid];
+      }
+      // end, start is lod, so end - start != 0
+      output[tid] = val / sqrt(end - start);
+    }
+  }
+};
+
+template <typename T>
+struct LastPoolFunctor {
+  HOSTDEVICE void operator()(const T* input, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* output, int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      output[tid] = input[item_dim * (end - 1) + tid];
+    }
+  }
+};
+
+template <typename T>
+struct FirstPoolFunctor {
+  HOSTDEVICE void operator()(const T* input, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* output, int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      output[tid] = input[item_dim * start + tid];
+    }
+  }
+};
 
 template <typename T>
-__global__ void KeMaxSequencePoolGrad(const T* out_grad, const int* max_index,
-                                      T* in_grad, int64_t num_seq,
-                                      int64_t dim) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  int col_idx = idx % dim;
-  if (idx < num_seq * dim) {
-    int step_id = max_index[idx];
-    in_grad[step_id * dim + col_idx] = out_grad[idx];
-  }
-}
+template <typename T, typename Range_OP>
+__global__ void sequence_pool_kernel(Range_OP op, const T* input,
+                                     const size_t* lod, const size_t lod_size,
+                                     const size_t item_dim, T* output,
+                                     int* index) {
+  int bid = blockIdx.x;
+  if (bid >= lod_size - 1) return;
+  size_t start = lod[bid];
+  size_t end = lod[bid + 1];
+  int* index_offset = nullptr;
+  if (index != nullptr) {
+    index_offset = &index[bid * item_dim];
+  }
+  op(input, start, end, item_dim, &output[bid * item_dim], index_offset);
+}
 
 template <typename T>
-class MaxSeqPoolGradFunctor<platform::CUDADeviceContext, T> {
+class SequencePoolFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& out_grad,
-                  const framework::Tensor& index,
-                  framework::LoDTensor* in_grad) {
-    auto og_dims = out_grad.dims();
-    auto idx_dims = index.dims();
-    auto ig_dims = in_grad->dims();
-    PADDLE_ENFORCE_GT(og_dims.size(), static_cast<int64_t>(1));
-    PADDLE_ENFORCE_GT(ig_dims.size(), static_cast<int64_t>(1));
-    for (int64_t i = 1; i < og_dims.size(); ++i) {
-      PADDLE_ENFORCE_EQ(og_dims[i], ig_dims[i]);
-    }
-    PADDLE_ENFORCE_EQ(idx_dims, og_dims);
-
-    const T* og_data = out_grad.data<T>();
-    const int* max_index = index.data<int>();
-    T* ig_data = in_grad->data<T>();
-
-    SetConstant<platform::CUDADeviceContext, T> set_zero;
-    set_zero(context, in_grad, static_cast<T>(0.0));
-    int64_t num_seq = og_dims[0];
-    int64_t dim = out_grad.numel() / num_seq;
-
-    unsigned int blocks = (num_seq * dim + 128 - 1) / 128;
-    dim3 threads(128, 1);
-    dim3 grid(blocks, 1);
-    auto stream = context.stream();
-    KeMaxSequencePoolGrad<T><<<grid, threads, 0, stream>>>(
-        og_data, max_index, ig_data, num_seq, dim);
-  }
-};
+                  const std::string pooltype,
+                  const framework::LoDTensor& input, framework::Tensor* output,
+                  framework::Tensor* index = nullptr) {
+    auto lod = input.lod()[0];
+    const size_t item_dim = output->numel() / output->dims()[0];
+    dim3 threads(1024, 1);
+    dim3 grid(lod.size(), 1);
+    if (pooltype == "MAX") {
+      sequence_pool_kernel<
+          T, MaxPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          MaxPoolFunctor<T>(), input.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          output->mutable_data<T>(context.GetPlace()), index->data<int>());
+    } else if (pooltype == "AVERAGE") {
+      sequence_pool_kernel<
+          T, AvgPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          AvgPoolFunctor<T>(), input.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          output->mutable_data<T>(context.GetPlace()), nullptr);
+    } else if (pooltype == "SUM") {
+      sequence_pool_kernel<
+          T, SumPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          SumPoolFunctor<T>(), input.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          output->mutable_data<T>(context.GetPlace()), nullptr);
+    } else if (pooltype == "SQRT") {
+      sequence_pool_kernel<
+          T, SqrtPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          SqrtPoolFunctor<T>(), input.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          output->mutable_data<T>(context.GetPlace()), nullptr);
+    } else if (pooltype == "LAST") {
+      sequence_pool_kernel<
+          T, LastPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          LastPoolFunctor<T>(), input.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          output->mutable_data<T>(context.GetPlace()), nullptr);
+    } else if (pooltype == "FIRST") {
+      sequence_pool_kernel<
+          T, FirstPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          FirstPoolFunctor<T>(), input.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          output->mutable_data<T>(context.GetPlace()), nullptr);
+    } else {
+      PADDLE_THROW("unsupported pooling pooltype");
+    }
+  }
+};
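Review note: the launch maps one thread block per sequence (`grid.x = lod.size()`; one block idles because offsets outnumber sequences by one, hence the `bid >= lod_size - 1` guard) and lets the block's threads stride across `item_dim`. A host-side C++ simulation of that work split, assuming a callable with the functors' signature (all names here are mine, not the patch's):

```cpp
// Sketch only: serial simulation of the grid/block mapping used above.
#include <cstddef>
#include <vector>

template <typename Op>
void SimulateSeqPoolLaunch(Op op, const float* input,
                           const std::vector<size_t>& lod, size_t item_dim,
                           float* output, int* index) {
  const size_t num_blocks = lod.size();    // grid.x = lod.size()
  for (size_t bid = 0; bid < num_blocks; ++bid) {
    if (bid >= lod.size() - 1) continue;   // the kernel's early return
    int* idx = index ? &index[bid * item_dim] : nullptr;
    // op plays the role of {Max,Avg,Sum,Sqrt,Last,First}PoolFunctor,
    // writing one pooled row per sequence.
    op(input, lod[bid], lod[bid + 1], item_dim, &output[bid * item_dim], idx);
  }
}
```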
 
+template <typename T>
+struct MaxPoolGradFunctor {
+  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* in_grad, const int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      for (int i = start; i < end; ++i) {
+        if (i == index[tid]) {
+          in_grad[item_dim * i + tid] = out_grad[tid];
+        } else {
+          in_grad[item_dim * i + tid] = static_cast<T>(0);
+        }
+      }
+    }
+  }
+};
+
+template <typename T>
+struct AvgPoolGradFunctor {
+  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* in_grad, const int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      for (int i = start; i < end; ++i) {
+        in_grad[item_dim * i + tid] = out_grad[tid] / (end - start);
+      }
+    }
+  }
+};
+
+template <typename T>
+struct SumPoolGradFunctor {
+  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* in_grad, const int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      for (int i = start; i < end; ++i) {
+        in_grad[item_dim * i + tid] = out_grad[tid];
+      }
+    }
+  }
+};
+
+template <typename T>
+struct SqrtPoolGradFunctor {
+  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* in_grad, const int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      for (int i = start; i < end; ++i) {
+        in_grad[item_dim * i + tid] =
+            out_grad[tid] / (sqrt(static_cast<T>(end - start)));
+      }
+    }
+  }
+};
+
+template <typename T>
+struct LastPoolGradFunctor {
+  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* in_grad, const int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      for (int i = start; i < end; ++i) {
+        if (i == end - 1) {
+          in_grad[item_dim * i + tid] = out_grad[tid];
+        } else {
+          in_grad[item_dim * i + tid] = static_cast<T>(0);
+        }
+      }
+    }
+  }
+};
+
+template <typename T>
+struct FirstPoolGradFunctor {
+  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* in_grad, const int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      for (int i = start; i < end; ++i) {
+        if (i == start) {
+          in_grad[item_dim * i + tid] = out_grad[tid];
+        } else {
+          in_grad[item_dim * i + tid] = static_cast<T>(0);
+        }
+      }
+    }
+  }
+};
+
+template <typename T, typename Range_OP>
+__global__ void sequence_pool_grad_kernel(Range_OP op, const T* out_grad,
+                                          const size_t* lod,
+                                          const size_t lod_size,
+                                          const size_t item_dim, T* in_grad,
+                                          const int* index) {
+  int bid = blockIdx.x;
+  if (bid >= lod_size - 1) return;
+  size_t start = lod[bid];
+  size_t end = lod[bid + 1];
+  const int* index_offset = nullptr;
+  if (index != nullptr) {
+    index_offset = &index[bid * item_dim];
+  }
+  op(&out_grad[bid * item_dim], start, end, item_dim, in_grad, index_offset);
+}
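Review note: unlike the elementwise pools, MAX routes each gradient value to the single argmax row recorded during the forward pass; every other row of the sequence is zeroed. A plain C++ statement of that scatter for one sequence (names are mine):

```cpp
// Sketch only: what MaxPoolGradFunctor computes for one sequence. The
// gradient lands on the row recorded in `index`; all other rows get zero.
#include <cstddef>

void MaxPoolGradRef(const float* out_grad, const int* index, size_t start,
                    size_t end, size_t item_dim, float* in_grad) {
  for (size_t tid = 0; tid < item_dim; ++tid)
    for (size_t i = start; i < end; ++i)
      in_grad[item_dim * i + tid] =
          (static_cast<int>(i) == index[tid]) ? out_grad[tid] : 0.f;
}
```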
 
+template <typename T>
+class SequencePoolGradFunctor<platform::CUDADeviceContext, T> {
+ public:
+  void operator()(const platform::CUDADeviceContext& context,
+                  const std::string pooltype,
+                  const framework::Tensor& out_grad,
+                  framework::LoDTensor* in_grad,
+                  /* max pool has index */
+                  const framework::Tensor* index = nullptr) {
+    auto lod = in_grad->lod()[0];
+    const size_t item_dim = in_grad->numel() / in_grad->dims()[0];
+    dim3 threads(1024, 1);
+    dim3 grid(lod.size(), 1);
+    if (pooltype == "MAX") {
+      sequence_pool_grad_kernel<
+          T, MaxPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          MaxPoolGradFunctor<T>(), out_grad.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          in_grad->mutable_data<T>(context.GetPlace()), index->data<int>());
+    } else if (pooltype == "AVERAGE") {
+      sequence_pool_grad_kernel<
+          T, AvgPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          AvgPoolGradFunctor<T>(), out_grad.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          in_grad->mutable_data<T>(context.GetPlace()), nullptr);
+    } else if (pooltype == "SUM") {
+      sequence_pool_grad_kernel<
+          T, SumPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          SumPoolGradFunctor<T>(), out_grad.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          in_grad->mutable_data<T>(context.GetPlace()), nullptr);
+    } else if (pooltype == "SQRT") {
+      sequence_pool_grad_kernel<
+          T, SqrtPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          SqrtPoolGradFunctor<T>(), out_grad.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          in_grad->mutable_data<T>(context.GetPlace()), nullptr);
+    } else if (pooltype == "LAST") {
+      sequence_pool_grad_kernel<
+          T, LastPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          LastPoolGradFunctor<T>(), out_grad.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          in_grad->mutable_data<T>(context.GetPlace()), nullptr);
+    } else if (pooltype == "FIRST") {
+      sequence_pool_grad_kernel<
+          T, FirstPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          FirstPoolGradFunctor<T>(), out_grad.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          in_grad->mutable_data<T>(context.GetPlace()), nullptr);
+
+    } else {
+      PADDLE_THROW("unsupported pooling pooltype");
+    }
+  }
+};
 
-template class MaxSeqPoolFunctor<platform::CUDADeviceContext, float>;
-template class MaxSeqPoolFunctor<platform::CUDADeviceContext, double>;
-template class MaxSeqPoolGradFunctor<platform::CUDADeviceContext, float>;
-template class MaxSeqPoolGradFunctor<platform::CUDADeviceContext, double>;
+// sequence pooling
+template class SequencePoolFunctor<platform::CUDADeviceContext, float>;
+template class SequencePoolFunctor<platform::CUDADeviceContext, double>;
+template class SequencePoolGradFunctor<platform::CUDADeviceContext, float>;
+template class SequencePoolGradFunctor<platform::CUDADeviceContext, double>;
 
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/fluid/operators/math/sequence_pooling.h b/paddle/fluid/operators/math/sequence_pooling.h
index ecb76884f..38e780222 100644
--- a/paddle/fluid/operators/math/sequence_pooling.h
+++ b/paddle/fluid/operators/math/sequence_pooling.h
@@ -21,23 +21,23 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-#define FLT_MAX __FLT_MAX__
-
 template <typename DeviceContext, typename T>
-class MaxSeqPoolFunctor {
+class SequencePoolFunctor {
  public:
-  void operator()(const DeviceContext& context,
+  /* max pool has index output */
+  void operator()(const DeviceContext& context, const std::string pooltype,
                   const framework::LoDTensor& input, framework::Tensor* output,
-                  framework::Tensor* index);
+                  framework::Tensor* index = nullptr);
 };
 
-template <typename T>
-class MaxSeqPoolGradFunctor {
+template <typename DeviceContext, typename T>
+class SequencePoolGradFunctor {
  public:
-  void operator()(const DeviceContext& context,
+  void operator()(const DeviceContext& context, const std::string pooltype,
                   const framework::Tensor& out_grad,
-                  const framework::Tensor& index,
-                  framework::LoDTensor* in_grad);
+                  framework::LoDTensor* in_grad,
+                  /* max pool has index */
+                  const framework::Tensor* index = nullptr);
 };
 
 }  // namespace math
diff --git a/paddle/fluid/operators/sequence_pool_op.h b/paddle/fluid/operators/sequence_pool_op.h
index 8706ff14a..c58d677c9 100644
--- a/paddle/fluid/operators/sequence_pool_op.h
+++ b/paddle/fluid/operators/sequence_pool_op.h
@@ -23,12 +23,6 @@ namespace operators {
 
 using Tensor = framework::Tensor;
 using LoDTensor = framework::LoDTensor;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
 template <typename DeviceContext, typename T>
 class SequencePoolKernel : public framework::OpKernel<T> {
@@ -37,11 +31,13 @@ class SequencePoolKernel : public framework::OpKernel<T> {
     auto* in = context.Input<LoDTensor>("X");
     auto* out = context.Output<Tensor>("Out");
     std::string pooltype = context.Attr<std::string>("pooltype");
+    Tensor* index = nullptr;
+    if (pooltype == "MAX") {
+      index = context.Output<Tensor>("MaxIndex");
+    }
 
     auto dims = in->dims();
     auto lod = in->lod();
-    int64_t w = in->numel() / dims[0];
-
     // InferShape by lod
     PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
     PADDLE_ENFORCE_GE(
@@ -50,45 +46,14 @@ class SequencePoolKernel : public framework::OpKernel<T> {
         "The first dimension of Input(X) must be large than batch size.");
     dims[0] = lod[0].size() - 1;
     out->Resize({dims});
-
-    auto lod_level_0 = lod[0];
-
     out->mutable_data<T>(context.GetPlace());
-    auto& dev_ctx = context.template device_context<DeviceContext>();
     if (pooltype == "MAX") {
-      math::MaxSeqPoolFunctor<DeviceContext, T> max_pool;
-      auto* index = context.Output<Tensor>("MaxIndex");
       index->Resize({dims});
       index->mutable_data<int>(context.GetPlace());
-      max_pool(dev_ctx, *in, out, index);
-      return;
-    }
-
-    auto& place =
-        *context.template device_context<DeviceContext>().eigen_device();
-    for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
-      Tensor in_t = in->Slice(static_cast<int>(lod_level_0[i]),
-                              static_cast<int>(lod_level_0[i + 1]));
-      Tensor out_t = out->Slice(i, i + 1);
-      int64_t h = static_cast<int64_t>(lod_level_0[i + 1] - lod_level_0[i]);
-      auto in_e = EigenMatrix<T>::From(in_t, framework::make_ddim({h, w}));
-      auto out_e = EigenVector<T>::Flatten(out_t);
-
-      if (pooltype == "AVERAGE") {
-        out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
-      } else if (pooltype == "SUM") {
-        out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}}));
-      } else if (pooltype == "SQRT") {
-        out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
-                              std::sqrt(static_cast<T>(h));
-      } else if (pooltype == "LAST") {
-        out_e.device(place) = in_e.chip(h - 1, 0);
-      } else if (pooltype == "FIRST") {
-        out_e.device(place) = in_e.chip(0, 0);
-      } else {
-        PADDLE_THROW("unsupported pooling pooltype");
-      }
     }
+    math::SequencePoolFunctor<DeviceContext, T> pool;
+    pool(context.template device_context<DeviceContext>(), pooltype, *in, out,
+         index);
   }
 };
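Review note: `MaxIndex` is the contract tying the two kernels together. The forward pass records, per output element, which row won the max, so the backward pass can scatter with one lookup instead of rescanning each sequence. A hedged sketch of the forward half (plain C++, names mine, not the patch's API):

```cpp
// Sketch only: max pooling over rows [start, end) of one sequence while
// recording, per column, the winning row for the gradient kernel.
#include <cstddef>

void MaxPoolForwardRef(const float* in, size_t start, size_t end,
                       size_t width, float* out, int* argmax) {
  for (size_t c = 0; c < width; ++c) {
    size_t best = start;
    for (size_t r = start; r < end; ++r)
      if (in[r * width + c] > in[best * width + c]) best = r;
    out[c] = in[best * width + c];
    argmax[c] = static_cast<int>(best);  // consumed by the backward pass
  }
}
```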
@@ -96,58 +61,17 @@ template <typename DeviceContext, typename T>
 class SequencePoolGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* in = context.Input<LoDTensor>("X");
     auto* out_g = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
     std::string pooltype = context.Attr<std::string>("pooltype");
-
-    auto dims = in->dims();
-    auto lod = in->lod()[0];
-    int64_t w = in->numel() / dims[0];
-
-    in_g->mutable_data<T>(context.GetPlace());
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-
+    const Tensor* index = nullptr;
     if (pooltype == "MAX") {
-      math::MaxSeqPoolGradFunctor<DeviceContext, T> max_pool_grad;
-      auto* index = context.Input<Tensor>("MaxIndex");
-      max_pool_grad(dev_ctx, *out_g, *index, in_g);
-      return;
-    }
-
-    if (pooltype == "LAST" || pooltype == "FIRST") {
-      // set X@Grad be zero at first when pooltype is LAST/FIRST
-      math::SetConstant<DeviceContext, T> functor;
-      functor(dev_ctx, in_g, 0);
-    }
-    auto& place =
-        *context.template device_context<DeviceContext>().eigen_device();
-
-    for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
-      auto in_g_t =
-          in_g->Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
-      auto out_g_t = out_g->Slice(i, i + 1);
-      int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
-      auto in_g_e = EigenMatrix<T>::From(in_g_t, {h, w});
-      auto out_g_e = EigenMatrix<T>::From(out_g_t, {1, w});
-      auto out_g_e_v = EigenVector<T>::Flatten(out_g_t);
-      Eigen::DSizes<int, 2> bcast(h, 1);
-
-      if (pooltype == "AVERAGE") {
-        in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
-      } else if (pooltype == "SUM") {
-        in_g_e.device(place) = (out_g_e).broadcast(bcast);
-      } else if (pooltype == "SQRT") {
-        in_g_e.device(place) =
-            (out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
-      } else if (pooltype == "LAST") {
-        in_g_e.chip(h - 1, 0).device(place) = out_g_e_v;
-      } else if (pooltype == "FIRST") {
-        in_g_e.chip(0, 0).device(place) = out_g_e_v;
-      } else {
-        PADDLE_THROW("unsupported pooling pooltype");
-      }
+      index = context.Input<Tensor>("MaxIndex");
     }
+    in_g->mutable_data<T>(context.GetPlace());
+    math::SequencePoolGradFunctor<DeviceContext, T> pool;
+    pool(context.template device_context<DeviceContext>(), pooltype, *out_g,
+         in_g, index);
   }
 };
diff --git a/python/paddle/fluid/tests/unittests/test_seq_pool.py b/python/paddle/fluid/tests/unittests/test_seq_pool.py
index 048847572..2e48ef0e8 100644
--- a/python/paddle/fluid/tests/unittests/test_seq_pool.py
+++ b/python/paddle/fluid/tests/unittests/test_seq_pool.py
@@ -49,6 +49,61 @@ class TestSeqAvgPool(OpTest):
         self.check_grad(["X"], "Out")
 
 
+class TestSeqSumPool(TestSeqAvgPool):
+    def compute(self, x, lod, out):
+        self.attrs = {'pooltype': "SUM"}
+        for i in range(4):
+            sub_x = x[lod[0][i]:lod[0][i + 1], :]
+            out[i] = sub_x.sum(axis=0)
+
+
+class TestSeqMaxPool(TestSeqAvgPool):
+    def set_data(self):
+        self.op_type = 'sequence_pool'
+        x = np.random.uniform(0.1, 1, [13, 23]).astype('float32')
+        lod = [[0, 4, 5, 8, 13]]
+        for i in range(4):
+            l = lod[0][i + 1] - lod[0][i]
+            x[lod[0][i] + np.random.randint(l), :] += 2.0
+
+        self.inputs = {'X': (x, lod)}
+
+        out = np.zeros((4, 23)).astype('float32')
+        self.outputs = {'Out': out}
+        return x, lod, out
+
+    def compute(self, x, lod, out):
+        self.attrs = {'pooltype': "MAX"}
+        for i in range(4):
+            sub_x = x[lod[0][i]:lod[0][i + 1], :]
+            out[i] = np.amax(sub_x, axis=0)
+
+
+class TestSeqSqrtPool(TestSeqAvgPool):
+    def compute(self, x, lod, out):
+        self.attrs = {'pooltype': "SQRT"}
+        for i in range(4):
+            sub_x = x[lod[0][i]:lod[0][i + 1], :]
+            len = lod[0][i + 1] - lod[0][i]
+            out[i] = sub_x.sum(axis=0) / np.sqrt(len)
+
+
+class TestSeqLastPool(TestSeqAvgPool):
+    def compute(self, x, lod, out):
+        self.attrs = {'pooltype': "LAST"}
+        for i in range(4):
+            sub_x = x[lod[0][i]:lod[0][i + 1], :]
+            out[i] = sub_x[-1, :]
+
+
+class TestSeqFirstPool(TestSeqAvgPool):
+    def compute(self, x, lod, out):
+        self.attrs = {'pooltype': "FIRST"}
+        for i in range(4):
+            sub_x = x[lod[0][i]:lod[0][i + 1], :]
+            out[i] = sub_x[0, :]
+
+
 class TestSeqAvgPool2D(TestSeqAvgPool):
     def set_data(self):
         self.op_type = 'sequence_pool'
@@ -68,14 +123,6 @@ class TestSeqAvgPool2D(TestSeqAvgPool):
         out[i] = np.reshape(sub_x.mean(axis=0), (3, 17))
 
 
-class TestSeqSumPool(TestSeqAvgPool):
-    def compute(self, x, lod, out):
-        self.attrs = {'pooltype': "SUM"}
-        for i in range(4):
-            sub_x = x[lod[0][i]:lod[0][i + 1], :]
-            out[i] = sub_x.sum(axis=0)
-
-
 class TestSeqSumPool2D(TestSeqAvgPool2D):
     def compute(self, x, lod, out):
         self.attrs = {'pooltype': "SUM"}
@@ -84,15 +131,6 @@ class TestSeqSumPool2D(TestSeqAvgPool2D):
         out[i] = np.reshape(sub_x.sum(axis=0), (3, 17))
 
 
-class TestSeqSqrtPool(TestSeqAvgPool):
-    def compute(self, x, lod, out):
-        self.attrs = {'pooltype': "SQRT"}
-        for i in range(4):
-            sub_x = x[lod[0][i]:lod[0][i + 1], :]
-            len = lod[0][i + 1] - lod[0][i]
-            out[i] = sub_x.sum(axis=0) / np.sqrt(len)
-
-
 class TestSeqSqrtPool2D(TestSeqAvgPool2D):
     def compute(self, x, lod, out):
         self.attrs = {'pooltype': "SQRT"}
@@ -108,28 +146,6 @@ class TestSeqSqrtPool2D(TestSeqAvgPool2D):
         self.check_grad(["X"], "Out", max_relative_error=0.06)
 
 
-class TestSeqMaxPool(TestSeqAvgPool):
-    def set_data(self):
-        self.op_type = 'sequence_pool'
-        x = np.random.uniform(0.1, 1, [13, 23]).astype('float32')
-        lod = [[0, 4, 5, 8, 13]]
-        for i in range(4):
-            l = lod[0][i + 1] - lod[0][i]
-            x[lod[0][i] + np.random.randint(l), :] += 2.0
-
-        self.inputs = {'X': (x, lod)}
-
-        out = np.zeros((4, 23)).astype('float32')
-        self.outputs = {'Out': out}
-        return x, lod, out
-
-    def compute(self, x, lod, out):
-        self.attrs = {'pooltype': "MAX"}
-        for i in range(4):
-            sub_x = x[lod[0][i]:lod[0][i + 1], :]
-            out[i] = np.amax(sub_x, axis=0)
-
-
 class TestSeqMaxPool2D(TestSeqAvgPool2D):
     def set_data(self):
         self.op_type = 'sequence_pool'
@@ -151,14 +167,6 @@ class TestSeqMaxPool2D(TestSeqAvgPool2D):
         out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 11))
 
 
-class TestSeqLastPool(TestSeqAvgPool):
-    def compute(self, x, lod, out):
-        self.attrs = {'pooltype': "LAST"}
-        for i in range(4):
-            sub_x = x[lod[0][i]:lod[0][i + 1], :]
-            out[i] = sub_x[-1, :]
-
-
 class TestSeqLastPool2D(TestSeqAvgPool2D):
     def compute(self, x, lod, out):
         self.attrs = {'pooltype': "LAST"}
@@ -167,14 +175,6 @@ class TestSeqLastPool2D(TestSeqAvgPool2D):
         out[i] = np.reshape(sub_x[-1, :], (3, 17))
 
 
-class TestSeqFirstPool(TestSeqAvgPool):
-    def compute(self, x, lod, out):
-        self.attrs = {'pooltype': "FIRST"}
-        for i in range(4):
-            sub_x = x[lod[0][i]:lod[0][i + 1], :]
-            out[i] = sub_x[0, :]
-
-
 class TestSeqFirstPool2D(TestSeqAvgPool2D):
     def compute(self, x, lod, out):
         self.attrs = {'pooltype': "FIRST"}
--
GitLab
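Review note: all of these tests share `lod = [[0, 4, 5, 8, 13]]`, i.e. four sequences of lengths 4, 1, 3, 5 covering 13 input rows, which is why every `compute` loops over `range(4)`. A tiny C++ check of that layout (illustrative only, not part of the test suite):

```cpp
// Sketch only: the segment layout the unit tests above exercise.
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> lod = {0, 4, 5, 8, 13};
  for (size_t i = 0; i + 1 < lod.size(); ++i)
    std::printf("sequence %zu: rows [%d, %d), length %d\n", i, lod[i],
                lod[i + 1], lod[i + 1] - lod[i]);
  return 0;  // output row count equals lod.size() - 1 == 4
}
```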