/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "Layer.h"

namespace paddle {

/**
 * @brief A layer integrating the open-source warp-ctc library
 *        <https://github.com/baidu-research/warp-ctc> to compute connectionist
 *        temporal classification cost.
 *
 * The config file api is warp_ctc_layer.
 *
 * The scalar configuration members are zero-initialized by the constructor so
 * that they never hold indeterminate values, even if they are inspected before
 * init() has assigned them.
 */
class WarpCTCLayer : public Layer {
public:
  /**
   * @brief Construct the layer from its configuration.
   *
   * @param config layer configuration, forwarded to the Layer base class.
   */
  explicit WarpCTCLayer(const LayerConfig& config)
      : Layer(config),
        numClasses_(0),
        blank_(0),
        maxSequenceLength_(0),
        normByTimes_(false) {}
  ~WarpCTCLayer() {}

  /**
   * @brief Initialize the layer. Defined out of line.
   *
   * NOTE(review): presumably fills numClasses_, blank_ and normByTimes_ from
   * the layer config -- confirm against the implementation file.
   */
  virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);

  /// Forward pass. Defined out of line.
  virtual void forward(PassType passType);

  /// Backward pass. Defined out of line.
  virtual void backward(const UpdateCallback& callback);

protected:
  /**
   * sequence matrix and batch matrix copy:
   * sequence (s0, s0, s0, s0; s1, s1; s2, s2, s2; s3)
   * batch    (s0, s1, s2, s3; s0, s1, s2, 0; s0, 0, s2, 0; s0, 0, 0, 0)
   *
   * That is, the sequence layout stores each sequence contiguously, while the
   * batch layout interleaves the sequences step by step and fills the steps of
   * sequences that have already ended with 0.
   *
   * @param seqValue          matrix in sequence layout.
   * @param batchValue        matrix in padded batch layout.
   * @param seqStartPositions start positions of the sequences.
   */
  void seq2batchPadding(const MatrixPtr& seqValue,
                        MatrixPtr& batchValue,
                        const ICpuGpuVectorPtr& seqStartPositions);

  /**
   * Counterpart of seq2batchPadding() for the batch -> sequence direction,
   * using the same two layouts.
   *
   * NOTE(review): going by the name, the data flows in the opposite direction
   * to seq2batchPadding() although the parameter names and constness are the
   * same -- confirm which argument is actually written in the implementation.
   *
   * @param seqValue          matrix in sequence layout.
   * @param batchValue        matrix in padded batch layout.
   * @param seqStartPositions start positions of the sequences.
   * @param normByTimes       presumably enables normalization by sequence
   *                          length during the copy -- TODO confirm.
   */
  void batch2seqPadding(const MatrixPtr& seqValue,
                        MatrixPtr& batchValue,
                        const ICpuGpuVectorPtr& seqStartPositions,
                        bool normByTimes);

  // Scalar settings; zero-initialized in the constructor.
  // NOTE(review): meanings below are inferred from the names -- confirm.
  size_t numClasses_;         // presumably the number of label classes
  size_t blank_;              // presumably the index of the CTC blank label
  size_t maxSequenceLength_;  // presumably the longest sequence in the batch
  bool normByTimes_;          // presumably the flag passed to batch2seqPadding

  // Buffers held by the layer (smart-pointer typedefs declared elsewhere).
  MatrixPtr batchValue_;
  MatrixPtr batchGrad_;
  VectorPtr workspace_;

  IVectorPtr cpuLabels_;
  MatrixPtr cpuCosts_;
};

}  // namespace paddle