From 1b01f1ea7b9a6c82beb9776e4a847ea127edfa4e Mon Sep 17 00:00:00 2001
From: Luo Tao
Date: Tue, 19 Sep 2017 17:31:46 +0800
Subject: [PATCH] implement framework of seq_pool_op and its unittest

---
 ...nce_avg_pool_op.cc => sequence_pool_op.cc} | 61 ++++++++++++------
 ...nce_avg_pool_op.cu => sequence_pool_op.cu} |  9 ++-
 ...uence_avg_pool_op.h => sequence_pool_op.h} | 59 ++++++++++++-----
 .../v2/framework/tests/test_seq_pool.py       | 63 ++++++++++++++-----
 4 files changed, 139 insertions(+), 53 deletions(-)
 rename paddle/operators/{sequence_avg_pool_op.cc => sequence_pool_op.cc} (55%)
 rename paddle/operators/{sequence_avg_pool_op.cu => sequence_pool_op.cu} (74%)
 rename paddle/operators/{sequence_avg_pool_op.h => sequence_pool_op.h} (62%)

diff --git a/paddle/operators/sequence_avg_pool_op.cc b/paddle/operators/sequence_pool_op.cc
similarity index 55%
rename from paddle/operators/sequence_avg_pool_op.cc
rename to paddle/operators/sequence_pool_op.cc
index 9815b8f3a8d..2b9875b786c 100644
--- a/paddle/operators/sequence_avg_pool_op.cc
+++ b/paddle/operators/sequence_pool_op.cc
@@ -12,22 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/operators/sequence_avg_pool_op.h"
+#include "paddle/operators/sequence_pool_op.h"
 
 namespace paddle {
 namespace operators {
 
-class SequenceAvgPoolOp : public framework::OperatorWithKernel {
+class SequencePoolOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
 protected:
  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(
-        ctx.InputVar("X"), "Input(X) of SequenceAvgPoolOp should not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
+                            "Input(X) of SequencePoolOp should not be null.");
     PADDLE_ENFORCE_NOT_NULL(
         ctx.OutputVar("Out"),
-        "Output(Out) of SequenceAvgPoolOp should not be null.");
+        "Output(Out) of SequencePoolOp should not be null.");
 
     auto* x = ctx.Input<framework::LoDTensor>("X");
     auto dims = x->dims();
@@ -42,21 +42,44 @@ class SequenceAvgPoolOp : public framework::OperatorWithKernel {
   }
 };
 
-class SequenceAvgPoolOpMaker : public framework::OpProtoAndCheckerMaker {
+class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SequenceAvgPoolOpMaker(framework::OpProto* proto,
-                         framework::OpAttrChecker* op_checker)
+  SequencePoolOpMaker(framework::OpProto* proto,
+                      framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "Input of SequenceAvgPoolOp.");
-    AddOutput("Out", "The output of SequenceAvgPoolOp.");
+    AddInput("X", "A LoDTensor, the variable-length input of SequencePoolOp.");
+    AddOutput("Out",
+              "A LoDTensor, the variable-length output of SequencePoolOp.");
+    AddAttr<int>(
+        "strategy",
+        "(int, default AVERAGE) the pooling strategy of SequencePoolOp.")
+        .SetDefault(AVERAGE)
+        .InEnum({AVERAGE, SUM, SQRT, MAX, LAST, FIRST});
     AddComment(R"DOC(
-    SequenceAvgPoolOp averages features of all time-steps of each instance.
-    More detailed comments will be added later.
+    SequencePoolOp pools features of all time-steps of each instance.
+
+    For a mini-batch of 3 variable-length sentences, containing 2, 3, and 2 words:
+
+      X = [[1, 3], [2, 4, 6], [5, 1]],
+
+    and X->lod()[0] = [0, 2, 5, 7].
+
+    Then, for the different strategies, we get:
+
+    - AVERAGE: Out = [2, 4, 3], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2
+    - SUM: Out = [4, 12, 6], where 4=1+3, 12=2+4+6, 6=5+1
+    - SQRT: Out = [2.83, 6.93, 4.24], where 2.83=(1+3)/sqrt(2),
+      6.93=(2+4+6)/sqrt(3), 4.24=(5+1)/sqrt(2)
+    - MAX: Out = [3, 6, 5], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1)
+    - LAST: Out = [3, 6, 1], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1)
+    - FIRST: Out = [1, 2, 5], where 1=first(1,3), 2=first(2,4,6), 5=first(5,1)
+
+    and Out->lod() is empty, since each instance is pooled to a single time-step.
     )DOC");
   }
 };
 
-class SequenceAvgPoolGradOp : public framework::OperatorWithKernel {
+class SequencePoolGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
@@ -84,12 +107,10 @@ class SequenceAvgPoolGradOp : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(sequence_avg_pool, ops::SequenceAvgPoolOp,
-            ops::SequenceAvgPoolOpMaker, sequence_avg_pool_grad,
-            ops::SequenceAvgPoolGradOp);
+REGISTER_OP(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker,
+            sequence_pool_grad, ops::SequencePoolGradOp);
 REGISTER_OP_CPU_KERNEL(
-    sequence_avg_pool,
-    ops::SequenceAvgPoolKernel<paddle::platform::CPUPlace, float>);
+    sequence_pool, ops::SequencePoolKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
-    sequence_avg_pool_grad,
-    ops::SequenceAvgPoolGradKernel<paddle::platform::CPUPlace, float>);
+    sequence_pool_grad,
+    ops::SequencePoolGradKernel<paddle::platform::CPUPlace, float>);
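The DOC block above fixes the semantics of each strategy, and they are easy to sanity-check outside the framework. A minimal NumPy sketch that reproduces the worked example (illustrative only, not part of the patch; variable names are made up):

import numpy as np

# The example from the DOC block: one LoD level with offsets [0, 2, 5, 7].
x = np.array([1, 3, 2, 4, 6, 5, 1], dtype=np.float64)
lod = [0, 2, 5, 7]
segments = [x[lod[i]:lod[i + 1]] for i in range(len(lod) - 1)]

print([s.mean() for s in segments])                   # AVERAGE -> [2.0, 4.0, 3.0]
print([s.sum() for s in segments])                    # SUM     -> [4.0, 12.0, 6.0]
print([s.sum() / np.sqrt(len(s)) for s in segments])  # SQRT    -> approx [2.83, 6.93, 4.24]
print([s.max() for s in segments])                    # MAX     -> [3.0, 6.0, 5.0]
print([s[-1] for s in segments])                      # LAST    -> [3.0, 6.0, 1.0]
print([s[0] for s in segments])                       # FIRST   -> [1.0, 2.0, 5.0]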
diff --git a/paddle/operators/sequence_avg_pool_op.cu b/paddle/operators/sequence_pool_op.cu
similarity index 74%
rename from paddle/operators/sequence_avg_pool_op.cu
rename to paddle/operators/sequence_pool_op.cu
index bc9d1611fcc..66850772d50 100644
--- a/paddle/operators/sequence_avg_pool_op.cu
+++ b/paddle/operators/sequence_pool_op.cu
@@ -14,12 +14,11 @@
 
 #define EIGEN_USE_GPU
 
-#include "paddle/operators/sequence_avg_pool_op.h"
+#include "paddle/operators/sequence_pool_op.h"
 
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(
-    sequence_avg_pool,
-    ops::SequenceAvgPoolKernel<paddle::platform::GPUPlace, float>);
+    sequence_pool, ops::SequencePoolKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(
-    sequence_avg_pool_grad,
-    ops::SequenceAvgPoolGradKernel<paddle::platform::GPUPlace, float>);
+    sequence_pool_grad,
+    ops::SequencePoolGradKernel<paddle::platform::GPUPlace, float>);
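The kernels in the header below iterate over the level-0 LoD offsets: each adjacent pair (lod[i], lod[i+1]) delimits the rows of one instance, so an instance of h rows of width w pools down to a single row. The same slicing in NumPy, as a hedged sketch with shapes borrowed from the unit test:

import numpy as np

x = np.random.uniform(0.1, 1, (11, 23)).astype('float32')  # 11 rows, w = 23
lod0 = [0, 4, 5, 8, 11]                                     # 4 instances

for i in range(len(lod0) - 1):
    in_t = x[lod0[i]:lod0[i + 1]]  # shape (h, w), h = lod0[i+1] - lod0[i]
    out_t = in_t.mean(axis=0)      # AVERAGE: reduce over the h time-steps
    assert out_t.shape == (23,)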
diff --git a/paddle/operators/sequence_avg_pool_op.h b/paddle/operators/sequence_pool_op.h
similarity index 62%
rename from paddle/operators/sequence_avg_pool_op.h
rename to paddle/operators/sequence_pool_op.h
index ebe0956344e..199b4430f7c 100644
--- a/paddle/operators/sequence_avg_pool_op.h
+++ b/paddle/operators/sequence_pool_op.h
@@ -28,54 +28,85 @@ template <typename T, int MajorType = Eigen::RowMajor,
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
+enum SeqPoolType {
+  AVERAGE = 0,
+  SUM = 1,
+  SQRT = 2,  // square_root_n
+  MAX = 3,
+  LAST = 4,
+  FIRST = 5
+};
+
 template <typename Place, typename T>
-class SequenceAvgPoolKernel : public framework::OpKernel {
+class SequencePoolKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in = context.Input<LoDTensor>("X");
     auto* out = context.Output<LoDTensor>("Out");
+    int strategy = context.Attr<int>("strategy");
 
     auto dims = in->dims();
-    auto lod = in->lod();
+    auto lod = in->lod()[0];
     int64_t w = in->numel() / dims[0];
 
     out->mutable_data<T>(context.GetPlace());
     auto place = context.GetEigenDevice<Place>();
-    for (int i = 0; i < static_cast<int>(lod[0].size()) - 1; ++i) {
-      Tensor in_t = in->Slice<T>(static_cast<int>(lod[0][i]),
-                                 static_cast<int>(lod[0][i + 1]));
+    for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
+      Tensor in_t =
+          in->Slice<T>(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
       Tensor out_t = out->Slice<T>(i, i + 1);
-      int64_t h = static_cast<int64_t>(lod[0][i + 1] - lod[0][i]);
+      int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
       auto in_e = EigenMatrix<T>::From(in_t, framework::make_ddim({h, w}));
       auto out_e = EigenVector<T>::Flatten(out_t);
-      out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
+
+      switch (strategy) {
+        case AVERAGE:
+          out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
+          break;
+        case SUM:
+          out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}}));
+          break;
+        default:
+          LOG(FATAL) << "unsupported pooling strategy";
+      }
     }
   }
 };
 
 template <typename Place, typename T>
-class SequenceAvgPoolGradKernel : public framework::OpKernel {
+class SequencePoolGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in = context.Input<LoDTensor>("X");
     auto* out_g = context.Input<LoDTensor>(framework::GradVarName("Out"));
     auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
+    int strategy = context.Attr<int>("strategy");
 
     auto dims = in->dims();
-    auto lod = in->lod();
+    auto lod = in->lod()[0];
     int64_t w = in->numel() / dims[0];
 
     in_g->mutable_data<T>(context.GetPlace());
     auto place = context.GetEigenDevice<Place>();
-    for (int i = 0; i < static_cast<int>(lod[0].size()) - 1; ++i) {
-      auto in_g_t = in_g->Slice<T>(static_cast<int>(lod[0][i]),
-                                   static_cast<int>(lod[0][i + 1]));
+    for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
+      auto in_g_t = in_g->Slice<T>(static_cast<int>(lod[i]),
+                                   static_cast<int>(lod[i + 1]));
       auto out_g_t = out_g->Slice<T>(i, i + 1);
-      int64_t h = static_cast<int64_t>(lod[0][i + 1] - lod[0][i]);
+      int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
       auto in_g_e = EigenMatrix<T>::From(in_g_t, {h, w});
       auto out_g_e = EigenMatrix<T>::From(out_g_t, {1, w});
       Eigen::DSizes<int, 2> bcast(h, 1);
-      in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
+
+      switch (strategy) {
+        case AVERAGE:
+          in_g_e.device(place) =
+              (out_g_e / static_cast<T>(h)).broadcast(bcast);
+          break;
+        case SUM:
+          in_g_e.device(place) = (out_g_e).broadcast(bcast);
+          break;
+        default:
+          LOG(FATAL) << "unsupported pooling strategy";
+      }
     }
   }
 };
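In the gradient kernel above, every row of a segment receives the segment's output gradient broadcast over its h rows, scaled by 1/h for AVERAGE and unscaled for SUM. A small NumPy sketch of that rule, with a finite-difference spot-check of the AVERAGE case (illustrative only, not part of the patch):

import numpy as np

h, w = 3, 5
d_out = np.ones((1, w))  # upstream gradient for one pooled instance

d_x_avg = np.broadcast_to(d_out / h, (h, w))  # AVERAGE: out = mean(x, axis=0)
d_x_sum = np.broadcast_to(d_out, (h, w))      # SUM:     out = sum(x, axis=0)

# Perturb one input entry; mean(x, axis=0) should move by eps / h.
x = np.random.rand(h, w)
eps = 1e-6
x_p = x.copy()
x_p[0, 0] += eps
num_grad = (x_p.mean(axis=0)[0] - x.mean(axis=0)[0]) / eps
assert abs(num_grad - d_x_avg[0, 0]) < 1e-4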
diff --git a/python/paddle/v2/framework/tests/test_seq_pool.py b/python/paddle/v2/framework/tests/test_seq_pool.py
index cf864936af6..211086e5f4d 100644
--- a/python/paddle/v2/framework/tests/test_seq_pool.py
+++ b/python/paddle/v2/framework/tests/test_seq_pool.py
@@ -3,20 +3,37 @@ import numpy as np
 from op_test import OpTest
 
 
-class TestSeqAvgPool1D(OpTest):
-    def setUp(self):
-        self.op_type = 'sequence_avg_pool'
+class SeqPoolType(object):
+    AVERAGE = 0
+    SUM = 1
+    SQRT = 2
+    MAX = 3
+    LAST = 4
+    FIRST = 5
+
+
+class TestSeqAvgPool(OpTest):
+    def set_data(self):
+        self.op_type = 'sequence_pool'
         # one level, batch size is 4
         x = np.random.uniform(0.1, 1, [11, 23]).astype('float32')
         lod = [[0, 4, 5, 8, 11]]
+        self.inputs = {'X': (x, lod)}
 
         out = np.zeros((4, 23)).astype('float32')
+        self.outputs = {'Out': out}
+
+    def compute(self):
+        self.attrs = {'strategy': SeqPoolType.AVERAGE}
+        x, lod = self.inputs['X']
+        out = self.outputs['Out']
         for i in range(4):
             sub_x = x[lod[0][i]:lod[0][i + 1], :]
             out[i] = sub_x.mean(axis=0)
 
-        self.inputs = {'X': (x, lod)}
-        self.outputs = {'Out': out}
+    def setUp(self):
+        self.set_data()
+        self.compute()
 
     def test_check_output(self):
         self.check_output()
@@ -25,26 +42,44 @@ class TestSeqAvgPool1D(OpTest):
         self.check_grad(["X"], "Out")
 
 
-class TestSeqAvgPool2D(OpTest):
-    def setUp(self):
-        self.op_type = 'sequence_avg_pool'
+class TestSeqAvgPool2D(TestSeqAvgPool):
+    def set_data(self):
+        self.op_type = 'sequence_pool'
         # one level, batch size is 4
         x = np.random.uniform(0.1, 1, [13, 3, 17]).astype('float32')
         lod = [[0, 4, 5, 8, 13]]
+        self.inputs = {'X': (x, lod)}
 
         out = np.zeros((4, 3, 17)).astype('float32')
+        self.outputs = {'Out': out}
+
+    def compute(self):
+        self.attrs = {'strategy': SeqPoolType.AVERAGE}
+        x, lod = self.inputs['X']
+        out = self.outputs['Out']
         for i in range(4):
             sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17))
             out[i] = np.reshape(sub_x.mean(axis=0), (3, 17))
 
-        self.inputs = {'X': (x, lod)}
-        self.outputs = {'Out': out}
 
-    def test_check_output(self):
-        self.check_output()
+class TestSeqSumPool(TestSeqAvgPool):
+    def compute(self):
+        self.attrs = {'strategy': SeqPoolType.SUM}
+        x, lod = self.inputs['X']
+        out = self.outputs['Out']
+        for i in range(4):
+            sub_x = x[lod[0][i]:lod[0][i + 1], :]
+            out[i] = sub_x.sum(axis=0)
 
-    def test_check_grad(self):
-        self.check_grad(["X"], "Out")
+
+class TestSeqSumPool2D(TestSeqAvgPool2D):
+    def compute(self):
+        self.attrs = {'strategy': SeqPoolType.SUM}
+        x, lod = self.inputs['X']
+        out = self.outputs['Out']
+        for i in range(4):
+            sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17))
+            out[i] = np.reshape(sub_x.sum(axis=0), (3, 17))
 
 
 if __name__ == '__main__':
-- 
GitLab
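SQRT, MAX, LAST, and FIRST are declared in the enum but not yet handled by the kernels (the switch falls through to LOG(FATAL) for them), which is why the tests only cover AVERAGE and SUM. Once those branches land, the test file's pattern extends naturally. A hypothetical sketch for SQRT, following the DOC block's definition of sum divided by sqrt(length); this class is not part of the patch and assumes it is added to test_seq_pool.py alongside the classes above:

class TestSeqSqrtPool(TestSeqAvgPool):
    def compute(self):
        self.attrs = {'strategy': SeqPoolType.SQRT}
        x, lod = self.inputs['X']
        out = self.outputs['Out']
        for i in range(4):
            sub_x = x[lod[0][i]:lod[0][i + 1], :]
            length = lod[0][i + 1] - lod[0][i]
            out[i] = sub_x.sum(axis=0) / np.sqrt(length)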