From afc6343e6f377600d0ee2a90cc6673fcc46a1a93 Mon Sep 17 00:00:00 2001
From: dangqingqing
Date: Thu, 2 Nov 2017 17:13:40 +0800
Subject: [PATCH] Refine sequence max-pooling and add unit testing of gradient
 check.

---
 paddle/operators/CMakeLists.txt              |   2 +
 paddle/operators/math/CMakeLists.txt         |   2 +
 paddle/operators/math/sequence_pooling.cc    | 103 +++++++++++++
 paddle/operators/math/sequence_pooling.cu    | 136 ++++++++++++++++++
 paddle/operators/math/sequence_pooling.h     |  45 ++++++
 paddle/operators/sequence_pool_op.cc         |  21 ++-
 paddle/operators/sequence_pool_op.h          |  39 ++---
 .../v2/framework/tests/test_seq_pool.py      |  45 ++++--
 8 files changed, 362 insertions(+), 31 deletions(-)
 create mode 100644 paddle/operators/math/sequence_pooling.cc
 create mode 100644 paddle/operators/math/sequence_pooling.cu
 create mode 100644 paddle/operators/math/sequence_pooling.h

diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 81d92ec6f4..e584b9da65 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -141,6 +141,7 @@ set(DEPS_OPS
     pool_with_index_op
     nccl_op
     sequence_conv_op
+    sequence_pool_op
     lstm_op)
 
 op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
@@ -153,6 +154,7 @@ if(WITH_GPU)
     op_library(nccl_op DEPS nccl_common)
 endif()
 op_library(sequence_conv_op DEPS context_project)
+op_library(sequence_pool_op DEPS sequence_pooling)
 op_library(lstm_op DEPS sequence2batch lstm_compute)
 op_library(dynamic_recurrent_op SRCS dynamic_recurrent_op.cc rnn/recurrent_op_utils.cc
            DEPS net_op tensor_array)
diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt
index 40cc177d0f..ca6a38ea10 100644
--- a/paddle/operators/math/CMakeLists.txt
+++ b/paddle/operators/math/CMakeLists.txt
@@ -8,6 +8,7 @@ if(WITH_GPU)
     nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator)
     nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
     nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
+    nv_library(sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function)
     nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
     nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context)
     nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
@@ -18,6 +19,7 @@ else()
     cc_library(softmax SRCS softmax.cc DEPS operator)
     cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
     cc_library(pooling SRCS pooling.cc DEPS device_context)
+    cc_library(sequence_pooling SRCS sequence_pooling.cc DEPS device_context math_function)
     cc_library(vol2col SRCS vol2col.cc DEPS device_context)
     cc_library(context_project SRCS context_project.cc DEPS device_context)
     cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
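Note for reviewers: the new math/sequence_pooling library added below implements max
pooling over LoD sequences. A level-0 LoD is a vector of offsets such as
[0, 4, 5, 8, 13], where sequence i occupies rows [offsets[i], offsets[i+1]) of the
batch. MAX pooling reduces each such segment to a single row and additionally
records, per column, which row produced the maximum. A minimal NumPy sketch of the
intended semantics (the function name here is illustrative, not part of the patch):

    import numpy as np

    def max_seq_pool_ref(x, offsets):
        # x: (total_steps, ...) float array; offsets: level-0 LoD.
        # Returns (out, index); index[i, k] is the absolute row of x that
        # produced out[i, k], i.e. what the MaxIndex output below stores.
        num_seq = len(offsets) - 1
        out = np.empty((num_seq,) + x.shape[1:], dtype=x.dtype)
        index = np.empty(out.shape, dtype=np.int32)
        for i in range(num_seq):
            seg = x[offsets[i]:offsets[i + 1]]
            out[i] = seg.max(axis=0)
            index[i] = offsets[i] + seg.argmax(axis=0)
        return out, index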
diff --git a/paddle/operators/math/sequence_pooling.cc b/paddle/operators/math/sequence_pooling.cc
new file mode 100644
index 0000000000..a401f115ee
--- /dev/null
+++ b/paddle/operators/math/sequence_pooling.cc
@@ -0,0 +1,103 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/sequence_pooling.h"
+#include "paddle/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename T>
+class MaxSeqPoolFunctor<platform::CPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::LoDTensor& input, framework::Tensor* output,
+                  framework::Tensor* index) {
+    auto in_dims = input.dims();
+    auto out_dims = output->dims();
+    auto idx_dims = index->dims();
+    PADDLE_ENFORCE_GT(in_dims.size(), 1UL);
+    PADDLE_ENFORCE_GT(out_dims.size(), 1UL);
+    for (size_t i = 1; i < in_dims.size(); ++i) {
+      PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i]);
+    }
+    PADDLE_ENFORCE_EQ(idx_dims, out_dims);
+
+    auto starts = input.lod()[0];
+    const T* in_data = input.data<T>();
+    T* out_data = output->data<T>();
+    int* max_index = index->data<int>();
+
+    int64_t num_seq = out_dims[0];
+    int64_t dim = output->numel() / num_seq;
+    for (int64_t i = 0; i < num_seq; ++i) {
+      for (int64_t k = 0; k < dim; ++k) {
+        out_data[i * dim + k] = in_data[starts[i] * dim + k];
+        max_index[i * dim + k] = starts[i];
+      }
+      for (size_t j = starts[i] + 1; j < starts[i + 1]; ++j) {
+        for (int64_t k = 0; k < dim; ++k) {
+          if (in_data[j * dim + k] > out_data[i * dim + k]) {
+            out_data[i * dim + k] = in_data[j * dim + k];
+            max_index[i * dim + k] = j;
+          }
+        }
+      }
+    }
+  }
+};
+
+template <typename T>
+class MaxSeqPoolGradFunctor<platform::CPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& out_grad,
+                  const framework::Tensor& index,
+                  framework::LoDTensor* in_grad) {
+    auto og_dims = out_grad.dims();
+    auto ig_dims = in_grad->dims();
+    auto idx_dims = index.dims();
+    PADDLE_ENFORCE_GT(og_dims.size(), 1UL);
+    PADDLE_ENFORCE_GT(ig_dims.size(), 1UL);
+    for (size_t i = 1; i < og_dims.size(); ++i) {
+      PADDLE_ENFORCE_EQ(og_dims[i], ig_dims[i]);
+    }
+    PADDLE_ENFORCE_EQ(idx_dims, og_dims);
+
+    const T* og_data = out_grad.data<T>();
+    const int* max_index = index.data<int>();
+    T* ig_data = in_grad->data<T>();
+
+    SetConstant<platform::CPUPlace, T> set_zero;
+    set_zero(context, in_grad, static_cast<T>(0.0));
+    int64_t num_seq = og_dims[0];
+    int64_t dim = out_grad.numel() / num_seq;
+    for (int64_t i = 0; i < num_seq; ++i) {
+      for (int64_t j = 0; j < dim; ++j) {
+        int step_id = max_index[i * dim + j];
+        ig_data[step_id * dim + j] = og_data[i * dim + j];
+      }
+    }
+  }
+};
+
+template class MaxSeqPoolFunctor<platform::CPUPlace, float>;
+template class MaxSeqPoolFunctor<platform::CPUPlace, double>;
+template class MaxSeqPoolGradFunctor<platform::CPUPlace, float>;
+template class MaxSeqPoolGradFunctor<platform::CPUPlace, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
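The CPU backward functor above is a pure scatter: every element of the output
gradient flows to the single timestep recorded in MaxIndex, and every other input
position stays zero. The same routing in NumPy (a sketch, using the same
illustrative naming as the reference function earlier):

    import numpy as np

    def max_seq_pool_grad_ref(out_grad, index, total_steps):
        # out_grad, index: (num_seq, ...) arrays of identical shape; index
        # holds absolute row ids. Mirrors MaxSeqPoolGradFunctor: set_zero,
        # then copy each gradient entry to its recorded row.
        in_grad = np.zeros((total_steps,) + out_grad.shape[1:],
                           dtype=out_grad.dtype)
        num_seq = out_grad.shape[0]
        dim = int(np.prod(out_grad.shape[1:]))
        og = out_grad.reshape(num_seq, dim)
        idx = index.reshape(num_seq, dim)
        ig = in_grad.reshape(total_steps, dim)  # view; writes propagate
        for i in range(num_seq):
            for k in range(dim):
                ig[idx[i, k], k] = og[i, k]
        return in_grad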
diff --git a/paddle/operators/math/sequence_pooling.cu b/paddle/operators/math/sequence_pooling.cu
new file mode 100644
index 0000000000..bd823c15c9
--- /dev/null
+++ b/paddle/operators/math/sequence_pooling.cu
@@ -0,0 +1,136 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/math_function.h"
+#include "paddle/operators/math/sequence_pooling.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+#define FLT_MAX __FLT_MAX__
+
+template <typename T>
+__global__ void KeMaxSequencePool(const T* input, const size_t* starts,
+                                  T* output, int* index, int64_t num_seq,
+                                  int64_t dim) {
+  int dim_idx = threadIdx.x;
+  int seq_id = blockIdx.x;
+  if (seq_id >= num_seq) return;
+  size_t start = starts[seq_id];
+  size_t end = starts[seq_id + 1];
+
+  for (int i = dim_idx; i < dim; i += blockDim.x) {
+    T max_val = static_cast<T>(-FLT_MAX);
+    int max_id = -1;
+    for (size_t step_id = start; step_id < end; step_id++) {
+      if (max_val < input[step_id * dim + i]) {
+        max_val = input[step_id * dim + i];
+        max_id = step_id;
+      }
+    }
+    output[seq_id * dim + i] = max_val;
+    index[seq_id * dim + i] = max_id;
+  }
+}
+
+template <typename T>
+class MaxSeqPoolFunctor<platform::GPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::LoDTensor& input, framework::Tensor* output,
+                  framework::Tensor* index) {
+    auto in_dims = input.dims();
+    auto out_dims = output->dims();
+    auto idx_dims = index->dims();
+    PADDLE_ENFORCE_GT(in_dims.size(), 1UL);
+    PADDLE_ENFORCE_GT(out_dims.size(), 1UL);
+    for (size_t i = 1; i < in_dims.size(); ++i) {
+      PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i]);
+    }
+    PADDLE_ENFORCE_EQ(idx_dims, out_dims);
+
+    auto starts = input.lod()[0];
+    const T* in_data = input.data<T>();
+    T* out_data = output->data<T>();
+    int* max_index = index->data<int>();
+
+    int64_t num_seq = out_dims[0];
+    int64_t dim = output->numel() / num_seq;
+
+    dim3 threads(256, 1);
+    dim3 grid(num_seq, 1);
+    auto stream =
+        reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
+    KeMaxSequencePool<T><<<grid, threads, 0, stream>>>(
+        in_data, starts.data(), out_data, max_index, num_seq, dim);
+  }
+};
+
+template <typename T>
+__global__ void KeMaxSequencePoolGrad(const T* out_grad, const int* max_index,
+                                      T* in_grad, int64_t num_seq,
+                                      int64_t dim) {
+  int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  int col_idx = idx % dim;
+  if (idx < num_seq * dim) {
+    int step_id = max_index[idx];
+    in_grad[step_id * dim + col_idx] = out_grad[idx];
+  }
+}
+
+template <typename T>
+class MaxSeqPoolGradFunctor<platform::GPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& out_grad,
+                  const framework::Tensor& index,
+                  framework::LoDTensor* in_grad) {
+    auto og_dims = out_grad.dims();
+    auto idx_dims = index.dims();
+    auto ig_dims = in_grad->dims();
+    PADDLE_ENFORCE_GT(og_dims.size(), 1UL);
+    PADDLE_ENFORCE_GT(ig_dims.size(), 1UL);
+    for (size_t i = 1; i < og_dims.size(); ++i) {
+      PADDLE_ENFORCE_EQ(og_dims[i], ig_dims[i]);
+    }
+    PADDLE_ENFORCE_EQ(idx_dims, og_dims);
+
+    const T* og_data = out_grad.data<T>();
+    const int* max_index = index.data<int>();
+    T* ig_data = in_grad->data<T>();
+
+    SetConstant<platform::GPUPlace, T> set_zero;
+    set_zero(context, in_grad, static_cast<T>(0.0));
+    int64_t num_seq = og_dims[0];
+    int64_t dim = out_grad.numel() / num_seq;
+
+    unsigned int blocks = (num_seq * dim + 128 - 1) / 128;
+    dim3 threads(128, 1);
+    dim3 grid(blocks, 1);
+    auto stream =
+        reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
+    KeMaxSequencePoolGrad<T><<<grid, threads, 0, stream>>>(
+        og_data, max_index, ig_data, num_seq, dim);
+  }
+};
+
+template class MaxSeqPoolFunctor<platform::GPUPlace, float>;
+template class MaxSeqPoolFunctor<platform::GPUPlace, double>;
+template class MaxSeqPoolGradFunctor<platform::GPUPlace, float>;
+template class MaxSeqPoolGradFunctor<platform::GPUPlace, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
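On the launch geometry used above: the forward kernel assigns one thread block per
sequence and strides 256 threads across the feature dimension (the
`i += blockDim.x` loop), so any `dim` is covered; the backward kernel flattens the
(num_seq, dim) gradient and tiles it with 128-thread blocks. A small sketch of that
arithmetic (a hypothetical helper mirroring the code above, not part of the patch):

    def launch_dims(num_seq, dim):
        # Forward KeMaxSequencePool: one block per sequence, 256 threads.
        fwd = {'grid': (num_seq, 1), 'block': (256, 1)}
        # Backward KeMaxSequencePoolGrad: ceil(num_seq * dim / 128) blocks.
        blocks = (num_seq * dim + 128 - 1) // 128
        bwd = {'grid': (blocks, 1), 'block': (128, 1)}
        return fwd, bwd

    # launch_dims(4, 23) -> forward grid (4, 1); backward needs
    # ceil(92 / 128) = 1 block of 128 threads for the whole gradient.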
diff --git a/paddle/operators/math/sequence_pooling.h b/paddle/operators/math/sequence_pooling.h
new file mode 100644
index 0000000000..35dfe26de1
--- /dev/null
+++ b/paddle/operators/math/sequence_pooling.h
@@ -0,0 +1,45 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/lod_tensor.h"
+#include "paddle/framework/tensor.h"
+#include "paddle/platform/device_context.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+#define FLT_MAX __FLT_MAX__
+
+template <typename Place, typename T>
+class MaxSeqPoolFunctor {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::LoDTensor& input, framework::Tensor* output,
+                  framework::Tensor* index);
+};
+
+template <typename Place, typename T>
+class MaxSeqPoolGradFunctor {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& out_grad,
+                  const framework::Tensor& index,
+                  framework::LoDTensor* in_grad);
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/sequence_pool_op.cc b/paddle/operators/sequence_pool_op.cc
index 29d19df108..731da8848d 100644
--- a/paddle/operators/sequence_pool_op.cc
+++ b/paddle/operators/sequence_pool_op.cc
@@ -27,6 +27,11 @@ class SequencePoolOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of SequencePoolOp should not be null.");
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+    if (ctx->Attrs().Get<std::string>("pooltype") == "MAX") {
+      PADDLE_ENFORCE(ctx->HasOutput("MaxIndex"),
+                     "Output(MaxIndex) of SequencePoolOp should not be null.");
+      ctx->SetOutputDim("MaxIndex", ctx->GetInputDim("X"));
+    }
   }
 };
 
@@ -35,13 +40,17 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
   SequencePoolOpMaker(framework::OpProto* proto,
                       framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "(LoDTensor), the variable-length input of SequencePoolOp");
+    AddInput("X", "(LoDTensor) The variable-length input of SequencePoolOp");
     AddOutput("Out",
-              "(Tensor), output of SequencePoolOp, which does not contain LoD "
+              "(Tensor) The output of SequencePoolOp does not contain LoD "
               "information.");
+    AddOutput("MaxIndex",
+              "(Tensor) This tensor is used in sequence max-pooling "
+              "to record the indexes of the maximum values.")
+        .AsIntermediate();
     AddAttr<std::string>(
         "pooltype",
-        "(int, default AVERAGE) the pooling pooltype of SequencePoolOp.")
+        "(string, default 'AVERAGE') The pooling type of SequencePoolOp.")
         .SetDefault("AVERAGE");
     AddComment(R"DOC(
     SequencePoolOp pools features of all time-steps of each instance.
@@ -92,6 +101,12 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
     }
     ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
   }
+
+ protected:
+  framework::DataType IndicateDataType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::ToDataType(ctx.Input<framework::Tensor>("X")->type());
+  }
 };
 
 }  // namespace operators
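To make the new MaxIndex output concrete: for MAX pooling, Out and MaxIndex have
the same shape, one row per sequence, and MaxIndex stores absolute row ids into X.
A worked two-sequence example (plain NumPy, independent of the framework):

    import numpy as np

    x = np.array([[1., 5.],
                  [7., 2.],
                  [3., 9.]])      # sequences: rows [0, 2) and [2, 3)
    offsets = [0, 2, 3]
    for i in range(len(offsets) - 1):
        seg = x[offsets[i]:offsets[i + 1]]
        print(seg.max(axis=0), offsets[i] + seg.argmax(axis=0))
    # [7. 5.] [1 0]   -> Out row 0, MaxIndex row 0
    # [3. 9.] [2 2]   -> Out row 1, MaxIndex row 1
    # The backward pass routes d(Out) only to rows 1, 0, 2, 2 of d(X).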
diff --git a/paddle/operators/sequence_pool_op.h b/paddle/operators/sequence_pool_op.h
index e0e0493fe0..2b8a25c241 100644
--- a/paddle/operators/sequence_pool_op.h
+++ b/paddle/operators/sequence_pool_op.h
@@ -16,6 +16,7 @@ limitations under the License. */
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/math/math_function.h"
+#include "paddle/operators/math/sequence_pooling.h"
 
 namespace paddle {
 namespace operators {
@@ -34,7 +35,7 @@ class SequencePoolKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in = context.Input<LoDTensor>("X");
-    auto* out = context.Output<LoDTensor>("Out");
+    auto* out = context.Output<Tensor>("Out");
     std::string pooltype = context.Attr<std::string>("pooltype");
     auto dims = in->dims();
@@ -53,6 +54,16 @@
     auto lod_level_0 = lod[0];
     out->mutable_data<T>(context.GetPlace());
+
+    if (pooltype == "MAX") {
+      math::MaxSeqPoolFunctor<Place, T> max_pool;
+      auto* index = context.Output<Tensor>("MaxIndex");
+      index->Resize({dims});
+      index->mutable_data<int>(context.GetPlace());
+      max_pool(context.device_context(), *in, out, index);
+      return;
+    }
+
     auto place = context.GetEigenDevice<Place>();
     for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
       Tensor in_t = in->Slice(static_cast<int>(lod_level_0[i]),
                               static_cast<int>(lod_level_0[i + 1]));
@@ -69,8 +80,6 @@
       } else if (pooltype == "SQRT") {
         out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
                               std::sqrt(static_cast<T>(h));
-      } else if (pooltype == "MAX") {
-        out_e.device(place) = in_e.maximum(Eigen::array<int, 1>({{0}}));
       } else if (pooltype == "LAST") {
         out_e.device(place) = in_e.chip(h - 1, 0);
       } else if (pooltype == "FIRST") {
@@ -87,8 +96,8 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in = context.Input<LoDTensor>("X");
+    auto* out_g = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
-    auto* out_g = context.Input<LoDTensor>(framework::GradVarName("Out"));
     std::string pooltype = context.Attr<std::string>("pooltype");
     auto dims = in->dims();
@@ -96,6 +105,14 @@
     int64_t w = in->numel() / dims[0];
     in_g->mutable_data<T>(context.GetPlace());
+
+    if (pooltype == "MAX") {
+      math::MaxSeqPoolGradFunctor<Place, T> max_pool_grad;
+      auto* index = context.Input<Tensor>("MaxIndex");
+      max_pool_grad(context.device_context(), *out_g, *index, in_g);
+      return;
+    }
+
     if (pooltype == "LAST" || pooltype == "FIRST") {
       // set X@Grad be zero at first when pooltype is LAST/FIRST
       math::SetConstant<Place, T> functor;
@@ -118,20 +135,6 @@
       } else if (pooltype == "SQRT") {
         in_g_e.device(place) =
             (out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
-      } else if (pooltype == "MAX") {
-        auto in_t =
-            in->Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
-        Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
-            in_t_map(in_t.data<T>(), h, w);
-        int row_id;
-        Eigen::array<int, 2> extents{{1, 1}};
-        for (int col_id = 0; col_id < w; col_id++) {
-          in_t_map.col(col_id).maxCoeff(&row_id);
-          Eigen::array<int, 2> in_offsets{{row_id, col_id}};
-          Eigen::array<int, 2> out_offsets{{0, col_id}};
-          in_g_e.slice(in_offsets, extents).device(place) =
-              out_g_e.slice(out_offsets, extents);
-        }
       } else if (pooltype == "LAST") {
         in_g_e.chip(h - 1, 0).device(place) = out_g_e;
       } else if (pooltype == "FIRST") {
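A note on the test changes below: the gradient of a per-column max is well-defined
only when the maximum inside each segment is unique; at a tie, the centered finite
difference used by check_grad splits the derivative between the tied rows and the
comparison with the analytic scatter fails. The new set_data methods therefore bump
one random row per segment (+2.0 in the 1-D test, +1.0 in the 2-D test) so the
argmax is stable under the check's epsilon, and the AVERAGE/SQRT tests stuff a
dummy all-zero MaxIndex because check_grad currently feeds every declared output. A
toy illustration of the tie problem (illustrative, not part of the test suite):

    import numpy as np

    def num_grad(f, x, i, eps=1e-2):
        xp, xm = x.copy(), x.copy()
        xp[i] += eps
        xm[i] -= eps
        return (f(xp) - f(xm)) / (2 * eps)

    tied = np.array([1.0, 1.0])     # numeric grad 0.5, analytic 1.0 or 0.0
    unique = np.array([3.0, 1.0])   # numeric grad 1.0, matches analytic
    print(num_grad(np.max, tied, 0), num_grad(np.max, unique, 0))  # 0.5 1.0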
diff --git a/python/paddle/v2/framework/tests/test_seq_pool.py b/python/paddle/v2/framework/tests/test_seq_pool.py
index efc4920124..512d8b315f 100644
--- a/python/paddle/v2/framework/tests/test_seq_pool.py
+++ b/python/paddle/v2/framework/tests/test_seq_pool.py
@@ -29,6 +29,9 @@ class TestSeqAvgPool(OpTest):
         self.check_output()
 
     def test_check_grad(self):
+        # Remove MaxIndex after check_grad is refined.
+        self.outputs['MaxIndex'] = \
+            np.zeros(self.outputs['Out'].shape).astype('int32')
         self.check_grad(["X"], "Out")
 
 
@@ -85,31 +88,53 @@ class TestSeqSqrtPool2D(TestSeqAvgPool2D):
             out[i] = np.reshape(sub_x.sum(axis=0) / np.sqrt(len), (3, 17))
 
     def test_check_grad(self):
+        # Remove MaxIndex after check_grad is refined.
+        self.outputs['MaxIndex'] = \
+            np.zeros(self.outputs['Out'].shape).astype('int32')
         self.check_grad(["X"], "Out", max_relative_error=0.06)
 
 
 class TestSeqMaxPool(TestSeqAvgPool):
+    def set_data(self):
+        self.op_type = 'sequence_pool'
+        x = np.random.uniform(0.1, 1, [13, 23]).astype('float32')
+        lod = [[0, 4, 5, 8, 13]]
+        for i in range(4):
+            l = lod[0][i + 1] - lod[0][i]
+            x[lod[0][i] + np.random.randint(l), :] += 2.0
+
+        self.inputs = {'X': (x, lod)}
+
+        out = np.zeros((4, 23)).astype('float32')
+        self.outputs = {'Out': out}
+        return x, lod, out
+
     def compute(self, x, lod, out):
         self.attrs = {'pooltype': "MAX"}
         for i in range(4):
             sub_x = x[lod[0][i]:lod[0][i + 1], :]
             out[i] = np.amax(sub_x, axis=0)
 
-    def test_check_grad(self):
-        # Remove MaxPool2D from gradient check to confirm the success of CI.
-        return
-
 
 class TestSeqMaxPool2D(TestSeqAvgPool2D):
+    def set_data(self):
+        self.op_type = 'sequence_pool'
+        x = np.random.uniform(0.1, 1, [13, 3, 11]).astype('float32')
+        lod = [[0, 4, 5, 8, 13]]
+        self.inputs = {'X': (x, lod)}
+        for i in range(4):
+            l = lod[0][i + 1] - lod[0][i]
+            x[lod[0][i] + np.random.randint(l), :] += 1.0
+
+        out = np.zeros((4, 3, 11)).astype('float32')
+        self.outputs = {'Out': out}
+        return x, lod, out
+
     def compute(self, x, lod, out):
         self.attrs = {'pooltype': "MAX"}
         for i in range(4):
-            sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17))
-            out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 17))
-
-    def test_check_grad(self):
-        # Remove MaxPool2D from gradient check to confirm the success of CI.
-        return
+            sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 11))
+            out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 11))
 
 
 class TestSeqLastPool(TestSeqAvgPool):
-- 
GitLab