提交 a2657fea 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #4472 from Xreki/core_add_sequence_softmax_op

Add sequence softmax operator.
......@@ -101,8 +101,8 @@ set(DEPS_OPS
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor net_op)
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
op_library(cross_entropy_op DEPS cross_entropy_function)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy_function softmax_function)
op_library(cross_entropy_op DEPS cross_entropy)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
......
......@@ -2,16 +2,14 @@ if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc
im2col.cu DEPS cblas device_context operator)
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
nv_library(softmax_function SRCS softmax.cc softmax.cu
DEPS operator)
nv_library(cross_entropy_function SRCS cross_entropy.cc cross_entropy.cu
DEPS operator)
nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
else()
cc_library(math_function SRCS math_function.cc im2col.cc
DEPS cblas device_context operator)
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
cc_library(softmax_function SRCS softmax.cc DEPS operator)
cc_library(cross_entropy_function SRCS cross_entropy.cc DEPS operator)
cc_library(softmax SRCS softmax.cc DEPS operator)
cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
endif()
cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/softmax.h"
......@@ -19,6 +19,7 @@ namespace operators {
namespace math {
template class SoftmaxFunctor<platform::CPUPlace, float>;
template class SoftmaxGradFunctor<platform::CPUPlace, float>;
} // namespace math
} // namespace operators
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
......@@ -21,6 +21,7 @@ namespace operators {
namespace math {
template class SoftmaxFunctor<platform::GPUPlace, float>;
template class SoftmaxGradFunctor<platform::GPUPlace, float>;
} // namespace math
} // namespace operators
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
......@@ -68,6 +68,37 @@ class SoftmaxFunctor {
.broadcast(one_by_class));
}
};
template <typename Place, typename T>
class SoftmaxGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor* y, const framework::Tensor* y_grad,
framework::Tensor* x_grad) {
auto softmax = EigenMatrix<T>::From(*y);
auto softmax_grad = EigenMatrix<T>::From(*y_grad);
auto logits_grad = EigenMatrix<T>::From(*x_grad);
const int kBatchDim = 0;
const int kClassDim = 1;
const int batch_size = softmax.dimension(kBatchDim);
const int num_classes = softmax.dimension(kClassDim);
Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
auto dot = (softmax * softmax_grad)
.sum(along_class)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class);
logits_grad.device(*context.GetEigenDevice<Place>()) =
(softmax_grad - dot) * softmax;
}
};
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/mul_op.h"
......@@ -35,12 +35,14 @@ class MulOp : public framework::OperatorWithKernel {
int x_num_col_dims = ctx->Attrs().Get<int>("x_num_col_dims");
int y_num_col_dims = ctx->Attrs().Get<int>("y_num_col_dims");
PADDLE_ENFORCE(x_dims.size() > x_num_col_dims,
"The rank of input tensor X should be larger than "
"`mul_op`'s `x_num_col_dims`.");
PADDLE_ENFORCE(y_dims.size() > y_num_col_dims,
"The rank of input tensor Y should be larger than "
"`mul_op`'s `y_num_col_dims`.");
PADDLE_ENFORCE_GT(
x_dims.size(), x_num_col_dims,
"The input tensor X's rank of MulOp should be larger than "
"x_num_col_dims.");
PADDLE_ENFORCE_GT(
y_dims.size(), y_num_col_dims,
"The input tensor Y's rank of MulOp should be larger than "
"y_num_col_dims.");
auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims);
......
......@@ -24,9 +24,9 @@ class SequencePoolOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceAvgPoolOp should not be null.");
"Input(X) of SequencePoolOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceAvgPoolOp should not be null.");
"Output(Out) of SequencePoolOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
};
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sequence_softmax_op.h"
namespace paddle {
namespace operators {
class SequenceSoftmaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceSoftmaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceSoftmaxOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceSoftmaxOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(LoDTensor) 1-D or 2-D input LoDTensor with the 2-nd dimension "
"of length 1.");
AddOutput("Out",
"(LoDTensor) 1-D or 2-D output LoDTensor with the 2-nd dimension "
"of length 1.");
AddComment(R"DOC(
SequenceSoftmaxOp computes softmax activation among all time-steps for each
sequence. The dimension of each time-step should be 1. Thus, the shape of
input Tensor can be either [N, 1] or [N], where N is the sum of all sequences'
lengths.
Equation:
for i-th sequence in a mini-batch:
Out(X[lod[i]:lod[i+1]], :) =
exp(X[lod[i]:lod[i+1], :]) / sum(exp(X[lod[i]:lod[i+1], :]))
For example, for a mini-batch of 3 sequences with variable-length,
each containing 2, 3, 2 time-steps, the lod of which is [0, 2, 5, 7],
then softmax will be computed among X[0:2, :], X[2:5, :], X[5:7, :]
and N turns out to be 7.
)DOC");
}
};
class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Out"),
"Input(Out) of SequenceSoftmaxGradOp should not be null.");
PADDLE_ENFORCE(
ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) of SequenceSoftmaxGradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceSoftmaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@GRAD) of SequenceSoftmaxOp should not be null.");
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Out"),
ctx->GetInputDim(framework::GradVarName("Out")),
"Input(Out) and Input(Out@GRAD) of SequenceSoftmaxGradOp should be of "
"the same shape.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sequence_softmax, ops::SequenceSoftmaxOp,
ops::SequenceSoftmaxOpMaker, sequence_softmax_grad,
ops::SequenceSoftmaxGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_softmax,
ops::SequenceSoftmaxKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
sequence_softmax_grad,
ops::SequenceSoftmaxGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/sequence_softmax_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
sequence_softmax,
ops::SequenceSoftmaxKernel<paddle::platform::GPUPlace, float>)
REGISTER_OP_GPU_KERNEL(
sequence_softmax_grad,
ops::SequenceSoftmaxGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/softmax.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename Place, typename T>
class SequenceSoftmaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<LoDTensor>("X");
auto* out = ctx.Output<LoDTensor>("Out");
auto lod = x->lod();
auto dims = x->dims();
const size_t level = lod.size() - 1;
PADDLE_ENFORCE_EQ(dims[0], static_cast<int64_t>(lod[level].back()),
"The first dimension of Input(X) should be equal to the "
"sum of all sequences' lengths.");
PADDLE_ENFORCE_EQ(dims[0], x->numel(),
"The width of each timestep in Input(X) of "
"SequenceSoftmaxOp should be 1.");
out->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
int start_pos = static_cast<int>(lod[level][i]);
int end_pos = static_cast<int>(lod[level][i + 1]);
Tensor x_i = x->Slice<T>(start_pos, end_pos);
Tensor out_i = out->Slice<T>(start_pos, end_pos);
// Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos});
x_i.Resize(dims_i);
out_i.Resize(dims_i);
math::SoftmaxFunctor<Place, T>()(ctx.device_context(), &x_i, &out_i);
}
}
};
template <typename Place, typename T>
class SequenceSoftmaxGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Input<LoDTensor>("Out");
auto* out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
auto* x = ctx.Input<LoDTensor>("X");
auto* x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
auto lod = x->lod();
const size_t level = lod.size() - 1;
x_grad->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
int start_pos = static_cast<int>(lod[level][i]);
int end_pos = static_cast<int>(lod[level][i + 1]);
Tensor out_i = out->Slice<T>(start_pos, end_pos);
Tensor out_grad_i = out_grad->Slice<T>(start_pos, end_pos);
Tensor x_grad_i = x_grad->Slice<T>(start_pos, end_pos);
// Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos});
out_i.Resize(dims_i);
out_grad_i.Resize(dims_i);
x_grad_i.Resize(dims_i);
math::SoftmaxGradFunctor<Place, T>()(ctx.device_context(), &out_i,
&out_grad_i, &x_grad_i);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -29,8 +29,8 @@ template <typename Place, typename T>
class SoftmaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto X = context.Input<Tensor>("X");
auto Y = context.Output<Tensor>("Y");
auto* X = context.Input<Tensor>("X");
auto* Y = context.Output<Tensor>("Y");
// allocate memory on device.
Y->mutable_data<T>(context.GetPlace());
......@@ -43,29 +43,14 @@ template <typename Place, typename T>
class SoftmaxGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto Y = context.Input<Tensor>("Y");
auto dY = context.Input<Tensor>(framework::GradVarName("Y"));
auto dX = context.Output<Tensor>(framework::GradVarName("X"));
dX->mutable_data<T>(context.GetPlace());
const int batch_size = Y->dims()[0];
const int class_num = Y->dims()[1];
Eigen::DSizes<int, 1> along_class(1);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, class_num);
auto* Y = context.Input<Tensor>("Y");
auto* dY = context.Input<Tensor>(framework::GradVarName("Y"));
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
auto Y_eigen = EigenMatrix<T>::From(*Y);
auto dY_eigen = EigenMatrix<T>::From(*dY);
auto dX_eigen = EigenMatrix<T>::From(*dX);
auto place = context.GetEigenDevice<Place>();
// allocate memory on device.
dX->mutable_data<T>(context.GetPlace());
auto dot = (Y_eigen * dY_eigen)
.sum(along_class)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class);
dX_eigen.device(place) = (dY_eigen - dot) * Y_eigen;
math::SoftmaxGradFunctor<Place, T>()(context.device_context(), Y, dY, dX);
}
};
......
import unittest
import numpy as np
from op_test import OpTest
def stable_softmax(x):
"""Compute the softmax of vector x in a numerically stable way."""
shiftx = x - np.max(x).clip(-64.)
exps = np.exp(shiftx)
return exps / np.sum(exps)
class TestSequenceSoftmaxOp(OpTest):
def setUp(self):
self.op_type = "sequence_softmax"
x = np.random.uniform(0.1, 1, (11, 1)).astype("float32")
lod = [[0, 4, 5, 8, 11]]
out = np.zeros((11, 1)).astype("float32")
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
sub_x = sub_x.reshape(1, lod[0][i + 1] - lod[0][i])
sub_out = stable_softmax(sub_x)
out[lod[0][i]:lod[0][i + 1], :] = sub_out.reshape(
lod[0][i + 1] - lod[0][i], 1)
self.inputs = {"X": (x, lod)}
self.outputs = {"Out": out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out", max_relative_error=0.01)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册