提交 1b8d2e65 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #2373 from qingqing01/row_conv

Row convolution operation.
...@@ -59,6 +59,11 @@ context_projection ...@@ -59,6 +59,11 @@ context_projection
.. autoclass:: paddle.v2.layer.context_projection .. autoclass:: paddle.v2.layer.context_projection
:noindex: :noindex:
row_conv
--------
.. autoclass:: paddle.v2.layer.row_conv
:noindex:
Image Pooling Layer Image Pooling Layer
=================== ===================
...@@ -346,6 +351,12 @@ sampling_id ...@@ -346,6 +351,12 @@ sampling_id
.. autoclass:: paddle.v2.layer.sampling_id .. autoclass:: paddle.v2.layer.sampling_id
:noindex: :noindex:
multiplex
---------
.. autoclass:: paddle.v2.layer.multiplex
:noindex:
Slicing and Joining Layers Slicing and Joining Layers
========================== ==========================
......
...@@ -28,6 +28,7 @@ if(WITH_TESTING) ...@@ -28,6 +28,7 @@ if(WITH_TESTING)
add_simple_unittest(PadOpTest) add_simple_unittest(PadOpTest)
add_simple_unittest(MulOpTest) add_simple_unittest(MulOpTest)
add_simple_unittest(CosSimOpTest) add_simple_unittest(CosSimOpTest)
add_simple_unittest(RowConvOpTest)
endif() endif()
endif() endif()
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "RowConvOp.h"
#include <iostream>
#include "paddle/math/Vector.h"
namespace paddle {
template <>
void RowConv<DEVICE_TYPE_CPU>(CpuMatrix& out,
const CpuMatrix& in,
const CpuMatrix& filter,
const CpuIVector& seq) {
const int* starts = seq.getData();
const size_t numSeq = seq.getSize() - 1;
const size_t contextLength = filter.getHeight();
for (size_t i = 0; i < numSeq; ++i) {
size_t begin = starts[i];
size_t end = starts[i + 1];
for (size_t j = begin; j < end; ++j) {
MatrixPtr x;
MatrixPtr w;
if ((j + contextLength) < end) {
x = (const_cast<CpuMatrix&>(in)).subMatrix(j, contextLength);
w = (const_cast<CpuMatrix&>(filter)).subMatrix(0, contextLength);
} else {
x = (const_cast<CpuMatrix&>(in)).subMatrix(j, end - j);
w = (const_cast<CpuMatrix&>(filter)).subMatrix(0, end - j);
}
MatrixPtr y = out.subMatrix(j, 1);
y->addDotMulVMM(*x, *w);
}
}
}
template <>
void RowConvGrad<DEVICE_TYPE_CPU>(const CpuMatrix& outG,
const CpuMatrix& in,
const CpuMatrix& filter,
CpuMatrix& inG,
CpuMatrix& filterG,
const CpuIVector& seq) {
// gradient w.r.t filter
const int* starts = seq.getData();
const size_t numSeq = seq.getSize() - 1;
const size_t contextLength = filter.getHeight();
if (filterG) {
for (size_t i = 0; i < numSeq; ++i) {
size_t begin = starts[i];
size_t end = starts[i + 1];
size_t steps = end - begin;
for (size_t j = 0; j < contextLength && (begin + j) < end; ++j) {
MatrixPtr x =
(const_cast<CpuMatrix&>(in)).subMatrix(begin + j, steps - j);
MatrixPtr dy =
(const_cast<CpuMatrix&>(outG)).subMatrix(begin, steps - j);
MatrixPtr dw = filterG.subMatrix(j, 1);
dw->addDotMulVMM(*dy, *x);
}
}
}
// gradient w.r.t input feature
if (inG) {
for (size_t i = 0; i < numSeq; ++i) {
size_t begin = starts[i];
size_t end = starts[i + 1];
size_t steps = end - begin;
for (size_t j = 0; j < steps; ++j) {
MatrixPtr dx = inG.subMatrix(begin + j, 1);
for (size_t t = 0; t < contextLength; ++t) {
if (int(j - t) >= 0) {
MatrixPtr dy =
(const_cast<CpuMatrix&>(outG)).subMatrix(begin + j - t, 1);
MatrixPtr w = (const_cast<CpuMatrix&>(filter)).subMatrix(t, 1);
dx->addDotMul(*dy, *w, 1.0, 1.0);
}
}
}
}
}
}
/**
* \brief The row convolution is called lookahead convolution. It is firstly
* introduced in deep-speech2 system. The bidirectional RNN that learns
* representation for a sequence by performing a forward and a backward pass
* through the entire sequence. However, unlike unidirectional RNNs,
* bidirectional RNNs are challenging to deploy in an online and low-latency
* setting. The lookahead convolution incorporates information from future
* subsequences in a computationally efficient manner to improve unidirectional
* recurrent neural networks.
*
* The connection of row convolution is different form the 1D sequence
* convolution. Assumed that, the future context-length is k, that is to say,
* it can get the output at timestep t by using the the input feature from t-th
* timestep to (t+k)-th timestep. Assumed that the hidden dim of input
* activations are d, the activations r_t for the new layer at time-step t are:
*
*
* -- k + 1
* r(t,i) = > W(i,j) * h(t+j-1, i), for (1 <= i <= d)
* -- j = 1
*
*
* The weight shape is: (k + 1) x d
* Function Arguments:
*
* \param inputs[0] The input activations.
* \param inputs[0] The filter (or weight) and shape is (k+1) x d.
* \param outputs[1] The output activations.
*
* [1] Dario Amodei, etc. Deep Speech 2 : End-to-End Speech Recognition in
* English
* and Mandarin. https://arxiv.org/abs/1512.02595
*/
template <DeviceType Device>
class RowConvFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override {}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
// check
CHECK_EQ(2UL, inputs.size());
CHECK_EQ(1UL, outputs.size());
// TODO(qingqing): support ASSIGN_TO.
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
<< "SequenceArg required here.";
const auto in = dynamic_cast<const SequenceArg&>(inputs[0]);
auto out = dynamic_cast<const SequenceArg&>(outputs[0]);
auto w = inputs[1];
CHECK(in.data() && out.data() && in.getSequenceId().data());
CHECK_EQ(in.shape().ndims(), 2UL);
CHECK(in.shape() == out.shape());
CHECK_EQ(w.shape()[1], in.shape()[1]);
auto outMat = out.matrix<Device>();
const auto inMat = in.matrix<Device>();
const auto wMat = w.matrix<Device>();
const auto seqId = in.getSequenceId().vector<int, Device>();
RowConv<Device>(outMat, inMat, wMat, seqId);
}
};
/**
* \brief The backward of row convolution function. This function calculated
* the gradient w.r.t filter and the gradient w.r.t input activations(or data).
*
* Argument in this Function:
*
* \param inputs[0] The gradient w.r.t output activations.
* \param inputs[1] The input activations.
* \param inputs[2] The filter (or weight) and shape is (k+1) x d.
* \param outputs[0] The gradient w.r.t input activations.
* \param outputs[1] The gradient w.r.r filter.
*
* Abbreviation:
* w.r.t: with respect to.
*/
template <DeviceType Device>
class RowConvGradFunc : public FunctionBase {
// TODO(qingqing): split into RowConvDataFunc and RowConvWeightFunc
public:
void init(const FuncConfig& config) override {}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
// check
CHECK_EQ(3UL, inputs.size());
CHECK_EQ(2UL, outputs.size());
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
CHECK_EQ(outputs[1].getArgType(), ADD_TO);
CHECK(inputs[0].isSequenceArg() && inputs[1].isSequenceArg() &&
outputs[0].isSequenceArg())
<< "SequenceArg required here.";
const auto outGrad = dynamic_cast<const SequenceArg&>(inputs[0]);
const auto in = dynamic_cast<const SequenceArg&>(inputs[1]);
const auto w = inputs[2];
auto inGrad = dynamic_cast<const SequenceArg&>(outputs[0]);
auto wGrad = outputs[1];
CHECK_EQ(in.shape().ndims(), 2UL);
CHECK(in.shape() == inGrad.shape());
CHECK(in.shape() == outGrad.shape());
CHECK_EQ(wGrad.shape()[1], in.shape()[1]);
const auto outGMat = outGrad.matrix<Device>();
const auto inMat = in.matrix<Device>();
const auto wMat = w.matrix<Device>();
auto inGMat = inGrad.data()
? inGrad.matrix<Device>()
: typename Tensor<real, Device>::Matrix(nullptr, 0, 0);
auto wGMat = wGrad.data()
? wGrad.matrix<Device>()
: typename Tensor<real, Device>::Matrix(nullptr, 0, 0);
const auto seqId = in.getSequenceId().vector<int, Device>();
RowConvGrad<Device>(outGMat, inMat, wMat, inGMat, wGMat, seqId);
}
};
REGISTER_TYPED_FUNC(RowConv, CPU, RowConvFunc);
REGISTER_TYPED_FUNC(RowConvGrad, CPU, RowConvGradFunc);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(RowConv, GPU, RowConvFunc);
REGISTER_TYPED_FUNC(RowConvGrad, GPU, RowConvGradFunc);
#endif
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Function.h"
namespace paddle {
/**
* \brief The forward of row convolution.
*
* \param[out] out The output data and shape is h x d. h is the sum of
* time steps of all samples in one mini-batch.
* \param[in] in The input data and shape is h x d.
* \param[in] filter The filter and shape is k x d. The lookahead step
* number plus one equals k.
* \param[in] seq The sequence start positions.
*
*/
template <DeviceType DType>
void RowConv(typename Tensor<real, DType>::Matrix& out,
const typename Tensor<real, DType>::Matrix& in,
const typename Tensor<real, DType>::Matrix& filter,
const typename Tensor<int, DType>::Vector& seq);
/**
* \brief The backward of row convolution.
*
* \param[in] outG The gradient w.r.t output data.
* \param[in] in The input data.
* \param[in] filter The filter.
* \param[out] inG The gradient w.r.t input data.
* \param[out] filterG The gradient w.r.t filter.
* \param[in] seq The sequence start positions.
*
*/
template <DeviceType DType>
void RowConvGrad(const typename Tensor<real, DType>::Matrix& outG,
const typename Tensor<real, DType>::Matrix& in,
const typename Tensor<real, DType>::Matrix& filter,
typename Tensor<real, DType>::Matrix& inG,
typename Tensor<real, DType>::Matrix& filterG,
const typename Tensor<int, DType>::Vector& seq);
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_base.h"
#include "RowConvOp.h"
namespace paddle {
template<int BLOCK_H, int BLOCK_W>
__global__ void KeRowConv(real* y, const real* x, const real* w,
const int* starts, const int height, const int width,
const int numSeq, const int context) {
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int blky = blockDim.y;
const int gidx = blockIdx.x * blockDim.x;
__shared__ real sw[BLOCK_H][BLOCK_W];
for (int i = tidy; i < context; i += blky) {
sw[i][tidx] = gidx + tidx < width ? w[i*width + gidx + tidx] : 0.0;
}
__syncthreads();
for (int i = 0; i < numSeq; ++i) {
const int start = starts[i];
const int end = starts[i + 1];
const int steps = end - start;
for (int j = tidy; j < steps; j += blky) {
real sum = 0;
int off = (start + j) * width;
for (int t = 0; t < context; ++t) {
if ((start + j + t) < end) {
int xoff = off + t * width;
real xVal = gidx + tidx < width ? x[xoff + gidx + tidx] : 0.0;
sum += sw[t][tidx] * xVal;
}
}
if (gidx + tidx < width) {
y[off + gidx + tidx] += sum;
}
}
}
}
__global__ void KeRowConv2(real* y, const real* x, const real* w,
const int* starts, const int height, const int width,
const int numSeq, const int context) {
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int blky = blockDim.y;
const int gidx = blockIdx.x * blockDim.x;
for (int i = 0; i < numSeq; ++i) {
const int start = starts[i];
const int end = starts[i + 1];
const int steps = end - start;
for (int j = tidy; j < steps; j += blky) {
int off = (start + j) * width;
real sum = 0;
for (int t = 0; t < context && (start + j + t) < end; ++t) {
int xoff = off + t * width;
real xd = gidx + tidx < width ? x[xoff + gidx + tidx] : 0.0;
real wd = gidx + tidx < width ? w[t * width + gidx + tidx] : 0.0;
sum += wd * xd;
}
if (gidx + tidx < width) {
y[off + gidx + tidx] += sum;
}
}
}
}
template <>
void RowConv<DEVICE_TYPE_GPU>(GpuMatrix& out,
const GpuMatrix& in,
const GpuMatrix& filter,
const GpuIVector& seq) {
const size_t numSeq = seq.getSize() - 1;
const size_t contextLength = filter.getHeight();
const size_t height = in.getHeight();
const size_t width = in.getWidth();
real* y = out.getData();
const real* x = in.getData();
const real* w = filter.getData();
const int* starts = seq.getData();
dim3 dimBlock(32, 32);
dim3 dimGrid(DIVUP(width, dimBlock.x), 1);
if (contextLength <= 32) {
KeRowConv<32, 32><<<dimGrid, dimBlock, 0, STREAM_DEFAULT>>>
(y, x, w, starts, height, width, numSeq, contextLength);
} else {
KeRowConv2<<<dimGrid, dimBlock, 0, STREAM_DEFAULT>>>
(y, x, w, starts, height, width, numSeq, contextLength);
}
CHECK_SYNC("RowConv");
}
template<int BLOCK_H, int BLOCK_W, int CONTEXT>
__global__ void KeRowConvBwWeight(real* dw, const real* x, const real* dy,
const int* starts, const int height, const int width, const int numSeq,
const int context) {
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int blky = blockDim.y;
const int gidx = blockIdx.x * blockDim.x;
__shared__ real sh_x[BLOCK_W][BLOCK_H];
__shared__ real sh_dy[BLOCK_W][BLOCK_H + CONTEXT - 1];
__shared__ real sh_dw[CONTEXT][BLOCK_W];
if (tidy < context) {
sh_dw[tidy][tidx] = 0.0;
}
__syncthreads();
for (int i = 0; i < numSeq; ++i) {
const int start = starts[i];
const int end = starts[i + 1];
const int steps = end - start;
const int size = ((steps + BLOCK_H - 1)/BLOCK_H) * BLOCK_H;
for (int j = tidy; j < size; j += BLOCK_H) {
int xoff = gidx + tidx;
int yoff = start + j;
// transpose
sh_x[tidx][tidy] = (xoff < width && yoff < end) ? x[yoff * width + xoff] : 0.0;
sh_dy[tidx][tidy + context - 1] = (xoff < width && yoff < end) ? dy[yoff * width + xoff] : 0.0;
__syncthreads();
if (tidy < (context - 1)) {
yoff = yoff - context + 1;
sh_dy[tidx][tidy] = (xoff < width && yoff >= start) ? dy[yoff * width + xoff] : 0.0;
}
__syncthreads();
for (int t = 0; t < context; t++) {
real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx + context - 1 - t];
__syncthreads();
// warp size and blockDim.x is 32.
val += __shfl_down(val, 16);
val += __shfl_down(val, 8);
val += __shfl_down(val, 4);
val += __shfl_down(val, 2);
val += __shfl_down(val, 1);
__syncthreads();
if (tidx == 0) {
sh_dw[t][tidy] += val;
}
__syncthreads();
}
}
}
for (int t = tidy; (t < context) && ((gidx + tidx) < width); t += blky) {
dw[t * width + gidx + tidx] += sh_dw[t][tidx];
}
}
template<int BLOCK_H, int BLOCK_W>
__global__ void KeRowConvBwWeight2(real* dw, const real* x, const real* dy,
const int* starts, const int height, const int width, const int numSeq,
const int context) {
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int gidx = blockIdx.x * blockDim.x;
__shared__ real sh_x[BLOCK_H][BLOCK_W];
__shared__ real sh_dy[BLOCK_H][BLOCK_W];
for (int i = 0; i < numSeq; ++i) {
const int start = starts[i];
const int end = starts[i + 1];
const int steps = end - start;
const int size = ((steps + BLOCK_H - 1)/BLOCK_H) * BLOCK_H;
for (int j = tidy; j < size; j += BLOCK_H) {
int xoff = gidx + tidx;
int yoff = start + j;
// transpose
sh_x[tidx][tidy] = (xoff < width && yoff < end) ? x[yoff * width + xoff] : 0.0;
__syncthreads();
for (int t = 0; t < context; t++) {
sh_dy[tidx][tidy] = (xoff < width && (yoff - t) >= start && yoff - t < end) ? dy[(yoff - t) * width + xoff] : 0.0;
__syncthreads();
real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx];
__syncthreads();
// warp size and blockDim.x is 32.
val += __shfl_down(val, 16);
val += __shfl_down(val, 8);
val += __shfl_down(val, 4);
val += __shfl_down(val, 2);
val += __shfl_down(val, 1);
__syncthreads();
if (tidx == 0 && (gidx + tidy) < width) {
dw[t*width + gidx + tidy] += val;
}
}
}
}
}
template<int BLOCK_H, int BLOCK_W>
__global__ void KeRowConvBwData(real* dx, const real* w, const real* dy,
const int* starts, const int height, const int width, const int numSeq,
const int context) {
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int blky = blockDim.y;
const int gidx = blockIdx.x * blockDim.x;
__shared__ real sw[BLOCK_H][BLOCK_W];
for (int i = tidy; i < context; i += blky) {
sw[i][tidx] = gidx + tidx < width ? w[i*width + gidx + tidx] : 0.0;
}
__syncthreads();
for (int i = 0; i < numSeq; ++i) {
const int start = starts[i];
const int end = starts[i + 1];
const int steps = end - start;
for (int j = tidy; j < steps; j += blky) {
real sum = 0;
int off = (start + j) * width;
for (int t = 0; t < context && (j - t) >= 0; ++t) {
int dyOff = off - t * width;
real dyVal = gidx + tidx < width ? dy[dyOff + gidx + tidx] : 0.0;
sum += sw[t][tidx] * dyVal;
}
if (gidx + tidx < width) {
dx[off + gidx + tidx] += sum;
}
}
}
}
__global__ void KeRowConvBwData2(real* dx, const real* w, const real* dy,
const int* starts, const int height, const int width, const int numSeq,
const int context) {
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int blky = blockDim.y;
const int gidx = blockIdx.x * blockDim.x;
for (int i = 0; i < numSeq; ++i) {
const int start = starts[i];
const int end = starts[i + 1];
const int steps = end - start;
for (int j = tidy; j < steps; j += blky) {
real sum = 0;
int off = (start + j) * width;
for (int t = 0; t < context && (j - t) >= 0; ++t) {
int dyOff = off - t * width;
real dyVal = gidx + tidx < width ? dy[dyOff + gidx + tidx] : 0.0;
real wVal = gidx + tidx < width ? w[t * width + gidx + tidx] : 0.0;
sum += wVal * dyVal;
}
if (gidx + tidx < width) {
dx[off + gidx + tidx] += sum;
}
}
}
}
template <>
void RowConvGrad<DEVICE_TYPE_GPU>(const GpuMatrix& outG,
const GpuMatrix& in,
const GpuMatrix& filter,
GpuMatrix& inG,
GpuMatrix& filterG,
const GpuIVector& seq) {
const size_t numSeq = seq.getSize() - 1;
const size_t contextLength = filter.getHeight();
const size_t height = in.getHeight();
const size_t width = in.getWidth();
const real* dy = outG.getData();
const real* x = in.getData();
const real* w = filter.getData();
const int* starts = seq.getData();
if (filterG) {
dim3 dimBlock(32, 32);
dim3 dimGrid(DIVUP(width, dimBlock.x), 1);
real* dw = filterG.getData();
if (contextLength <= 32) {
KeRowConvBwWeight<32, 32, 32>
<<<dimGrid, dimBlock, 0, STREAM_DEFAULT>>>
(dw, x, dy, starts, height, width, numSeq, contextLength);
} else {
KeRowConvBwWeight2<32, 32>
<<<dimGrid, dimBlock, 0, STREAM_DEFAULT>>>
(dw, x, dy, starts, height, width, numSeq, contextLength);
}
}
if (inG) {
real* dx = inG.getData();
dim3 dimBlock2(32, 32);
dim3 dimGrid2(DIVUP(width, dimBlock2.x), 1);
if (contextLength <= 64) {
KeRowConvBwData<32, 64>
<<<dimGrid2, dimBlock2, 0, STREAM_DEFAULT>>>
(dx, w, dy, starts, height, width, numSeq, contextLength);
} else {
KeRowConvBwData2
<<<dimGrid2, dimBlock2, 0, STREAM_DEFAULT>>>
(dx, w, dy, starts, height, width, numSeq, contextLength);
}
}
CHECK_SYNC("RowConvGrad");
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "FunctionTest.h"
namespace paddle {
void testRowConvFw(size_t batchSize, size_t dim, size_t contextLength) {
FunctionCompare test("RowConv", FuncConfig());
test.addSequence(SequenceIdArg(TensorShape{batchSize}));
test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim}));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{contextLength, dim}));
test.addOutputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim}),
ADD_TO);
test.run();
}
void testRowConvBw(size_t batchSize, size_t dim, size_t contextLength) {
FunctionCompare test("RowConvGrad", FuncConfig());
test.addSequence(SequenceIdArg(TensorShape{batchSize}));
test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim}));
test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim}));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{contextLength, dim}));
test.addOutputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim}),
ADD_TO);
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, TensorShape{contextLength, dim}),
ADD_TO);
test.run();
}
TEST(RowConv, real) {
for (size_t numSamples : {17, 129, 2020}) {
for (size_t dim : {16, 512, 2560}) {
for (size_t context : {3, 19, 65}) {
VLOG(3) << " numSamples=" << numSamples << " dim=" << dim
<< " context length=" << context;
testRowConvFw(numSamples, dim, context);
testRowConvBw(numSamples, dim, context);
}
}
}
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "RowConvLayer.h"
#include "paddle/utils/Stat.h"
namespace paddle {
REGISTER_LAYER(row_conv, RowConvLayer);
bool RowConvLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
contexLength_ = config_.inputs(0).row_conv_conf().context_length();
CHECK_EQ(inputLayers_.size(), 1UL);
weight_.reset(new Weight(contexLength_, getSize(), parameters_[0]));
createFunction(forward_, "RowConv", FuncConfig());
createFunction(backward_, "RowConvGrad", FuncConfig());
return true;
}
void RowConvLayer::forward(PassType passType) {
Layer::forward(passType);
MatrixPtr input = getInputValue(0);
size_t height = input->getHeight();
size_t width = input->getWidth();
CHECK_EQ(width, getSize());
resetOutput(height, width);
const auto startPos = getInput(0).sequenceStartPositions->getVector(useGpu_);
MatrixPtr w = weight_->getW();
wDims_ = TensorShape({w->getHeight(), w->getWidth()});
MatrixPtr outV = getOutputValue();
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getInputValue(0), *startPos);
inputs.addArg(*w, wDims_);
outputs.addArg(*getOutputValue(), *startPos, ADD_TO);
{
REGISTER_TIMER_INFO("RowConvForward", getName().c_str());
forward_[0]->calc(inputs, outputs);
}
/* activation */ {
REGISTER_TIMER_INFO("FwAtvTimer", getName().c_str());
forwardActivation();
}
}
void RowConvLayer::backward(const UpdateCallback& callback) {
/* Do derivation */ {
REGISTER_TIMER_INFO("BpAvtTimer", getName().c_str());
backwardActivation();
}
const auto startPos = getInput(0).sequenceStartPositions->getVector(useGpu_);
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getOutputGrad(), *startPos);
inputs.addArg(*getInputValue(0), *startPos);
inputs.addArg(*weight_->getW(), wDims_);
MatrixPtr inGrad = getInputGrad(0);
MatrixPtr wGrad = weight_->getWGrad();
size_t h = getInputValue(0)->getHeight();
size_t w = getInputValue(0)->getWidth();
outputs.addArg(
inGrad ? (*inGrad) : *(Matrix::create(nullptr, h, w, false, useGpu_)),
*startPos,
ADD_TO);
outputs.addArg(
wGrad ? (*wGrad)
: *(Matrix::create(nullptr, contexLength_, w, false, useGpu_)),
wDims_,
ADD_TO);
{
REGISTER_TIMER_INFO("RowConvBackward", getName().c_str());
backward_[0]->calc(inputs, outputs);
}
{
REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
weight_->getParameterPtr()->incUpdate(callback);
}
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Layer.h"
namespace paddle {
/**
* \brief Row Convolution Layer.
*/
class RowConvLayer : public Layer {
public:
explicit RowConvLayer(const LayerConfig& config) : Layer(config) {}
~RowConvLayer() {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
protected:
// Row convolution weight, context_lenght_ * fan_out.
// fan_out is the size of output feature.
std::unique_ptr<Weight> weight_;
// The step number to look ahead plus one equals contexLength_.
size_t contexLength_;
TensorShape wDims_;
};
} // namespace paddle
...@@ -1705,6 +1705,26 @@ TEST(Layer, TransLayer) { ...@@ -1705,6 +1705,26 @@ TEST(Layer, TransLayer) {
} }
} }
TEST(Layer, RowConvLayer) {
const int context = 3;
const int size = 512;
TestConfig config;
config.layerConfig.set_type("row_conv");
config.layerConfig.set_size(size);
config.layerConfig.set_active_type("sigmoid");
config.inputDefs.push_back(
{INPUT_SEQUENCE_DATA, "layer_0", size, context * size});
LayerInputConfig* input = config.layerConfig.add_inputs();
RowConvConfig* conv = input->mutable_row_conv_conf();
conv->set_context_length(context);
for (auto useGpu : {false, true}) {
testLayerGrad(config, "row_conv", 100, false, useGpu, false);
}
}
int main(int argc, char** argv) { int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
initMain(argc, argv); initMain(argc, argv);
......
...@@ -194,6 +194,10 @@ message MaxOutConfig { ...@@ -194,6 +194,10 @@ message MaxOutConfig {
required uint32 groups = 2; required uint32 groups = 2;
} }
message RowConvConfig {
required uint32 context_length = 1;
}
message ProjectionConfig { message ProjectionConfig {
required string type = 1; required string type = 1;
required string name = 2; required string name = 2;
...@@ -279,6 +283,7 @@ message LayerInputConfig { ...@@ -279,6 +283,7 @@ message LayerInputConfig {
optional SppConfig spp_conf = 12; optional SppConfig spp_conf = 12;
optional PriorBoxConfig priorbox_conf = 13; optional PriorBoxConfig priorbox_conf = 13;
optional PadConfig pad_conf = 14; optional PadConfig pad_conf = 14;
optional RowConvConfig row_conv_conf = 15;
} }
message LayerConfig { message LayerConfig {
......
...@@ -2081,6 +2081,23 @@ class MaxOutLayer(LayerBase): ...@@ -2081,6 +2081,23 @@ class MaxOutLayer(LayerBase):
g_layer_map[input_layer.name].width, out_channels) g_layer_map[input_layer.name].width, out_channels)
@config_layer('row_conv')
class RowConvLayer(LayerBase):
def __init__(self, name, inputs, context_length, **xargs):
super(RowConvLayer, self).__init__(
name, 'maxout', 0, inputs=inputs, **xargs)
config_assert(
len(self.inputs) == 1,
'TransLayer must have one and only one input')
input_layer = self.get_input_layer(0)
row_conv_conf = self.config.inputs[0].row_conv_conf
row_conv_conf.context_length = context_length
self.set_layer_size(input_layer.size)
psize = context_length * input_layer.size
dims = [context_length, input_layer.size]
self.create_input_parameter(0, psize, dims)
# key: cost type # key: cost type
# value: cost class # value: cost class
g_cost_map = {} g_cost_map = {}
......
...@@ -121,6 +121,7 @@ __all__ = [ ...@@ -121,6 +121,7 @@ __all__ = [
'smooth_l1_cost', 'smooth_l1_cost',
'layer_support', 'layer_support',
'multiplex_layer', 'multiplex_layer',
'row_conv_layer',
'dropout_layer', 'dropout_layer',
'prelu_layer', 'prelu_layer',
] ]
...@@ -179,17 +180,18 @@ class LayerType(object): ...@@ -179,17 +180,18 @@ class LayerType(object):
EOSID_LAYER = 'eos_id' EOSID_LAYER = 'eos_id'
RECURRENT_LAYER = 'recurrent' RECURRENT_LAYER = 'recurrent'
CONV_SHIFT_LAYER = 'conv_shift' CONV_SHIFT_LAYER = "conv_shift"
TENSOR_LAYER = 'tensor' TENSOR_LAYER = "tensor"
SEL_FC_LAYER = 'selective_fc' SEL_FC_LAYER = "selective_fc"
SAMPLING_ID_LAYER = 'sampling_id' SAMPLING_ID_LAYER = "sampling_id"
SLOPE_INTERCEPT_LAYER = 'slope_intercept' SLOPE_INTERCEPT_LAYER = "slope_intercept"
LINEAR_COMBINATION_LAYER = 'convex_comb' LINEAR_COMBINATION_LAYER = "convex_comb"
BLOCK_EXPAND = 'blockexpand' BLOCK_EXPAND = "blockexpand"
MAXOUT = 'maxout' MAXOUT = "maxout"
SPP_LAYER = 'spp' SPP_LAYER = "spp"
PAD_LAYER = 'pad' PAD_LAYER = "pad"
MULTIPLEX_LAYER = 'multiplex' MULTIPLEX_LAYER = "multiplex"
ROW_CONV_LAYER = "row_conv"
PRINT_LAYER = 'print' PRINT_LAYER = 'print'
PRIORBOX_LAYER = 'priorbox' PRIORBOX_LAYER = 'priorbox'
...@@ -5585,6 +5587,79 @@ def dropout_layer(input, dropout_rate, name=None): ...@@ -5585,6 +5587,79 @@ def dropout_layer(input, dropout_rate, name=None):
@wrap_name_default() @wrap_name_default()
@wrap_act_default(act=LinearActivation())
@wrap_param_attr_default()
@layer_support(DROPOUT)
def row_conv_layer(input,
context_len,
act=None,
name=None,
param_attr=None,
layer_attr=None):
"""
The row convolution is called lookahead convolution. It is firstly
introduced in paper of `Deep Speech 2: End-toEnd Speech Recognition
in English and Mandarin <https://arxiv.org/pdf/1512.02595v1.pdf>`_ .
The bidirectional RNN that learns representation for a sequence by
performing a forward and a backward pass through the entire sequence.
However, unlike unidirectional RNNs, bidirectional RNNs are challenging
to deploy in an online and low-latency setting. The lookahead convolution
incorporates information from future subsequences in a computationally
efficient manner to improve unidirectional recurrent neural networks.
The connection of row convolution is different form the 1D sequence
convolution. Assumed that, the future context-length is k, that is to say,
it can get the output at timestep t by using the the input feature from t-th
timestep to (t+k+1)-th timestep. Assumed that the hidden dim of input
activations are d, the activations r_t for the new layer at time-step t are:
.. math::
r_{t,r} = \sum_{j=1}^{k + 1} {w_{i,j}h_{t+j-1, i}}
\quad \text{for} \quad (1 \leq i \leq d)
Note:
The `context_len` is `k + 1`. That is to say, the lookahead step
number plus one equals context_len.
.. code-block:: python
row_conv = row_conv_layer(input=input_layer, context_len=3)
:param input: The input layer.
:type input: LayerOutput
:param context_len: The context length equals the lookahead step number
plus one.
:type context_len: int
:param act: Activation Type. Default is linear activation.
:type act: BaseActivation
:param param_attr: The Parameter Attribute. If None, the parameter will be
initialized smartly. It's better set it by yourself.
:type param_attr: ParameterAttribute
:param layer_attr: Extra Layer config.
:type layer_attr: ExtraLayerAttribute|None
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert isinstance(input, LayerOutput)
assert context_len > 0, "the context_len must be greatet than 0."
Layer(
inputs=[Input(input.name, **param_attr.attr)],
name=name,
context_length=context_len,
type=LayerType.ROW_CONV_LAYER,
active_type=act.name,
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(
name, LayerType.ROW_CONV_LAYER, input, activation=act, size=input.size)
@layer_support() @layer_support()
@wrap_name_default() @wrap_name_default()
@wrap_param_attr_default() @wrap_param_attr_default()
......
...@@ -6,6 +6,6 @@ img_layers img_trans_layers util_layers simple_rnn_layers unused_layers test_cos ...@@ -6,6 +6,6 @@ img_layers img_trans_layers util_layers simple_rnn_layers unused_layers test_cos
test_rnn_group shared_fc shared_lstm shared_gru test_cost_layers_with_weight test_rnn_group shared_fc shared_lstm shared_gru test_cost_layers_with_weight
test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops
test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer
test_prelu_layer) test_prelu_layer test_row_conv)
export whole_configs=(test_split_datasource) export whole_configs=(test_split_datasource)
type: "nn"
layers {
name: "data"
type: "data"
size: 2560
active_type: ""
}
layers {
name: "__row_conv_layer_0__"
type: "maxout"
size: 2560
active_type: "relu"
inputs {
input_layer_name: "data"
input_parameter_name: "___row_conv_layer_0__.w0"
row_conv_conf {
context_length: 19
}
}
}
parameters {
name: "___row_conv_layer_0__.w0"
size: 48640
initial_mean: 0.0
initial_std: 0.229415733871
dims: 19
dims: 2560
initial_strategy: 0
initial_smart: true
}
input_layer_names: "data"
output_layer_names: "__row_conv_layer_0__"
sub_models {
name: "root"
layer_names: "data"
layer_names: "__row_conv_layer_0__"
input_layer_names: "data"
output_layer_names: "__row_conv_layer_0__"
is_recurrent_layer_group: false
}
from paddle.trainer_config_helpers import *
settings(batch_size=1000, learning_rate=1e-5)
data = data_layer(name='data', size=2560)
row_conv = row_conv_layer(input=data, context_len=19, act=ReluActivation())
outputs(row_conv)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册