提交 5392a503 编写于 作者: W wangmeng28

Merge remote-tracking branch 'upstream/develop' into factorization_machine_layer

......@@ -382,6 +382,11 @@ cos_sim
.. autoclass:: paddle.v2.layer.cos_sim
:noindex:
l2_distance
-----------
.. autoclass:: paddle.v2.layer.l2_distance
:noindex:
trans
-----
.. autoclass:: paddle.v2.layer.trans
......
......@@ -513,19 +513,14 @@ ParamGradInfoMap AppendBackward(
const int root_block_idx = 0;
auto root_block = program_desc.MutableBlock(root_block_idx);
// insert fill one op for target
// TODO(qiao) add some check to the target.
std::string fill_one_op_out = GradVarName(target.Name());
std::vector<int64_t> target_shape_desc = target.Shape();
std::vector<int> target_shape;
std::transform(target_shape_desc.begin(), target_shape_desc.end(),
std::back_inserter(target_shape),
[](int64_t dim) { return static_cast<int>(dim); });
bool is_scalar = target.Shape() == std::vector<int64_t>{1};
PADDLE_ENFORCE(is_scalar, "target should be scalar");
VLOG(3) << "backward from loss=" << target.Name()
<< " data_type=" << target.GetDataType();
std::unique_ptr<OpDescBind> fill_one_op(
new OpDescBind("fill_constant", {}, {{"Out", {fill_one_op_out}}},
{{"shape", target_shape},
{{"shape", std::vector<int>{1}},
{"value", static_cast<float>(1.0)},
{"data_type", target.GetDataType()}}));
// infer var type of fill_one_op
......
......@@ -508,6 +508,7 @@ TEST(Backward, simple_single_op) {
op->SetOutput("Out", {"out"});
auto target = f::VarDescBind("out");
target.SetShape({1});
auto var_to_grad = AppendBackward(program, target, {});
ASSERT_EQ(block->AllOps().size(), 3UL);
......@@ -544,6 +545,7 @@ TEST(Backward, default_attribute) {
op->CheckAttrs();
auto target = f::VarDescBind("out");
target.SetShape({1});
AppendBackward(program, target, {});
ASSERT_EQ(block->AllOps().size(), 3UL);
......@@ -581,6 +583,7 @@ TEST(Backward, simple_mult_op) {
op3->SetOutput("Out", {"out3"});
auto target = f::VarDescBind("out3");
target.SetShape({1});
size_t forward_len = block->AllOps().size();
auto var_to_grad = AppendBackward(program, target, {});
......@@ -670,6 +673,7 @@ TEST(Backward, intermedia_var_no_grad) {
op4->SetOutput("Out", {"out4"});
auto target = f::VarDescBind("out4");
target.SetShape({1});
size_t forward_len = block->AllOps().size();
auto var_to_grad = AppendBackward(program, target, {"out3"});
......@@ -730,6 +734,7 @@ TEST(Backward, var_no_grad) {
op2->SetOutput("Z", {"z2"});
auto target = f::VarDescBind("z2");
target.SetShape({1});
size_t forward_len = block->AllOps().size();
auto var_to_grad = AppendBackward(program, target, {"z1"});
......@@ -810,6 +815,7 @@ TEST(Backward, shared_var) {
op3->SetOutput("Out", {"out3"});
auto target = f::VarDescBind("out3");
target.SetShape({1});
size_t forward_len = block->AllOps().size();
auto var_to_grad = AppendBackward(program, target, {});
......@@ -888,6 +894,7 @@ TEST(Backward, half_backward) {
op1->SetOutput("Out", {"out"});
auto target = f::VarDescBind("out");
target.SetShape({1});
size_t forward_len = block->AllOps().size();
auto var_to_grad = AppendBackward(program, target, {"b"});
f::OpDescBind *fill_op = block->AllOps()[forward_len];
......
......@@ -46,6 +46,8 @@ inline std::type_index ToTypeIndex(DataType type) {
return typeid(int);
case DataType::INT64:
return typeid(int64_t);
case DataType::BOOL:
return typeid(bool);
default:
PADDLE_THROW("Not support type %d", type);
}
......@@ -66,6 +68,9 @@ inline void VisitDataType(DataType type, Visitor visitor) {
case DataType::INT64:
visitor.template operator()<int64_t>();
break;
case DataType::BOOL:
visitor.template operator()<bool>();
break;
default:
PADDLE_THROW("Not supported");
}
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "L2DistanceLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace paddle {
REGISTER_LAYER(l2_distance, L2DistanceLayer);
/**
 * Initializes the layer and validates its static configuration.
 *
 * Runs the base-class initialization first and returns false if it reports
 * failure. Then aborts (via CHECK) unless the layer has exactly two input
 * layers and a configured size of 1.
 *
 * @param layerMap      forwarded unchanged to Layer::init.
 * @param parameterMap  forwarded unchanged to Layer::init.
 * @return true when initialization and validation succeed.
 */
bool L2DistanceLayer::init(const LayerMap& layerMap,
                           const ParameterMap& parameterMap) {
  /* Initialize the basic parent class; do not continue if it failed. */
  if (!Layer::init(layerMap, parameterMap)) return false;

  CHECK_EQ(inputLayers_.size(), 2UL) << "The L2DistanceLayer accepts two and "
                                     << "only two inputs.";
  CHECK_EQ(getSize(), 1UL) << "The output dimensionality of L2DistanceLayer "
                           << "is fixed to be 1.";
  return true;
}
/**
 * Forward pass of the L2 distance layer.
 *
 * Checks that both input value matrices exist and have identical height and
 * width, stores their difference (input 0 minus input 1) in inputSub_, and
 * writes the result of sumOfProducts(inputSub_, inputSub_) followed by sqrt2
 * into the layer output, which is reserved as (input height) x getSize().
 *
 * @param passType forwarded unchanged to Layer::forward.
 */
void L2DistanceLayer::forward(PassType passType) {
  Layer::forward(passType);

  const auto inV1 = getInputValue(0);
  const auto inV2 = getInputValue(1);

  // Both inputs must be present and shaped identically.
  CHECK(inV1 && inV2);
  CHECK_EQ(inV1->getHeight(), inV2->getHeight())
      << "The height of two inputs of this layer must be the same.";
  CHECK_EQ(inV1->getWidth(), inV2->getWidth())
      << "The width of two inputs of this layer must be the same.";

  int batchSize = inV1->getHeight();
  int output_dim = getSize();

  {
    // NOTE(review): this timer label is the same string used in backward();
    // confirm whether a distinct forward label was intended.
    REGISTER_TIMER_INFO("L2DistanceBpAtvTimer", getName().c_str());
    reserveOutput(batchSize, output_dim);
    auto outV = getOutputValue();
    CHECK(outV) << "The output matrix should not be null.";

    // inputSub_ is sized like the inputs and kept as a member so that
    // backward() can reuse the difference.
    Matrix::resizeOrCreate(
        inputSub_, inV1->getHeight(), inV1->getWidth(), false, useGpu_);

    // inputSub_ = inV1 - inV2
    inputSub_->assign(*inV1);
    inputSub_->sub(*inV2);
    // Presumably a row-wise sum of inputSub_ * inputSub_ (scale 1, discarding
    // the previous output), followed by an element-wise square root --
    // TODO confirm against the Matrix API.
    outV->sumOfProducts(*inputSub_, *inputSub_, 1, 0);
    outV->sqrt2(*outV);
  }
}
/**
 * Backward pass of the L2 distance layer.
 *
 * Uses the output value, the output gradient and the difference cached in
 * inputSub_ by forward() to update the gradient matrices of whichever inputs
 * have one. The `callback` argument is not used in this function.
 *
 * NOTE(review): the matrix returned by getOutputValue() and inputSub_ are
 * both passed as the destination of in-place operations here, so they
 * presumably no longer hold the forward results afterwards -- confirm that
 * nothing reads them after backward().
 */
void L2DistanceLayer::backward(const UpdateCallback& callback) {
  const auto outG = getOutputGrad();
  const auto outV = getOutputValue();
  CHECK(outG && outV);

  auto inGrad1 = getInputGrad(0);
  auto inGrad2 = getInputGrad(1);

  {
    REGISTER_TIMER_INFO("L2DistanceBpAtvTimer", getName().c_str());

    if (inGrad1 || inGrad2) {
      // Presumably outV = 1 / outV, then outV = outG .* outV, i.e. the
      // per-row scale outG / distance -- TODO confirm scalarDiv semantics.
      // NOTE(review): a zero distance would make the reciprocal non-finite;
      // verify how identical inputs are expected to behave.
      outV->scalarDiv(*outV, 1.);
      outV->dotMul(*outG, *outV);
    }

    // Input 0 receives inputSub_ scaled per row by outV.
    if (inGrad1) inGrad1->addRowScale(0, *inputSub_, *outV);

    if (inGrad2) {
      // Input 1 receives the same term with the sign flipped; the flip is
      // applied to inputSub_ itself, so this must come after the inGrad1 update.
      inputSub_->mulScalar(-1.);
      inGrad2->addRowScale(0, *inputSub_, *outV);
    }
  }
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Layer.h"
#include "paddle/math/Matrix.h"
namespace paddle {
/**
 * @brief The layer calculates the l2 distance between two input vectors.
 * \f[
 * f(\bf{x}, \bf{y}) = \sqrt{\sum_{i=1}^D(x_i - y_i)^2}
 * \f]
 *
 * - Input1: A vector (batchSize * dataDim)
 * - Input2: A vector (batchSize * dataDim)
 * - Output: A vector (batchSize * 1)
 *
 * The configuration api is: l2_distance_layer.
 */
class L2DistanceLayer : public Layer {
public:
  explicit L2DistanceLayer(const LayerConfig& config) : Layer(config) {}
  ~L2DistanceLayer() {}

  // Validates that the layer has exactly two inputs and an output size of 1.
  bool init(const LayerMap& layerMap,
            const ParameterMap& parameterMap) override;

  // Computes the per-sample distance and caches the input difference.
  void forward(PassType passType) override;
  // Propagates the output gradient to the inputs that have a gradient buffer.
  void backward(const UpdateCallback& callback = nullptr) override;

private:
  // Store the result of subtracting Input2 from Input1 in forward computation,
  // which will be reused in backward computation.
  MatrixPtr inputSub_;
};
} // namespace paddle
......@@ -583,6 +583,7 @@ TEST(Layer, maxoutLayer) {
testLayerGrad(config, "maxout", 10, false, useGpu);
}
}
void testFcLayer(string format, size_t nnz) {
TestConfig config;
config.biasSize = 1024;
......@@ -2444,6 +2445,25 @@ TEST(Layer, ScaleSubRegionLayer) {
}
}
// Gradient check for the "l2_distance" layer: two data inputs of equal width,
// no bias, output size 1, run once with useGpu=false and once with useGpu=true.
TEST(Layer, L2DistanceLayer) {
  const size_t kInputDim = 27;
  const size_t kBatchSize = 11;

  TestConfig config;
  config.biasSize = 0;
  config.layerConfig.set_type("l2_distance");
  config.layerConfig.set_size(1);

  // First input.
  config.inputDefs.push_back({INPUT_DATA, "layer_0", kInputDim, 0});
  config.layerConfig.add_inputs();
  // Second input, same dimensionality as the first.
  config.inputDefs.push_back({INPUT_DATA, "layer_1", kInputDim, 0});
  config.layerConfig.add_inputs();

  const bool gpuModes[] = {false, true};
  for (bool useGpu : gpuModes) {
    testLayerGrad(config, "l2_distance", kBatchSize, false, useGpu);
  }
}
void testFactorizationMachineLayer(InputType type, bool useGpu) {
const int FACTOR_SIZE = 10;
TestConfig config;
......
......@@ -87,6 +87,11 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n")
endif()
if ("${TARGET}" STREQUAL "logical_op")
set(pybind_flag 1)
file(APPEND ${pybind_file} "USE_OP(logical_and);\n")
endif()
# pool_with_index_op contains several operators
if ("${TARGET}" STREQUAL "pool_with_index_op")
set(pybind_flag 1)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/logical_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
/**
 * Proto maker for two-operand logical operators.
 *
 * OpComment must expose `type` and `equation` members usable with a "%s"
 * format; they are substituted into the descriptions of inputs X and Y, of
 * output Out, and into the operator comment.
 */
template <typename OpComment>
class BinaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  BinaryLogicalOpProtoMaker(framework::OpProto *proto,
                            framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    OpComment info;

    // Build every description up front, then register them in X, Y, Out order.
    const auto lhs_doc = string::Sprintf(
        "(LoDTensor) Left hand operand of %s operator", info.type);
    const auto rhs_doc = string::Sprintf(
        "(LoDTensor) Right hand operand of %s operator", info.type);
    const auto out_doc = string::Sprintf(
        "(LoDTensor) n-dim bool tensor. Each element is %s", info.equation);
    const auto op_doc = string::Sprintf(R"DOC(%s Operator
It operates element-wise on X and Y, and returns the Out. X, Y and Out are N-dim boolean tensors.
Each element of Out is calculated by %s
)DOC",
                                        info.type, info.equation);

    AddInput("X", lhs_doc);
    AddInput("Y", rhs_doc);
    AddOutput("Out", out_doc);
    AddComment(op_doc);
  }
};
/**
 * Proto maker for single-operand logical operators.
 *
 * OpComment must expose `type` and `equation` members usable with a "%s"
 * format; they are substituted into the descriptions of input X, of output
 * Out, and into the operator comment.
 */
template <typename OpComment>
class UnaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  UnaryLogicalOpProtoMaker(framework::OpProto *proto,
                           framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    OpComment info;

    // Build every description up front, then register them in X, Out order.
    const auto operand_doc =
        string::Sprintf("(LoDTensor) Operand of %s operator", info.type);
    const auto out_doc = string::Sprintf(
        "(LoDTensor) n-dim bool tensor. Each element is %s", info.equation);
    const auto op_doc = string::Sprintf(R"DOC(%s Operator
It operates element-wise on X, and returns the Out. X and Out are N-dim boolean tensors.
Each element of Out is calculated by %s
)DOC",
                                        info.type, info.equation);

    AddInput("X", operand_doc);
    AddOutput("Out", out_doc);
    AddComment(op_doc);
  }
};
/**
 * Shape inference for two-operand logical operators.
 *
 * Requires inputs X and Y to be present and to contain the same number of
 * elements; Out then takes X's dimensions and shares X's LoD.
 */
template <typename OpComment>
class BinaryLogicalOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    OpComment info;

    PADDLE_ENFORCE(context->HasInput("X"),
                   "Input(X) of %s operator must not be null", info.type);
    PADDLE_ENFORCE(context->HasInput("Y"),
                   "Input(Y) of %s operator must not be null", info.type);

    const auto dim_x = context->GetInputDim("X");
    const auto dim_y = context->GetInputDim("Y");
    // Only the element counts are compared, not the individual dimensions.
    PADDLE_ENFORCE_EQ(framework::product(dim_x), framework::product(dim_y),
                      "The number of elements in X and Y should be same");

    // The output mirrors X in both shape and LoD.
    context->SetOutputDim("Out", dim_x);
    context->ShareLoD("X", "Out");
  }
};
/**
 * Shape inference for single-operand logical operators.
 *
 * Requires input X to be present; Out then takes X's dimensions and shares
 * X's LoD.
 */
template <typename OpComment>
class UnaryLogicalOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    OpComment comment;
    PADDLE_ENFORCE(context->HasInput("X"),
                   "Input(X) of %s operator must not be null", comment.type);
    // Look the dimensions up once and reuse them for the output.
    auto dim_x = context->GetInputDim("X");
    context->SetOutputDim("Out", dim_x);
    context->ShareLoD("X", "Out");
  }
};
/**
 * Operator class shared by the logical operators registered in this file.
 *
 * It only customizes kernel selection: the kernel type reported by the base
 * class is kept, except that its place is replaced by the place of input X.
 */
class LogicalOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  framework::OpKernelType GetKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto kernel_type = OperatorWithKernel::GetKernelType(ctx);
    // LogicalOp kernel's device type is decided by input tensor place
    const auto *x = ctx.Input<framework::LoDTensor>("X");
    kernel_type.place_ = x->place();
    return kernel_type;
  }
};
} // namespace operators
} // namespace paddle
// Registers a two-operand logical operator.
//
// Expands to a struct named _<op_type>Comment whose static char arrays `type`
// and `equation` hold the stringified operator name and the given equation
// text, then registers `op_type` with LogicalOp, the binary proto maker and
// infer-shape functor instantiated on that struct, and EmptyGradOpMaker.
// `_equation` must be a string literal because it initializes a char array.
#define REGISTER_BINARY_LOGICAL_OP(op_type, _equation) \
  struct _##op_type##Comment { \
    static char type[]; \
    static char equation[]; \
  }; \
  char _##op_type##Comment::type[]{#op_type}; \
  char _##op_type##Comment::equation[]{_equation}; \
  REGISTER_OPERATOR( \
      op_type, ::paddle::operators::LogicalOp, \
      ::paddle::operators::BinaryLogicalOpProtoMaker<_##op_type##Comment>, \
      ::paddle::operators::BinaryLogicalOpInferShape<_##op_type##Comment>, \
      ::paddle::framework::EmptyGradOpMaker);
