diff --git a/doc/api/v2/config/layer.rst b/doc/api/v2/config/layer.rst index d4d182f6692e09b3e40f3620b77d9a0f20ec5af3..c3f9c18d0663a7a24880b441981875c1e4f015aa 100644 --- a/doc/api/v2/config/layer.rst +++ b/doc/api/v2/config/layer.rst @@ -54,7 +54,7 @@ img_conv .. _api_v2.layer_context_projection: -context_projection +context_projection ------------------ .. autoclass:: paddle.v2.layer.context_projection :noindex: @@ -70,7 +70,7 @@ Image Pooling Layer img_pool -------- .. autoclass:: paddle.v2.layer.img_pool - :noindex: + :noindex: spp --- @@ -104,7 +104,7 @@ sum_to_one_norm --------------- .. autoclass:: paddle.v2.layer.sum_to_one_norm :noindex: - + cross_channel_norm ------------------ .. autoclass:: paddle.v2.layer.cross_channel_norm @@ -114,7 +114,7 @@ row_l2_norm ----------- .. autoclass:: paddle.v2.layer.row_l2_norm :noindex: - + Recurrent Layers ================ @@ -415,6 +415,13 @@ multiplex .. autoclass:: paddle.v2.layer.multiplex :noindex: +Factorization Machine Layer +============================ + +factorization_machine +--------------------- +.. autoclass:: paddle.v2.layer.factorization_machine + :noindex: Slicing and Joining Layers ========================== diff --git a/paddle/gserver/layers/FactorizationMachineLayer.cpp b/paddle/gserver/layers/FactorizationMachineLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..be26b9ba88c279036f73b0a0baaff164755fe067 --- /dev/null +++ b/paddle/gserver/layers/FactorizationMachineLayer.cpp @@ -0,0 +1,158 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "FactorizationMachineLayer.h" +#include +#include +#include "paddle/math/SparseMatrix.h" +#include "paddle/utils/Logging.h" +#include "paddle/utils/Stat.h" + +namespace paddle { + +REGISTER_LAYER(factorization_machine, FactorizationMachineLayer); + +bool FactorizationMachineLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + /* Initialize the basic parent class */ + Layer::init(layerMap, parameterMap); + + factorSize_ = config_.factor_size(); + + /* initialize the latentVectors_ */ + CHECK_EQ(inputLayers_.size(), 1UL); + size_t inputSize = inputLayers_[0]->getSize(); + CHECK_EQ(parameters_[0]->getSize(), inputSize * factorSize_); + latentVectors_ = std::unique_ptr( + new Weight(inputSize, factorSize_, parameters_[0])); + + return true; +} + +void FactorizationMachineLayer::forward(PassType passType) { + Layer::forward(passType); + + const MatrixPtr& inputV = getInputValue(0); + + size_t batchSize = inputV->getHeight(); + size_t outputSize = getSize(); + size_t inputSize = inputLayers_[0]->getSize(); + reserveOutput(batchSize, outputSize); + + MatrixPtr outV = getOutputValue(); + + Matrix::resizeOrCreate( + latentVectorsSquare_, inputSize, factorSize_, false, useGpu_); + Matrix::resizeOrCreate( + inputMulFactor_, batchSize, factorSize_, false, useGpu_); + Matrix::resizeOrCreate(tmpOut_, batchSize, factorSize_, false, useGpu_); + + REGISTER_TIMER_INFO("FmInputMulFactorTimer", getName().c_str()); + inputMulFactor_->mul(*inputV, *latentVectors_->getW()); + inputMulFactor_->square2(*tmpOut_); + outV->sumRows(*tmpOut_, 0.5, 0); + + if 
(dynamic_cast(inputV.get())) { + Matrix::resizeOrCreateSparseMatrix(inputSquare_, + inputV->getHeight(), + inputV->getWidth(), + inputV->getElementCnt(), + inputV->getValueType()); + inputSquare_->copyFrom(*inputV); + (dynamic_cast(inputSquare_.get()))->square2(); + } else { + Matrix::resizeOrCreate( + inputSquare_, inputV->getHeight(), inputV->getWidth(), false, useGpu_); + inputV->square2(*inputSquare_); + } + latentVectors_->getW()->square2(*latentVectorsSquare_); + tmpOut_->mul(*inputSquare_, *latentVectorsSquare_); + outV->sumRows(*tmpOut_, -0.5, 1.0); + + /* activation */ { + REGISTER_TIMER_INFO("FmFwAtvTimer", getName().c_str()); + forwardActivation(); + } +} + +void FactorizationMachineLayer::backward(const UpdateCallback& callback) { + /* Do derivation */ { backwardActivation(); } + + const MatrixPtr& inputV = getInputValue(0); + const MatrixPtr& oGrad = getOutputGrad(); + + Matrix::resizeOrCreate( + tmpSum_, 1, latentVectors_->getW()->getHeight(), false, useGpu_); + MatrixPtr tmpSumTrans = Matrix::create(tmpSum_->getRowBuf(0), + latentVectors_->getW()->getHeight(), + 1, + false, + useGpu_); + + /* Calculate the gradients of the latentVectors_ matrix */ + if (latentVectors_->getWGrad()) { + if (dynamic_cast(inputV.get())) { + Matrix::resizeOrCreateSparseMatrix(tmpInput_, + inputV->getHeight(), + inputV->getWidth(), + inputV->getElementCnt()); + + CpuSparseMatrix* sparseInputV = + dynamic_cast(inputV.get()); + CpuSparseMatrix* sparseInputSquare = + dynamic_cast(inputSquare_.get()); + CpuSparseMatrix* sparseTmpInput = + dynamic_cast(tmpInput_.get()); + sparseTmpInput->copyFrom(*sparseInputV); + + sparseTmpInput->rowScale(0, *sparseInputV, *oGrad); + latentVectors_->getWGrad()->mul( + *sparseTmpInput->getTranspose(), *inputMulFactor_, 1, 1); + sparseTmpInput->rowScale(0, *sparseInputSquare, *oGrad); + + Matrix::resizeOrCreate(negOnes_, 1, inputV->getHeight(), false, useGpu_); + negOnes_->zeroMem(); + negOnes_->add(-1); + tmpSum_->mul(*negOnes_, 
*sparseTmpInput, 1, 0); + } else { + Matrix::resizeOrCreate( + tmpInput_, inputV->getHeight(), inputV->getWidth(), false, useGpu_); + + tmpInput_->rowScale(0, *inputV, *oGrad); + latentVectors_->getWGrad()->mul( + *tmpInput_->getTranspose(), *inputMulFactor_, 1, 1); + tmpInput_->rowScale(0, *inputSquare_, *oGrad); + + tmpSum_->sumCols(*tmpInput_, -1, 0); + } + + latentVectors_->getWGrad()->addRowScale( + 0, *latentVectors_->getW(), *tmpSumTrans); + + /* Increasing the number of gradient */ + latentVectors_->getParameterPtr()->incUpdate(callback); + } + + /* Calculate the input layers gradient */ + MatrixPtr inGrad = getInputGrad(0); + if (inGrad != NULL) { + inGrad->mul( + *inputMulFactor_, *latentVectors_->getW()->getTranspose(), 1, 1); + tmpSumTrans->sumRows(*latentVectorsSquare_, -1, 0); + inGrad->addColScale(0, *inputV, *tmpSum_); + inGrad->rowScale(0, *inGrad, *oGrad); + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/FactorizationMachineLayer.h b/paddle/gserver/layers/FactorizationMachineLayer.h new file mode 100644 index 0000000000000000000000000000000000000000..df20a49934d5dd444f127842c8fdb7c77f4ebeb1 --- /dev/null +++ b/paddle/gserver/layers/FactorizationMachineLayer.h @@ -0,0 +1,80 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "Layer.h" +#include "paddle/math/Matrix.h" +#include "paddle/utils/ThreadLocal.h" + +namespace paddle { +/** + * @brief The Factorization Machine models pairwise (order-2) feature + * interactions as inner product of the learned latent vectors corresponding + * to each input feature. + * + * The Factorization Machine can effectively capture feature interactions + * especially when the input is sparse. While in principle FM can model higher + * order feature interaction, in practice usually only order-2 feature + * interactions are considered. The Factorization Machine Layer here only + * computes the order-2 interactions with the formula: + * + * \f[ + * y = \sum_{i=1}^{n-1}\sum_{j=i+1}^n\langle v_i, v_j \rangle x_i x_j + * \f] + * + * The detailed calculation for forward and backward can be found in this paper: + * + * Factorization machines. + * + * The config file api is factorization_machine. + */ + +class FactorizationMachineLayer : public Layer { +protected: + // The latent vectors, shape: (size, factorSize_) + // Each row of the latentVectors_ matrix is the latent vector + // corresponding to one input feature dimension + std::unique_ptr latentVectors_; + // The hyperparameter that defines the dimensionality of the factorization + size_t factorSize_; + +private: + // Store the square values of the latent vectors matrix + MatrixPtr latentVectorsSquare_; + // Store the square values of input matrix + MatrixPtr inputSquare_; + // The result of input matrix * latent vector matrix that will be used in + // both forward and backward step + MatrixPtr inputMulFactor_; + // Store temporary calculation result + MatrixPtr tmpOut_; + MatrixPtr tmpSum_; + MatrixPtr tmpInput_; + // Row vector with every element set to -1 + MatrixPtr negOnes_; + +public: + explicit FactorizationMachineLayer(const LayerConfig& config) + : Layer(config) {} + ~FactorizationMachineLayer() {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; 
+ void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; +}; + +} // namespace paddle diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index cacf10692942f5eca2f6c498183f4acc00768460..a9fc733d1de441dea9f817c18ec65743836c2f23 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -2464,6 +2464,25 @@ TEST(Layer, L2DistanceLayer) { } } +void testFactorizationMachineLayer(InputType type, bool useGpu) { + const int FACTOR_SIZE = 10; + TestConfig config; + config.layerConfig.set_type("factorization_machine"); + config.layerConfig.set_factor_size(FACTOR_SIZE); + config.layerConfig.set_size(1); + config.biasSize = 0; + config.inputDefs.push_back({type, "layer_0", 128, 1280}); + config.layerConfig.add_inputs(); + testLayerGrad(config, "factorization_machine", 16, false, useGpu, false); +} + +TEST(Layer, FactorizationMachineLayer) { + for (auto useGpu : {false, true}) { + testFactorizationMachineLayer(INPUT_DATA, useGpu); + } + testFactorizationMachineLayer(INPUT_SPARSE_FLOAT_VALUE_DATA, false); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); diff --git a/paddle/math/CpuSparseMatrix.cpp b/paddle/math/CpuSparseMatrix.cpp index bf62229c03bb1d6e2bdf86d8c56a8157938fb832..dc6979cf5a5229fb09866189f28217889d58c2d0 100644 --- a/paddle/math/CpuSparseMatrix.cpp +++ b/paddle/math/CpuSparseMatrix.cpp @@ -260,6 +260,35 @@ void CpuSparseMatrix::printOneRow(std::ostream& os, size_t idx) const { os << ";"; } +void CpuSparseMatrix::rowScale(size_t cCol, CpuSparseMatrix& b, Matrix& c) { + CHECK(getFormat() != SPARSE_CSC) << "Not supported"; + CHECK_EQ(height_, b.getHeight()); + CHECK_EQ(width_, b.getWidth()); + real* A = getValue(); + real* B = b.getValue(); + if (b.getValueType() == FLOAT_VALUE) { + for (size_t i = 0; i < height_; i++) { + size_t start = getRowStartIdx(i); + size_t end = 
getRowStartIdx(i + 1); + CHECK_EQ(start, b.getRowStartIdx(i)); + CHECK_EQ(end, b.getRowStartIdx(i + 1)); + for (size_t j = start; j < end; j++) { + A[j] = B[j] * c.getElement(i, cCol); + } + } + } else if (b.getValueType() == NO_VALUE) { + for (size_t i = 0; i < height_; i++) { + size_t start = getRowStartIdx(i); + size_t end = getRowStartIdx(i + 1); + CHECK_EQ(start, b.getRowStartIdx(i)); + CHECK_EQ(end, b.getRowStartIdx(i + 1)); + for (size_t j = start; j < end; j++) { + A[j] = c.getElement(i, cCol); + } + } + } +} + void CpuSparseMatrix::randomizeUniform() { CHECK_LE(elementCnt_, height_ * width_); if (valueType_ == FLOAT_VALUE) { diff --git a/paddle/math/CpuSparseMatrix.h b/paddle/math/CpuSparseMatrix.h index aad1348353d558abca72ed0fa5cf943237e3ac78..522b436a2a69179d3f4f17c919d5ba024102db7b 100644 --- a/paddle/math/CpuSparseMatrix.h +++ b/paddle/math/CpuSparseMatrix.h @@ -239,6 +239,15 @@ public: const unsigned int* cols, const real* values); + /** + * @brief this_row = b_row * c_row[cCol] + * + * @param[in] cCol the column of matrix c used to scale each row of b + * @param[in] b CpuSparseMatrix + * @param[in] c Matrix + */ + void rowScale(size_t cCol, CpuSparseMatrix& b, Matrix& c); + void randomizeUniform(); void copyFrom(const GpuSparseMatrix& src, hl_stream_t stream); diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index e2f5592248fd0b6166c2d11af02cef7815673def..2fcdbbc8bd671f8ae911cf82c7a91091f252a82f 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -544,6 +544,9 @@ message LayerConfig { // for batch normalization layer // The small constant added to the variance to improve numeric stability. 
optional double epsilon = 60 [ default = 0.00001 ]; + + // for factorization machine layer + optional uint32 factor_size = 61; } message EvaluatorConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index cfe2a34a1f34a9c828486a7a6dbe320f230bb986..267393d611d6fad1a77a6c1e0a45be4be1e34731 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -3870,6 +3870,21 @@ class ScaleSubRegionLayer(LayerBase): image_conf.channels) +@config_layer('factorization_machine') +class FactorizationMachineLayer(LayerBase): + def __init__(self, name, inputs, factor_size, **xargs): + super(FactorizationMachineLayer, self).__init__( + name, 'factorization_machine', size=1, inputs=inputs, **xargs) + config_assert( + len(self.inputs) == 1, + 'factorization machine layer must have one and only one input.') + self.config.factor_size = factor_size + input_layer = self.get_input_layer(0) + psize = input_layer.size * factor_size + dims = [input_layer.size, factor_size] + self.create_input_parameter(0, psize, dims) + + # Deprecated, use a new layer specific class instead @config_func def Layer(name, type, **xargs): diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 469e667e80900b26578db6199e6426be8d0e5945..5c711bd769bdfe5eb514d0ff84a358b7e36170cf 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -148,6 +148,7 @@ __all__ = [ 'resize_layer', 'sub_seq_layer', 'scale_sub_region_layer', + 'factorization_machine', ] @@ -264,6 +265,8 @@ class LayerType(object): SCALE_SUB_REGION_LAYER = 'scale_sub_region' + FACTORIZATION_MACHINE = 'factorization_machine' + @staticmethod def is_layer_type(type_name): """ @@ -7403,3 +7406,73 @@ def scale_sub_region_layer(input, indices, value, name=None): parents=[input, indices], num_filters=input.num_filters, size=input.size) + + +@wrap_name_default() 
+@wrap_act_default(act=LinearActivation()) +@wrap_param_attr_default() +@layer_support() +def factorization_machine(input, + factor_size, + act=None, + name=None, + param_attr=None, + layer_attr=None): + """ + The Factorization Machine models pairwise feature interactions as inner + product of the learned latent vectors corresponding to each input feature. + The Factorization Machine can effectively capture feature interactions + especially when the input is sparse. + + This implementation only considers the 2-order feature interactions using + Factorization Machine with the formula: + + .. math:: + y = \sum_{i=1}^{n-1}\sum_{j=i+1}^n\langle v_i, v_j \rangle x_i x_j + + Note: + X is the input vector with size n. V is the factor matrix. Each row of V + is the latent vector corresponding to each input dimension. The size of + each latent vector is k. + + For details of Factorization Machine, please refer to the paper: + Factorization machines. + + .. code-block:: python + first_order = paddle.layer.fc(input=input, + size=1, + act=paddle.activation.Linear()) + second_order = paddle.layer.factorization_machine(input=input, + factor_size=10) + fm = paddle.layer.addto(input=[first_order, second_order], + act=paddle.activation.Linear(), + bias_attr=False) + + :param input: The input layer. Supported input types: all input data types + on CPU, and only dense input types on GPU. + :type input: LayerOutput + :param factor_size: The hyperparameter that defines the dimensionality of + each latent vector. + :type factor_size: int + :param act: Activation Type. Default is linear activation. + :type act: BaseActivation + :param param_attr: The parameter attribute. See ParameterAttribute for + details. + :type param_attr: ParameterAttribute + :param layer_attr: Extra Layer config. + :type layer_attr: ExtraLayerAttribute|None + :return: LayerOutput object. 
+ :rtype: LayerOutput + """ + assert isinstance(input, LayerOutput) + assert factor_size > 0, "the factor_size must be greater than 0." + + Layer( + inputs=[Input(input.name, **param_attr.attr)], + name=name, + factor_size=factor_size, + type=LayerType.FACTORIZATION_MACHINE, + active_type=act.name, + **ExtraLayerAttribute.to_kwargs(layer_attr)) + return LayerOutput( + name, LayerType.FACTORIZATION_MACHINE, input, activation=act, size=1) diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh index a21f67a2d99e7eab39708e2a571d30d7e9f20ce6..10c941f707498ec45e79bed9d3f8054eea19887d 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh @@ -11,6 +11,7 @@ test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_l test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer test_seq_slice_layer test_cross_entropy_over_beam test_roi_pool_layer test_pooling3D_layer test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer -test_scale_sub_region_layer test_dot_prod_layer test_l2_distance_layer) +test_scale_sub_region_layer test_dot_prod_layer test_l2_distance_layer +test_factorization_machine) export whole_configs=(test_split_datasource) diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_factorization_machine.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_factorization_machine.protostr new file mode 100644 index 0000000000000000000000000000000000000000..4f3002b19942ed58970bfd64e5978c1601273992 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_factorization_machine.protostr @@ -0,0 +1,39 @@ +type: "nn" +layers { + name: "data" + type: "data" + size: 1024 + active_type: "" +} +layers { + name: "__factorization_machine_0__" + type: "factorization_machine" 
+ size: 1 + active_type: "" + inputs { + input_layer_name: "data" + input_parameter_name: "___factorization_machine_0__.w0" + } + factor_size: 10 +} +parameters { + name: "___factorization_machine_0__.w0" + size: 10240 + initial_mean: 0.0 + initial_std: 0.03125 + dims: 1024 + dims: 10 + initial_strategy: 0 + initial_smart: true +} +input_layer_names: "data" +output_layer_names: "__factorization_machine_0__" +sub_models { + name: "root" + layer_names: "data" + layer_names: "__factorization_machine_0__" + input_layer_names: "data" + output_layer_names: "__factorization_machine_0__" + is_recurrent_layer_group: false +} + diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_factorization_machine.py b/python/paddle/trainer_config_helpers/tests/configs/test_factorization_machine.py new file mode 100644 index 0000000000000000000000000000000000000000..b249de0fee3c8ca4ad0520872fa2497c493d31b5 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_factorization_machine.py @@ -0,0 +1,7 @@ +from paddle.trainer_config_helpers import * + +data = data_layer(name='data', size=1024) + +fm = factorization_machine(input=data, factor_size=10) + +outputs(fm)