From 393c748c89049a7d9b8991266eeec09558395cc5 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Thu, 12 Oct 2017 19:35:46 +0800 Subject: [PATCH] add seqlastin/seqfirstin for seq_pool op --- paddle/operators/sequence_pool_op.h | 17 ++++++++ .../v2/framework/tests/test_seq_pool.py | 40 +++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/paddle/operators/sequence_pool_op.h b/paddle/operators/sequence_pool_op.h index fd056b71cf1..8bfb80c33fe 100644 --- a/paddle/operators/sequence_pool_op.h +++ b/paddle/operators/sequence_pool_op.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { @@ -81,6 +82,12 @@ class SequencePoolKernel : public framework::OpKernel { out_e.device(place) = in_e.sum(Eigen::array({{0}})) / std::sqrt(static_cast(h)); break; + case LAST: + out_e.device(place) = in_e.chip(h - 1, 0); + break; + case FIRST: + out_e.device(place) = in_e.chip(0, 0); + break; default: PADDLE_THROW("unsupported pooling strategy"); } @@ -102,6 +109,10 @@ class SequencePoolGradKernel : public framework::OpKernel { int64_t w = in->numel() / dims[0]; in_g->mutable_data(context.GetPlace()); + if (strategy > 2) { + // set X@Grad be zero at first when strategy is LAST/FIRST/MAX + math::SetConstant(context.device_context(), in_g, 0); + } auto place = context.GetEigenDevice(); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { auto in_g_t = in_g->Slice(static_cast(lod[i]), @@ -123,6 +134,12 @@ class SequencePoolGradKernel : public framework::OpKernel { in_g_e.device(place) = (out_g_e / std::sqrt(static_cast(h))).broadcast(bcast); break; + case LAST: + in_g_e.chip(h - 1, 0).device(place) = out_g_e; + break; + case FIRST: + in_g_e.chip(0, 0).device(place) = out_g_e; + break; default: PADDLE_THROW("unsupported pooling strategy"); } diff --git a/python/paddle/v2/framework/tests/test_seq_pool.py b/python/paddle/v2/framework/tests/test_seq_pool.py index fbcf6dac93f..0ebf78bf8f0 100644 --- a/python/paddle/v2/framework/tests/test_seq_pool.py +++ b/python/paddle/v2/framework/tests/test_seq_pool.py @@ -107,5 +107,45 @@ class TestSeqSqrtPool2D(TestSeqAvgPool2D): self.check_grad(["X"], "Out", max_relative_error=0.06) +class TestSeqLastPool(TestSeqAvgPool): + def compute(self): + self.attrs = {'strategy': SeqPoolType.LAST} + x, lod = self.inputs['X'] + out = self.outputs['Out'] + for i in range(4): + sub_x = x[lod[0][i]:lod[0][i + 1], :] + out[i] = sub_x[-1, :] + + +class TestSeqLastPool2D(TestSeqAvgPool2D): + def compute(self): + self.attrs = {'strategy': SeqPoolType.LAST} + x, lod = self.inputs['X'] + out = self.outputs['Out'] + for i in range(4): + sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17)) + out[i] = np.reshape(sub_x[-1, :], (3, 17)) + + +class TestSeqFirstPool(TestSeqAvgPool): + def compute(self): + self.attrs = {'strategy': SeqPoolType.FIRST} + x, lod = self.inputs['X'] + out = self.outputs['Out'] + for i in range(4): + sub_x = x[lod[0][i]:lod[0][i + 1], :] + out[i] = sub_x[0, :] + + +class TestSeqFirstPool2D(TestSeqAvgPool2D): + def compute(self): + self.attrs = {'strategy': SeqPoolType.FIRST} + x, lod = self.inputs['X'] + out = self.outputs['Out'] + for i in range(4): + sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17)) + out[i] = np.reshape(sub_x[0, :], (3, 17)) + + if __name__ == '__main__': unittest.main() -- GitLab