// Registers a single-operand logical operator.
//
// Expands to a struct named _<op_type>Comment whose static char arrays `type`
// and `equation` hold the stringified operator name and the given equation
// text, then registers `op_type` with LogicalOp, the unary proto maker and
// infer-shape functor instantiated on that struct, and EmptyGradOpMaker.
// `_equation` must be a string literal because it initializes a char array.
#define REGISTER_UNARY_LOGICAL_OP(op_type, _equation) \
  struct _##op_type##Comment { \
    static char type[]; \
    static char equation[]; \
  }; \
  char _##op_type##Comment::type[]{#op_type}; \
  char _##op_type##Comment::equation[]{_equation}; \
  REGISTER_OPERATOR( \
      op_type, ::paddle::operators::LogicalOp, \
      ::paddle::operators::UnaryLogicalOpProtoMaker<_##op_type##Comment>, \
      ::paddle::operators::UnaryLogicalOpInferShape<_##op_type##Comment>, \
      ::paddle::framework::EmptyGradOpMaker);
// Operator and CPU kernel registrations for the logical operators. The string
// passed with each operator is the equation text that ends up in the
// generated operator documentation, so it must describe that operator.
REGISTER_BINARY_LOGICAL_OP(logical_and, "Out = X && Y");
REGISTER_BINARY_LOGICAL_KERNEL(logical_and, CPU,
                               paddle::operators::LogicalAndFunctor);
