提交 487dc670 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #72 from emailweixu/cos_sim_and_linear_comb

Change cos_sim to use CosSimLayer layer when size=1 and rename convex_comb_layer to linear_comb_layer     
......@@ -245,10 +245,10 @@ addto_layer
:members: addto_layer
:noindex:
convex_comb_layer
linear_comb_layer
-----------------
.. automodule:: paddle.trainer_config_helpers.layers
:members: convex_comb_layer
:members: linear_comb_layer
:noindex:
interpolation_layer
......@@ -280,7 +280,13 @@ tensor_layer
.. automodule:: paddle.trainer_config_helpers.layers
:members: tensor_layer
:noindex:
cos_sim
-------
.. automodule:: paddle.trainer_config_helpers.layers
:members: cos_sim
:noindex:
trans_layer
------------
.. automodule:: paddle.trainer_config_helpers.layers
......@@ -341,12 +347,6 @@ rank_cost
:members: rank_cost
:noindex:
cos_sim
-------
.. automodule:: paddle.trainer_config_helpers.layers
:members: cos_sim
:noindex:
crf_layer
-----------------
.. automodule:: paddle.trainer_config_helpers.layers
......
......@@ -21,18 +21,20 @@ limitations under the License. */
namespace paddle {
/**
* @brief A layer for convex weighted average of vectors,
* @brief A layer for weighted sum of vectors,
* which is used in NEURAL MACHINE TRANSLATION BY JOINTLY LEARNING TO ALIGN AND
* TRANSLATE
* - Input: the first input contains the convex weights (batchSize x weightDim),
* and the shape of second input is (batchSize x (weightdim*dataDim)).
* - Output: the shape of output is (batchSize x dataDim).
* - Input: the the size of the first input is weightDim,
* and the size of the second input is weightdim * dataDim.
* - Output: the sizeof the output is dataDim
* \f[
* out[i][j] = \sum_{j}(in0(i, j) * in1(i,j + i * dataDim)),
* i = 0,1,...,(batchSize-1); j = 0, 1,...,(dataDim-1)
* out(j) = \sum_{i}(in0(i) * in1(i,j + i * dataDim)),
* i = 0,1,...,(weightDim-1); j = 0, 1,...,(dataDim-1)
* \f]
* Note that the above computation is for one sample. Multiple samples are
* processed in one batch.
*
* The config file api is convex_comb_layer.
* The config file api is linear_comb_layer.
*/
class ConvexCombinationLayer : public Layer {
protected:
......
......@@ -48,7 +48,7 @@ void CosSimLayer::forward(PassType passType) {
REGISTER_TIMER_INFO("CosFwAtvTimer", getName().c_str());
MatrixPtr prevOut1 = getInputValue(0);
MatrixPtr prevOut2 = getInputValue(1);
outV->cosSim(*prevOut1, *prevOut2, kCosSimScale_);
outV->cosSim(*prevOut1, *prevOut2, config_.cos_scale());
}
}
......@@ -59,7 +59,7 @@ void CosSimLayer::backward(const UpdateCallback& callback) {
outG->cosSimDerivative(*this->getOutputValue(), *getInputValue(0),
*getInputValue(1), *getInputGrad(0),
*getInputGrad(1), kCosSimScale_);
*getInputGrad(1), config_.cos_scale());
}
}
......
......@@ -36,7 +36,7 @@ namespace paddle {
class CosSimLayer : public Layer {
public:
explicit CosSimLayer(const LayerConfig& config)
: Layer(config), kCosSimScale_(5.0f) {}
: Layer(config) {}
~CosSimLayer() {}
......@@ -44,8 +44,6 @@ public:
void forward(PassType passType);
void backward(const UpdateCallback& callback = nullptr);
const real kCosSimScale_;
};
} // namespace paddle
......@@ -22,6 +22,8 @@ find_python_module(pip REQUIRED)
find_python_module(wheel REQUIRED)
find_python_module(google.protobuf REQUIRED)
add_subdirectory(paddle/trainer_config_helpers/tests)
install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/dist/
DESTINATION opt/paddle/share/wheels
)
......@@ -2264,6 +2264,9 @@ class ConvexCombinationLayer(LayerBase):
name, 'convex_comb', size, inputs=inputs, device=device)
config_assert(len(self.inputs) == 2,
'ConvexCombinationLayer must have 2 inputs')
config_assert(
size * self.get_input_layer(0).size == self.get_input_layer(1).size,
'Wrong input size for ConvexCombinationLayer')
self.set_layer_size(size)
@config_layer('interpolation')
......@@ -2313,6 +2316,9 @@ class CosSimVecMatLayer(LayerBase):
self.config.cos_scale = cos_scale
config_assert(len(self.inputs) == 2,
'CosSimVecMatLayer must have 2 inputs')
config_assert(
size * self.get_input_layer(0).size == self.get_input_layer(1).size,
'Wrong input size for CosSimVecMatLayer')
@config_layer('sampling_id')
class SamplingIdLayer(LayerBase):
......@@ -2361,6 +2367,7 @@ class CosSimLayer(LayerBase):
self,
name,
inputs,
cos_scale=5,
device=None):
super(CosSimLayer, self).__init__(
name, 'cos', 1, inputs=inputs, device=device)
......@@ -2368,6 +2375,7 @@ class CosSimLayer(LayerBase):
config_assert(
self.get_input_layer(0).size == self.get_input_layer(1).size,
'inputs of CosSimLayer must have same dim')
self.config.cos_scale = cos_scale
@config_layer('tensor')
......
......@@ -47,6 +47,7 @@ __all__ = ["full_matrix_projection", "AggregateLevel", "ExpandLevel",
'BaseGeneratedInput', 'conv_operator', 'conv_shift_layer',
'tensor_layer', 'selective_fc_layer', 'sampling_id_layer',
'slope_intercept_layer', 'trans_full_matrix_projection',
'linear_comb_layer',
'convex_comb_layer', 'ctc_layer', 'crf_layer', 'crf_decoding_layer',
'cross_entropy_with_selfnorm', 'cross_entropy',
'multi_binary_label_cross_entropy',
......@@ -70,7 +71,8 @@ class LayerType(object):
POOLING_AVG = 'average'
FC_LAYER = "fc"
COST = 'cost'
COSINE_SIM = 'cos_vm'
COSINE_SIM_VEC = 'cos_vm'
COSINE_SIM = 'cos'
HSIGMOID = 'hsigmoid'
CONV_LAYER = "conv"
POOL_LAYER = "pool"
......@@ -102,7 +104,7 @@ class LayerType(object):
SEL_FC_LAYER = "selective_fc"
SAMPLING_ID_LAYER = "sampling_id"
SLOPE_INTERCEPT_LAYER = "slope_intercept"
CONVEX_COMBINATION_LAYER = "convex_comb"
LINEAR_COMBINATION_LAYER = "convex_comb"
BLOCK_EXPAND = "blockexpand"
CTC_LAYER = "ctc"
......@@ -1171,13 +1173,16 @@ def power_layer(input, weight, name=None, layer_attr=None):
@layer_support()
def scaling_layer(input, weight, name=None, layer_attr=None):
"""
A layer for each row of a matrix, multiplying with a element of a vector.
A layer for multiplying input vector by weight scalar.
.. math::
y.row[i] = w[i] * x.row[i]
y = w x
where :math:`x` is (batchSize x dataDim) input, :math:`w` is
(batchSize x 1) weight vector, and :math:`y` is (batchSize x dataDim) output.
where :math:`x` is size=dataDim input, :math:`w` is size=1 weight,
and :math:`y` is size=dataDim output.
Note that the above computation is for one sample. Multiple samples are
processed in one batch.
The example usage is:
......@@ -1251,11 +1256,14 @@ def cos_sim(a, b, scale=5, size=1, name=None, layer_attr=None):
.. math::
similarity = cos(\\theta) = {\\mathbf{a} \\cdot \\mathbf{b}
\\over \\|\\mathbf{b}\\| \\|\\mathbf{b}\\|}
\\over \\|\\mathbf{a}\\| \\|\\mathbf{b}\\|}
The size of a is M, size of b is M*N,
Similarity will be calculated N times by step M. The output size is
N. The scale will be multiplied to similarity.
And the input dimension is :math:`a \in R^M`, :math:`b \in R^{MN}`. The
similarity will be calculated N times by step M. The output dimension is
:math:`R^N`. The scale will be multiplied to similarity.
Note that the above computation is for one sample. Multiple samples are
processed in one batch.
:param name: layer name
:type name: basestring
......@@ -1272,14 +1280,23 @@ def cos_sim(a, b, scale=5, size=1, name=None, layer_attr=None):
:return: LayerOutput object.
:rtype: LayerOutput
"""
Layer(
name=name,
type=LayerType.COSINE_SIM,
size=size,
cos_scale=scale,
inputs=[a.name, b.name],
**ExtraLayerAttribute.to_kwargs(layer_attr)
)
if size == 1:
Layer(
name=name,
type=LayerType.COSINE_SIM,
cos_scale=scale,
inputs=[a.name, b.name],
**ExtraLayerAttribute.to_kwargs(layer_attr)
)
else:
Layer(
name=name,
type=LayerType.COSINE_SIM_VEC,
size=size,
cos_scale=scale,
inputs=[a.name, b.name],
**ExtraLayerAttribute.to_kwargs(layer_attr)
)
return LayerOutput(name, LayerType.COSINE_SIM, parents=[a, b])
@wrap_name_default()
......@@ -2911,29 +2928,37 @@ def slope_intercept_layer(input, name=None, slope=1.0, intercept=0.0):
@wrap_name_default()
def convex_comb_layer(input, size, name=None):
def linear_comb_layer(weights, vectors, size, name=None):
"""
A layer for convex weighted average of vectors takes two inputs.
- Input: a vector containing the convex weights (batchSize x weightdim),
and a matrix in a vector form (batchSize x (weightdim * datadim)).
- Output: a vector (batchSize * datadim).
A layer for weighted sum of vectors takes two inputs.
- Input: size of weights is M
size of vectors is M*N
- Output: a vector of size=N
.. math::
y[i][j] = \sum_{j}(x_{1}(i, j) * x_{2}(i,j + i * dataDim)),
z(i) = \sum_{j=0}^{M-1} x(j) y(i+Nj)
where :math:`0 \le i \le N-1`
Or in the matrix notation:
.. math::
i = 0,1,...,(batchSize-1); j = 0, 1,...,(dataDim-1)
z = x^T Y
In this formular:
- :math:`x_{1}`: the first input.
- :math:`x_{2}`: the second input.
- :math:`y`: the output.
- :math:`x`: weights
- :math:`y`: vectors.
- :math:`z`: the output.
Note that the above computation is for one sample. Multiple samples are
processed in one batch.
The simple usage is:
.. code-block:: python
convex_comb = convex_comb_layer(input=inputs,
linear_comb = linear_comb_layer(weighs=weight, vectors=vectors,
size=elem_dim)
:param input: The input layers.
......@@ -2946,15 +2971,16 @@ def convex_comb_layer(input, size, name=None):
:rtype: LayerOutput
"""
assert isinstance(input, list) or isinstance(input, tuple)
assert len(input) == 2
Layer(
name=name,
type=LayerType.CONVEX_COMBINATION_LAYER,
type=LayerType.LINEAR_COMBINATION_LAYER,
size=size,
inputs=[Input(input[0].name), Input(input[1].name)],
inputs=[Input(weights.name), Input(vectors.name)],
)
return LayerOutput(name, LayerType.CONVEX_COMBINATION_LAYER, input, size=size)
return LayerOutput(name, LayerType.LINEAR_COMBINATION_LAYER,
[weights, vectors], size=size)
convex_comb_layer = linear_comb_layer
@wrap_name_default()
def block_expand_layer(input,
......
#################### test_config_parser #########################
add_test(NAME layers_test
COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/
python ${PROJ_ROOT}/python/paddle/trainer_config_helpers/tests/layers_test.py
WORKING_DIRECTORY ${PROJ_ROOT}/python/paddle)
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer.config_parser import parse_config_and_serialize
if __name__ == '__main__':
parse_config_and_serialize(
'trainer_config_helpers/tests/layers_test_config.py', '')
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer_config_helpers import *
num_classes = 5
x = data_layer(name="input1", size=3)
y = data_layer(name="input2", size=5)
x1 = fc_layer(input=x, size=5)
y1 = fc_layer(input=y, size=5)
y2 = fc_layer(input=y, size=15)
cos1 = cos_sim(a=x1, b=y1)
cos3 = cos_sim(a=x1, b=y2, size=3)
linear_comb = linear_comb_layer(weights=x1, vectors=y2, size=3)
out = fc_layer(input=[cos1, cos3, linear_comb],
size=num_classes,
act=SoftmaxActivation())
outputs(classification_cost(out, data_layer(name="label", size=num_classes)))
settings(
batch_size=10,
learning_rate=2e-3,
learning_method=AdamOptimizer(),
regularization=L2Regularization(8e-4),
gradient_clipping_threshold=25
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册