提交 1644c72a 编写于 作者: W wangmeng28

Add framework of the factorization machine layer

上级 3f874143
......@@ -54,7 +54,7 @@ img_conv
.. _api_v2.layer_context_projection:
context_projection
context_projection
------------------
.. autoclass:: paddle.v2.layer.context_projection
:noindex:
......@@ -70,7 +70,7 @@ Image Pooling Layer
img_pool
--------
.. autoclass:: paddle.v2.layer.img_pool
:noindex:
:noindex:
spp
---
......@@ -99,7 +99,7 @@ sum_to_one_norm
---------------
.. autoclass:: paddle.v2.layer.sum_to_one_norm
:noindex:
cross_channel_norm
------------------
.. autoclass:: paddle.v2.layer.cross_channel_norm
......@@ -109,7 +109,7 @@ row_l2_norm
-----------
.. autoclass:: paddle.v2.layer.row_l2_norm
:noindex:
Recurrent Layers
================
......@@ -395,6 +395,13 @@ multiplex
.. autoclass:: paddle.v2.layer.multiplex
:noindex:
Factorization Machine Layer
============================
factorization_machine
---------------------
.. autoclass:: paddle.v2.layer.factorization_machine
:noindex:
Slicing and Joining Layers
==========================
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "FactorizationMachineLayer.h"
#include <algorithm>
#include <vector>
#include "paddle/math/SparseMatrix.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace paddle {
REGISTER_LAYER(factorization_machine, FactorizationMachineLayer);
bool FactorizationMachineLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
factorSize_ = config_.factor_size();
/* initialize the latentVectors_ */
CHECK_EQ(inputLayers_.size(), 1UL);
size_t height = inputLayers_[0]->getSize();
latentVectors_.reset(new Weight(height, factorSize_, parameters_[0]));
return true;
}
void FactorizationMachineLayer::forward(PassType passType) {
Layer::forward(passType);
auto input = getInput(0);
int batchSize = input.getBatchSize();
int size = getSize();
reserveOutput(batchSize, size);
MatrixPtr outV = getOutputValue();
/* activation */ {
REGISTER_TIMER_INFO("FwAtvTimer", getName().c_str());
forwardActivation();
}
}
void FactorizationMachineLayer::backward(const UpdateCallback& callback) {
/* Do derivation */ {
REGISTER_TIMER_INFO("BpAvtTimer", getName().c_str());
backwardActivation();
}
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Layer.h"
#include "paddle/math/Matrix.h"
#include "paddle/utils/ThreadLocal.h"
namespace paddle {
/**
* @brief The Factorization Machine models pairwise (order-2) feature
* interactions as inner product of the learned latent vectors corresponding
* to each input feature.
*
* The Factorization Machine can effectively capture feature interactions
* especially when the input is sparse. While in principle FM can model higher
* order feature interaction, in practice usually only order-2 feature
* interactions are considered. The Factorization Machine Layer here only
* computes the order-2 interations with the formula:
*
* \f[
* y = \sum_{i=1}^{n-1}\sum_{j=i+1}^n\langle v_i, v_j \rangle x_i x_j
* \f]
*
* The config file api is factorization_machine.
*/
class FactorizationMachineLayer : public Layer {
protected:
/// The latent vectors, shape: (size, factorSize_)
std::unique_ptr<Weight> latentVectors_;
/// The hyperparameter that defines the dimensionality of the factorization
size_t factorSize_;
public:
explicit FactorizationMachineLayer(const LayerConfig& config)
: Layer(config) {}
~FactorizationMachineLayer() {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
} // namespace paddle
......@@ -2359,6 +2359,25 @@ TEST(Layer, ScaleShiftLayer) {
}
}
void testFactorizationMachineLayer(InputType type, bool useGpu) {
const int FACTOR_SIZE = 10;
TestConfig config;
config.layerConfig.set_type("factorization_machine");
config.layerConfig.set_factor_size(FACTOR_SIZE);
config.biasSize = 1;
config.inputDefs.push_back({type, "layer_0", 8192, 0});
config.layerConfig.add_inputs();
testLayerGrad(config, "factorization_machine", 16, false, useGpu, false);
}
TEST(Layer, FactorizationMachineLayer) {
testFactorizationMachineLayer(INPUT_DATA, false);
testFactorizationMachineLayer(INPUT_SPARSE_FLOAT_VALUE_DATA, false);
#ifdef PADDLE_WITH_CUDA
testFactorizationMachineLayer(INPUT_DATA, true);
#endif
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
initMain(argc, argv);
......
......@@ -525,6 +525,9 @@ message LayerConfig {
// for switch order layer
optional ReshapeConfig reshape_conf = 59;
// for factorization machine layer
optional uint32 factor_size = 60;
}
message EvaluatorConfig {
......
......@@ -3780,6 +3780,21 @@ class SwitchOrderLayer(LayerBase):
self.config.reshape_conf.width_axis.extend(reshape['width'])
@config_layer('factorization_machine')
class FactorizationMachineLayer(LayerBase):
def __init__(self, name, inputs, factor_size, **xargs):
super(FactorizationMachineLayer, self).__init__(
name, 'factorization_machine', size=1, inputs=inputs, **xargs)
config_assert(
len(self.inputs) == 1,
'factorization machine layer must have one and only one input.')
self.config.factor_size = factor_size
input_layer = self.get_input_layer(0)
psize = input_layer.size * factor_size
dims = [input_layer.size, 1]
self.create_input_parameter(0, psize, dims)
# Deprecated, use a new layer specific class instead
@config_func
def Layer(name, type, **xargs):
......
......@@ -143,6 +143,7 @@ __all__ = [
'scale_shift_layer',
'img_conv3d_layer',
'resize_layer',
'factorization_machine',
]
......@@ -253,6 +254,8 @@ class LayerType(object):
RESIZE = 'resize'
FACTORIZATION_MACHINE = 'factorization_machine'
@staticmethod
def is_layer_type(type_name):
"""
......@@ -6955,3 +6958,65 @@ def resize_layer(input, size, name=None):
"""
Layer(name=name, type=LayerType.RESIZE, inputs=Input(input.name), size=size)
return LayerOutput(name, LayerType.RESIZE, parents=[input], size=input.size)
@wrap_name_default()
@wrap_act_default(act=LinearActivation())
@wrap_param_attr_default()
@layer_support()
def factorization_machine(input,
factor_size,
act=None,
name=None,
param_attr=None,
layer_attr=None):
"""
The Factorization Machine models pairwise feature interactions as inner
product of the learned latent vectors corresponding to each input feature.
The Factorization Machine can effectively capture feature interactions
especially when the input is sparse. In practice, usually order 2 feature
interactions are considered using Factorization Machine with the formula:
.. math::
y = \sum_{i=1}^{n-1}\sum_{j=i+1}^n\langle v_i, v_j \rangle x_i x_j
Note:
X is the input vector with size n. V is the factor matrix. Each row of V
is the latent vector corresponding to each input dimesion. The size of
each latent vector is k.
.. code-block:: python
factor_machine = factorization_machine(input=input_layer, factor_size=10)
:param input: The input layer.
:type input: LayerOutput
:param factor_size: The hyperparameter that defines the dimensionality of
the latent vector size
:type context_len: int
:param act: Activation Type. Default is linear activation.
:type act: BaseActivation
:param param_attr: The Parameter Attribute. If None, the latent vectors will
be initialized smartly. It's better to set it by
yourself.
:type param_attr: ParameterAttribute
:param layer_attr: Extra Layer config.
:type layer_attr: ExtraLayerAttribute|None
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert isinstance(input, LayerOutput)
assert factor_size > 0, "the factor_size must be greater than 0."
Layer(
inputs=[Input(input.name, **param_attr.attr)],
name=name,
factor_size=factor_size,
type=LayerType.FACTORIZATION_MACHINE,
active_type=act.name,
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(
name, LayerType.FACTORIZATION_MACHINE, input, activation=act, size=1)
......@@ -10,6 +10,7 @@ test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_la
test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer
test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer
test_seq_slice_layer test_cross_entropy_over_beam test_pooling3D_layer
test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer)
test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer
test_factorization_machine)
export whole_configs=(test_split_datasource)
type: "nn"
layers {
name: "data"
type: "data"
size: 1024
active_type: ""
}
layers {
name: "__factorization_machine_0__"
type: "factorization_machine"
size: 1
active_type: ""
inputs {
input_layer_name: "data"
input_parameter_name: "___factorization_machine_0__.w0"
}
factor_size: 10
}
parameters {
name: "___factorization_machine_0__.w0"
size: 10240
initial_mean: 0.0
initial_std: 0.03125
dims: 1024
dims: 1
initial_strategy: 0
initial_smart: true
}
input_layer_names: "data"
output_layer_names: "__factorization_machine_0__"
sub_models {
name: "root"
layer_names: "data"
layer_names: "__factorization_machine_0__"
input_layer_names: "data"
output_layer_names: "__factorization_machine_0__"
is_recurrent_layer_group: false
}
from paddle.trainer_config_helpers import *
settings(batch_size=1000, learning_rate=1e-5)
data = data_layer(name='data', size=1024)
fm = factorization_machine(input=data, factor_size=10)
outputs(fm)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册