// logical_or is documented with `||`, matching LogicalOrFunctor.
REGISTER_BINARY_LOGICAL_OP(logical_or, "Out = X || Y");
REGISTER_BINARY_LOGICAL_KERNEL(logical_or, CPU,
                               paddle::operators::LogicalOrFunctor);
REGISTER_UNARY_LOGICAL_OP(logical_not, "Out = !X");
REGISTER_UNARY_LOGICAL_KERNEL(logical_not, CPU,
                              paddle::operators::LogicalNotFunctor);
REGISTER_BINARY_LOGICAL_OP(logical_xor, "Out = (X || Y) && !(X && Y)");
REGISTER_BINARY_LOGICAL_KERNEL(logical_xor, CPU,
                               paddle::operators::LogicalXorFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/logical_op.h"
// GPU kernel registrations for the four logical operators. Each uses the
// same functor as the corresponding CPU registration; only the device
// argument differs.
REGISTER_BINARY_LOGICAL_KERNEL(logical_and, GPU,
                               paddle::operators::LogicalAndFunctor);
REGISTER_BINARY_LOGICAL_KERNEL(logical_or, GPU,
                               paddle::operators::LogicalOrFunctor);
REGISTER_UNARY_LOGICAL_KERNEL(logical_not, GPU,
                              paddle::operators::LogicalNotFunctor);
REGISTER_BINARY_LOGICAL_KERNEL(logical_xor, GPU,
                               paddle::operators::LogicalXorFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include <type_traits>
#include "paddle/framework/op_registry.h"
#include "paddle/platform/transform.h"
namespace paddle {
namespace operators {
// Element functor returning `lhs && rhs`. ELEM_TYPE names the operand type so
// that the kernels can recover it from the functor.
template <typename T>
struct LogicalAndFunctor {
  typedef T ELEM_TYPE;
  HOSTDEVICE bool operator()(const T& lhs, const T& rhs) const {
    return lhs && rhs;
  }
};
// Element functor returning `lhs || rhs`. ELEM_TYPE names the operand type so
// that the kernels can recover it from the functor.
template <typename T>
struct LogicalOrFunctor {
  typedef T ELEM_TYPE;
  HOSTDEVICE bool operator()(const T& lhs, const T& rhs) const {
    return lhs || rhs;
  }
};
// Element functor returning `!value`. ELEM_TYPE names the operand type so
// that the kernels can recover it from the functor.
template <typename T>
struct LogicalNotFunctor {
  typedef T ELEM_TYPE;
  HOSTDEVICE bool operator()(const T& value) const {
    return !value;
  }
};
// Element functor for exclusive or: true when at least one operand is true
// but not both. ELEM_TYPE names the operand type so that the kernels can
// recover it from the functor.
template <typename T>
struct LogicalXorFunctor {
  typedef T ELEM_TYPE;
  HOSTDEVICE bool operator()(const T& lhs, const T& rhs) const {
    if (!(lhs || rhs)) return false;
    return !(lhs && rhs);
  }
};
/**
 * Kernel for two-operand logical operators.
 *
 * Reads X and Y as Functor::ELEM_TYPE, applies Functor over X's numel()
 * elements through platform::Transform<Place>, and writes bool results into
 * Out, which is allocated on the context's place.
 */
template <typename Place, typename Functor>
class BinaryLogicalOpKernel
    : public framework::OpKernel<typename Functor::ELEM_TYPE> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    using T = typename Functor::ELEM_TYPE;

    auto* lhs = context.Input<framework::Tensor>("X");
    auto* rhs = context.Input<framework::Tensor>("Y");
    auto* result = context.Output<framework::Tensor>("Out");

    // The element range is driven by X only; Y is read from its start.
    auto* lhs_begin = lhs->data<T>();
    auto* lhs_end = lhs->data<T>() + lhs->numel();
    auto* rhs_begin = rhs->data<T>();
    auto* result_begin = result->mutable_data<bool>(context.GetPlace());

    Functor element_op;
    platform::Transform<Place> apply;
    apply(context.device_context(), lhs_begin, lhs_end, rhs_begin,
          result_begin, element_op);
  }
};
/**
 * Kernel for single-operand logical operators.
 *
 * Reads X as Functor::ELEM_TYPE, applies Functor over X's numel() elements
 * through platform::Transform<Place>, and writes bool results into Out, which
 * is allocated on the context's place.
 */
template <typename Place, typename Functor>
class UnaryLogicalOpKernel
    : public framework::OpKernel<typename Functor::ELEM_TYPE> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    using T = typename Functor::ELEM_TYPE;

    auto* operand = context.Input<framework::Tensor>("X");
    auto* result = context.Output<framework::Tensor>("Out");

    auto* operand_begin = operand->data<T>();
    auto* operand_end = operand->data<T>() + operand->numel();
    auto* result_begin = result->mutable_data<bool>(context.GetPlace());

    Functor element_op;
    platform::Transform<Place> apply;
    apply(context.device_context(), operand_begin, operand_end, result_begin,
          element_op);
  }
};
} // namespace operators
} // namespace paddle
// Registers the binary logical kernel of `op_type` for device `dev`.
// `dev` is pasted into both the registration macro name
// (REGISTER_OP_<dev>_KERNEL) and the place type (<dev>Place); `functor` is a
// class template that is instantiated with bool.
#define REGISTER_BINARY_LOGICAL_KERNEL(op_type, dev, functor) \
  REGISTER_OP_##dev##_KERNEL( \
      op_type, ::paddle::operators::BinaryLogicalOpKernel< \
                   ::paddle::platform::dev##Place, functor<bool>>);
