diff --git a/doc/api/v2/config/layer.rst b/doc/api/v2/config/layer.rst index 59bd8b91255944a8ef702edd389214c17d0cb35d..c7b017bc07b25bc606fd838a5fb9d3715f4faecb 100644 --- a/doc/api/v2/config/layer.rst +++ b/doc/api/v2/config/layer.rst @@ -59,6 +59,11 @@ context_projection .. autoclass:: paddle.v2.layer.context_projection :noindex: +row_conv +-------- +.. autoclass:: paddle.v2.layer.row_conv + :noindex: + Image Pooling Layer =================== @@ -346,6 +351,12 @@ sampling_id .. autoclass:: paddle.v2.layer.sampling_id :noindex: +multiplex +--------- +.. autoclass:: paddle.v2.layer.multiplex + :noindex: + + Slicing and Joining Layers ========================== diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index 233a53709a80f06dd2a06995b159c1aef10e2788..1f54ac1231c6ac2e19b25bb336292194c63c11e9 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -28,6 +28,7 @@ if(WITH_TESTING) add_simple_unittest(PadOpTest) add_simple_unittest(MulOpTest) add_simple_unittest(CosSimOpTest) + add_simple_unittest(RowConvOpTest) endif() endif() diff --git a/paddle/function/RowConvOp.cpp b/paddle/function/RowConvOp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b6501e8f4db7fd33891cd80e07a6f36dd0b34532 --- /dev/null +++ b/paddle/function/RowConvOp.cpp @@ -0,0 +1,225 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "RowConvOp.h" +#include +#include "paddle/math/Vector.h" + +namespace paddle { + +template <> +void RowConv(CpuMatrix& out, + const CpuMatrix& in, + const CpuMatrix& filter, + const CpuIVector& seq) { + const int* starts = seq.getData(); + const size_t numSeq = seq.getSize() - 1; + const size_t contextLength = filter.getHeight(); + for (size_t i = 0; i < numSeq; ++i) { + size_t begin = starts[i]; + size_t end = starts[i + 1]; + for (size_t j = begin; j < end; ++j) { + MatrixPtr x; + MatrixPtr w; + if ((j + contextLength) < end) { + x = (const_cast(in)).subMatrix(j, contextLength); + w = (const_cast(filter)).subMatrix(0, contextLength); + } else { + x = (const_cast(in)).subMatrix(j, end - j); + w = (const_cast(filter)).subMatrix(0, end - j); + } + MatrixPtr y = out.subMatrix(j, 1); + y->addDotMulVMM(*x, *w); + } + } +} + +template <> +void RowConvGrad(const CpuMatrix& outG, + const CpuMatrix& in, + const CpuMatrix& filter, + CpuMatrix& inG, + CpuMatrix& filterG, + const CpuIVector& seq) { + // gradient w.r.t filter + const int* starts = seq.getData(); + const size_t numSeq = seq.getSize() - 1; + const size_t contextLength = filter.getHeight(); + if (filterG) { + for (size_t i = 0; i < numSeq; ++i) { + size_t begin = starts[i]; + size_t end = starts[i + 1]; + size_t steps = end - begin; + for (size_t j = 0; j < contextLength && (begin + j) < end; ++j) { + MatrixPtr x = + (const_cast(in)).subMatrix(begin + j, steps - j); + MatrixPtr dy = + (const_cast(outG)).subMatrix(begin, steps - j); + MatrixPtr dw = filterG.subMatrix(j, 1); + dw->addDotMulVMM(*dy, *x); + } + } + } + + // gradient w.r.t input feature + if (inG) { + for (size_t i = 0; i < numSeq; ++i) { + size_t begin = starts[i]; + size_t end = starts[i + 1]; + size_t steps = end - begin; + for (size_t j = 0; j < steps; ++j) { + MatrixPtr dx = inG.subMatrix(begin + j, 1); + for (size_t t = 0; t < contextLength; ++t) { + if (int(j - t) >= 0) { + MatrixPtr dy = + (const_cast(outG)).subMatrix(begin + j - t, 1); + MatrixPtr w = (const_cast(filter)).subMatrix(t, 1); + dx->addDotMul(*dy, *w, 1.0, 1.0); + } + } + } + } + } +} + +/** + * \brief The row convolution is called lookahead convolution. It is firstly + * introduced in deep-speech2 system. The bidirectional RNN that learns + * representation for a sequence by performing a forward and a backward pass + * through the entire sequence. However, unlike unidirectional RNNs, + * bidirectional RNNs are challenging to deploy in an online and low-latency + * setting. The lookahead convolution incorporates information from future + * subsequences in a computationally efficient manner to improve unidirectional + * recurrent neural networks. + * + * The connection of row convolution is different form the 1D sequence + * convolution. Assumed that, the future context-length is k, that is to say, + * it can get the output at timestep t by using the the input feature from t-th + * timestep to (t+k)-th timestep. Assumed that the hidden dim of input + * activations are d, the activations r_t for the new layer at time-step t are: + * + * + * -- k + 1 + * r(t,i) = > W(i,j) * h(t+j-1, i), for (1 <= i <= d) + * -- j = 1 + * + * + * The weight shape is: (k + 1) x d + * Function Arguments: + * + * \param inputs[0] The input activations. + * \param inputs[0] The filter (or weight) and shape is (k+1) x d. + * \param outputs[1] The output activations. + * + * [1] Dario Amodei, etc. Deep Speech 2 : End-to-End Speech Recognition in + * English + * and Mandarin. https://arxiv.org/abs/1512.02595 + */ + +template +class RowConvFunc : public FunctionBase { +public: + void init(const FuncConfig& config) override {} + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + // check + CHECK_EQ(2UL, inputs.size()); + CHECK_EQ(1UL, outputs.size()); + // TODO(qingqing): support ASSIGN_TO. + CHECK_EQ(outputs[0].getArgType(), ADD_TO); + CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg()) + << "SequenceArg required here."; + const auto in = dynamic_cast(inputs[0]); + auto out = dynamic_cast(outputs[0]); + auto w = inputs[1]; + CHECK(in.data() && out.data() && in.getSequenceId().data()); + CHECK_EQ(in.shape().ndims(), 2UL); + CHECK(in.shape() == out.shape()); + CHECK_EQ(w.shape()[1], in.shape()[1]); + + auto outMat = out.matrix(); + const auto inMat = in.matrix(); + const auto wMat = w.matrix(); + const auto seqId = in.getSequenceId().vector(); + + RowConv(outMat, inMat, wMat, seqId); + } +}; + +/** + * \brief The backward of row convolution function. This function calculated + * the gradient w.r.t filter and the gradient w.r.t input activations(or data). + * + * Argument in this Function: + * + * \param inputs[0] The gradient w.r.t output activations. + * \param inputs[1] The input activations. + * \param inputs[2] The filter (or weight) and shape is (k+1) x d. + * \param outputs[0] The gradient w.r.t input activations. + * \param outputs[1] The gradient w.r.r filter. + * + * Abbreviation: + * w.r.t: with respect to. + */ + +template +class RowConvGradFunc : public FunctionBase { + // TODO(qingqing): split into RowConvDataFunc and RowConvWeightFunc +public: + void init(const FuncConfig& config) override {} + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + // check + CHECK_EQ(3UL, inputs.size()); + CHECK_EQ(2UL, outputs.size()); + CHECK_EQ(outputs[0].getArgType(), ADD_TO); + CHECK_EQ(outputs[1].getArgType(), ADD_TO); + CHECK(inputs[0].isSequenceArg() && inputs[1].isSequenceArg() && + outputs[0].isSequenceArg()) + << "SequenceArg required here."; + + const auto outGrad = dynamic_cast(inputs[0]); + const auto in = dynamic_cast(inputs[1]); + const auto w = inputs[2]; + auto inGrad = dynamic_cast(outputs[0]); + auto wGrad = outputs[1]; + + CHECK_EQ(in.shape().ndims(), 2UL); + CHECK(in.shape() == inGrad.shape()); + CHECK(in.shape() == outGrad.shape()); + CHECK_EQ(wGrad.shape()[1], in.shape()[1]); + + const auto outGMat = outGrad.matrix(); + const auto inMat = in.matrix(); + const auto wMat = w.matrix(); + auto inGMat = inGrad.data() + ? inGrad.matrix() + : typename Tensor::Matrix(nullptr, 0, 0); + auto wGMat = wGrad.data() + ? wGrad.matrix() + : typename Tensor::Matrix(nullptr, 0, 0); + const auto seqId = in.getSequenceId().vector(); + + RowConvGrad(outGMat, inMat, wMat, inGMat, wGMat, seqId); + } +}; + +REGISTER_TYPED_FUNC(RowConv, CPU, RowConvFunc); +REGISTER_TYPED_FUNC(RowConvGrad, CPU, RowConvGradFunc); +#ifndef PADDLE_ONLY_CPU +REGISTER_TYPED_FUNC(RowConv, GPU, RowConvFunc); +REGISTER_TYPED_FUNC(RowConvGrad, GPU, RowConvGradFunc); +#endif + +} // namespace paddle diff --git a/paddle/function/RowConvOp.h b/paddle/function/RowConvOp.h new file mode 100644 index 0000000000000000000000000000000000000000..2c5de6151aa2ed73ace1569d880b1c3677c195c4 --- /dev/null +++ b/paddle/function/RowConvOp.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Function.h" + +namespace paddle { + +/** + * \brief The forward of row convolution. + * + * \param[out] out The output data and shape is h x d. h is the sum of + * time steps of all samples in one mini-batch. + * \param[in] in The input data and shape is h x d. + * \param[in] filter The filter and shape is k x d. The lookahead step + * number plus one equals k. + * \param[in] seq The sequence start positions. + * + */ +template +void RowConv(typename Tensor::Matrix& out, + const typename Tensor::Matrix& in, + const typename Tensor::Matrix& filter, + const typename Tensor::Vector& seq); + +/** + * \brief The backward of row convolution. + * + * \param[in] outG The gradient w.r.t output data. + * \param[in] in The input data. + * \param[in] filter The filter. + * \param[out] inG The gradient w.r.t input data. + * \param[out] filterG The gradient w.r.t filter. + * \param[in] seq The sequence start positions. + * + */ +template +void RowConvGrad(const typename Tensor::Matrix& outG, + const typename Tensor::Matrix& in, + const typename Tensor::Matrix& filter, + typename Tensor::Matrix& inG, + typename Tensor::Matrix& filterG, + const typename Tensor::Vector& seq); +} // namespace paddle diff --git a/paddle/function/RowConvOpGpu.cu b/paddle/function/RowConvOpGpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..c0b947e224313abaf4fadfb8293dc78ca085ff84 --- /dev/null +++ b/paddle/function/RowConvOpGpu.cu @@ -0,0 +1,344 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "hl_base.h" +#include "RowConvOp.h" + +namespace paddle { + +template +__global__ void KeRowConv(real* y, const real* x, const real* w, + const int* starts, const int height, const int width, + const int numSeq, const int context) { + + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int blky = blockDim.y; + const int gidx = blockIdx.x * blockDim.x; + + __shared__ real sw[BLOCK_H][BLOCK_W]; + + for (int i = tidy; i < context; i += blky) { + sw[i][tidx] = gidx + tidx < width ? w[i*width + gidx + tidx] : 0.0; + } + + __syncthreads(); + + for (int i = 0; i < numSeq; ++i) { + const int start = starts[i]; + const int end = starts[i + 1]; + const int steps = end - start; + for (int j = tidy; j < steps; j += blky) { + real sum = 0; + int off = (start + j) * width; + for (int t = 0; t < context; ++t) { + if ((start + j + t) < end) { + int xoff = off + t * width; + real xVal = gidx + tidx < width ? x[xoff + gidx + tidx] : 0.0; + sum += sw[t][tidx] * xVal; + } + } + if (gidx + tidx < width) { + y[off + gidx + tidx] += sum; + } + } + } +} + +__global__ void KeRowConv2(real* y, const real* x, const real* w, + const int* starts, const int height, const int width, + const int numSeq, const int context) { + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int blky = blockDim.y; + const int gidx = blockIdx.x * blockDim.x; + + for (int i = 0; i < numSeq; ++i) { + const int start = starts[i]; + const int end = starts[i + 1]; + const int steps = end - start; + for (int j = tidy; j < steps; j += blky) { + int off = (start + j) * width; + real sum = 0; + for (int t = 0; t < context && (start + j + t) < end; ++t) { + int xoff = off + t * width; + real xd = gidx + tidx < width ? x[xoff + gidx + tidx] : 0.0; + real wd = gidx + tidx < width ? w[t * width + gidx + tidx] : 0.0; + sum += wd * xd; + } + if (gidx + tidx < width) { + y[off + gidx + tidx] += sum; + } + } + } +} + + + +template <> +void RowConv(GpuMatrix& out, + const GpuMatrix& in, + const GpuMatrix& filter, + const GpuIVector& seq) { + const size_t numSeq = seq.getSize() - 1; + const size_t contextLength = filter.getHeight(); + const size_t height = in.getHeight(); + const size_t width = in.getWidth(); + + real* y = out.getData(); + const real* x = in.getData(); + const real* w = filter.getData(); + const int* starts = seq.getData(); + + dim3 dimBlock(32, 32); + dim3 dimGrid(DIVUP(width, dimBlock.x), 1); + + if (contextLength <= 32) { + KeRowConv<32, 32><<>> + (y, x, w, starts, height, width, numSeq, contextLength); + } else { + KeRowConv2<<>> + (y, x, w, starts, height, width, numSeq, contextLength); + } + CHECK_SYNC("RowConv"); +} + + +template +__global__ void KeRowConvBwWeight(real* dw, const real* x, const real* dy, + const int* starts, const int height, const int width, const int numSeq, + const int context) { + + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int blky = blockDim.y; + const int gidx = blockIdx.x * blockDim.x; + + __shared__ real sh_x[BLOCK_W][BLOCK_H]; + __shared__ real sh_dy[BLOCK_W][BLOCK_H + CONTEXT - 1]; + __shared__ real sh_dw[CONTEXT][BLOCK_W]; + + if (tidy < context) { + sh_dw[tidy][tidx] = 0.0; + } + __syncthreads(); + + for (int i = 0; i < numSeq; ++i) { + const int start = starts[i]; + const int end = starts[i + 1]; + const int steps = end - start; + const int size = ((steps + BLOCK_H - 1)/BLOCK_H) * BLOCK_H; + for (int j = tidy; j < size; j += BLOCK_H) { + int xoff = gidx + tidx; + int yoff = start + j; + + // transpose + sh_x[tidx][tidy] = (xoff < width && yoff < end) ? x[yoff * width + xoff] : 0.0; + sh_dy[tidx][tidy + context - 1] = (xoff < width && yoff < end) ? dy[yoff * width + xoff] : 0.0; + __syncthreads(); + if (tidy < (context - 1)) { + yoff = yoff - context + 1; + sh_dy[tidx][tidy] = (xoff < width && yoff >= start) ? dy[yoff * width + xoff] : 0.0; + } + __syncthreads(); + + for (int t = 0; t < context; t++) { + real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx + context - 1 - t]; + __syncthreads(); + // warp size and blockDim.x is 32. + val += __shfl_down(val, 16); + val += __shfl_down(val, 8); + val += __shfl_down(val, 4); + val += __shfl_down(val, 2); + val += __shfl_down(val, 1); + __syncthreads(); + if (tidx == 0) { + sh_dw[t][tidy] += val; + } + __syncthreads(); + } + } + } + + for (int t = tidy; (t < context) && ((gidx + tidx) < width); t += blky) { + dw[t * width + gidx + tidx] += sh_dw[t][tidx]; + } +} + +template +__global__ void KeRowConvBwWeight2(real* dw, const real* x, const real* dy, + const int* starts, const int height, const int width, const int numSeq, + const int context) { + + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int gidx = blockIdx.x * blockDim.x; + + __shared__ real sh_x[BLOCK_H][BLOCK_W]; + __shared__ real sh_dy[BLOCK_H][BLOCK_W]; + + for (int i = 0; i < numSeq; ++i) { + const int start = starts[i]; + const int end = starts[i + 1]; + const int steps = end - start; + + const int size = ((steps + BLOCK_H - 1)/BLOCK_H) * BLOCK_H; + for (int j = tidy; j < size; j += BLOCK_H) { + int xoff = gidx + tidx; + int yoff = start + j; + + // transpose + sh_x[tidx][tidy] = (xoff < width && yoff < end) ? x[yoff * width + xoff] : 0.0; + __syncthreads(); + + for (int t = 0; t < context; t++) { + sh_dy[tidx][tidy] = (xoff < width && (yoff - t) >= start && yoff - t < end) ? dy[(yoff - t) * width + xoff] : 0.0; + __syncthreads(); + + real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx]; + __syncthreads(); + // warp size and blockDim.x is 32. + val += __shfl_down(val, 16); + val += __shfl_down(val, 8); + val += __shfl_down(val, 4); + val += __shfl_down(val, 2); + val += __shfl_down(val, 1); + __syncthreads(); + + if (tidx == 0 && (gidx + tidy) < width) { + dw[t*width + gidx + tidy] += val; + } + } + } + } +} + +template +__global__ void KeRowConvBwData(real* dx, const real* w, const real* dy, + const int* starts, const int height, const int width, const int numSeq, + const int context) { + + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int blky = blockDim.y; + const int gidx = blockIdx.x * blockDim.x; + + __shared__ real sw[BLOCK_H][BLOCK_W]; + + for (int i = tidy; i < context; i += blky) { + sw[i][tidx] = gidx + tidx < width ? w[i*width + gidx + tidx] : 0.0; + } + + __syncthreads(); + + for (int i = 0; i < numSeq; ++i) { + const int start = starts[i]; + const int end = starts[i + 1]; + const int steps = end - start; + for (int j = tidy; j < steps; j += blky) { + real sum = 0; + int off = (start + j) * width; + for (int t = 0; t < context && (j - t) >= 0; ++t) { + int dyOff = off - t * width; + real dyVal = gidx + tidx < width ? dy[dyOff + gidx + tidx] : 0.0; + sum += sw[t][tidx] * dyVal; + } + if (gidx + tidx < width) { + dx[off + gidx + tidx] += sum; + } + } + } +} + +__global__ void KeRowConvBwData2(real* dx, const real* w, const real* dy, + const int* starts, const int height, const int width, const int numSeq, + const int context) { + + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int blky = blockDim.y; + const int gidx = blockIdx.x * blockDim.x; + + for (int i = 0; i < numSeq; ++i) { + const int start = starts[i]; + const int end = starts[i + 1]; + const int steps = end - start; + for (int j = tidy; j < steps; j += blky) { + real sum = 0; + int off = (start + j) * width; + for (int t = 0; t < context && (j - t) >= 0; ++t) { + int dyOff = off - t * width; + real dyVal = gidx + tidx < width ? dy[dyOff + gidx + tidx] : 0.0; + real wVal = gidx + tidx < width ? w[t * width + gidx + tidx] : 0.0; + sum += wVal * dyVal; + } + if (gidx + tidx < width) { + dx[off + gidx + tidx] += sum; + } + } + } +} + + +template <> +void RowConvGrad(const GpuMatrix& outG, + const GpuMatrix& in, + const GpuMatrix& filter, + GpuMatrix& inG, + GpuMatrix& filterG, + const GpuIVector& seq) { + const size_t numSeq = seq.getSize() - 1; + const size_t contextLength = filter.getHeight(); + const size_t height = in.getHeight(); + const size_t width = in.getWidth(); + + const real* dy = outG.getData(); + const real* x = in.getData(); + const real* w = filter.getData(); + const int* starts = seq.getData(); + + if (filterG) { + dim3 dimBlock(32, 32); + dim3 dimGrid(DIVUP(width, dimBlock.x), 1); + real* dw = filterG.getData(); + if (contextLength <= 32) { + KeRowConvBwWeight<32, 32, 32> + <<>> + (dw, x, dy, starts, height, width, numSeq, contextLength); + } else { + KeRowConvBwWeight2<32, 32> + <<>> + (dw, x, dy, starts, height, width, numSeq, contextLength); + } + } + + if (inG) { + real* dx = inG.getData(); + dim3 dimBlock2(32, 32); + dim3 dimGrid2(DIVUP(width, dimBlock2.x), 1); + if (contextLength <= 64) { + KeRowConvBwData<32, 64> + <<>> + (dx, w, dy, starts, height, width, numSeq, contextLength); + } else { + KeRowConvBwData2 + <<>> + (dx, w, dy, starts, height, width, numSeq, contextLength); + } + } + + CHECK_SYNC("RowConvGrad"); +} + +} // namespace paddle diff --git a/paddle/function/RowConvOpTest.cpp b/paddle/function/RowConvOpTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1c95d3ff2cccbf33f4c5f91f6daf340871a8f7b0 --- /dev/null +++ b/paddle/function/RowConvOpTest.cpp @@ -0,0 +1,62 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "FunctionTest.h" + +namespace paddle { + +void testRowConvFw(size_t batchSize, size_t dim, size_t contextLength) { + FunctionCompare test("RowConv", FuncConfig()); + + test.addSequence(SequenceIdArg(TensorShape{batchSize})); + test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim})); + test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{contextLength, dim})); + + test.addOutputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim}), + ADD_TO); + + test.run(); +} + +void testRowConvBw(size_t batchSize, size_t dim, size_t contextLength) { + FunctionCompare test("RowConvGrad", FuncConfig()); + + test.addSequence(SequenceIdArg(TensorShape{batchSize})); + test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim})); + test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim})); + test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{contextLength, dim})); + + test.addOutputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim}), + ADD_TO); + test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{contextLength, dim}), + ADD_TO); + + test.run(); +} + +TEST(RowConv, real) { + for (size_t numSamples : {17, 129, 2020}) { + for (size_t dim : {16, 512, 2560}) { + for (size_t context : {3, 19, 65}) { + VLOG(3) << " numSamples=" << numSamples << " dim=" << dim + << " context length=" << context; + testRowConvFw(numSamples, dim, context); + testRowConvBw(numSamples, dim, context); + } + } + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/RowConvLayer.cpp b/paddle/gserver/layers/RowConvLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..54d77999ad5b30a8d9f4feaa02d81417957544a7 --- /dev/null +++ b/paddle/gserver/layers/RowConvLayer.cpp @@ -0,0 +1,106 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "RowConvLayer.h" +#include "paddle/utils/Stat.h" + +namespace paddle { + +REGISTER_LAYER(row_conv, RowConvLayer); + +bool RowConvLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + /* Initialize the basic parent class */ + Layer::init(layerMap, parameterMap); + + contexLength_ = config_.inputs(0).row_conv_conf().context_length(); + + CHECK_EQ(inputLayers_.size(), 1UL); + weight_.reset(new Weight(contexLength_, getSize(), parameters_[0])); + createFunction(forward_, "RowConv", FuncConfig()); + createFunction(backward_, "RowConvGrad", FuncConfig()); + + return true; +} + +void RowConvLayer::forward(PassType passType) { + Layer::forward(passType); + MatrixPtr input = getInputValue(0); + size_t height = input->getHeight(); + size_t width = input->getWidth(); + CHECK_EQ(width, getSize()); + resetOutput(height, width); + + const auto startPos = getInput(0).sequenceStartPositions->getVector(useGpu_); + MatrixPtr w = weight_->getW(); + wDims_ = TensorShape({w->getHeight(), w->getWidth()}); + + MatrixPtr outV = getOutputValue(); + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getInputValue(0), *startPos); + inputs.addArg(*w, wDims_); + outputs.addArg(*getOutputValue(), *startPos, ADD_TO); + + { + REGISTER_TIMER_INFO("RowConvForward", getName().c_str()); + forward_[0]->calc(inputs, outputs); + } + + /* activation */ { + REGISTER_TIMER_INFO("FwAtvTimer", getName().c_str()); + forwardActivation(); + } +} + +void RowConvLayer::backward(const UpdateCallback& callback) { + /* Do derivation */ { + REGISTER_TIMER_INFO("BpAvtTimer", getName().c_str()); + backwardActivation(); + } + + const auto startPos = getInput(0).sequenceStartPositions->getVector(useGpu_); + + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getOutputGrad(), *startPos); + inputs.addArg(*getInputValue(0), *startPos); + inputs.addArg(*weight_->getW(), wDims_); + + MatrixPtr inGrad = getInputGrad(0); + MatrixPtr wGrad = weight_->getWGrad(); + size_t h = getInputValue(0)->getHeight(); + size_t w = getInputValue(0)->getWidth(); + outputs.addArg( + inGrad ? (*inGrad) : *(Matrix::create(nullptr, h, w, false, useGpu_)), + *startPos, + ADD_TO); + outputs.addArg( + wGrad ? (*wGrad) + : *(Matrix::create(nullptr, contexLength_, w, false, useGpu_)), + wDims_, + ADD_TO); + + { + REGISTER_TIMER_INFO("RowConvBackward", getName().c_str()); + backward_[0]->calc(inputs, outputs); + } + + { + REGISTER_TIMER_INFO("WeightUpdate", getName().c_str()); + weight_->getParameterPtr()->incUpdate(callback); + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/RowConvLayer.h b/paddle/gserver/layers/RowConvLayer.h new file mode 100644 index 0000000000000000000000000000000000000000..b3bdda2f3542b312eda531695c975ff2dfc29c93 --- /dev/null +++ b/paddle/gserver/layers/RowConvLayer.h @@ -0,0 +1,44 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "Layer.h" + +namespace paddle { + +/** + * \brief Row Convolution Layer. + */ +class RowConvLayer : public Layer { +public: + explicit RowConvLayer(const LayerConfig& config) : Layer(config) {} + + ~RowConvLayer() {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; + +protected: + // Row convolution weight, context_lenght_ * fan_out. + // fan_out is the size of output feature. + std::unique_ptr weight_; + + // The step number to look ahead plus one equals contexLength_. + size_t contexLength_; + TensorShape wDims_; +}; +} // namespace paddle diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index e1e8e7fae7ca4c96206d60703db1f35aa1196875..6adffcf53b7966bd6f3d02970e5f07cc9802f469 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -1705,6 +1705,26 @@ TEST(Layer, TransLayer) { } } +TEST(Layer, RowConvLayer) { + const int context = 3; + const int size = 512; + + TestConfig config; + config.layerConfig.set_type("row_conv"); + config.layerConfig.set_size(size); + config.layerConfig.set_active_type("sigmoid"); + + config.inputDefs.push_back( + {INPUT_SEQUENCE_DATA, "layer_0", size, context * size}); + LayerInputConfig* input = config.layerConfig.add_inputs(); + RowConvConfig* conv = input->mutable_row_conv_conf(); + conv->set_context_length(context); + + for (auto useGpu : {false, true}) { + testLayerGrad(config, "row_conv", 100, false, useGpu, false); + } +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 4f9b53d6f6553e55406dd000029a598a92fd2fb6..29270829bbc3af6990aaf03a5228ef7f6a892a5c 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -194,6 +194,10 @@ message MaxOutConfig { required uint32 groups = 2; } +message RowConvConfig { + required uint32 context_length = 1; +} + message ProjectionConfig { required string type = 1; required string name = 2; @@ -279,6 +283,7 @@ message LayerInputConfig { optional SppConfig spp_conf = 12; optional PriorBoxConfig priorbox_conf = 13; optional PadConfig pad_conf = 14; + optional RowConvConfig row_conv_conf = 15; } message LayerConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index c5ec5349cf0c097d7bdbb50bd6d6d3568ff32400..0792e2d40b43f5fb2de8d6bb43a62cfa23f77082 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -2081,6 +2081,23 @@ class MaxOutLayer(LayerBase): g_layer_map[input_layer.name].width, out_channels) +@config_layer('row_conv') +class RowConvLayer(LayerBase): + def __init__(self, name, inputs, context_length, **xargs): + super(RowConvLayer, self).__init__( + name, 'maxout', 0, inputs=inputs, **xargs) + config_assert( + len(self.inputs) == 1, + 'TransLayer must have one and only one input') + input_layer = self.get_input_layer(0) + row_conv_conf = self.config.inputs[0].row_conv_conf + row_conv_conf.context_length = context_length + self.set_layer_size(input_layer.size) + psize = context_length * input_layer.size + dims = [context_length, input_layer.size] + self.create_input_parameter(0, psize, dims) + + # key: cost type # value: cost class g_cost_map = {} diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 5320f5c32ce00f4780cea16abaee718c95707467..2d8ddbb9007b241eb1986887d8ea6c2de8235c29 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -121,6 +121,7 @@ __all__ = [ 'smooth_l1_cost', 'layer_support', 'multiplex_layer', + 'row_conv_layer', 'dropout_layer', 'prelu_layer', ] @@ -179,17 +180,18 @@ class LayerType(object): EOSID_LAYER = 'eos_id' RECURRENT_LAYER = 'recurrent' - CONV_SHIFT_LAYER = 'conv_shift' - TENSOR_LAYER = 'tensor' - SEL_FC_LAYER = 'selective_fc' - SAMPLING_ID_LAYER = 'sampling_id' - SLOPE_INTERCEPT_LAYER = 'slope_intercept' - LINEAR_COMBINATION_LAYER = 'convex_comb' - BLOCK_EXPAND = 'blockexpand' - MAXOUT = 'maxout' - SPP_LAYER = 'spp' - PAD_LAYER = 'pad' - MULTIPLEX_LAYER = 'multiplex' + CONV_SHIFT_LAYER = "conv_shift" + TENSOR_LAYER = "tensor" + SEL_FC_LAYER = "selective_fc" + SAMPLING_ID_LAYER = "sampling_id" + SLOPE_INTERCEPT_LAYER = "slope_intercept" + LINEAR_COMBINATION_LAYER = "convex_comb" + BLOCK_EXPAND = "blockexpand" + MAXOUT = "maxout" + SPP_LAYER = "spp" + PAD_LAYER = "pad" + MULTIPLEX_LAYER = "multiplex" + ROW_CONV_LAYER = "row_conv" PRINT_LAYER = 'print' PRIORBOX_LAYER = 'priorbox' @@ -5585,6 +5587,79 @@ def dropout_layer(input, dropout_rate, name=None): @wrap_name_default() +@wrap_act_default(act=LinearActivation()) +@wrap_param_attr_default() +@layer_support(DROPOUT) +def row_conv_layer(input, + context_len, + act=None, + name=None, + param_attr=None, + layer_attr=None): + """ + + The row convolution is called lookahead convolution. It is firstly + introduced in paper of `Deep Speech 2: End-toEnd Speech Recognition + in English and Mandarin `_ . + + The bidirectional RNN that learns representation for a sequence by + performing a forward and a backward pass through the entire sequence. + However, unlike unidirectional RNNs, bidirectional RNNs are challenging + to deploy in an online and low-latency setting. The lookahead convolution + incorporates information from future subsequences in a computationally + efficient manner to improve unidirectional recurrent neural networks. + + The connection of row convolution is different form the 1D sequence + convolution. Assumed that, the future context-length is k, that is to say, + it can get the output at timestep t by using the the input feature from t-th + timestep to (t+k+1)-th timestep. Assumed that the hidden dim of input + activations are d, the activations r_t for the new layer at time-step t are: + + .. math:: + + r_{t,r} = \sum_{j=1}^{k + 1} {w_{i,j}h_{t+j-1, i}} + \quad \text{for} \quad (1 \leq i \leq d) + + Note: + The `context_len` is `k + 1`. That is to say, the lookahead step + number plus one equals context_len. + + + .. code-block:: python + + row_conv = row_conv_layer(input=input_layer, context_len=3) + + + :param input: The input layer. + :type input: LayerOutput + :param context_len: The context length equals the lookahead step number + plus one. + :type context_len: int + :param act: Activation Type. Default is linear activation. + :type act: BaseActivation + :param param_attr: The Parameter Attribute. If None, the parameter will be + initialized smartly. It's better set it by yourself. + :type param_attr: ParameterAttribute + :param layer_attr: Extra Layer config. + :type layer_attr: ExtraLayerAttribute|None + :return: LayerOutput object. + :rtype: LayerOutput + + """ + assert isinstance(input, LayerOutput) + assert context_len > 0, "the context_len must be greatet than 0." + + Layer( + inputs=[Input(input.name, **param_attr.attr)], + name=name, + context_length=context_len, + type=LayerType.ROW_CONV_LAYER, + active_type=act.name, + **ExtraLayerAttribute.to_kwargs(layer_attr)) + return LayerOutput( + name, LayerType.ROW_CONV_LAYER, input, activation=act, size=input.size) + + @layer_support() @wrap_name_default() @wrap_param_attr_default() diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh index bef14bffaf648b92e608a6a18cd46d57e850550e..c24102255f5bbed0f551b2dbfec20be7daf5f5b4 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh @@ -6,6 +6,6 @@ img_layers img_trans_layers util_layers simple_rnn_layers unused_layers test_cos test_rnn_group shared_fc shared_lstm shared_gru test_cost_layers_with_weight test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer -test_prelu_layer) +test_prelu_layer test_row_conv) export whole_configs=(test_split_datasource) diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_row_conv.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_row_conv.protostr new file mode 100644 index 0000000000000000000000000000000000000000..9ec15d2a19ec50a1729f9eeaa6dce8b1153c776b --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_row_conv.protostr @@ -0,0 +1,41 @@ +type: "nn" +layers { + name: "data" + type: "data" + size: 2560 + active_type: "" +} +layers { + name: "__row_conv_layer_0__" + type: "maxout" + size: 2560 + active_type: "relu" + inputs { + input_layer_name: "data" + input_parameter_name: "___row_conv_layer_0__.w0" + row_conv_conf { + context_length: 19 + } + } +} +parameters { + name: "___row_conv_layer_0__.w0" + size: 48640 + initial_mean: 0.0 + initial_std: 0.229415733871 + dims: 19 + dims: 2560 + initial_strategy: 0 + initial_smart: true +} +input_layer_names: "data" +output_layer_names: "__row_conv_layer_0__" +sub_models { + name: "root" + layer_names: "data" + layer_names: "__row_conv_layer_0__" + input_layer_names: "data" + output_layer_names: "__row_conv_layer_0__" + is_recurrent_layer_group: false +} + diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_row_conv.py b/python/paddle/trainer_config_helpers/tests/configs/test_row_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..ab33c496b0663d8472ce4b272be6c5cecbcfc978 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_row_conv.py @@ -0,0 +1,9 @@ +from paddle.trainer_config_helpers import * + +settings(batch_size=1000, learning_rate=1e-5) + +data = data_layer(name='data', size=2560) + +row_conv = row_conv_layer(input=data, context_len=19, act=ReluActivation()) + +outputs(row_conv)