/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cstddef>
#include <memory>

#include "Layer.h"
#include "paddle/legacy/math/Matrix.h"
#include "paddle/legacy/utils/ThreadLocal.h"

namespace paddle {
/**
 * @brief The Factorization Machine models pairwise (order-2) feature
 * interactions as inner product of the learned latent vectors corresponding
 * to each input feature.
 *
 * The Factorization Machine can effectively capture feature interactions
 * especially when the input is sparse. While in principle FM can model higher
 * order feature interaction, in practice usually only order-2 feature
 * interactions are considered. The Factorization Machine Layer here only
 * computes the order-2 interations with the formula:
 *
 * \f[
 *     y = \sum_{i=1}^{n-1}\sum_{j=i+1}^n\langle v_i, v_j \rangle x_i x_j
 * \f]
 *
37 38
 * The detailed calculation for forward and backward can be found at this paper:
 *
39
 *     Factorization machines.
40
 *
41 42 43 44
 * The config file api is factorization_machine.
 */

class FactorizationMachineLayer : public Layer {
W
Wu Yi 已提交
45
 protected:
46 47 48
  // The latent vectors, shape: (size, factorSize_)
  // Each row of the latentVectors_ matrix is the latent vector
  // corresponding to one input feature dimension
49
  std::unique_ptr<Weight> latentVectors_;
50
  // The hyperparameter that defines the dimensionality of the factorization
51 52
  size_t factorSize_;

W
Wu Yi 已提交
53
 private:
54 55 56 57 58 59 60
  // Store the square values of the letent vectors matrix
  MatrixPtr latentVectorsSquare_;
  // Store the square values of input matrix
  MatrixPtr inputSquare_;
  // The result of input matrix * latent vector matrix that will be used in
  // both forward and backward step
  MatrixPtr inputMulFactor_;
61
  // Store temporary calculation result
W
wangmeng28 已提交
62
  MatrixPtr tmpOut_;
63
  MatrixPtr tmpSum_;
64
  MatrixPtr tmpInput_;
65 66
  // Negative identity matrix
  MatrixPtr negOnes_;
W
wangmeng28 已提交
67

W
Wu Yi 已提交
68
 public:
69 70 71 72 73 74 75 76 77 78 79 80
  explicit FactorizationMachineLayer(const LayerConfig& config)
      : Layer(config) {}
  ~FactorizationMachineLayer() {}

  bool init(const LayerMap& layerMap,
            const ParameterMap& parameterMap) override;

  void forward(PassType passType) override;
  void backward(const UpdateCallback& callback = nullptr) override;
};

}  // namespace paddle