diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 2b8b82e74fc49d454b5331460acbffd0e9404fb5..db17cd004b5a393bb6b6af50b3169e39fa9b3a96 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -177,6 +177,8 @@ paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, k paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None)) paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None)) +paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None)) +paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None) diff --git a/paddle/fluid/operators/add_position_encoding_op.cc b/paddle/fluid/operators/add_position_encoding_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..8127e554bed1aae7a5ce8837bcadf1b7f13f1ac2 --- /dev/null +++ b/paddle/fluid/operators/add_position_encoding_op.cc @@ -0,0 +1,97 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/add_position_encoding_op.h" + +namespace paddle { +namespace operators { + +class AddPositionEncodingOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "X(Input) of add_position_encoding_op should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("Out"), + "Out(Output) of add_position_encoding_op should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + ctx->SetOutputDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class AddPositionEncodingOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) must not be null."); + PADDLE_ENFORCE(ctx->HasInput("Out"), "Out must not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Out@GRAD must not be null."); + + auto out_dims = ctx->GetInputDim("Out"); + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), out_dims); + } + } +}; + +class AddPositionEncodingOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "Input of AddPositionEncoding operator"); + AddOutput("Out", 
+ "Output of AddPositionEncoding operator"); + AddAttr<float>("alpha", "The scale applied to the original embedding.") + .SetDefault(1.0f) + .AddCustomChecker([](const float& alpha) { + PADDLE_ENFORCE(alpha >= 0.0f, "'alpha' must be >= 0.0."); + }); + AddAttr<float>("beta", "The scale applied to the position embedding.") + .SetDefault(1.0f) + .AddCustomChecker([](const float& beta) { + PADDLE_ENFORCE(beta >= 0.0f, "'beta' must be >= 0.0."); + }); + AddComment(R"DOC( + Add Position Encoding Operator. + + The add position encoding calculates the output based on the input, alpha, beta. + The size of each dimension of the input is checked in the kernel. + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plt = paddle::platform; + +REGISTER_OPERATOR(add_position_encoding, ops::AddPositionEncodingOp, + ops::AddPositionEncodingOpMaker, + paddle::framework::DefaultGradOpDescMaker<true>); +REGISTER_OPERATOR(add_position_encoding_grad, ops::AddPositionEncodingOpGrad); + +REGISTER_OP_CPU_KERNEL( + add_position_encoding, + ops::AddPositionEncodingKernel<plt::CPUDeviceContext, float>, + ops::AddPositionEncodingKernel<plt::CPUDeviceContext, double>); + +REGISTER_OP_CPU_KERNEL( + add_position_encoding_grad, + ops::AddPositionEncodingGradKernel<plt::CPUDeviceContext, float>, + ops::AddPositionEncodingGradKernel<plt::CPUDeviceContext, double>); diff --git a/paddle/fluid/operators/add_position_encoding_op.h b/paddle/fluid/operators/add_position_encoding_op.h new file mode 100644 index 0000000000000000000000000000000000000000..5f371235f160c416058e877dbba2d9fe89abf7db --- /dev/null +++ b/paddle/fluid/operators/add_position_encoding_op.h @@ -0,0 +1,105 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detail/safe_ref.h" + +namespace paddle { +namespace operators { + +template <typename DeviceContext, typename T> +class AddPositionEncodingKernel : public framework::OpKernel<T> { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input<framework::LoDTensor>("X"); + auto& x_lod = X->lod(); + auto* src_ptr = X->data<T>(); + + auto* Out = context.Output<framework::LoDTensor>("Out"); + auto* dst_ptr = Out->mutable_data<T>(context.GetPlace()); + + float alpha = context.Attr<float>("alpha"); + float beta = context.Attr<float>("beta"); + + auto x_dim = X->dims(); + int batch_size = 0; + int max_seq_len = 0; + int enc_size = 0; + + if (x_lod.empty()) { + PADDLE_ENFORCE( + x_dim.size() == 3UL, + "The input X of Add Position Encoding should be 3-D Tensor!"); + batch_size = x_dim[0]; + max_seq_len = x_dim[1]; + enc_size = x_dim[2]; + } else { + PADDLE_ENFORCE( + x_dim.size() == 2UL, + "The input X of Add Position Encoding should be 2-D LoDTensor!"); + PADDLE_ENFORCE( + x_lod.size() == 1UL, + "The Add Position Encoding Op only supports lod_level == 1!"); + batch_size = x_lod[0].size() - 1; + max_seq_len = -1; + enc_size = x_dim[1]; + } + + PADDLE_ENFORCE(enc_size % 2 == 0, "Only support even encode size!"); + + const int half_size = enc_size / 2; + for (int i = 0; i < batch_size; ++i) { + const int max_length = + x_lod.empty() ? 
+ max_seq_len : x_lod[0][i + 1] - x_lod[0][i]; + for (int j = 0; j < max_length; ++j) { + for (int k = 0; k < half_size; ++k) { + const double val = (half_size > 1) + ? j / pow(10000.0, double(k) / (half_size - 1)) + : j / 10000.0; + dst_ptr[k] = src_ptr[k] * alpha + sin(val) * beta; + dst_ptr[half_size + k] = + src_ptr[half_size + k] * alpha + cos(val) * beta; + } + src_ptr += enc_size; + dst_ptr += enc_size; + } + } + } +}; + +template <typename DeviceContext, typename T> +class AddPositionEncodingGradKernel : public framework::OpKernel<T> { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* dOut = + context.Input<framework::LoDTensor>(framework::GradVarName("Out")); + auto dout = framework::EigenVector<T>::Flatten(*dOut); + + auto* dX = + context.Output<framework::LoDTensor>(framework::GradVarName("X")); + dX->mutable_data<T>(context.GetPlace()); + auto dx = framework::EigenVector<T>::Flatten(*dX); + + float alpha = context.Attr<float>("alpha"); + + auto* place = + context.template device_context<DeviceContext>().eigen_device(); + dx.device(*place) = dout * static_cast<T>(alpha); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 4bfa89d9facf1d368e3018a248dc090c81c3402e..7fd616dbf65fa028badd28c0bf6c6f150c2cc61b 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -157,6 +157,8 @@ __all__ = [ 'sequence_reverse', 'affine_channel', 'hash', + 'log_loss', + 'add_position_encoding', ] @@ -7580,3 +7582,99 @@ def hash(input, hash_size, num_hash=1, name=None): attrs={'num_hash': num_hash, 'mod_by': hash_size}) return out + + +def log_loss(input, label, epsilon=1e-4, name=None): + """ + **Negative Log Loss Layer** + + This layer accepts input predictions and target label and returns the + negative log loss. + + .. math:: + + Out = -label * \\log{(input + \\epsilon)} + - (1 - label) * \\log{(1 - input + \\epsilon)} + + Args: + input (Variable|list): a 2-D tensor with shape [N x 1], where N is the + batch size. 
+ This input is a probability computed + by the previous operator. + label (Variable|list): the ground truth which is a 2-D tensor with + shape [N x 1], where N is the batch size. + epsilon (float): small value added inside both logarithms (see formula) + name (string): the name of log_loss + + Returns: + Variable: A 2-D tensor with shape [N x 1], the negative log loss. + + Examples: + .. code-block:: python + + prob = fluid.layers.sigmoid(net) + cost = fluid.layers.log_loss(input=prob, label=label) + """ + helper = LayerHelper('log_loss', **locals()) + + if name is None: + loss = helper.create_variable_for_type_inference(dtype=input.dtype) + else: + loss = helper.create_variable( + name=name, dtype=input.dtype, persistable=False) + + helper.append_op( + type='log_loss', + inputs={'Predicted': [input], + 'Labels': [label]}, + outputs={'Loss': [loss]}, + attrs={'epsilon': epsilon}) + return loss + + +def add_position_encoding(input, alpha, beta, name=None): + """ + **Add Position Encoding Layer** + + This layer accepts an input 3D-Tensor of shape [N x M x P], and returns an + output Tensor of shape [N x M x P] with positional encoding value. + + Refer to `Attention Is All You Need <http://arxiv.org/pdf/1706.03762.pdf>`_ . + + .. math:: + PE(pos, 2i) = \\sin{(pos / 10000^{2i / P})} \\\\ + PE(pos, 2i + 1) = \\cos{(pos / 10000^{2i / P})} \\\\ + Out(:, pos, i) = \\alpha * input(:, pos, i) + \\beta * PE(pos, i) + + Where: + * PE(pos, 2i): the increment for the number at even position + * PE(pos, 2i + 1): the increment for the number at odd position + + Args: + input (Variable): 3-D input tensor with shape [N x M x P] + alpha (float): scale factor applied to the input tensor + beta (float): scale factor applied to the positional encoding + name (string): the name of position encoding layer + + Returns: + Variable: A 3-D Tensor of shape [N x M x P] with positional encoding. + + Examples: + .. 
+ code-block:: python + + position_tensor = fluid.layers.add_position_encoding(input=tensor, alpha=1.0, beta=1.0) + """ + helper = LayerHelper('add_position_encoding', **locals()) + dtype = helper.input_dtype() + + if name is None: + out = helper.create_variable_for_type_inference(dtype=dtype) + else: + out = helper.create_variable(name=name, dtype=dtype, persistable=False) + + helper.append_op( + type="add_position_encoding", + inputs={"X": input}, + outputs={"Out": out}, + attrs={"alpha": alpha, + "beta": beta}) + return out diff --git a/python/paddle/fluid/tests/unittests/test_add_position_encoding_op.py b/python/paddle/fluid/tests/unittests/test_add_position_encoding_op.py new file mode 100644 index 0000000000000000000000000000000000000000..3f2a33793028f0883ffe94dd8a32626ad5c0351c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_add_position_encoding_op.py @@ -0,0 +1,134 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import unittest +import numpy as np +import math +import paddle.fluid.core as core +from op_test import OpTest + + +class TestAddPositionEncodingTensorOp(OpTest): + """ + This class is to test the AddPositionEncodingOp + """ + + def setUp(self): + """ + the prepared section for add position encoding op + """ + self.op_type = "add_position_encoding" + self.dtype = np.float32 + self.init_input_output() + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(self.x), } + self.outputs = {'Out': self.out} + self.attrs = {'alpha': self.alpha, 'beta': self.beta} + + def test_check_output(self): + """ + check the correctness of output + """ + self.check_output() + + def test_check_grad(self): + """ + check the correctness of grad + """ + self.check_grad(['X'], 'Out', max_relative_error=0.005) + + def init_input_output(self): + """ + init the input and output for test cases + """ + self.alpha = 0.6 + self.beta = 0.5 + self.x = np.random.uniform(0.1, 1, [2, 4, 4]).astype(self.dtype) + self.out = np.copy(self.x) + + batch_size = self.x.shape[0] + max_length = self.x.shape[1] + enc_size = self.x.shape[2] + + half_shape = int(enc_size / 2) + for i in range(batch_size): + for j in range(max_length): + for k in range(half_shape): + val = j / pow(10000.0, float(k) / ( + half_shape - 1)) if half_shape > 1 else j / 10000.0 + self.out[i, j, k] = \ + self.x[i, j, k] * self.alpha + math.sin(val) * self.beta + self.out[i, j, half_shape + k] = \ + self.x[i, j, half_shape + k] * self.alpha + math.cos(val) * self.beta + + +class TestAddPositionEncodingLoDTensorOp(OpTest): + """ + This class is to test the AddPositionEncodingOp with a LoDTensor input + """ + + def setUp(self): + """ + the prepared section for add position encoding LoDTensor op + """ + self.op_type = "add_position_encoding" + self.dtype = np.float32 + self.init_input_output() + + self.inputs = {'X': (self.x, self.lod), } + self.outputs = {'Out': (self.out, self.lod)} + self.attrs = {'alpha': self.alpha, 'beta': self.beta} + + def 
+ test_check_output(self): + """ + check the correctness of output + """ + self.check_output() + + def test_check_grad(self): + """ + check the correctness of grad + """ + self.check_grad(['X'], 'Out', max_relative_error=0.005) + + def init_input_output(self): + """ + init the input and output for test cases + """ + self.alpha = 0.6 + self.beta = 0.5 + self.x = np.random.uniform(0.1, 1, [10, 4]).astype(self.dtype) + self.lod = [[3, 7]] + self.out = np.copy(self.x) + + batch_size = len(self.lod[0]) + enc_size = self.x.shape[1] + + start = 0 + half_shape = int(enc_size / 2) + for i in range(batch_size): + max_length = self.lod[0][i] + for j in range(max_length): + for k in range(half_shape): + val = j / pow(10000.0, float(k) / ( + half_shape - 1)) if half_shape > 1 else j / 10000.0 + pos = start + j + self.out[pos, k] = \ + self.x[pos, k] * self.alpha + math.sin(val) * self.beta + self.out[pos, half_shape + k] = \ + self.x[pos, half_shape + k] * self.alpha + math.cos(val) * self.beta + start += max_length + + +if __name__ == '__main__': + unittest.main()