// Same as above, for single-operand logical kernels.
#define REGISTER_UNARY_LOGICAL_KERNEL(op_type, dev, functor) \
  REGISTER_OP_##dev##_KERNEL( \
      op_type, ::paddle::operators::UnaryLogicalOpKernel< \
                   ::paddle::platform::dev##Place, functor<bool>>);
......@@ -3342,6 +3342,20 @@ class RowL2NormLayer(LayerBase):
self.set_layer_size(input_layer.size)
@config_layer('cos')
class CosSimLayer(LayerBase):
    # Config-side definition of the 'cos' layer: checks that there are exactly
    # two inputs of equal size and records cos_scale in the layer config.
    def __init__(self, name, inputs, cos_scale=1, device=None):
        super(CosSimLayer, self).__init__(
            name, 'cos', 1, inputs=inputs, device=device)

        input_count = len(self.inputs)
        config_assert(input_count == 2,
                      'The CosSimLayer expects two and only two inputs.')

        first_size = self.get_input_layer(0).size
        second_size = self.get_input_layer(1).size
        config_assert(
            first_size == second_size,
            'The two inputs of CosSimLayer must have the same dimensionality.')

        self.config.cos_scale = cos_scale
@config_layer('cos_vm')
class CosSimVecMatLayer(LayerBase):
def __init__(self, name, size, inputs, cos_scale=1.0, device=None):
......@@ -3349,10 +3363,24 @@ class CosSimVecMatLayer(LayerBase):
name, 'cos_vm', size, inputs=inputs, device=device)
self.config.cos_scale = cos_scale
config_assert(
len(self.inputs) == 2, 'CosSimVecMatLayer must have 2 inputs')
len(self.inputs) == 2, 'The CosSimVecMatLayer must have 2 inputs.')
config_assert(
size * self.get_input_layer(0).size == self.get_input_layer(1).size,
'Wrong input size for CosSimVecMatLayer')
'Wrong input size for CosSimVecMatLayer.')
@config_layer('l2_distance')
class L2DistanceLayer(LayerBase):
    # Config-side definition of the 'l2_distance' layer: checks that there are
    # exactly two inputs and that both have the same size.
    def __init__(self, name, inputs, device=None):
        super(L2DistanceLayer, self).__init__(
            name, 'l2_distance', 1, inputs=inputs, device=device)

        input_count = len(self.inputs)
        config_assert(input_count == 2, ('The L2DistanceLayer must have '
                                         'and only have 2 inputs.'))

        first_size = self.get_input_layer(0).size
        second_size = self.get_input_layer(1).size
        config_assert(first_size == second_size,
                      ('Two inputs of the L2DistanceLayer must have '
                       'the same dimensionality.'))
