diff --git a/paddle/operators/row_conv_op.cc b/paddle/operators/row_conv_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ea0bb99f8d2e6b9a3b6eac7e298a89cd604ec49c --- /dev/null +++ b/paddle/operators/row_conv_op.cc @@ -0,0 +1,257 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/row_conv_op.h" +#include "paddle/framework/eigen.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using framework::Tensor; + +template +using EigenMatrix = framework::EigenMatrix; + +class RowConvOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of RowConvOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Filter"), + "Input(Filter) of RowConvOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of RowConvOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto filter_dims = ctx->GetInputDim("Filter"); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(filter_dims.size(), 2, "Input(Y)'s rank should be 2."); + PADDLE_ENFORCE_EQ( + x_dims[1], filter_dims[1], + "The 2nd dimension of Input(X) and Input(Filter) should be same."); + ctx->SetOutputDim("Out", x_dims); + ctx->ShareLoD("X", "Out"); + } +}; + +class RowConvGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Filter"), + "Input(Filter) should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Gradient of output(Out) should not be null."); + + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + auto x_dims = ctx->GetInputDim("X"); + ctx->SetOutputDim(x_grad_name, x_dims); + } + + auto filter_grad_name = framework::GradVarName("Filter"); + if (ctx->HasOutput(filter_grad_name)) { + auto filter_dims = ctx->GetInputDim("Filter"); + ctx->SetOutputDim(filter_grad_name, filter_dims); + } + } +}; + +class RowConvOpMaker : public framework::OpProtoAndCheckerMaker { + public: + RowConvOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : framework::OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(LoDTensor), the input(X) is a LodTensor, which supports " + "variable time-length input sequences. The underlying tensor " + "in this LoDTensor is a matrix with shape (T x N), where T " + "is the total time steps in this mini-batch and N is the input " + "data dimension."); + AddInput("Filter", + "(Tensor), the input(Filter) is a learnable parameter. It " + "is a 2-D tensor with shape (future_context x N), where, " + "future_context is the future context length and N is the data " + "dimension."); + AddOutput("Out", + "(LoDTensor), the output(Out) is a LodTensor, which supports " + "variable time-length input sequences. The underlying tensor " + "in this LodTensor is a matrix with shape T x N, i.e., the " + "same shape as X."); + AddComment(R"DOC( +Row-convolution Operator. + +The row convolution is called lookahead convolution. This operator was +introduced in the following paper for DeepSpeech2: +http://www.cs.cmu.edu/~dyogatam/papers/wang+etal.iclrworkshop2016.pdf + +The main motivation is that a bidirectional RNN, useful in DeepSpeech +like speech models, learns representation for a sequence by performing a +forward and a backward pass through the entire sequence. However, unlike +unidirectional RNNs, bidirectional RNNs are challenging to deploy in an online +and low-latency setting. The lookahead convolution incorporates information +from future subsequences in a computationally efficient manner to improve +unidirectional recurrent neural networks. The row convolution operator is +different from the 1D sequence convolution, and is computed as follows: + +Given an input sequence $in$ of length $t$ and input dimension $d$, +and a filter ($W$) of size $context \times d$, +the output sequence is convolved as: + +$$ +out_{i, :} = \sum_{j=i}^{i + context} in_{j,:} \dot W_{i-j, :} +$$ + +)DOC"); + } +}; + +template +class RowConvKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *x = context.Input("X"); + auto *filter = context.Input("Filter"); + auto *out = context.Output("Out"); + + out->mutable_data(context.GetPlace()); + + auto batch_indices = x->lod()[0]; + auto input_dim = x->dims()[1]; // 'in' is of size T x N + size_t num_sequence = batch_indices.size() - 1; + + auto future_context = filter->dims()[0]; + auto weights = EigenMatrix::From(*filter); + + for (size_t i = 0; i < num_sequence; i++) { + int start = static_cast(batch_indices[i]); + int end = static_cast(batch_indices[i + 1]); + int current_timesteps = end - start; + Tensor cur_input_sequence = + x->Slice(start, end); // Current input sequence + Tensor cur_output_sequence = + out->Slice(start, end); // Current output sequence + auto cip_seq = EigenMatrix::From(cur_input_sequence); + auto cot_seq = EigenMatrix::From(cur_output_sequence); + + for (int k = 0; k < current_timesteps; + k++) { // For different time steps in the same sequence + for (int w = 0; (w < future_context) && ((k + w) < current_timesteps); + w++) { + for (int d = 0; d < input_dim; d++) { + if (w == 0) { + cot_seq(k, d) = weights(w, d) * cip_seq(k + w, d); + } else { + cot_seq(k, d) += weights(w, d) * cip_seq(k + w, d); + } + } + } + } + } + } +}; + +template +class RowConvGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *x = context.Input("X"); + auto *filter = context.Input("Filter"); + auto *d_out = context.Input(framework::GradVarName("Out")); + auto *dx = context.Output(framework::GradVarName("X")); + auto *d_filter = context.Output(framework::GradVarName("Filter")); + + auto input_dim = x->dims()[1]; // 'x' is of size T x N + auto batch_indices = x->lod()[0]; + size_t num_sequence = batch_indices.size() - 1; + auto future_context = filter->dims()[0]; + + if (d_filter) { + d_filter->mutable_data(context.GetPlace()); + auto dweights = + EigenMatrix::From(*d_filter); // Gradient of weight matrix + dweights.setZero(); + + for (size_t i = 0; i < num_sequence; i++) { // For different sequences + int start = static_cast(batch_indices[i]); + int end = static_cast(batch_indices[i + 1]); + + Tensor cur_input = x->Slice(start, end); // Current input sequence + Tensor cur_doutput = + d_out->Slice(start, end); // Current output grad sequence + + auto cur_ip = EigenMatrix::From(cur_input); + auto cur_dout = EigenMatrix::From(cur_doutput); + int current_timesteps = end - start; + + for (int k = 0; k < current_timesteps; + k++) { // For different time steps in the same sequence + for (int w = 0; (w < future_context) && ((k + w) < current_timesteps); + w++) { + // For dweights (Updating the gradient of weight matrix) + for (int d = 0; d < input_dim; d++) { + dweights(w, d) += cur_ip(k + w, d) * cur_dout(k, d); + } + } + } + } + } + + if (dx) { + dx->mutable_data(context.GetPlace()); + auto weights = EigenMatrix::From(*filter); + for (size_t i = 0; i < num_sequence; i++) { // For different sequences + int start = static_cast(batch_indices[i]); + int end = static_cast(batch_indices[i + 1]); + + Tensor cur_doutput = + d_out->Slice(start, end); // Current output grad sequence + Tensor cur_dinput = + dx->Slice(start, end); // Current input grad sequence + + auto cur_dout = EigenMatrix::From(cur_doutput); + auto cur_dip = EigenMatrix::From(cur_dinput); + cur_dip.setZero(); + int current_timesteps = end - start; + + for (int k = 0; k < current_timesteps; + k++) { // For different time steps in the same sequence + for (int w = 0; (w < future_context) && ((k + w) < current_timesteps); + w++) { + // For dinput (Updating the gradient wrt input) + for (int d = 0; d < input_dim; d++) { + cur_dip(k + w, d) += weights(w, d) * cur_dout(k, d); + } + } + } + } + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(row_conv, ops::RowConvOp, ops::RowConvOpMaker, row_conv_grad, + ops::RowConvGradOp); +REGISTER_OP_CPU_KERNEL(row_conv, + ops::RowConvKernel); +REGISTER_OP_CPU_KERNEL( + row_conv_grad, ops::RowConvGradKernel); diff --git a/paddle/operators/row_conv_op.cu b/paddle/operators/row_conv_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..e0d7ebda7e76919f2d9be702a5399793278ccc43 --- /dev/null +++ b/paddle/operators/row_conv_op.cu @@ -0,0 +1,408 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/row_conv_op.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using framework::Tensor; + +namespace { + +inline int DivUp(int x, int y) { return (x + y - 1) / y; } + +// Forward prop (shared memory version, for small future_context) +template +__global__ void RowConvForwardSharedMemory(const T *in, const T *wt, + int num_sequence, int input_dim, + int future_context, + const size_t *batch_indices, + T *out) { + int blx = blockDim.x; + int bly = blockDim.y; + int thx = threadIdx.x; + int thy = threadIdx.y; + int d = blockIdx.x * blx + thx; // index along input dim + + extern __shared__ T mem[]; + T *sw = mem; + + if (thy < future_context) { + sw[thy * blx + thx] = + (d < input_dim) ? wt[thy * input_dim + d] : static_cast(0); + } + __syncthreads(); + + for (size_t i = 0; i < num_sequence; i++) { + int start = static_cast(batch_indices[i]); + int end = static_cast(batch_indices[i + 1]); + int current_timesteps = end - start; + for (int k = thy; k < current_timesteps; k += bly) { + T sum = 0; + for (int w = 0; (w < future_context) && ((k + w) < current_timesteps); + w++) { + sum += (d < input_dim) + ? sw[w * blx + thx] * in[(start + k + w) * input_dim + d] + : static_cast(0); + } + if (d < input_dim) { + out[(start + k) * input_dim + d] = sum; + } + } + } +} + +// Forward prop (naive version) +template +__global__ void RowConvForward(const T *in, const T *wt, int num_sequence, + int input_dim, int future_context, + const size_t *batch_indices, T *out) { + int d = blockIdx.x * blockDim.x + threadIdx.x; // index along input_dim + int bly = blockDim.y; + int thy = threadIdx.y; + + if (d >= input_dim) return; + + for (size_t i = 0; i < num_sequence; i++) { + int start = static_cast(batch_indices[i]); + int end = static_cast(batch_indices[i + 1]); + int current_timesteps = end - start; + for (int k = thy; k < current_timesteps; k += bly) { + T sum = 0; + for (int w = 0; (w < future_context) && ((k + w) < current_timesteps); + w++) { + sum += (wt[w * input_dim + d] * in[(start + k + w) * input_dim + d]); + } + out[(start + k) * input_dim + d] = sum; + } + } +} + +// Compute input gradient (shared memory version, for small future_context) +template +__global__ void RowConvGradInputSharedMemory(const T *dout, const T *wt, + int num_sequence, int input_dim, + int future_context, + const size_t *batch_indices, + T *din) { + int blx = blockDim.x; + int bly = blockDim.y; + int thx = threadIdx.x; + int thy = threadIdx.y; + int d = blockIdx.x * blx + thx; // index along input dim + + extern __shared__ T mem[]; + T *sw = mem; + if (thy < future_context) { + sw[thy * blx + thx] = + (d < input_dim) ? wt[thy * input_dim + d] : static_cast(0); + } + __syncthreads(); + + for (int i = 0; i < num_sequence; i++) { + int start = static_cast(batch_indices[i]); + int end = static_cast(batch_indices[i + 1]); + int current_timesteps = end - start; + for (int k = thy; k < current_timesteps; k += bly) { + T sum = 0; + for (int w = 0; (w < future_context) && ((k - w) >= 0); w++) { + sum += (d < input_dim) + ? (sw[w * blx + thx] * dout[(k + start - w) * input_dim + d]) + : static_cast(0); + } + if (d < input_dim) { + din[(k + start) * input_dim + d] = sum; + } + } + } +} + +// Compute input gradient (Naive version) +template +__global__ void RowConvGradInput(const T *dout, const T *wt, int num_sequence, + int input_dim, int future_context, + const size_t *batch_indices, T *din) { + int d = blockIdx.x * blockDim.x + threadIdx.x; // index along input_dim + int bly = blockDim.y; + int thy = threadIdx.y; + + if (d >= input_dim) return; + for (int i = 0; i < num_sequence; i++) { + int start = static_cast(batch_indices[i]); + int end = static_cast(batch_indices[i + 1]); + int current_timesteps = end - start; + for (int k = thy; k < current_timesteps; k += bly) { + T sum = 0; + for (int w = 0; (w < future_context) && ((k - w) >= 0); w++) { + sum += (wt[w * input_dim + d] * dout[(k + start - w) * input_dim + d]); + } + din[(k + start) * input_dim + d] = sum; + } + } +} + +// Compute W gradient (small future_context version) +template +__global__ void RowConvGradFilterImproved(const T *in, const T *dout, + int num_sequence, int input_dim, + int future_context, int block_x, + int block_y, + const size_t *batch_indices, + T *dfilter) { + int blx = blockDim.x; + int bly = blockDim.y; + int thx = threadIdx.x; + int thy = threadIdx.y; + int gx = blockIdx.x * blx; + int d = gx + thx; // index along input dim + + extern __shared__ T mem[]; + + int xdim_sh_in = block_y; + int xdim_sh_dout = block_y; + // int xdim_sh_dfilter = future_context; + int ydim_sh_in = block_x; + int ydim_sh_dout = block_x + future_context - 1; + int ydim_sh_dfilter = block_y; + + T *sh_in = mem; + T *sh_dout = &mem[xdim_sh_in * ydim_sh_in]; + T *sh_dfilter = &mem[xdim_sh_in * ydim_sh_in + xdim_sh_dout * ydim_sh_dout]; + + if (thy < future_context) { + sh_dfilter[thy * ydim_sh_dfilter + thx] = static_cast(0); + } + __syncthreads(); + + for (int i = 0; i < num_sequence; i++) { + int start = static_cast(batch_indices[i]); + int end = static_cast(batch_indices[i + 1]); + int current_timesteps = end - start; + int scaled_cur_steps = + ((current_timesteps + block_x - 1) / block_x) * block_x; + + for (int k = thy; k < scaled_cur_steps; k += block_x) { + int pos = start + k; + sh_in[thx * ydim_sh_in + thy] = + (d < input_dim && pos < end) ? in[pos * input_dim + d] : T(0); + sh_dout[thx * ydim_sh_dout + thy + future_context - 1] = + (d < input_dim && pos < end) ? dout[pos * input_dim + d] : T(0); + __syncthreads(); + + if (thy < future_context - 1) { + int pos_offset = pos - future_context + 1; + sh_dout[thx * ydim_sh_dout + thy] = + (d < input_dim && pos_offset >= start) + ? dout[pos_offset * input_dim + d] + : T(0); + } + __syncthreads(); + + for (int w = 0; w < future_context; w++) { + T val = sh_in[thy * ydim_sh_in + thx] * + sh_dout[thy * ydim_sh_dout + thx + future_context - 1 - w]; + __syncthreads(); + + for (int offset = 16; offset > 0; + offset = offset / 2) { // blockDim.x is 32. + val += __shfl_down(val, offset); + } + __syncthreads(); + + if (thx == 0) { + sh_dfilter[w * ydim_sh_dfilter + thy] += val; + } + __syncthreads(); + } + } + } + for (int w = thy; (w < future_context) && (d < input_dim); w += bly) { + dfilter[w * input_dim + d] += sh_dfilter[w * ydim_sh_dfilter + thx]; + } +} + +// Compute weight(filter) gradient +template +__global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence, + int input_dim, int future_context, + int block_x, int block_y, + const size_t *batch_indices, T *dfilter) { + int blx = blockDim.x; + int bly = blockDim.y; + int thx = threadIdx.x; + int thy = threadIdx.y; + int gx = blockIdx.x * blx; + int d = gx + thx; // index along input dim + extern __shared__ T mem[]; + T *sh_in = mem; + T *sh_dout = &mem[block_x * block_y]; + + for (int i = 0; i < num_sequence; i++) { + int start = static_cast(batch_indices[i]); + int end = static_cast(batch_indices[i + 1]); + int current_timesteps = end - start; + int scaled_cur_steps = + ((current_timesteps + block_x - 1) / block_x) * block_x; + + for (int k = thy; k < scaled_cur_steps; k += block_x) { + int pos = start + k; + sh_in[thx * block_y + thy] = + (d < input_dim && pos < end) ? in[pos * input_dim + d] : 0.0; + __syncthreads(); + + for (int w = 0; w < future_context; w++) { + sh_dout[thx * block_y + thy] = + (d < input_dim && (k - w) >= 0 && (k - w) < current_timesteps) + ? dout[(pos - w) * input_dim + d] + : 0.0; + __syncthreads(); + + T val = sh_in[thy * block_y + thx] * sh_dout[thy * block_y + thx]; + __syncthreads(); + + for (int offset = 16; offset > 0; + offset = offset / 2) { // blockDim.x is 32. + val += __shfl_down(val, offset); + } + __syncthreads(); + + if (thx == 0 && (gx + thy) < input_dim) { + dfilter[w * input_dim + gx + thy] += val; + } + } + } + } +} + +} // namespace + +template +class RowConvKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *X = context.Input("X"); + auto *Filter = context.Input("Filter"); + auto *Out = context.Output("Out"); + + const T *in = X->data(); + const T *weight = Filter->data(); + T *out = Out->mutable_data(context.GetPlace()); + + auto batch_indices = X->lod()[0]; + int input_dim = X->dims()[1]; + int num_sequence = batch_indices.size() - 1; + int future_context = Filter->dims()[0]; + size_t *idx = batch_indices.data(); + auto stream = context.cuda_device_context().stream(); + + if (future_context <= 32) { + dim3 block_dim = dim3(32, 32); + dim3 grid_dim = dim3(DivUp(input_dim, block_dim.x), 1); + int mem_per_block = (future_context * block_dim.x) * sizeof(T); + RowConvForwardSharedMemory< + T><<>>( + in, weight, num_sequence, input_dim, future_context, idx, out); + } else { + dim3 block_dim = dim3(32, 32); + dim3 grid_dim = dim3(DivUp(input_dim, block_dim.x), 1); + RowConvForward<<>>( + in, weight, num_sequence, input_dim, future_context, idx, out); + } + } +}; + +template +class RowConvGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *X = context.Input("X"); + auto *Filter = context.Input("Filter"); + auto *dOut = context.Input(framework::GradVarName("Out")); + const T *in = X->data(); + const T *weights = Filter->data(); + const T *dout = dOut->data(); + + Tensor *dX = context.Output(framework::GradVarName("X")); + Tensor *dFilter = context.Output(framework::GradVarName("Filter")); + + auto batch_indices = X->lod()[0]; + int input_dim = X->dims()[1]; + int num_sequence = batch_indices.size() - 1; + int future_context = Filter->dims()[0]; + size_t *idx = batch_indices.data(); + + auto &device_ctx = context.cuda_device_context(); + math::SetConstant zero; + + if (dFilter) { + T *dfilter = dFilter->mutable_data(context.GetPlace()); + zero(device_ctx, dFilter, static_cast(0.0)); + + if (future_context <= 32) { + dim3 block_dim = dim3(32, 32); + dim3 grid_dim = dim3(DivUp(input_dim, block_dim.x), 1); + int block_x = block_dim.x; + int block_y = block_dim.y; + int mem_per_block = + (block_y * block_x + block_y * (block_x + future_context - 1) + + future_context * block_y) * + sizeof(T); + RowConvGradFilterImproved< + T><<>>( + in, dout, num_sequence, input_dim, future_context, block_x, block_y, + idx, dfilter); + } else { + dim3 block_dim = dim3(32, 32); + dim3 grid_dim = dim3(DivUp(input_dim, block_dim.x), 1); + int block_x = block_dim.x; + int block_y = block_dim.y; + int mem_per_block = + (block_x * block_y * 2) * sizeof(T); // For 2 arrays of size 32x32 + RowConvGradFilter< + T><<>>( + in, dout, num_sequence, input_dim, future_context, block_x, block_y, + idx, dfilter); + } + } + + if (dX) { + T *din = dX->mutable_data(context.GetPlace()); + if (future_context <= 32) { + dim3 block_dim = dim3(32, 32); + dim3 grid_dim = dim3(DivUp(input_dim, block_dim.x), 1); + int mem_per_block = (future_context * block_dim.x) * sizeof(T); + RowConvGradInputSharedMemory< + T><<>>( + dout, weights, num_sequence, input_dim, future_context, idx, din); + } else { + dim3 block_dim = dim3(32, 32); + dim3 grid_dim = dim3(DivUp(input_dim, block_dim.x), 1); + RowConvGradInput<<>>( + dout, weights, num_sequence, input_dim, future_context, idx, din); + } + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(row_conv, + ops::RowConvKernel); +REGISTER_OP_GPU_KERNEL( + row_conv_grad, ops::RowConvGradKernel); diff --git a/paddle/operators/row_conv_op.h b/paddle/operators/row_conv_op.h new file mode 100644 index 0000000000000000000000000000000000000000..525e83908d86fb6ff0c4ee48257b39a836ea9f97 --- /dev/null +++ b/paddle/operators/row_conv_op.h @@ -0,0 +1,33 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class RowConvKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override; +}; + +template +class RowConvGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override; +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_row_conv_op.py b/python/paddle/v2/fluid/tests/test_row_conv_op.py new file mode 100644 index 0000000000000000000000000000000000000000..1ed86e23ac28a575cdc3388e9da547918eb8a1be --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_row_conv_op.py @@ -0,0 +1,95 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def row_conv_forward(x, lod, wt): + out = np.zeros_like(x) + seq_info = lod[0] + num_sequences = len(seq_info) - 1 + context_length = wt.shape[0] + + for i in range(num_sequences): # loop over number of sequences + start = seq_info[i] + end = seq_info[i + 1] + curinput = x[start:end, :] + curoutput = out[start:end, :] + + cur_timesteps = end - start + for j in range(cur_timesteps): # loop over different timesteps + for k in range(context_length): + + if j + k >= cur_timesteps: + continue + curoutput[j, :] += curinput[j + k, :] * wt[k, :] + + return out + + +class TestRowConvOp1(OpTest): + def setUp(self): + + self.op_type = "row_conv" + lod = [[0, 2, 5, 7]] + T = lod[0][-1] + D = 16 + context_length = 2 + + x = np.random.random((T, D)).astype("float32") + wt = np.random.random((context_length, D)).astype("float32") + self.inputs = {'X': (x, lod), 'Filter': wt} + + out = row_conv_forward(x, lod, wt) + self.outputs = {'Out': (out, lod)} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Filter'], 'Out', max_relative_error=0.05) + + def test_check_grad_ignore_x(self): + self.check_grad( + ['Filter'], 'Out', max_relative_error=0.05, no_grad_set=set('X')) + + def test_check_grad_ignore_wt(self): + self.check_grad( + ['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Filter')) + + +class TestRowConvOp2(OpTest): + def setUp(self): + + self.op_type = "row_conv" + lod = [[0, 20, 50, 100]] + T = lod[0][-1] + D = 35 + context_length = 35 + + x = np.random.random((T, D)).astype("float32") + wt = np.random.random((context_length, D)).astype("float32") + self.inputs = {'X': (x, lod), 'Filter': wt} + + out = row_conv_forward(x, lod, wt) + self.outputs = {'Out': (out, lod)} + + def test_check_output(self): + self.check_output() + + #max_relative_error is increased from 0.05 to 0.06 as for higher + #dimensional input, the dX on CPU for some values has max_rel_error + #slightly more than 0.05 + def test_check_grad_normal(self): + self.check_grad(['X', 'Filter'], 'Out', max_relative_error=0.06) + + def test_check_grad_ignore_x(self): + self.check_grad( + ['Filter'], 'Out', max_relative_error=0.06, no_grad_set=set('X')) + + def test_check_grad_ignore_wt(self): + self.check_grad( + ['X'], 'Out', max_relative_error=0.06, no_grad_set=set('Filter')) + + +if __name__ == '__main__': + unittest.main()