Commit 03897f25 authored by Liu Yiqun

Finish the SequenceSoftmaxGradKernel, using SoftmaxGradFunctor.

Parent 05ed8ee8
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/mul_op.h"
@@ -35,12 +35,14 @@ class MulOp : public framework::OperatorWithKernel {
int x_num_col_dims = ctx->Attrs().Get<int>("x_num_col_dims");
int y_num_col_dims = ctx->Attrs().Get<int>("y_num_col_dims");
PADDLE_ENFORCE(x_dims.size() > x_num_col_dims,
"The rank of input tensor X should be larger than "
"`mul_op`'s `x_num_col_dims`.");
PADDLE_ENFORCE(y_dims.size() > y_num_col_dims,
"The rank of input tensor Y should be larger than "
"`mul_op`'s `y_num_col_dims`.");
PADDLE_ENFORCE_GT(
x_dims.size(), x_num_col_dims,
"The input tensor X's rank of MulOp should be larger than "
"x_num_col_dims.");
PADDLE_ENFORCE_GT(
y_dims.size(), y_num_col_dims,
"The input tensor Y's rank of MulOp should be larger than "
"y_num_col_dims.");
auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims);
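Note: flatten_to_2d collapses the leading num_col_dims dimensions into the row dimension and the remaining dimensions into the column dimension, which is how MulOp treats higher-rank inputs as matrices. A rough NumPy equivalent (shapes here are illustrative only):

    import numpy as np
    x = np.zeros((2, 3, 4))
    k = 2  # x_num_col_dims
    x_mat = x.reshape(int(np.prod(x.shape[:k])), -1)  # shape (6, 4)
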
@@ -22,41 +22,42 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(
ctx.InputVar("X"), "Input(X) of SequenceSoftmaxOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(
ctx.OutputVar("Out"),
"Output(Out) of SequenceSoftmaxOp should not be null.");
auto *x = ctx.Input<framework::LoDTensor>("X");
auto lod = x->lod();
auto dims = x->dims();
PADDLE_ENFORCE_GE(
dims[0],
/* batch_size */ static_cast<int64_t>(lod[0].size() - 1),
"The first dimension of Input(X) should be larger than batch size.");
const size_t level = lod.size() - 1;
PADDLE_ENFORCE_EQ(x->numel(), static_cast<int64_t>(lod[level].back()),
"The width of each timestep in Input(X) of "
"SequenceSoftmaxOp should be 1.");
std::cout << DebugString() << std::endl;
ctx.Output<framework::LoDTensor>("Out")->Resize({dims});
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceSoftmaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceSoftmaxOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceSoftmaxOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
SequenceSoftmaxOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor)");
AddOutput("Out", "(LoDTensor)");
AddInput("X",
"(LoDTensor) 1-D or 2-D input LoDTensor with the 2-nd dimension "
"of length 1.");
AddOutput("Out",
"(LoDTensor) 1-D or 2-D output LoDTensor with the 2-nd dimension "
"of length 1.");
AddComment(R"DOC(
Softmax of Sequence.
SequenceSoftmaxOp computes the softmax activation among all time-steps of each
sequence. The dimension of each time-step should be 1, so the shape of the
input Tensor can be either [N, 1] or [N], where N is the sum of the lengths of
all sequences.
Equation:
for the i-th sequence in the mini-batch:
Out[lod[i]:lod[i+1], :] =
exp(X[lod[i]:lod[i+1], :]) / sum(exp(X[lod[i]:lod[i+1], :]))
For example, for a mini-batch of 3 sequences with variable lengths containing
2, 3, and 2 time-steps respectively, the lod is [0, 2, 5, 7], and softmax is
computed independently within X[0:2, :], X[2:5, :], and X[5:7, :]; N turns
out to be 7.
)DOC");
}
};
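
The per-sequence computation described in the comment above can be sketched in NumPy. The helper name sequence_softmax and the example lod are illustrative only, not part of the operator's API:

    import numpy as np

    def sequence_softmax(x, lod):
        # x: flat input of shape [N]; lod: level-0 offsets, e.g. [0, 2, 5, 7]
        out = np.zeros_like(x, dtype=np.float64)
        for i in range(len(lod) - 1):
            seg = x[lod[i]:lod[i + 1]].astype(np.float64)
            seg = np.exp(seg - np.max(seg))  # shift by the max for numerical stability
            out[lod[i]:lod[i + 1]] = seg / np.sum(seg)
        return out

    x = np.random.uniform(0.1, 1.0, (7,)).astype(np.float32)
    print(sequence_softmax(x, [0, 2, 5, 7]))  # three independent softmaxes over X[0:2], X[2:5], X[5:7]
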
@@ -66,7 +67,25 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {}
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Out"),
"Input(Out) of SequenceSoftmaxGradOp should not be null.");
PADDLE_ENFORCE(
ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) of SequenceSoftmaxGradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceSoftmaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@GRAD) of SequenceSoftmaxOp should not be null.");
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Out"),
ctx->GetInputDim(framework::GradVarName("Out")),
"Input(Out) and Input(Out@GRAD) of SequenceSoftmaxGradOp should be of "
"the same shape.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
} // namespace operators
@@ -81,4 +100,4 @@ REGISTER_OP_CPU_KERNEL(
ops::SequenceSoftmaxKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
sequence_softmax_grad,
ops::SequenceSoftmaxGradKernel<paddle::platform::GPUPlace, float>);
ops::SequenceSoftmaxGradKernel<paddle::platform::CPUPlace, float>);
@@ -16,19 +16,13 @@ limitations under the License. */
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/softmax_function.h"
#include "paddle/operators/math/softmax.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class SequenceSoftmaxKernel : public framework::OpKernel {
@@ -38,7 +32,17 @@ class SequenceSoftmaxKernel : public framework::OpKernel {
auto* out = ctx.Output<LoDTensor>("Out");
auto lod = x->lod();
auto dims = x->dims();
PADDLE_ENFORCE_GE(
dims[0],
/* batch_size */ static_cast<int64_t>(lod[0].size() - 1),
"The first dimension of Input(X) should be larger than batch size.");
const size_t level = lod.size() - 1;
PADDLE_ENFORCE_EQ(x->numel(), static_cast<int64_t>(lod[level].back()),
"The width of each timestep in Input(X) of "
"SequenceSoftmaxOp should be 1.");
out->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
@@ -48,10 +52,10 @@ class SequenceSoftmaxKernel : public framework::OpKernel {
Tensor out_i = out->Slice<T>(start_pos, end_pos);
// Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
framework::DDim dims = framework::make_ddim({1UL, end_pos - start_pos});
x_i.Resize(dims);
out_i.Resize(dims);
math::SoftmaxFunctor<Place, T>()(&x_i, &out_i, ctx);
framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos});
x_i.Resize(dims_i);
out_i.Resize(dims_i);
math::SoftmaxFunctor<Place, T>()(ctx, &x_i, &out_i);
}
}
};
@@ -59,7 +63,32 @@ class SequenceSoftmaxKernel : public framework::OpKernel {
template <typename Place, typename T>
class SequenceSoftmaxGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {}
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Input<LoDTensor>("Out");
auto* out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
auto* x = ctx.Input<LoDTensor>("X");
auto* x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
auto lod = x->lod();
const size_t level = lod.size() - 1;
x_grad->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
int start_pos = static_cast<int>(lod[level][i]);
int end_pos = static_cast<int>(lod[level][i + 1]);
Tensor out_i = out->Slice<T>(start_pos, end_pos);
Tensor out_grad_i = out_grad->Slice<T>(start_pos, end_pos);
Tensor x_grad_i = x_grad->Slice<T>(start_pos, end_pos);
// Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos});
out_i.Resize(dims_i);
out_grad_i.Resize(dims_i);
x_grad_i.Resize(dims_i);
math::SoftmaxGradFunctor<Place, T>()(ctx, &out_i, &out_grad_i, &x_grad_i);
}
}
};
} // namespace operators
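
For reference, the backward pass that SequenceSoftmaxGradKernel delegates to math::SoftmaxGradFunctor (declared in paddle/operators/math/softmax.h) is, per sequence, the usual softmax Jacobian-vector product dX = (dOut - sum(dOut * Out)) * Out. A minimal NumPy sketch under that assumption; the helper name is illustrative, not the functor's actual implementation:

    import numpy as np

    def sequence_softmax_grad(out, dout, lod):
        # Standard softmax backward applied to each sequence independently:
        #   dx = (dout - sum(dout * out)) * out
        dx = np.zeros_like(out, dtype=np.float64)
        for i in range(len(lod) - 1):
            s, e = lod[i], lod[i + 1]
            dot = np.sum(dout[s:e] * out[s:e])
            dx[s:e] = (dout[s:e] - dot) * out[s:e]
        return dx
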
@@ -5,7 +5,7 @@ from op_test import OpTest
def stable_softmax(x):
"""Compute the softmax of vector x in a numerically stable way."""
shiftx = x - np.max(x)
shiftx = x - np.max(x).clip(-64.)
exps = np.exp(shiftx)
return exps / np.sum(exps)
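
A quick, illustrative check of why the shift matters: with large inputs the unshifted exponentials overflow, while the shifted version stays finite.

    import numpy as np
    x = np.array([1000., 1001., 1002.])
    # np.exp(x) overflows to inf, so a naive exp(x) / sum(exp(x)) yields nan
    print(stable_softmax(x))  # approximately [0.0900, 0.2447, 0.6652]
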
@@ -30,6 +30,9 @@ class TestSequenceSoftmaxOp(OpTest):
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out", max_relative_error=0.01)
if __name__ == "__main__":
unittest.main()
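
The gradient check above compares the analytic gradient produced by SequenceSoftmaxGradKernel against a numeric gradient obtained by finite differences on the forward output, allowing at most 1% relative error. A rough sketch of the idea only; OpTest's actual implementation differs in detail:

    import numpy as np

    def numeric_grad(f, x, eps=1e-3):
        # Central finite differences of sum(f(x)) w.r.t. each element of x.
        # Assumes x is a contiguous array so reshape(-1) returns a view.
        grad = np.zeros_like(x)
        flat = x.reshape(-1)
        for i in range(flat.size):
            orig = flat[i]
            flat[i] = orig + eps
            fp = np.sum(f(x))
            flat[i] = orig - eps
            fm = np.sum(f(x))
            flat[i] = orig
            grad.reshape(-1)[i] = (fp - fm) / (2 * eps)
        return grad
    # The analytic and numeric gradients are then compared element-wise
    # against the max_relative_error tolerance (0.01 above).
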