LinearChainCRF.h 3.0 KB
Newer Older
Z
zhangjinchao01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/math/Matrix.h"

namespace paddle {

class LinearChainCRF {
public:
23
  /**
24
   * The size of para must be \f$(numClasses + 2) * numClasses\f$.
25 26 27 28 29 30 31 32
   * The first numClasses values of para are for starting weights (\f$a\f$).
   * The next numClasses values of para are for ending weights (\f$b\f$),
   * The remaning values are for transition weights (\f$w\f$).
   *
   * The probability of a state sequence s of length \f$L\f$ is defined as:
   * \f$P(s) = (1/Z) exp(a_{s_1} + b_{s_L}
   *                  + \sum_{l=1}^L x_{s_l}
   *                  + \sum_{l=2}^L w_{s_{l-1},s_l})\f$
33 34
   * where \f$Z\f$ is a normalization value so that the sum of \f$P(s)\f$ over
   * all possible
35
   * sequences is \f$1\f$, and \f$x\f$ is the input feature to the CRF.
Z
zhangjinchao01 已提交
36
   */
37
  LinearChainCRF(int numClasses, real* para);
Z
zhangjinchao01 已提交
38

39 40 41 42
  /**
   * Calculate the negative log likelihood of s given x.
   * The size of x must be length * numClasses. Each consecutive numClasses
   * values are the features for one time step.
Z
zhangjinchao01 已提交
43 44 45
   */
  real forward(real* x, int* s, int length);

46 47 48 49
  /**
   * Calculate the gradient with respect to x, a, b, and w.
   * backward() can only be called after a corresponding call to forward() with
   * the same x, s and length.
50 51 52 53
   * The gradient with respect to a, b, and w will not be calculated if
   * needWGrad is false.
   * @note Please call getWGrad() and getXGrad() to get the gradient with
   * respect to (a, b, w) and x respectively.
Z
zhangjinchao01 已提交
54
   */
55
  void backward(real* x, int* s, int length, bool needWGrad);
Z
zhangjinchao01 已提交
56

57 58
  /**
   * Find the most probable sequence given x. The result will be stored in s.
Z
zhangjinchao01 已提交
59 60 61
   */
  void decode(real* x, int* s, int length);

62 63 64 65 66 67 68 69 70 71 72 73
  /*
   * Return the gradient with respect to (a, b, w). It can only be called after
   * a corresponding call to backward().
   */
  MatrixPtr getWGrad() { return matWGrad_; }

  /*
   * Return the gradient with respect to x. It can only be called after a
   * corresponding call to backward().
   */
  MatrixPtr getXGrad() { return matGrad_; }

Z
zhangjinchao01 已提交
74 75 76 77 78
protected:
  int numClasses_;
  MatrixPtr a_;
  MatrixPtr b_;
  MatrixPtr w_;
79
  MatrixPtr matWGrad_;
Z
zhangjinchao01 已提交
80 81 82 83 84 85
  MatrixPtr da_;
  MatrixPtr db_;
  MatrixPtr dw_;
  MatrixPtr ones_;

  MatrixPtr expX_;
86
  MatrixPtr matGrad_;
Z
zhangjinchao01 已提交
87 88 89 90 91 92 93 94 95 96 97
  MatrixPtr alpha_;
  MatrixPtr beta_;
  MatrixPtr maxX_;
  MatrixPtr expW_;

  // track_(k,i) = j means that the best sequence at time k for class i comes
  // from the sequence at time k-1 for class j
  IVectorPtr track_;
};

}  // namespace paddle