From 18cd1f2558338e3e999670cdfac7e1c61c1ea428 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Sun, 4 Jun 2017 13:29:16 +0800 Subject: [PATCH] Fix bug and Python API. --- paddle/function/RowConvOp.cpp | 93 ++++++++------ paddle/function/RowConvOp.h | 18 ++- paddle/function/RowConvOpGpu.cu | 113 ++++++++++-------- paddle/function/RowConvOpTest.cpp | 27 ++--- paddle/gserver/layers/RowConvLayer.cpp | 2 +- paddle/gserver/layers/RowConvLayer.h | 4 +- python/paddle/trainer/config_parser.py | 17 +++ .../paddle/trainer_config_helpers/layers.py | 76 ++++++++++++ .../tests/configs/file_list.sh | 2 +- .../configs/protostr/test_row_conv.protostr | 41 +++++++ .../tests/configs/test_row_conv.py | 9 ++ 11 files changed, 295 insertions(+), 107 deletions(-) create mode 100644 python/paddle/trainer_config_helpers/tests/configs/protostr/test_row_conv.protostr create mode 100644 python/paddle/trainer_config_helpers/tests/configs/test_row_conv.py diff --git a/paddle/function/RowConvOp.cpp b/paddle/function/RowConvOp.cpp index 24b7e3cdffe..c3abb64971f 100644 --- a/paddle/function/RowConvOp.cpp +++ b/paddle/function/RowConvOp.cpp @@ -61,7 +61,7 @@ void RowConvGrad(const CpuMatrix& outG, size_t begin = starts[i]; size_t end = starts[i + 1]; size_t steps = end - begin; - for (size_t j = 0; j < contextLength; ++j) { + for (size_t j = 0; j < contextLength && (begin + j) < end; ++j) { MatrixPtr x = (const_cast(in)).subMatrix(begin + j, steps - j); MatrixPtr dy = @@ -81,7 +81,7 @@ void RowConvGrad(const CpuMatrix& outG, for (size_t j = 0; j < steps; ++j) { MatrixPtr dx = inG.subMatrix(begin + j, 1); for (size_t t = 0; t < contextLength; ++t) { - if ((int(j) - int(t)) >= 0) { + if (int(j - t) >= 0) { MatrixPtr dy = (const_cast(outG)).subMatrix(begin + j - t, 1); MatrixPtr w = (const_cast(filter)).subMatrix(t, 1); @@ -94,8 +94,37 @@ void RowConvGrad(const CpuMatrix& outG, } /** - * \brief TODO(qingqing) + * \brief The row convolution is called lookahead convolution. It is firstly + * introduced in deep-speech2 system. The bidirectional RNN that learns + * representation for a sequence by performing a forward and a backward pass + * through the entire sequence. However, unlike unidirectional RNNs, + * bidirectional RNNs are challenging to deploy in an online and low-latency + * setting. The lookahead convolution incorporates information from future + * subsequences in a computationally efficient manner to improve unidirectional + * recurrent neural networks. * + * The connection of row convolution is different form the 1D sequence + * convolution. Assumed that, the future context-length is k, that is to say, + * it can get the output at timestep t by using the the input feature from t-th + * timestep to (t+k)-th timestep. Assumed that the hidden dim of input + * activations are d, the activations r_t for the new layer at time-step t are: + * + * + * -- k + 1 + * r(t,i) = > W(i,j) * h(t+j-1, i), for (1 <= i <= d) + * -- j = 1 + * + * + * The weight shape is: (k + 1) x d + * Function Arguments: + * + * \param inputs[0] The input activations. + * \param inputs[0] The filter (or weight) and shape is (k+1) x d. + * \param outputs[1] The output activations. + * + * [1] Dario Amodei, etc. Deep Speech 2 : End-to-End Speech Recognition in + * English + * and Mandarin. https://arxiv.org/abs/1512.02595 */ template @@ -128,10 +157,21 @@ public: RowConv(outMat, inMat, wMat, seqId); } }; + /** - * \brief TODO(qingqing) + * \brief The backward of row convolution function. This function calculated + * the gradient w.r.t filter and the gradient w.r.t input activations(or data). * * Argument in this Function: + * + * \param inputs[0] The gradient w.r.t output activations. + * \param inputs[1] The input activations. + * \param inputs[2] The filter (or weight) and shape is (k+1) x d. + * \param outputs[0] The gradient w.r.t input activations. + * \param outputs[1] The gradient w.r.r filter. + * + * Abbreviation: + * w.r.t: with respect to. */ template @@ -140,12 +180,27 @@ public: void init(const FuncConfig& config) override {} void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + // check + CHECK_EQ(3UL, inputs.size()); + CHECK_EQ(2UL, outputs.size()); + CHECK_EQ(outputs[0].getArgType(), ADD_TO); + CHECK_EQ(outputs[1].getArgType(), ADD_TO); + CHECK(inputs[0].isSequenceArg() && inputs[1].isSequenceArg() && + outputs[0].isSequenceArg()) + << "SequenceArg required here."; + const auto outGrad = dynamic_cast(inputs[0]); const auto in = dynamic_cast(inputs[1]); const auto w = inputs[2]; auto inGrad = dynamic_cast(outputs[0]); auto wGrad = outputs[1]; + CHECK_EQ(in.shape().ndims(), 2UL); + CHECK_EQ(outGrad.shape().ndims(), 2UL); + CHECK_EQ(in.shape()[1], outGrad.shape()[1]); + CHECK_EQ(in.shape()[0], outGrad.shape()[0]); + CHECK_EQ(wGrad.shape()[1], in.shape()[1]); + const auto outGMat = outGrad.matrix(); const auto inMat = in.matrix(); const auto wMat = w.matrix(); @@ -157,37 +212,7 @@ public: : typename Tensor::Matrix(nullptr, 0, 0); const auto seqId = in.getSequenceId().vector(); - std::cout << "in:" << std::endl; - for (int i = 0; i < inMat.getHeight(); ++i) { - for (int j = 0; j < inMat.getWidth(); ++j) { - std::cout << outGMat.getElement(i, j) << " "; - } - std::cout << std::endl; - } - - std::cout << "w:" << std::endl; - for (int i = 0; i < wMat.getHeight(); ++i) { - for (int j = 0; j < wMat.getWidth(); ++j) { - std::cout << wMat.getElement(i, j) << " "; - } - std::cout << std::endl; - } - - std::cout << "w:" << std::endl; - for (int i = 0; i < seqId.getSize(); ++i) { - std::cout << seqId.getElement(i) << " "; - } - std::cout << std::endl; - RowConvGrad(outGMat, inMat, wMat, inGMat, wGMat, seqId); - - std::cout << std::endl << "out:" << std::endl; - for (int i = 0; i < inGMat.getHeight(); ++i) { - for (int j = 0; j < inGMat.getWidth(); ++j) { - std::cout << inGMat.getElement(i, j) << " "; - } - std::cout << std::endl; - } } }; diff --git a/paddle/function/RowConvOp.h b/paddle/function/RowConvOp.h index cd78ec724ab..2c5de6151aa 100644 --- a/paddle/function/RowConvOp.h +++ b/paddle/function/RowConvOp.h @@ -19,7 +19,14 @@ limitations under the License. */ namespace paddle { /** - * \brief TODO(qingqing) + * \brief The forward of row convolution. + * + * \param[out] out The output data and shape is h x d. h is the sum of + * time steps of all samples in one mini-batch. + * \param[in] in The input data and shape is h x d. + * \param[in] filter The filter and shape is k x d. The lookahead step + * number plus one equals k. + * \param[in] seq The sequence start positions. * */ template @@ -29,7 +36,14 @@ void RowConv(typename Tensor::Matrix& out, const typename Tensor::Vector& seq); /** - * \brief TODO(qingqing) + * \brief The backward of row convolution. + * + * \param[in] outG The gradient w.r.t output data. + * \param[in] in The input data. + * \param[in] filter The filter. + * \param[out] inG The gradient w.r.t input data. + * \param[out] filterG The gradient w.r.t filter. + * \param[in] seq The sequence start positions. * */ template diff --git a/paddle/function/RowConvOpGpu.cu b/paddle/function/RowConvOpGpu.cu index 5b0e065a21e..06e2c2baac2 100644 --- a/paddle/function/RowConvOpGpu.cu +++ b/paddle/function/RowConvOpGpu.cu @@ -96,11 +96,6 @@ void RowConv(GpuMatrix& out, const size_t height = in.getHeight(); const size_t width = in.getWidth(); - LOG(INFO) << numSeq; - LOG(INFO) << contextLength; - LOG(INFO) << height; - LOG(INFO) << width; - real* y = out.getData(); const real* x = in.getData(); const real* w = filter.getData(); @@ -108,7 +103,6 @@ void RowConv(GpuMatrix& out, dim3 dimBlock(32, 32); dim3 dimGrid(DIVUP(width, dimBlock.x), 1); - LOG(INFO) << dimGrid.x; if (contextLength <= 32) { KeRowConv<32, 32><<>> @@ -131,12 +125,12 @@ __global__ void KeRowConvBwWeight(real* dw, const real* x, const real* dy, const int blky = blockDim.y; const int gidx = blockIdx.x * blockDim.x; - __shared__ real sh_x[BLOCK_H][BLOCK_W]; - __shared__ real sh_dy[BLOCK_H][BLOCK_W]; + __shared__ real sh_x[BLOCK_W][BLOCK_H]; + __shared__ real sh_dy[BLOCK_W][BLOCK_H + CONTEXT - 1]; __shared__ real sh_dw[CONTEXT][BLOCK_W]; - for (int t = tidy; t < context; t += blky) { - sh_dw[t][tidx] = 0.0; + if (tidy < context) { + sh_dw[tidy][tidx] = 0.0; } __syncthreads(); @@ -144,21 +138,31 @@ __global__ void KeRowConvBwWeight(real* dw, const real* x, const real* dy, const int start = starts[i]; const int end = starts[i + 1]; const int steps = end - start; - for (int j = tidy; j < steps; j += BLOCK_H) { + const int size = ((steps + BLOCK_H - 1)/BLOCK_H) * BLOCK_H; + for (int j = tidy; j < size; j += BLOCK_H) { int xoff = gidx + tidx; int yoff = start + j; // transpose - sh_x[tidx][tidy] = xoff < width && yoff < end ? x[yoff * width + xoff] : 0.0; - sh_dy[tidx][tidy] = xoff < width && yoff < end ? dy[yoff * width + xoff] : 0.0; + sh_x[tidx][tidy] = (xoff < width && yoff < end) ? x[yoff * width + xoff] : 0.0; + sh_dy[tidx][tidy + context - 1] = (xoff < width && yoff < end) ? dy[yoff * width + xoff] : 0.0; + __syncthreads(); + if (tidy < (context - 1)) { + yoff = yoff - context + 1; + sh_dy[tidx][tidy] = (xoff < width && yoff >= start) ? dy[yoff * width + xoff] : 0.0; + } __syncthreads(); for (int t = 0; t < context; t++) { - real val = tidx + t < blockDim.x ? sh_x[tidy][tidx + t] * sh_dy[tidy][tidx]: 0.0; + real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx + context - 1 - t]; + __syncthreads(); // warp size and blockDim.x is 32. - for (int offset = 16; offset > 0; offset /= 2) { - val += __shfl_down(val, offset); - } + val += __shfl_down(val, 16); + val += __shfl_down(val, 8); + val += __shfl_down(val, 4); + val += __shfl_down(val, 2); + val += __shfl_down(val, 1); + __syncthreads(); if (tidx == 0) { sh_dw[t][tidy] += val; } @@ -167,7 +171,7 @@ __global__ void KeRowConvBwWeight(real* dw, const real* x, const real* dy, } } - for (int t = tidy; t < context && (gidx + tidx) < width; t += blky) { + for (int t = tidy; (t < context) && ((gidx + tidx) < width); t += blky) { dw[t * width + gidx + tidx] += sh_dw[t][tidx]; } } @@ -188,21 +192,30 @@ __global__ void KeRowConvBwWeight2(real* dw, const real* x, const real* dy, const int start = starts[i]; const int end = starts[i + 1]; const int steps = end - start; - for (int j = 0; j < steps; j += BLOCK_H) { + + const int size = ((steps + BLOCK_H - 1)/BLOCK_H) * BLOCK_H; + for (int j = tidy; j < size; j += BLOCK_H) { int xoff = gidx + tidx; int yoff = start + j; // transpose - sh_x[tidx][tidy] = xoff < width && yoff < end ? x[yoff * width + xoff] : 0.0; - sh_dy[tidx][tidy] = xoff < width && yoff < end ? dy[yoff * width + xoff] : 0.0; + sh_x[tidx][tidy] = (xoff < width && yoff < end) ? x[yoff * width + xoff] : 0.0; __syncthreads(); for (int t = 0; t < context; t++) { - real val = tidx + t < blockDim.x ? sh_x[tidy][tidx + t] * sh_dy[tidy][tidx]: 0.0; + sh_dy[tidx][tidy] = (xoff < width && (yoff - t) >= start && yoff - t < end) ? dy[(yoff - t) * width + xoff] : 0.0; + __syncthreads(); + + real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx]; + __syncthreads(); // warp size and blockDim.x is 32. - for (int offset = 16; offset > 0; offset /= 2) { - val += __shfl_down(val, offset); - } + val += __shfl_down(val, 16); + val += __shfl_down(val, 8); + val += __shfl_down(val, 4); + val += __shfl_down(val, 2); + val += __shfl_down(val, 1); + __syncthreads(); + if (tidx == 0 && (gidx + tidy) < width) { dw[t*width + gidx + tidy] += val; } @@ -293,34 +306,36 @@ void RowConvGrad(const GpuMatrix& outG, const real* dy = outG.getData(); const real* x = in.getData(); const real* w = filter.getData(); - real* dx = inG.getData(); - real* dw = filterG.getData(); const int* starts = seq.getData(); - dim3 dimBlock(32, 32); - dim3 dimGrid(DIVUP(width, dimBlock.x), 1); - - if (contextLength <= 16) { - KeRowConvBwWeight<32, 32, 16> - <<>> - (dw, x, dy, starts, height, width, numSeq, contextLength); - } else { - KeRowConvBwWeight2<32, 32> - <<>> - (dw, x, dy, starts, height, width, numSeq, contextLength); + if (filterG) { + dim3 dimBlock(32, 32); + dim3 dimGrid(DIVUP(width, dimBlock.x), 1); + real* dw = filterG.getData(); + if (contextLength <= 16) { + KeRowConvBwWeight<32, 32, 16> + <<>> + (dw, x, dy, starts, height, width, numSeq, contextLength); + } else { + KeRowConvBwWeight2<32, 32> + <<>> + (dw, x, dy, starts, height, width, numSeq, contextLength); + } } - - dim3 dimBlock2(32, 32); - dim3 dimGrid2(DIVUP(width, dimBlock2.x), 1); - if (contextLength <= 64) { - KeRowConvBwData<32, 64> - <<>> - (dx, w, dy, starts, height, width, numSeq, contextLength); - } else { - KeRowConvBwData2 - <<>> - (dx, w, dy, starts, height, width, numSeq, contextLength); + if (inG) { + real* dx = inG.getData(); + dim3 dimBlock2(32, 32); + dim3 dimGrid2(DIVUP(width, dimBlock2.x), 1); + if (contextLength <= 64) { + KeRowConvBwData<32, 64> + <<>> + (dx, w, dy, starts, height, width, numSeq, contextLength); + } else { + KeRowConvBwData2 + <<>> + (dx, w, dy, starts, height, width, numSeq, contextLength); + } } CHECK_SYNC("RowConvGrad"); diff --git a/paddle/function/RowConvOpTest.cpp b/paddle/function/RowConvOpTest.cpp index 9898df1a974..1c95d3ff2cc 100644 --- a/paddle/function/RowConvOpTest.cpp +++ b/paddle/function/RowConvOpTest.cpp @@ -47,23 +47,16 @@ void testRowConvBw(size_t batchSize, size_t dim, size_t contextLength) { } TEST(RowConv, real) { - // for (size_t numSamples : {17, 129}) { - // for (size_t dim : {16, 248}) { - // for (size_t context: {3, 7, 65}) { - LOG(INFO) << "==========="; - // for (size_t numSamples : {17}) { - // for (size_t dim : {16}) { - // for (size_t context: {3}) { - size_t numSamples = 17; - size_t dim = 16; - size_t context = 3; - LOG(INFO) << " numSamples=" << numSamples << " dim=" << dim - << " context length=" << context; - testRowConvFw(numSamples, dim, context); - // testRowConvBw(numSamples, dim, context); - // } - // } - // } + for (size_t numSamples : {17, 129, 2020}) { + for (size_t dim : {16, 512, 2560}) { + for (size_t context : {3, 19, 65}) { + VLOG(3) << " numSamples=" << numSamples << " dim=" << dim + << " context length=" << context; + testRowConvFw(numSamples, dim, context); + testRowConvBw(numSamples, dim, context); + } + } + } } } // namespace paddle diff --git a/paddle/gserver/layers/RowConvLayer.cpp b/paddle/gserver/layers/RowConvLayer.cpp index d4b14062977..5302e0e1a8f 100644 --- a/paddle/gserver/layers/RowConvLayer.cpp +++ b/paddle/gserver/layers/RowConvLayer.cpp @@ -75,7 +75,7 @@ void RowConvLayer::backward(const UpdateCallback& callback) { BufferArgs outputs; inputs.addArg(*getOutputGrad(), *startPos); inputs.addArg(*getInputValue(0), *startPos); - inputs.addArg(*weight_->getW(), *startPos); + inputs.addArg(*weight_->getW(), wDims_); MatrixPtr inGrad = getInputGrad(0); MatrixPtr wGrad = weight_->getWGrad(); diff --git a/paddle/gserver/layers/RowConvLayer.h b/paddle/gserver/layers/RowConvLayer.h index 05be6ca6a9b..b3bdda2f354 100644 --- a/paddle/gserver/layers/RowConvLayer.h +++ b/paddle/gserver/layers/RowConvLayer.h @@ -37,9 +37,7 @@ protected: // fan_out is the size of output feature. std::unique_ptr weight_; - // std::unique_ptr biases_; - - // how many steps to look ahead + // The step number to look ahead plus one equals contexLength_. size_t contexLength_; TensorShape wDims_; }; diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 5d540664a7f..9066ce05f33 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -2081,6 +2081,23 @@ class MaxOutLayer(LayerBase): g_layer_map[input_layer.name].width, out_channels) +@config_layer('row_conv') +class RowConvLayer(LayerBase): + def __init__(self, name, inputs, context_length, **xargs): + super(RowConvLayer, self).__init__( + name, 'maxout', 0, inputs=inputs, **xargs) + config_assert( + len(self.inputs) == 1, + 'TransLayer must have one and only one input') + input_layer = self.get_input_layer(0) + row_conv_conf = self.config.inputs[0].row_conv_conf + row_conv_conf.context_length = context_length + self.set_layer_size(input_layer.size) + psize = context_length * input_layer.size + dims = [context_length, input_layer.size] + self.create_input_parameter(0, psize, dims) + + # key: cost type # value: cost class g_cost_map = {} diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 81cce31fecc..47b62328772 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -120,6 +120,7 @@ __all__ = [ 'smooth_l1_cost', 'layer_support', 'multiplex_layer', + 'row_conv_layer', ] @@ -187,6 +188,7 @@ class LayerType(object): SPP_LAYER = "spp" PAD_LAYER = "pad" MULTIPLEX_LAYER = "multiplex" + ROW_CONV_LAYER = "row_conv" PRINT_LAYER = "print" PRIORBOX_LAYER = "priorbox" @@ -5528,3 +5530,77 @@ def multiplex_layer(input, name=None, layer_attr=None): layer_type=LayerType.MULTIPLEX_LAYER, parents=input, size=l.config.size) + + +@wrap_name_default() +@wrap_act_default(act=LinearActivation()) +@wrap_param_attr_default() +@layer_support(DROPOUT) +def row_conv_layer(input, + context_len, + act=None, + name=None, + param_attr=None, + layer_attr=None): + """ + + The row convolution is called lookahead convolution. It is firstly + introduced in paper of `Deep Speech 2: End-toEnd Speech Recognition + in English and Mandarin `_ . + + The bidirectional RNN that learns representation for a sequence by + performing a forward and a backward pass through the entire sequence. + However, unlike unidirectional RNNs, bidirectional RNNs are challenging + to deploy in an online and low-latency setting. The lookahead convolution + incorporates information from future subsequences in a computationally + efficient manner to improve unidirectional recurrent neural networks. + + The connection of row convolution is different form the 1D sequence + convolution. Assumed that, the future context-length is k, that is to say, + it can get the output at timestep t by using the the input feature from t-th + timestep to (t+k+1)-th timestep. Assumed that the hidden dim of input + activations are d, the activations r_t for the new layer at time-step t are: + + .. math:: + + r_{t,r} = \sum_{j=1}^{k + 1} {w_{i,j}h_{t+j-1, i}} + \quad \text{for} \quad (1 \leq i \leq d) + + Note: + The `context_len` is `k + 1`. That is to say, the lookahead step + number plus one equals context_len. + + + .. code-block:: python + + row_conv = row_conv_layer(input=input_layer, context_len=3) + + + :param input: The input layer. + :type input: LayerOutput + :param context_len: The context length equals the lookahead step number + plus one. + :type context_len: int + :param act: Activation Type. Default is linear activation. + :type act: BaseActivation + :param param_attr: The Parameter Attribute. If None, the parameter will be + initialized smartly. It's better set it by yourself. + :type param_attr: ParameterAttribute + :param layer_attr: Extra Layer config. + :type layer_attr: ExtraLayerAttribute|None + :return: LayerOutput object. + :rtype: LayerOutput + + """ + assert isinstance(input, LayerOutput) + assert context_len > 0, "the context_len must be greatet than 0." + + Layer( + inputs=[Input(input.name, **param_attr.attr)], + name=name, + context_length=context_len, + type=LayerType.ROW_CONV_LAYER, + active_type=act.name, + **ExtraLayerAttribute.to_kwargs(layer_attr)) + return LayerOutput( + name, LayerType.ROW_CONV_LAYER, input, activation=act, size=input.size) diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh index 981ccbf2483..db3d3c65505 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh @@ -5,6 +5,6 @@ last_first_seq test_expand_layer test_ntm_layers test_hsigmoid img_layers img_trans_layers util_layers simple_rnn_layers unused_layers test_cost_layers test_rnn_group shared_fc shared_lstm shared_gru test_cost_layers_with_weight test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops -test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer) +test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer test_row_conv) export whole_configs=(test_split_datasource) diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_row_conv.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_row_conv.protostr new file mode 100644 index 00000000000..9ec15d2a19e --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_row_conv.protostr @@ -0,0 +1,41 @@ +type: "nn" +layers { + name: "data" + type: "data" + size: 2560 + active_type: "" +} +layers { + name: "__row_conv_layer_0__" + type: "maxout" + size: 2560 + active_type: "relu" + inputs { + input_layer_name: "data" + input_parameter_name: "___row_conv_layer_0__.w0" + row_conv_conf { + context_length: 19 + } + } +} +parameters { + name: "___row_conv_layer_0__.w0" + size: 48640 + initial_mean: 0.0 + initial_std: 0.229415733871 + dims: 19 + dims: 2560 + initial_strategy: 0 + initial_smart: true +} +input_layer_names: "data" +output_layer_names: "__row_conv_layer_0__" +sub_models { + name: "root" + layer_names: "data" + layer_names: "__row_conv_layer_0__" + input_layer_names: "data" + output_layer_names: "__row_conv_layer_0__" + is_recurrent_layer_group: false +} + diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_row_conv.py b/python/paddle/trainer_config_helpers/tests/configs/test_row_conv.py new file mode 100644 index 00000000000..ab33c496b06 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_row_conv.py @@ -0,0 +1,9 @@ +from paddle.trainer_config_helpers import * + +settings(batch_size=1000, learning_rate=1e-5) + +data = data_layer(name='data', size=2560) + +row_conv = row_conv_layer(input=data, context_len=19, act=ReluActivation()) + +outputs(row_conv) -- GitLab