From 99c6f44a5a093245b9b65e7cb000e7fe5678e890 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 26 Oct 2017 16:40:29 +0800 Subject: [PATCH] follow comments --- paddle/operators/CMakeLists.txt | 2 +- paddle/operators/math/CMakeLists.txt | 4 +- ...sequence_project.cc => context_project.cc} | 6 +- ...sequence_project.cu => context_project.cu} | 6 +- .../{sequence_project.h => context_project.h} | 37 +++++----- paddle/operators/sequence_conv_op.cc | 68 +++++++++++-------- paddle/operators/sequence_conv_op.h | 54 +++++---------- .../v2/framework/tests/test_seq_conv.py | 17 +++-- 8 files changed, 90 insertions(+), 104 deletions(-) rename paddle/operators/math/{sequence_project.cc => context_project.cc} (79%) rename paddle/operators/math/{sequence_project.cu => context_project.cu} (80%) rename paddle/operators/math/{sequence_project.h => context_project.h} (89%) diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index c9a93cd653..afe772dff1 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -128,7 +128,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(sum_op DEPS net_op) op_library(pool_op DEPS pooling) op_library(pool_with_index_op DEPS pooling) -op_library(sequence_conv_op DEPS sequence_project) +op_library(sequence_conv_op DEPS context_project) op_library(lstm_op DEPS sequence2batch lstm_compute) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index a3a744e5f7..40cc177d0f 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -9,7 +9,7 @@ if(WITH_GPU) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) - nv_library(sequence_project SRCS sequence_project.cc sequence_project.cu DEPS device_context) + nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context) nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context) nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) else() @@ -19,7 +19,7 @@ else() cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) cc_library(pooling SRCS pooling.cc DEPS device_context) cc_library(vol2col SRCS vol2col.cc DEPS device_context) - cc_library(sequence_project SRCS sequence_project.cc DEPS device_context) + cc_library(context_project SRCS context_project.cc DEPS device_context) cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) endif() diff --git a/paddle/operators/math/sequence_project.cc b/paddle/operators/math/context_project.cc similarity index 79% rename from paddle/operators/math/sequence_project.cc rename to paddle/operators/math/context_project.cc index d478ea6379..f82ea5d7be 100644 --- a/paddle/operators/math/sequence_project.cc +++ b/paddle/operators/math/context_project.cc @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/math/sequence_project.h" +#include "paddle/operators/math/context_project.h" namespace paddle { namespace operators { namespace math { -template class SequenceProjectFunctor; -template class SequenceProjectFunctor; +template class ContextProjectFunctor; +template class ContextProjectFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/sequence_project.cu b/paddle/operators/math/context_project.cu similarity index 80% rename from paddle/operators/math/sequence_project.cu rename to paddle/operators/math/context_project.cu index e049ebfcb8..04eeed543c 100644 --- a/paddle/operators/math/sequence_project.cu +++ b/paddle/operators/math/context_project.cu @@ -14,14 +14,14 @@ limitations under the License. */ #define EIGEN_USE_GPU -#include "paddle/operators/math/sequence_project.h" +#include "paddle/operators/math/context_project.h" namespace paddle { namespace operators { namespace math { -template class SequenceProjectFunctor; -template class SequenceProjectFunctor; +template class ContextProjectFunctor; +template class ContextProjectFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/sequence_project.h b/paddle/operators/math/context_project.h similarity index 89% rename from paddle/operators/math/sequence_project.h rename to paddle/operators/math/context_project.h index 1d799a0c1c..e37f3a5bf2 100644 --- a/paddle/operators/math/sequence_project.h +++ b/paddle/operators/math/context_project.h @@ -23,31 +23,29 @@ namespace paddle { namespace operators { namespace math { -// template -// using EigenVector = framework::EigenVector; - template using EigenMatrix = framework::EigenMatrix; /* - * \brief SequenceProject projects features of context_length time-steps of each - * instance. - * + * \brief Context projection concatenate features in adjacent time steps in + * a sequence. The i-th row of the output is the concatenation of + * context_length rows of the input. The context_length rows are the + * consecutive rows from the i+shift_start row. + * \param in Input data. - * \param inShape The shape of Input data, + * \param Shape The shape of Input data, * [minibatch, number_of_input_features]. - * \param inShape A float LoDTensor. + * \param type A float LoDTensor. * * \param padding_data Padding data. - * \param inShape The shape of Padding data, + * \param Shape The shape of Padding data, * [up_pad + down_pad, number_of_input_features]. - * \param inShape A float LoDTensor. + * \param type A float Tensor. * * \param col Col data. - * \param inShape The shape of Col data, - * [minibatch, 1]. - * \param inShape A float LoDTensor. + * \param Shape The shape of Col data, + * [minibatch, context_length * number_of_input_features]. + * \param type A float Tensor. * * For a mini-batch of 2 variable lengths sentences, containing 3, and 1 * time-steps: @@ -87,7 +85,7 @@ using EigenMatrix = framework::EigenMatrix; */ template -class SequenceProjectFunctor { +class ContextProjectFunctor { public: void operator()(const platform::DeviceContext& context, framework::LoDTensor& in, framework::Tensor& padding_data, @@ -147,8 +145,7 @@ class SequenceProjectFunctor { /*stride_height*/ context_stride, /*stride_width*/ 1, up_pad, down_pad, 0, 0); } - out_t.Resize(framework::make_ddim( - {sequence_height, context_length * sequence_width})); + out_t.Resize({sequence_height, context_length * sequence_width}); } } } @@ -162,8 +159,7 @@ class SequenceProjectFunctor { sequence_height = static_cast(out_t.dims()[0]); // add up trainable data - out_t.Resize(framework::make_ddim( - {sequence_height * context_length, sequence_width})); + out_t.Resize({sequence_height * context_length, sequence_width}); if (up_pad > 0) { // add up pad int padding_rows = std::min( @@ -223,8 +219,7 @@ class SequenceProjectFunctor { } } } - out_t.Resize(framework::make_ddim( - {sequence_height, context_length * sequence_width})); + out_t.Resize({sequence_height, context_length * sequence_width}); } } } diff --git a/paddle/operators/sequence_conv_op.cc b/paddle/operators/sequence_conv_op.cc index 463bca7a44..139000c561 100644 --- a/paddle/operators/sequence_conv_op.cc +++ b/paddle/operators/sequence_conv_op.cc @@ -38,10 +38,9 @@ class SequenceConvOp : public framework::OperatorWithKernel { auto filter_dims = ctx->GetInputDim("Filter"); PADDLE_ENFORCE(in_dims.size() == 2 && filter_dims.size() == 2, "Input(X, Filter) should be 2-D tensor."); - PADDLE_ENFORCE( - filter_dims[0] == context_length && filter_dims[1] == in_dims[1], - "Filter's shape should be (context_length x " - "number_of_input_features)."); + PADDLE_ENFORCE(filter_dims[0] == context_length * in_dims[1], + "Filter's height should be context_length * " + "number_of_input_features ."); if (padding_trainable) { PADDLE_ENFORCE( @@ -66,8 +65,9 @@ class SequenceConvOp : public framework::OperatorWithKernel { "and 'context_length'."); } - in_dims[1] = 1; + in_dims[1] = filter_dims[1]; ctx->SetOutputDim("Out", in_dims); + ctx->ShareLoD("X", "Out"); } }; @@ -101,35 +101,51 @@ class SequenceConvOpMaker : public framework::OpProtoAndCheckerMaker { SequenceConvOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", - "(A float LoDTensor) the input of SequenceConvOp, a vector of " - "2-D matrix of size (minibatch, number_of_input_features)."); + AddInput( + "X", + "(LoDTensor) the input(X) is a LodTensor, which support " + "variable-time length input sequence. The underlying tensor in " + "this LoDTensor is a matrix with shape (T, D), where, T is the " + "total time steps in this mini-batch, D is the input feature size."); AddInput("PaddingData", - "(Tensor) the input of SequenceConvOp, a vector of " - "2-D matrix of size (up_pad + down_pad, " - "number_of_input_features). ") + "(Tensor, optional) the input(PaddingData) is an optional " + "parameter, and it is learnable. " + "This is a tensor with shape (N, D), where N is the " + "top_pad + bottom_pad, D is the input feature size. In order to " + "ensure the equal length of sequence before and after " + "convolution, it is necessary to fill the top and bottom of each " + "sequence according to context_length, context_stride and " + "context_start") .AsDispensable(); AddInput("Filter", - "(Tensor) the input of SequenceConvOp, a vector of " - "2-D matrix of size (context_length x number_of_input_features)."); - AddOutput("Out", - "(A float LoDTensor) the output of SequenceConvOp, a vector " - "of 2-D matrix of size (minibatch, 1)."); + "(Tensor) the input(Filter) is an learnable parameter." + "This is a tensor with shape (N, D), where N is the " + "context_length, D is the output feature size."); + AddOutput( + "Out", + "(LoDTensor) the output(Out) is a LodTensor, which support " + "variable-time length output sequence. The underlying tensor in " + "this LoDTensor is a matrix with shape (T, D), where, T is the " + "total time steps in this mini-batch, D is the output feature size."); AddAttr("padding_trainable", "(bool, default false) the padding data of SequenceConvOp " "is trainable or not.") .SetDefault(false); AddAttr("context_length", - "(int, default 3) the context_length of SequenceConvOp.") + "(int, default 3) the context_length of SequenceConvOp is the " + "height of the convolution kernel.") .SetDefault(3) .GreaterThan(0); AddAttr("context_start", - "(int, default 0) the context_start of SequenceConvOp.") + "(int, default 0) the context_start of SequenceConvOp " + "represents the beginning of the convolution of the number of " + "rows of sequence, which can be negative.") .SetDefault(0); AddAttr("context_stride", - "(int, default 1) the context_stride of SequenceConvOp. " - "Currently, sequence_project_op only support " + "(int, default 1) the context_stride of SequenceConvOp " + "represents the step length of convolution. " + "Currently, SequenceConvOp only supports" "context_stride=1.") .SetDefault(1) .GreaterThan(0); @@ -139,14 +155,10 @@ class SequenceConvOpMaker : public framework::OpProtoAndCheckerMaker { context_length time-steps of each instance. The convolution operation calculates the output based on the input, filter and strides, paddings parameters. The size of each dimension of the - parameters is checked in the infer-shape. - -Example: - Input: - X shape: (minibatch, number_of_input_features) - Filter shape: (context_length, number_of_input_features) - Output: - Out shape: (minibatch, 1) + parameters is checked in the infer-shape. In order to ensure the equal + length of sequence before and after convolution, it is necessary to fill + the top and bottom of each sequence according to context_length, + context_stride and context_start. )DOC"); } }; diff --git a/paddle/operators/sequence_conv_op.h b/paddle/operators/sequence_conv_op.h index 6907c011a0..cd8a8d4cea 100644 --- a/paddle/operators/sequence_conv_op.h +++ b/paddle/operators/sequence_conv_op.h @@ -15,20 +15,14 @@ limitations under the License. */ #pragma once #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/math/context_project.h" #include "paddle/operators/math/math_function.h" -#include "paddle/operators/math/sequence_project.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; -// template -// using EigenVector = framework::EigenVector; -template -using EigenMatrix = framework::EigenMatrix; template class SequenceConvKernel : public framework::OpKernel { @@ -39,7 +33,7 @@ class SequenceConvKernel : public framework::OpKernel { auto filter = *context.Input("Filter"); out->mutable_data(context.GetPlace()); - // out->set_lod(in->lod()); + context.ShareLoD("X", "Out"); int context_start = context.Attr("context_start"); int context_length = context.Attr("context_length"); @@ -60,17 +54,16 @@ class SequenceConvKernel : public framework::OpKernel { int sequence_width; sequence_width = static_cast(in->dims()[1]); - // use col_shape in the im2col calculation + // Use col_shape in the im2col calculation. framework::DDim col_shape = {in->dims()[0], sequence_width * context_length}; Tensor col; col.mutable_data(col_shape, context.GetPlace()); + math::SetConstant set_zero; // Because if padding_trainable is false, padding data should be zeros. - auto temp = framework::EigenVector::Flatten(col); - temp.device(context.GetEigenDevice()) = - temp.constant(static_cast(0)); + set_zero(context.device_context(), &col, static_cast(0)); - paddle::operators::math::SequenceProjectFunctor + paddle::operators::math::ContextProjectFunctor seq_project_functor; LoDTensor* input = const_cast(in); Tensor* pad_data = const_cast(padding_data); @@ -79,9 +72,8 @@ class SequenceConvKernel : public framework::OpKernel { padding_trainable, context_start, context_length, context_stride, up_pad, down_pad, false, false, false); - filter.Resize(framework::make_ddim({context_length * sequence_width, 1})); math::matmul(context.device_context(), col, false, filter, false, - T(1.0), out, T(0.0)); + static_cast(1.0), out, static_cast(0.0)); } }; @@ -102,7 +94,6 @@ class SequenceConvGradKernel : public framework::OpKernel { int context_stride = context.Attr("context_stride"); bool padding_trainable = context.Attr("padding_trainable"); - // InferShape by in_lod PADDLE_ENFORCE_EQ(in->lod().size(), 1UL, "Only support one level sequence now."); auto lod_g_level_0 = in->lod()[0]; @@ -111,6 +102,7 @@ class SequenceConvGradKernel : public framework::OpKernel { int down_pad = std::max(0, context_start + context_length - 1); int sequence_width = static_cast(in->dims()[1]); + math::SetConstant set_zero; // use col_shape in the im2col calculation framework::DDim col_shape = {in->dims()[0], sequence_width * context_length}; @@ -119,22 +111,17 @@ class SequenceConvGradKernel : public framework::OpKernel { if (in_g || filter_g || (padding_trainable && padding_data_g)) { col.mutable_data(col_shape, context.GetPlace()); // Because if padding_trainable is false, padding data should be zeros. - auto temp = framework::EigenVector::Flatten(col); - temp.device(context.GetEigenDevice()) = - temp.constant(static_cast(0)); - + set_zero(context.device_context(), &col, static_cast(0)); math::matmul(context.device_context(), *out_g, false, *filter, true, T(1.0), &col, T(1.0)); } - paddle::operators::math::SequenceProjectFunctor + paddle::operators::math::ContextProjectFunctor seq_project_functor; if (in_g) { in_g->mutable_data(context.GetPlace()); in_g->set_lod(in->lod()); - - math::SetConstant functor; - functor(context.device_context(), in_g, 0); + set_zero(context.device_context(), in_g, static_cast(0)); seq_project_functor(context.device_context(), *in_g, *padding_data_g, col, padding_trainable, context_start, context_length, @@ -143,9 +130,7 @@ class SequenceConvGradKernel : public framework::OpKernel { if (padding_trainable && padding_data_g) { padding_data_g->mutable_data(context.GetPlace()); - - math::SetConstant functor; - functor(context.device_context(), padding_data_g, 0); + set_zero(context.device_context(), padding_data_g, static_cast(0)); LoDTensor* input = const_cast(in); seq_project_functor(context.device_context(), *input, *padding_data_g, @@ -155,12 +140,10 @@ class SequenceConvGradKernel : public framework::OpKernel { if (filter_g) { filter_g->mutable_data(context.GetPlace()); + set_zero(context.device_context(), filter_g, static_cast(0)); - math::SetConstant functor; - functor(context.device_context(), filter_g, 0); - - Tensor filter_grad_ = *filter_g; - LoDTensor out_grad_ = *out_g; + Tensor filter_grad = *filter_g; + LoDTensor out_grad = *out_g; const Tensor* padding_data = nullptr; if (padding_trainable) { @@ -177,11 +160,8 @@ class SequenceConvGradKernel : public framework::OpKernel { context_stride, up_pad, down_pad, false, false, false); - filter_grad_.Resize( - framework::make_ddim({context_length * sequence_width, 1})); - - math::matmul(context.device_context(), col, true, out_grad_, - false, T(1.0), &filter_grad_, T(1.0)); + math::matmul(context.device_context(), col, true, out_grad, + false, T(1.0), &filter_grad, T(1.0)); } } }; diff --git a/python/paddle/v2/framework/tests/test_seq_conv.py b/python/paddle/v2/framework/tests/test_seq_conv.py index b7b3c0811c..f0337c20a9 100644 --- a/python/paddle/v2/framework/tests/test_seq_conv.py +++ b/python/paddle/v2/framework/tests/test_seq_conv.py @@ -20,8 +20,9 @@ class TestSeqProject(OpTest): # one level, batch size x = np.random.uniform(0.1, 1, [self.input_size[0], self.input_size[1]]).astype('float32') - w = np.random.uniform( - 0.1, 1, [self.context_length, self.input_size[1]]).astype('float32') + w = np.random.uniform(0.1, 1, [ + self.context_length * self.input_size[1], self.output_represention + ]).astype('float32') begin_pad = np.max([0, -self.context_start]) end_pad = np.max([0, self.context_start + self.context_length - 1]) @@ -49,7 +50,8 @@ class TestSeqProject(OpTest): 'padding_trainable': self.padding_trainable, 'context_stride': self.context_stride } - out = np.zeros((self.input_size[0], 1)).astype('float32') + out = np.zeros( + (self.input_size[0], self.output_represention)).astype('float32') self.outputs = {'Out': out} self.compute() @@ -95,13 +97,7 @@ class TestSeqProject(OpTest): out[out_begin:out_end, j * self.input_size[1]:(j + 1) * self.input_size[1]] += in_sub - filter_dim = filter.shape - output_dim = self.outputs['Out'].shape - filter.shape = filter_dim[0] * filter_dim[1] - self.outputs['Out'].shape = (output_dim[0], ) np.dot(out, filter, out=self.outputs['Out']) - filter.shape = filter_dim - self.outputs['Out'].shape = output_dim def test_check_output(self): self.check_output() @@ -166,6 +162,7 @@ class TestSeqProject(OpTest): self.input_size = [self.input_row, 23] self.lod = [[0, 4, 5, 8, self.input_row]] + self.output_represention = 8 # output feature size class TestSeqProjectCase1(TestSeqProject): @@ -178,6 +175,7 @@ class TestSeqProjectCase1(TestSeqProject): self.input_size = [self.input_row, 23] self.lod = [[0, 4, 5, 8, self.input_row]] + self.output_represention = 8 # output feature size class TestSeqProjectCase2(TestSeqProject): @@ -193,6 +191,7 @@ class TestSeqProjectCase2(TestSeqProject): del idx[0] self.lod = [[0] + np.sort(random.sample(idx, 8)).tolist() + [self.input_size[0]]] + self.output_represention = 8 # output feature size if __name__ == '__main__': -- GitLab