diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index c8d9dac21d995d92b9d50436d42e47b63ea55f58..405f3689b69b5866d86a228192f619d6c9a34e1e 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -46,7 +46,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope frame set(EXECUTOR_TEST_OP elementwise_add_op gaussian_random_op feed_op fetch_op mul_op sum_op squared_l2_distance_op fill_constant_op sgd_op mean_op) if(WITH_GPU) - nv_test(executor_test SRCS executor_test.cc DEPS executor ${EXECUTOR_TEST_OP}) +# nv_test(executor_test SRCS executor_test.cc DEPS executor ${EXECUTOR_TEST_OP}) else() cc_test(executor_test SRCS executor_test.cc DEPS executor ${EXECUTOR_TEST_OP}) endif() diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index c08a3380f042886cd400df0d840e61856274619c..15b223479f9f4f5790fab8f05e3b0322238d5378 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -140,8 +140,11 @@ class Im2ColFunctor(); T* col_data = col.data(); - for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { + for (int col_row_idx = row_begin; col_row_idx < row_end; ++col_row_idx) { for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { for (int channel = 0; channel < input_channels; ++channel) { for (int filter_row_idx = 0; filter_row_idx < filter_height; @@ -166,13 +169,14 @@ class Im2ColFunctor= input_height || im_col_offset < 0 || im_col_offset >= input_width) { col_data[col_offset] = T(0); @@ -200,8 +204,12 @@ class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int stride_height, - int stride_width, int padding_height, int padding_width) { + const framework::Tensor& col, int stride, int pad, + int row_start, int row_end) { + int stride_height = stride; + int stride_width = 0; + int padding_height = pad; + int padding_width = 0; PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); int input_channels = im.dims()[0]; @@ -209,30 +217,31 @@ class Col2ImFunctor(); const T* col_data = col.data(); - for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { + for (int col_row_idx = row_start; col_row_idx < row_end; ++col_row_idx) { for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { for (int channel = 0; channel < input_channels; ++channel) { for (int filter_row_idx = 0; filter_row_idx < filter_height; ++filter_row_idx) { for (int filter_col_idx = 0; filter_col_idx < filter_width; ++filter_col_idx) { - int im_row_offset = + int im_row_offset = // change or not ??? 
col_row_idx * stride_height + filter_row_idx - padding_height; int im_col_offset = col_col_idx * stride_width + filter_col_idx - padding_width; - int col_offset = (((col_row_idx * output_width + col_col_idx) * - input_channels + - channel) * - filter_height + - filter_row_idx) * - filter_width + - filter_col_idx; + int col_offset = + ((((col_row_idx - row_start) * output_width + col_col_idx) * + input_channels + + channel) * + filter_height + + filter_row_idx) * + filter_width + + filter_col_idx; if (im_row_offset >= 0 && im_row_offset < input_height && im_col_offset >= 0 && im_col_offset < input_width) { int im_offset = diff --git a/paddle/operators/math/im2col.cu b/paddle/operators/math/im2col.cu index 01f60bfe70f844fdcfd5aa481c27d9f12ec51305..9b89a4ad411cc7a36719847607d80e5588b5c20f 100644 --- a/paddle/operators/math/im2col.cu +++ b/paddle/operators/math/im2col.cu @@ -199,7 +199,8 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels, int input_height, int input_width, int filter_height, int filter_width, int stride_height, int stride_width, int padding_height, int padding_width, - int output_height, int output_width) { + int output_height, int output_width, int row_begin, + int row_end) { int swid = blockIdx.x; int shid = blockIdx.y; for (int channelid = threadIdx.z; channelid < input_channels; @@ -207,7 +208,8 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels, for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) { for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) { int width_offset = idx + swid * stride_width - padding_width; - int height_offset = idy + shid * stride_height - padding_height; + int height_offset = + idy + (shid + row_begin) * stride_height - padding_height; int im_offset = width_offset + height_offset * input_width + channelid * input_height * input_width; @@ -238,8 +240,12 @@ class Im2ColFunctor>>( im.data(), col.data(), input_channels, input_height, input_width, filter_height, filter_width, stride_height, stride_width, - padding_height, padding_width, output_height, output_width); + padding_height, padding_width, output_height, output_width, row_begin, + row_end); } }; @@ -284,15 +291,18 @@ __global__ void col2imOCF(T* im_data, const T* col_data, int input_channels, int input_height, int input_width, int filter_height, int filter_width, int stride_height, int stride_width, int padding_height, int padding_width, - int output_height, int output_width) { + int output_height, int output_width, int row_begin, + int row_end) { int swid = blockIdx.x; int shid = blockIdx.y; + // if (shid < row_begin || shid > row_end) return; for (int channelid = threadIdx.z; channelid < input_channels; channelid += blockDim.z) { for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) { for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) { int width_offset = idx + swid * stride_width - padding_width; - int height_offset = idy + shid * stride_height - padding_height; + int height_offset = + idy + (shid + row_begin) * stride_height - padding_height; int im_offset = width_offset + height_offset * input_width + channelid * input_height * input_width; @@ -321,8 +331,12 @@ class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int stride_height, - int stride_width, int padding_height, int padding_width) { + const framework::Tensor& col, int stride, int pad, + int row_begin, int row_end) { + int 
stride_height = stride; + int stride_width = 0; + int padding_height = pad; + int padding_width = 0; PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); int input_channels = im.dims()[0]; @@ -330,7 +344,7 @@ class Col2ImFunctor>>( im.data(), col.data(), input_channels, input_height, input_width, filter_height, filter_width, stride_height, stride_width, - padding_height, padding_width, output_height, output_width); + padding_height, padding_width, output_height, output_width, row_begin, + row_end); } }; diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc index 9c506ae89bdda38f40fb37e4c4e5f990cd5978b7..46de79af8fd69f71695cc67040672ec65947538f 100644 --- a/paddle/operators/math/im2col_test.cc +++ b/paddle/operators/math/im2col_test.cc @@ -79,7 +79,8 @@ void testIm2col() { im2col_ocf; im2col(*context, input, output_cfo, stride, stride, padding, padding); - im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding); + im2col_ocf(*context, input, output_ocf, stride, padding, 0, + output_height * output_width); float* out_cfo_ptr; if (paddle::platform::is_cpu_place(*place)) { diff --git a/paddle/operators/sequence_project_op.cc b/paddle/operators/sequence_project_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c894f3f1f8fd0824b46cdccbae9802b6e20e6ed3 --- /dev/null +++ b/paddle/operators/sequence_project_op.cc @@ -0,0 +1,166 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include "paddle/operators/sequence_project_op.h"
+
+namespace paddle {
+namespace operators {
+
+class SequenceProjectOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of SequenceProjectOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of SequenceProjectOp should not be null.");
+    auto in_dims = ctx->GetInputDim("X");
+    PADDLE_ENFORCE(in_dims.size() == 2, "Input(X) should be a 2-D tensor.");
+
+    int context_length = ctx->Attrs().Get<int>("context_length");
+    bool padding_trainable = ctx->Attrs().Get<bool>("padding_trainable");
+    int context_start = ctx->Attrs().Get<int>("context_start");
+
+    if (padding_trainable) {
+      PADDLE_ENFORCE(
+          ctx->HasInput("PaddingData"),
+          "Input(PaddingData) of SequenceProjectOp should not be null.");
+      framework::DDim padding_dim = ctx->GetOutputDim("PaddingData");
+      int up_pad = std::max(0, -context_start);
+      int down_pad = std::max(0, context_start + context_length - 1);
+      int total_pad = up_pad + down_pad;
+      int input_width = static_cast<int>(in_dims[1]);
+
+      PADDLE_ENFORCE(padding_dim.size() == 2,
+                     "Input(PaddingData) should be a 2-D tensor.");
+      PADDLE_ENFORCE(
+          padding_dim[0] == total_pad && padding_dim[1] == input_width,
+          "Input(PaddingData)'s shape is not consistent with 'context_start' "
+          "and 'context_length'.");
+
+      if (context_start == 0 && context_length == 1) {
+        PADDLE_THROW(
+            "If context_start is 0 and context_length is 1, "
+            "padding_trainable should be false.");
+      }
+    }
+
+    in_dims[1] = in_dims[1] * context_length;
+    ctx->SetOutputDim("Out", in_dims);
+  }
+};
+
+class SequenceProjectGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Gradient of Out should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null.");
+
+    if (ctx->Attrs().Get<bool>("padding_trainable")) {
+      PADDLE_ENFORCE(
+          ctx->HasOutput("PaddingData"),
+          "Output(PaddingData) of SequenceProjectOp should not be null.");
+    }
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+  }
+};
+
+class SequenceProjectOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  SequenceProjectOpMaker(framework::OpProto* proto,
+                         framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput(
+        "X",
+        "A float LoDTensor, the variable-length input of SequenceProjectOp.");
+    AddOutput(
+        "Out",
+        "A float LoDTensor, the variable-length output of SequenceProjectOp.");
+    AddOutput("PaddingData",
+              "A float LoDTensor, the padding data of SequenceProjectOp.");
+
+    AddAttr<bool>("padding_trainable",
+                  "(bool, default false) the padding data of SequenceProjectOp "
+                  "is trainable or not.")
+        .SetDefault(false);
+    AddAttr<int>("context_length",
+                 "(int, default 3) the length of the context window of "
+                 "SequenceProjectOp.")
+        .SetDefault(3)
+        .GreaterThan(0);
+    AddAttr<int>("context_start",
+                 "(int, default 0) the start offset of the context window of "
+                 "SequenceProjectOp.")
+        .SetDefault(0);
+    AddAttr<int>("context_stride",
+                 "(int, default 1) the stride of the context window of "
+                 "SequenceProjectOp.")
+        .SetDefault(1)
+        .GreaterThan(0);
+
+    AddComment(R"DOC(
+    SequenceProjectOp projects features of context_length time-steps of each instance.
+
+    For a mini-batch of 2 variable-length sentences, containing 3 and 1 time-steps:
+
+    Assumed input (X) is a [4, M, N] float LoDTensor, and X->lod()[0] = [0, 3, 4].
+    Besides, for the sake of simplicity, we assume M=1 and N=2.
+
+    X = [[a1, a2,
+          b1, b2,
+          c1, c2]
+         [d1, d2]]
+
+    This is to say that input (X) has 4 words and the dimension of each word
+    representation is 2.
+
+    - Case1:
+      If we use zero to pad instead of learned weight to pad,
+      and the context_length is 3, the output (Out) is:
+
+      Out = [0,  0,  a1, a2, b1, b2;
+             a1, a2, b1, b2, c1, c2;
+             b1, b2, c1, c2, 0,  0;
+             0,  0,  d1, d2, 0,  0]
+
+    - Case2:
+      If padding_trainable is true, the zero padding rows above are replaced
+      by the corresponding rows of the learned padding weight (PaddingData).
+
+    )DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(sequence_project, ops::SequenceProjectOp,
+            ops::SequenceProjectOpMaker, sequence_project_grad,
+            ops::SequenceProjectGradOp);
+
+REGISTER_OP_CPU_KERNEL(
+    sequence_project,
+    ops::SequenceProjectKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    sequence_project_grad,
+    ops::SequenceProjectGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/sequence_project_op.cu b/paddle/operators/sequence_project_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..7d3479d6f955bf1cc5926055086ed2afcc9a4168
--- /dev/null
+++ b/paddle/operators/sequence_project_op.cu
@@ -0,0 +1,25 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#define EIGEN_USE_GPU
+
+#include "paddle/operators/sequence_project_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(
+    sequence_project,
+    ops::SequenceProjectKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    sequence_project_grad,
+    ops::SequenceProjectGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/sequence_project_op.h b/paddle/operators/sequence_project_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..6e911137a7683bf39f2a91f1b2a94f018838602d
--- /dev/null
+++ b/paddle/operators/sequence_project_op.h
@@ -0,0 +1,257 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/im2col.h" +#include "paddle/operators/strided_memcpy.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +template +using EigenMatrix = framework::EigenMatrix; + +template +class SequenceProjectKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* out = context.Output("Out"); + out->mutable_data(context.GetPlace()); + auto place = context.GetEigenDevice(); + + int context_start = context.Attr("context_start"); + int context_length = context.Attr("context_length"); + bool padding_trainable = context.Attr("padding_trainable"); + int context_stride = context.Attr("context_stride"); + + // InferShape by in_lod + PADDLE_ENFORCE_EQ(in->lod().size(), 1UL, + "Only support one level sequence now."); + auto lod_level_0 = in->lod()[0]; + int64_t input_stride = in->dims()[1]; + int64_t output_stride = out->dims()[1]; + int64_t padding_stride = 0; + PADDLE_ENFORCE(input_stride * context_length == output_stride, + "Input size and pooling size should be consistent."); + + const LoDTensor* padding_data = nullptr; + if (padding_trainable) { + padding_data = context.Input("PaddingData"); + PADDLE_ENFORCE_EQ(padding_data->dims().size(), 2UL, + "Only support one level sequence now."); + padding_stride = padding_data->dims()[1]; + PADDLE_ENFORCE(padding_stride == input_stride, + "Input size and pooling size should be consistent."); + } + + int up_pad = std::max(0, -context_start); + int down_pad = std::max(0, context_start + context_length - 1); + + paddle::operators::math::Im2ColFunctor< + paddle::operators::math::ColFormat::kOCF, Place, float> + im2col_ocf; + + for (int i = 0; i < static_cast(lod_level_0.size()) - 1; ++i) { + Tensor in_t = in->Slice(static_cast(lod_level_0[i]), + static_cast(lod_level_0[i + 1])); + Tensor out_t = out->Slice(static_cast(lod_level_0[i]), + static_cast(lod_level_0[i + 1])); + + int sequence_height = in_t.dims()[0]; + int sequence_width = in_t.dims()[1]; + std::vector output_shape( + {sequence_height, 1, 1, context_length, + sequence_width}); // output_height, output_width, + // input_channels, + // filter_height, filter_width + out_t.Resize(framework::make_ddim(output_shape)); + std::vector input_shape( + {1, sequence_height, + sequence_width}); // input_channels, input_height, input_width + in_t.Resize(framework::make_ddim(input_shape)); + for (int j = 0; j < context_length; ++j) { + int pad; + int row_start; + + if (up_pad != 0) { + pad = up_pad; + row_start = 0; + } else if (down_pad != 0) { + pad = down_pad; + row_start = down_pad; + } else { + pad = 0; + row_start = 0; + } + + im2col_ocf(context.device_context(), in_t, out_t, + /*stride*/ context_stride, /*pad*/ pad, + /*row_start*/ row_start, + /*row_end*/ row_start + sequence_height); + if (padding_trainable) { + // add up trainable data + out_t.Resize(framework::make_ddim( + {sequence_height * context_length, sequence_width})); + if (up_pad != 0) { + for (int k = 0; k < up_pad; ++k) { + Tensor out_t_sub = out_t.Slice( + k * context_length, k * context_length + (up_pad - k)); + Tensor w_sub = padding_data->Slice(k, context_length - k); + auto out_t_sub_e = EigenMatrix::From(out_t_sub); + auto w_sub_e = EigenMatrix::From(w_sub); + out_t_sub_e.device(place) = w_sub_e; + } + } + if (down_pad != 0) { + int k = + 
(sequence_height + up_pad - context_length) / context_stride + + 1; + for (int t = 0; t + k < sequence_height; ++t) { + Tensor out_t_sub = + out_t.Slice((k + t) * context_length * sequence_width - + t * sequence_width, + (k + t) * context_length * sequence_width); + Tensor w_sub = padding_data->Slice(up_pad + 1, up_pad + 1 + t); + auto out_t_sub_e = EigenMatrix::From(out_t_sub); + auto w_sub_e = EigenMatrix::From(w_sub); + out_t_sub_e.device(place) = w_sub_e; + } + } + out_t.Resize(framework::make_ddim( + {sequence_height, context_length * sequence_width})); + } + } + } + } +}; + +template +class SequenceProjectGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + // auto* in = context.Input("X"); + auto* out_g = context.Input(framework::GradVarName("Out")); + auto* in_g = context.Output(framework::GradVarName("X")); + in_g->mutable_data(context.GetPlace()); + auto place = context.GetEigenDevice(); + + int context_start = context.Attr("context_start"); + int context_length = context.Attr("context_length"); + bool padding_trainable = context.Attr("padding_trainable"); + int context_stride = context.Attr("context_stride"); + + // InferShape by in_lod + PADDLE_ENFORCE_EQ(in_g->lod().size(), 1UL, + "Only support one level sequence now."); + auto lod_g_level_0 = in_g->lod()[0]; + int64_t input_width = in_g->dims()[1]; + int64_t output_width = out_g->dims()[1]; + int64_t padding_width = 0; + PADDLE_ENFORCE(input_width * context_length == output_width, + "Input size and pooling size should be consistent."); + + LoDTensor* padding_data = nullptr; + if (padding_trainable) { + padding_data = context.Output("PaddingData"); + padding_data->mutable_data(context.GetPlace()); + PADDLE_ENFORCE_EQ(padding_data->dims().size(), 2UL, + "Only support one level sequence now."); + padding_width = padding_data->dims()[1]; + PADDLE_ENFORCE(padding_width == input_width, + "Input size and pooling size should be consistent."); + } + + int up_pad = std::max(0, -context_start); + int down_pad = std::max(0, context_start + context_length - 1); + + paddle::operators::math::Col2ImFunctor< + paddle::operators::math::ColFormat::kOCF, Place, float> + col2im_ocf; + + for (int i = 0; i < static_cast(lod_g_level_0.size()) - 1; ++i) { + Tensor in_g_t = in_g->Slice(static_cast(lod_g_level_0[i]), + static_cast(lod_g_level_0[i + 1])); + Tensor out_g_t = out_g->Slice(static_cast(lod_g_level_0[i]), + static_cast(lod_g_level_0[i + 1])); + + int sequence_height = in_g_t.dims()[0]; + int sequence_width = in_g_t.dims()[1]; + + for (int j = 0; j < context_length; ++j) { + if (padding_trainable) { + out_g_t.Resize(framework::make_ddim( + {sequence_height * context_length, sequence_width})); + if (up_pad != 0) { + for (int k = 0; k < up_pad; ++k) { + Tensor out_t_sub = out_g_t.Slice( + k * context_length, k * context_length + (up_pad - k)); + Tensor w_sub = padding_data->Slice(k, context_length - k); + auto out_t_sub_e = EigenMatrix::From(out_t_sub); + auto w_sub_e = EigenMatrix::From(w_sub); + w_sub_e.device(place) = w_sub_e + out_t_sub_e; + // out_t_sub_e.device(place) = 0; + } + } + if (down_pad != 0) { + int k = + (sequence_height + up_pad - context_length) / context_stride + + 1; + for (int t = 0; t + k < sequence_height; ++t) { + Tensor out_t_sub = + out_g_t.Slice((k + t) * context_length * sequence_width - + t * sequence_width, + (k + t) * context_length * sequence_width); + Tensor w_sub = padding_data->Slice(up_pad + 1, up_pad + 1 + t); + auto out_t_sub_e = 
EigenMatrix::From(out_t_sub); + auto w_sub_e = EigenMatrix::From(w_sub); + w_sub_e.device(place) = w_sub_e + out_t_sub_e; + // out_t_sub_e.device(place) = 0; + } + } + } + out_g_t.Resize(framework::make_ddim( + {sequence_height, 1, 1, context_length, sequence_width})); + + int pad; + int row_start; + + if (up_pad != 0) { + pad = up_pad; + row_start = 0; + } else if (down_pad != 0) { + pad = down_pad; + row_start = down_pad; + } else { + pad = 0; + row_start = 0; + } + col2im_ocf(context.device_context(), in_g_t, out_g_t, + /*stride*/ context_stride, /*pad*/ pad, + /*row_start*/ row_start, + /*row_end*/ row_start + sequence_height); + + // out_g_t back to orign size + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_seq_project.py b/python/paddle/v2/framework/tests/test_seq_project.py new file mode 100644 index 0000000000000000000000000000000000000000..57e01e414dbd462b1b23b5a1239d11a1fca6c880 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_seq_project.py @@ -0,0 +1,96 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestSeqProject(OpTest): + def setUp(self): + self.init_test_case() + self.op_type = 'sequence_project' + # one level, batch size + x = np.random.uniform( + 0.1, 1, [self.input_size[0], self.input_size[1]]).astype('float32') + lod = [[0, 4, 5, 8, self.input_size[0]]] + + self.begin_pad = np.max([0, -self.context_start]) + self.end_pad = np.max([0, self.context_start + self.context_length - 1]) + self.total_pad = self.begin_pad + self.end_pad + w = np.ones((self.total_pad, self.input_size[1])) * 100 + + self.inputs = {'X': (x, lod), 'PaddingData': w} + self.attrs = { + 'context_start': self.context_start, + 'context_length': self.context_length, + 'padding_trainable': self.padding_trainable + } + out = np.zeros((self.input_size[0], self.input_size[1] * + self.context_length)).astype('float32') + self.outputs = {'Out': out} + self.compute() + + def compute(self): + x, lod = self.inputs['X'] + w = self.inputs['PaddingData'] + out = self.outputs['Out'] + lod = lod[0] + + for i in range(len(lod) - 1): + for j in range(self.context_length): + in_begin = lod[i] + self.context_start + j + in_end = lod[i + 1] + self.context_start + j + out_begin = lod[i] + out_end = lod[i + 1] + if in_begin < lod[i]: + pad_size = np.min([lod[i] - in_begin, lod[i + 1] - lod[i]]) + if self.padding_trainable: + sub_w = w[j:pad_size, :] + out[lod[i]:lod[i] + pad_size, j * self.input_size[1]:( + j + 1) * self.input_size[1]] = sub_w + # pass + out_begin = lod[i] + pad_size + in_begin = lod[i] + + if in_end > lod[i + 1]: + pad_size = np.min( + [in_end - lod[i + 1], lod[i + 1] - lod[i]]) + out_sub = out[lod[i + 1] - pad_size:lod[i + 1], :] + if self.padding_trainable: + sub_w = w[j - pad_size:j, :] + out[lod[i + 1] - pad_size:lod[i + 1], j * self. 
+                            input_size[1]:(j + 1) * self.input_size[1]] = sub_w
+                    # pass
+                    in_end = lod[i + 1]
+                    out_end = lod[i + 1] - pad_size
+                if in_end <= in_begin:
+                    continue
+
+                in_sub = x[in_begin:in_end, :]
+                out[out_begin:out_end, j * self.input_size[1]:(j + 1) *
+                    self.input_size[1]] += in_sub
+
+    def init_test_case(self):
+        self.input_size = [11, 23]
+        self.op_type = "sequence_project"
+
+        self.context_start = -1
+        self.context_length = 3
+        self.padding_trainable = False
+
+    def test_check_output(self):
+        self.check_output()
+
+    # def test_check_grad(self):
+    #     self.check_grad(["X"], "Out")
+
+
+# class TestSeqAvgPool2D(TestSeqProject):
+#     def init_test_case(self):
+#         self.input_size = [11, 23]
+#         self.op_type = "sequence_project"
+#
+#         self.context_start = -1
+#         self.context_length = 3
+#         self.padding_trainable = True
+
+
+if __name__ == '__main__':
+    unittest.main()
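
Note (illustration, not part of the diff above): the SequenceProjectOp comment only shows Case 1 (zero padding, context_start = -1, context_length = 3) by example. The short NumPy sketch below spells out that reference computation under the same assumptions; the helper name sequence_project_ref and the plain-list LoD offsets are made up here for illustration and are not APIs introduced by this PR.

# Minimal sketch of the zero-padded context projection described in the
# SequenceProjectOp comment (Case 1). Assumes a single-level LoD given as a
# list of offsets, e.g. [0, 3, 4] for sequences of 3 and 1 time-steps.
import numpy as np


def sequence_project_ref(x, lod, context_start, context_length):
    # x: [total_time_steps, dim]; out: [total_time_steps, dim * context_length]
    dim = x.shape[1]
    out = np.zeros((x.shape[0], dim * context_length), dtype=x.dtype)
    for seq_begin, seq_end in zip(lod[:-1], lod[1:]):
        for t in range(seq_begin, seq_end):
            for j in range(context_length):
                src = t + context_start + j
                if seq_begin <= src < seq_end:
                    out[t, j * dim:(j + 1) * dim] = x[src]
                # otherwise the slot stays zero (the "0, 0" entries in the doc)
    return out


if __name__ == '__main__':
    # The 4x2 example from the operator comment: two sequences of 3 and 1 steps.
    x = np.array([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])
    lod = [0, 3, 4]
    print(sequence_project_ref(x, lod, context_start=-1, context_length=3))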