@config_layer('sampling_id')
......@@ -3396,18 +3424,6 @@ class AverageLayer(LayerBase):
self.create_bias_parameter(bias, self.config.size)
@config_layer('cos')
class CosSimLayer(LayerBase):
    # Config-side definition of the 'cos' layer: checks that there are exactly
    # two inputs of equal size and records cos_scale in the layer config.
    def __init__(self, name, inputs, cos_scale=1, device=None):
        super(CosSimLayer, self).__init__(
            name, 'cos', 1, inputs=inputs, device=device)

        input_count = len(self.inputs)
        config_assert(input_count == 2, 'CosSimLayer must have 2 inputs')

        first_size = self.get_input_layer(0).size
        second_size = self.get_input_layer(1).size
        config_assert(first_size == second_size,
                      'inputs of CosSimLayer must have same dim')

        self.config.cos_scale = cos_scale
@config_layer('tensor')
class TensorLayer(LayerBase):
def __init__(self, name, size, inputs, bias=True, **xargs):
......
......@@ -51,6 +51,7 @@ __all__ = [
'last_seq',
'first_seq',
'cos_sim',
'l2_distance_layer',
'hsigmoid',
'conv_projection',
'square_error_cost',
......@@ -169,6 +170,7 @@ class LayerType(object):
COST = 'cost'
COSINE_SIM_VEC = 'cos_vm'
COSINE_SIM = 'cos'
L2_DISTANCE = 'l2_distance'
HSIGMOID = 'hsigmoid'
CONV_LAYER = 'conv'
CONVTRANS_LAYER = 'convt'
......@@ -2337,6 +2339,51 @@ def cos_sim(a, b, scale=1, size=1, name=None, layer_attr=None):
return LayerOutput(name, LayerType.COSINE_SIM, parents=[a, b], size=size)
@wrap_name_default()
@layer_support()
def l2_distance_layer(x, y, name=None, layer_attr=None):
    """
    This layer calculates and returns the Euclidean distance between two input
    vectors x and y. The equation is as follows:

    ..  math::
        l2_distance(\\mathbf{x}, \\mathbf{y}) = \\sqrt{\\sum_{i=1}^D(x_i - y_i)^2}

    The output size of this layer is fixed to be 1. Note that the above
    computation is for one sample. Multiple samples are processed in one batch.

    The example usage is:

    ..  code-block:: python

        l2_sim = l2_distance_layer(x=layer1, y=layer2)

    :param name: The name of this layer. It is optional.
    :type name: basestring
    :param x: The first input x for this layer, whose output is a matrix with
              dimensionality N x D. N is the sample number in a mini-batch.
              D is the dimensionality of x's output.
    :type x: LayerOutput
    :param y: The second input y for this layer, whose output is a matrix with
              dimensionality N x D. N is the sample number in a mini-batch.
              D is the dimensionality of y's output.
    :type y: LayerOutput
    :param layer_attr: The extra layer attributes, for example, drop rate.
                       See ExtraLayerAttribute for more details.
    :type layer_attr: ExtraLayerAttribute
    :return: The returned LayerOutput object.
    :rtype: LayerOutput
    """

    # Both inputs must already be layer outputs; their names are what gets
    # written into the layer configuration.
    assert isinstance(x, LayerOutput) and isinstance(y, LayerOutput)
    Layer(
        name=name,
        type=LayerType.L2_DISTANCE,
        inputs=[x.name, y.name],
        **ExtraLayerAttribute.to_kwargs(layer_attr))
    # The output size is always 1.
    return LayerOutput(name, LayerType.L2_DISTANCE, parents=[x, y], size=1)
@wrap_name_default()
@wrap_bias_attr_default(has_bias=True)
@wrap_param_attr_default()
......
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from activations import LinearActivation, ReluActivation, SoftmaxActivation, \
IdentityActivation, TanhActivation, SequenceSoftmaxActivation
......@@ -26,9 +26,9 @@ __all__ = [
'sequence_conv_pool', 'simple_lstm', "simple_img_conv_pool",
"img_conv_bn_pool", 'lstmemory_group', 'lstmemory_unit', 'small_vgg',
'img_conv_group', 'vgg_16_network', 'gru_unit', 'gru_group', 'simple_gru',
'simple_attention', 'dot_product_attention', 'simple_gru2',
'bidirectional_gru', 'text_conv_pool', 'bidirectional_lstm', 'inputs',
'outputs'
'simple_attention', 'dot_product_attention', 'multi_head_attention',
'simple_gru2', 'bidirectional_gru', 'text_conv_pool', 'bidirectional_lstm',
'inputs', 'outputs'
]
######################################################
......@@ -1476,10 +1476,8 @@ def dot_product_attention(encoded_sequence,
expand_as=encoded_sequence,
name='%s_expand' % name)
m = linear_comb_layer(
weights=expanded,
vectors=encoded_sequence,
name='%s_dot-product' % name)
m = dot_prod_layer(
input1=expanded, input2=encoded_sequence, name='%s_dot-product' % name)
attention_weight = fc_layer(
input=m,
......@@ -1498,6 +1496,134 @@ def dot_product_attention(encoded_sequence,
input=scaled, pooling_type=SumPooling(), name="%s_pooling" % name)
@wrap_name_default()
def multi_head_attention(query,
                         key,
                         value,
                         key_proj_size,
                         value_proj_size,
                         head_num,
                         attention_type,
                         softmax_param_attr=None,
                         name=None):
    """
    Calculate and return a context vector with multi-head attention, where
    each head uses either scaled dot-product attention or additive attention.
    The dimension of the context vector equals to value_proj_size * head_num.

    Please refer to **Attention Is All You Need** for more details. The link is
    as follows:
    https://arxiv.org/abs/1706.03762.

    The example usage is:

    ..  code-block:: python

        context = multi_head_attention(query=decoder_state,
                                       key=enc_seq,
                                       value=enc_seq,
                                       key_proj_size=64,
                                       value_proj_size=64,
                                       head_num=8,
                                       attention_type='dot-product attention')

    :param name: A prefix attached to the name of each layer that defined inside
                 the multi_head_attention.
    :type name: basestring
    :param softmax_param_attr: The parameter attribute of sequence softmax
                               that is used to produce attention weight.
    :type softmax_param_attr: ParameterAttribute
    :param query: query is used to calculate attention weights over values at current step.
    :type query: LayerOutput
    :param key: key is used to calculate the attention weight of the corresponding value.
    :type key: LayerOutput
    :param value: value is the sequence to be attended.
    :type value: LayerOutput
    :param key_proj_size: The dimension of the linear projection performed on key and query.
    :type key_proj_size: int
    :param value_proj_size: The dimension of the linear projection performed on value.
    :type value_proj_size: int
    :param head_num: The number of attention heads.
    :type head_num: int
    :param attention_type: The type of the attention mechanism used in each attention
                           heads. Now, we only support scaled dot-product attention and
                           additive attention. Must be either 'dot-product attention'
                           or 'additive attention'.
    :type attention_type: basestring
    :return: The context vector.
    :rtype: LayerOutput
    """
    assert attention_type in ['dot-product attention', 'additive attention']

    # Project query, key and value once for all heads; each projection is
    # head_num times as wide as a single head's slice.
    with mixed_layer(
            size=key_proj_size * head_num,
            name='%s_query_proj' % name) as query_proj:
        query_proj += full_matrix_projection(query)
    # The projected query is expanded using key as the template so that it can
    # be combined with the projected key below.
    query_proj = expand_layer(input=query_proj, expand_as=key)

    with mixed_layer(
            size=key_proj_size * head_num,
            name='%s_key_proj' % name) as key_proj:
        key_proj += full_matrix_projection(key)

    with mixed_layer(
            size=value_proj_size * head_num,
            name='%s_value_proj' % name) as value_proj:
        value_proj += full_matrix_projection(value)

    head_list = []
    for i in range(head_num):
        # Cut the i-th head's slice out of each shared projection.
        with mixed_layer(size=key_proj_size) as sub_query_proj:
            sub_query_proj += identity_projection(
                query_proj, offset=key_proj_size * i, size=key_proj_size)

        with mixed_layer(size=key_proj_size) as sub_key_proj:
            sub_key_proj += identity_projection(
                key_proj, offset=key_proj_size * i, size=key_proj_size)

        with mixed_layer(size=value_proj_size) as sub_value_proj:
            sub_value_proj += identity_projection(
                value_proj, offset=value_proj_size * i, size=value_proj_size)

        if attention_type == 'dot-product attention':
            m = dot_prod_layer(
                input1=sub_query_proj,
                input2=sub_key_proj,
                name='%s_dot-product_%d' % (name, i))
            # Scale the dot product by 1 / sqrt(key_proj_size).
            m = slope_intercept_layer(
                input=m,
                slope=math.sqrt(1.0 / key_proj_size),
                name='%s_dot-product_scaling_%d' % (name, i))
        else:
            # Additive attention: tanh over the sum of the query and key slices.
            with mixed_layer(
                    size=key_proj_size,
                    act=TanhActivation(),
                    name='%s_combine_%d' % (name, i)) as m:
                m += identity_projection(sub_query_proj)
                m += identity_projection(sub_key_proj)

        # One weight per element of m, normalized by a sequence softmax.
        attention_weight = fc_layer(
            input=m,
            size=1,
            act=SequenceSoftmaxActivation(),
            param_attr=softmax_param_attr,
            name="%s_softmax_%d" % (name, i),
            bias_attr=False)

        # Weight the value slice and sum-pool it into this head's output.
        scaled = scaling_layer(
            weight=attention_weight,
            input=sub_value_proj,
            name='%s_scaling_%d' % (name, i))
        head = pooling_layer(
            input=scaled,
            pooling_type=SumPooling(),
            name="%s_pooling_%d" % (name, i))

        head_list.append(head)

    # Concatenate all heads into the final context vector.
    attended = concat_layer(head_list)

    return attended
def inputs(layers, *args):
"""
Declare the inputs of network. The order of input should be as same as
......
......@@ -10,7 +10,8 @@ test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_la
test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer
test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer
test_seq_slice_layer test_cross_entropy_over_beam test_roi_pool_layer test_pooling3D_layer
test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer test_scale_sub_region_layer
test_dot_prod_layer test_factorization_machine)
test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer
test_scale_sub_region_layer test_dot_prod_layer test_l2_distance_layer
test_factorization_machine)
export whole_configs=(test_split_datasource)
type: "nn"
layers {
name: "x"
type: "data"
size: 128
active_type: ""
}
layers {
name: "y"
type: "data"
size: 128
active_type: ""
}
layers {
name: "__l2_distance_layer_0__"
type: "l2_distance"
size: 1
active_type: ""
inputs {
input_layer_name: "x"
}
inputs {
input_layer_name: "y"
}
}
input_layer_names: "x"
input_layer_names: "y"
output_layer_names: "__l2_distance_layer_0__"
sub_models {
name: "root"
layer_names: "x"
layer_names: "y"
layer_names: "__l2_distance_layer_0__"
input_layer_names: "x"
input_layer_names: "y"
output_layer_names: "__l2_distance_layer_0__"
is_recurrent_layer_group: false
}
from paddle.trainer_config_helpers import *
# Minimal network config: two 128-wide data layers named 'x' and 'y' feeding
# a single l2_distance_layer, which is declared as the network output.
x_input = data_layer(name='x', size=128)
y_input = data_layer(name='y', size=128)
outputs(l2_distance_layer(x=x_input, y=y_input))
import op_test
import unittest
import numpy as np
def create_test_class(op_type, callback, binary_op=True):
    """
    Build an OpTest subclass for one logical operator and publish it in this
    module's globals under the name `op_type`.

    :param op_type: operator name; also used as the generated class name.
    :param callback: reference implementation, called as callback(a, b) for
                     binary operators and callback(a) otherwise.
    :param binary_op: whether the operator takes a second input 'Y'.
    """

    class Cls(op_test.OpTest):
        def setUp(self):
            shape = (10, 7)
            # Random boolean operands; 'Y' is only drawn for binary operators.
            a = np.random.choice(a=[True, False], size=shape).astype(bool)
            if binary_op:
                b = np.random.choice(a=[True, False], size=shape).astype(bool)
                feeds = {'X': a, 'Y': b}
                expected = callback(a, b)
            else:
                feeds = {'X': a}
                expected = callback(a)

            self.outputs = {'Out': expected}
            self.op_type = op_type
            self.inputs = feeds

        def test_output(self):
            self.check_output()

    Cls.__name__ = op_type
    globals()[op_type] = Cls
# Register one generated case per logical operator; the NumPy function of the
# same name is the reference implementation for the expected output.
create_test_class('logical_and', np.logical_and)
create_test_class('logical_or', np.logical_or)
create_test_class('logical_not', np.logical_not, False)
create_test_class('logical_xor', np.logical_xor)

if __name__ == '__main__':
    unittest.main()
......@@ -16,14 +16,18 @@ class TestOptimizer(unittest.TestCase):
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mul",
inputs={"X": mul_x,
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.01)
opts = sgd_optimizer.minimize(mul_out, init_program)
opts = sgd_optimizer.minimize(mean_out, init_program)
self.assertEqual(len(opts), 1)
sgd_op = opts[0]
self.assertEqual(sgd_op.type, "sgd")
......@@ -44,12 +48,16 @@ class TestOptimizer(unittest.TestCase):
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
global_step = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="step")
learning_rate = 0.01
sgd_optimizer = optimizer.SGDOptimizer(
learning_rate=learning_rate, global_step=global_step)
opts = sgd_optimizer.minimize(mul_out, init_program)
opts = sgd_optimizer.minimize(mean_out, init_program)
self.assertEqual(len(opts), 2)
sgd_op = opts[0]
self.assertEqual(sgd_op.type, "sgd")
......@@ -90,7 +98,11 @@ class TestMomentumOptimizer(unittest.TestCase):
learning_rate = 0.01
momentum_optimizer = self.MockMomentum(
learning_rate=learning_rate, momentum=0.2)
params_grads = append_backward_ops(mul_out)
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
params_grads = append_backward_ops(mean_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
opts = momentum_optimizer.create_optimization_pass(
......@@ -132,10 +144,14 @@ class TestMomentumOptimizer(unittest.TestCase):
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
learning_rate = 0.01
momentum_optimizer = self.MockMomentum(
learning_rate=learning_rate, momentum=0.2, use_nesterov=True)
params_grads = append_backward_ops(mul_out)
params_grads = append_backward_ops(mean_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
opts = momentum_optimizer.create_optimization_pass(
......@@ -186,10 +202,14 @@ class TestAdagradOptimizer(unittest.TestCase):
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
learning_rate = 0.01
adagrad_optimizer = self.MockAdagrad(
learning_rate=learning_rate, epsilon=1.0e-6)
params_grads = append_backward_ops(mul_out)
params_grads = append_backward_ops(mean_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(adagrad_optimizer.get_accumulators()), 0)
opts = adagrad_optimizer.create_optimization_pass(params_grads, mul_out,
......@@ -242,10 +262,14 @@ class TestAdamOptimizer(unittest.TestCase):
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
learning_rate = 0.01
adam_optimizer = self.MockAdam(
learning_rate=learning_rate, beta1=0.9, beta2=0.999)
params_grads = append_backward_ops(mul_out)
params_grads = append_backward_ops(mean_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(adam_optimizer.get_accumulators()), 0)
opts = adam_optimizer.create_optimization_pass(params_grads, mul_out,
......@@ -300,10 +324,14 @@ class TestAdamaxOptimizer(unittest.TestCase):
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
learning_rate = 0.01
adamax_optimizer = self.MockAdamax(
learning_rate=learning_rate, beta1=0.9, beta2=0.999)
params_grads = append_backward_ops(mul_out)
params_grads = append_backward_ops(mean_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(adamax_optimizer.get_accumulators()), 0)
opts = adamax_optimizer.create_optimization_pass(params_grads, mul_out,
......@@ -355,10 +383,14 @@ class TestDecayedAdagradOptimizer(unittest.TestCase):
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
learning_rate = 0.01
decayed_adagrad_optimizer = self.MockDecayedAdagrad(
learning_rate=learning_rate, decay=0.95, epsilon=1.0e-6)
params_grads = append_backward_ops(mul_out)
params_grads = append_backward_ops(mean_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(decayed_adagrad_optimizer.get_accumulators()), 0)
opts = decayed_adagrad_optimizer.create_optimization_pass(
......
import unittest
import paddle.v2.fluid.core as core
from paddle.v2.fluid.framework import Program
from paddle.v2.fluid.framework import g_main_program
......@@ -98,21 +97,26 @@ class TestProgram(unittest.TestCase):
"Y": add_y},
outputs={"Out": add_out},
attrs={"x_num_col_dims": 1})
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": add_out}, outputs={"Out": mean_out})
self.assertEqual(mul_op.idx, 0)
self.assertEqual(add_op.idx, 1)
param_to_grad = prog.append_backward(add_out, set())
param_to_grad = prog.append_backward(mean_out, set())
def grad_name(name):
    """Return ``name`` with the "@GRAD" suffix appended."""
    grad_suffix = "@GRAD"
    return name + grad_suffix
for var_name in ("mul.x", "mul.y", "mul.out", "add.y", "add.out"):
for var_name in ("mul.x", "mul.y", "mul.out", "add.y", "add.out",
"mean.out"):
self.assertEqual(param_to_grad[var_name][0], grad_name(var_name))
self.assertEqual(param_to_grad[var_name][1], 0)
expect_ops = [
"mul", "elementwise_add", "fill_constant", "elementwise_add_grad",
"mul_grad"
"mul", "elementwise_add", "mean", "fill_constant", "mean_grad",
"elementwise_add_grad", "mul_grad"
]
actual_ops = []
for op in block.ops:
......
......@@ -29,7 +29,11 @@ class TestL2DecayRegularizer(unittest.TestCase):
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
params_grads = append_backward_ops(mul_out)
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
params_grads = append_backward_ops(mean_out)
self.assertEqual(len(params_grads), 1)
count_ops = len(block.ops)
params_grads = optimizer.append_regularization_ops(params_grads)
......@@ -62,7 +66,11 @@ class TestL1DecayRegularizer(unittest.TestCase):
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
params_grads = append_backward_ops(mul_out)
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
params_grads = append_backward_ops(mean_out)
self.assertEqual(len(params_grads), 1)
count_ops = len(block.ops)
params_grads = optimizer.append_regularization_ops(params_grads)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册