From 60f706a1d6f497088f1957354910176e649059e8 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Tue, 10 Oct 2017 19:04:29 +0800 Subject: [PATCH] add SQRT strategy for sequence_pool_op --- paddle/operators/sequence_pool_op.cc | 14 +++++------ paddle/operators/sequence_pool_op.h | 8 ++++++ .../v2/framework/tests/test_seq_pool.py | 25 +++++++++++++++++++ 3 files changed, 39 insertions(+), 8 deletions(-) diff --git a/paddle/operators/sequence_pool_op.cc b/paddle/operators/sequence_pool_op.cc index 06c00d31e..9b8d86b40 100644 --- a/paddle/operators/sequence_pool_op.cc +++ b/paddle/operators/sequence_pool_op.cc @@ -36,11 +36,9 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker { SequencePoolOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", - "A float LoDTensor, the variable-length input of SequencePoolOp"); - AddOutput( - "Out", - "A float LoDTensor, the variable-length output of SequencePoolOp."); + AddInput("X", "A LoDTensor, the variable-length input of SequencePoolOp"); + AddOutput("Out", + "A LoDTensor, the variable-length output of SequencePoolOp."); AddAttr( "strategy", "(int, default AVERAGE) the pooling strategy of SequencePoolOp.") @@ -49,13 +47,13 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( SequencePoolOp pools features of all time-steps of each instance. - For a mini-batch of 3 variable lengths sentences, containing 2, 3, and 2 time-steps: + For a mini-batch of 3 variable-length sentences, containing 2, 3, and 2 time-steps: - Assume X is a [7,M,N] float LoDTensor, and X->lod()[0] = [0, 2, 5, 7]. + Assume X is a [7,M,N] LoDTensor, and X->lod()[0] = [0, 2, 5, 7], 7=2+3+2. Besides, for the sake of simplicity, we assume M=1 and N=1, and the value of X = [[1, 3], [2, 4, 6], [5, 1]]. - Thus, Out is a [3,1,1] float LoDTensor, but Out->lod() is nullptr. + Thus, Out is a [3,1,1] LoDTensor, but Out->lod() is nullptr. And for different strategy, the value of Out is as follows: - AVERAGE: [2, 4, 3], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2 diff --git a/paddle/operators/sequence_pool_op.h b/paddle/operators/sequence_pool_op.h index 752d71412..fd056b71c 100644 --- a/paddle/operators/sequence_pool_op.h +++ b/paddle/operators/sequence_pool_op.h @@ -77,6 +77,10 @@ class SequencePoolKernel : public framework::OpKernel { case SUM: out_e.device(place) = in_e.sum(Eigen::array({{0}})); break; + case SQRT: + out_e.device(place) = in_e.sum(Eigen::array({{0}})) / + std::sqrt(static_cast(h)); + break; default: PADDLE_THROW("unsupported pooling strategy"); } @@ -115,6 +119,10 @@ class SequencePoolGradKernel : public framework::OpKernel { case SUM: in_g_e.device(place) = (out_g_e).broadcast(bcast); break; + case SQRT: + in_g_e.device(place) = + (out_g_e / std::sqrt(static_cast(h))).broadcast(bcast); + break; default: PADDLE_THROW("unsupported pooling strategy"); } diff --git a/python/paddle/v2/framework/tests/test_seq_pool.py b/python/paddle/v2/framework/tests/test_seq_pool.py index 211086e5f..fbcf6dac9 100644 --- a/python/paddle/v2/framework/tests/test_seq_pool.py +++ b/python/paddle/v2/framework/tests/test_seq_pool.py @@ -82,5 +82,30 @@ class TestSeqSumPool2D(TestSeqAvgPool2D): out[i] = np.reshape(sub_x.sum(axis=0), (3, 17)) +class TestSeqSqrtPool(TestSeqAvgPool): + def compute(self): + self.attrs = {'strategy': SeqPoolType.SQRT} + x, lod = self.inputs['X'] + out = self.outputs['Out'] + for i in range(4): + sub_x = x[lod[0][i]:lod[0][i + 1], :] + len = lod[0][i + 1] - lod[0][i] + out[i] = sub_x.sum(axis=0) / np.sqrt(len) + + +class TestSeqSqrtPool2D(TestSeqAvgPool2D): + def compute(self): + self.attrs = {'strategy': SeqPoolType.SQRT} + x, lod = self.inputs['X'] + out = self.outputs['Out'] + for i in range(4): + sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17)) + len = lod[0][i + 1] - lod[0][i] + out[i] = np.reshape(sub_x.sum(axis=0) / np.sqrt(len), (3, 17)) + + def test_check_grad(self): + self.check_grad(["X"], "Out", max_relative_error=0.06) + + if __name__ == '__main__': unittest.main() -- GitLab