From 03897f251dc40ae3ded98a84caa3b40fed164de9 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Thu, 28 Sep 2017 06:39:23 +0000 Subject: [PATCH] Finish the SequenceSoftmaxGradKernel, using SoftmaxGradFunctor. --- paddle/operators/mul_op.cc | 32 ++++---- paddle/operators/sequence_softmax_op.cc | 79 ++++++++++++------- paddle/operators/sequence_softmax_op.h | 53 ++++++++++--- .../tests/test_sequence_softmax_op.py | 5 +- 4 files changed, 111 insertions(+), 58 deletions(-) diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 9858c4d9c2..3c8fe04d2e 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #include "paddle/operators/mul_op.h" @@ -35,12 +35,14 @@ class MulOp : public framework::OperatorWithKernel { int x_num_col_dims = ctx->Attrs().Get("x_num_col_dims"); int y_num_col_dims = ctx->Attrs().Get("y_num_col_dims"); - PADDLE_ENFORCE(x_dims.size() > x_num_col_dims, - "The rank of input tensor X should be larger than " - "`mul_op`'s `x_num_col_dims`."); - PADDLE_ENFORCE(y_dims.size() > y_num_col_dims, - "The rank of input tensor Y should be larger than " - "`mul_op`'s `y_num_col_dims`."); + PADDLE_ENFORCE_GT( + x_dims.size(), x_num_col_dims, + "The input tensor X's rank of MulOp should be larger than " + "x_num_col_dims."); + PADDLE_ENFORCE_GT( + y_dims.size(), y_num_col_dims, + "The input tensor Y's rank of MulOp should be larger than " + "y_num_col_dims."); auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims); auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims); diff --git a/paddle/operators/sequence_softmax_op.cc b/paddle/operators/sequence_softmax_op.cc index 58ef77b1a3..e85b587a94 100644 --- a/paddle/operators/sequence_softmax_op.cc +++ b/paddle/operators/sequence_softmax_op.cc @@ -22,41 +22,42 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL( - ctx.InputVar("X"), "Input(X) of SequenceSoftmaxOp should not be null."); - PADDLE_ENFORCE_NOT_NULL( - ctx.OutputVar("Out"), - "Output(Out) of SequenceSoftmaxOp should not be null."); - - auto *x = ctx.Input("X"); - auto lod = x->lod(); - auto dims = x->dims(); - PADDLE_ENFORCE_GE( - dims[0], - /* batch_size */ static_cast(lod[0].size() - 1), - "The first dimension of Input(X) should be larger than batch size."); - - const size_t level = lod.size() - 1; - PADDLE_ENFORCE_EQ(x->numel(), static_cast(lod[level].back()), - "The width of each timestep in Input(X) of " - "SequenceSoftmaxOp should be 1."); - - std::cout << DebugString() << std::endl; - - ctx.Output("Out")->Resize({dims}); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequenceSoftmaxOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SequenceSoftmaxOp should not be null."); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ "Out"); } }; class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { public: - SequenceSoftmaxOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + SequenceSoftmaxOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "(LoDTensor)"); - AddOutput("Out", "(LoDTensor)"); + AddInput("X", + "(LoDTensor) 1-D or 2-D input LoDTensor with the 2-nd dimension " + "of length 1."); + AddOutput("Out", + "(LoDTensor) 1-D or 2-D output LoDTensor with the 2-nd dimension " + "of length 1."); AddComment(R"DOC( -Softmax of Sequence. +SequenceSoftmaxOp computes softmax activation among all time-steps for each +sequences. The dimension of each time-step should be 1. Thus, the shape of +input Tensor can be either [N, 1] or [N], where N is the sum of all sequences' +length. + +Equation: + for i-th sequence in mini-batch: + Out(X[lod[i]:lod[i+1]], :) = + exp(X[lod[i]:lod[i+1], :]) / sum(exp(X[lod[i]:lod[i+1], :])) + +For example, for a mini-batch of 3 sequences with variable-length, +each containing 2, 3, 2 time-steps, the lod of which is [0, 2, 5, 7], +then softmax will be computed among X[0:2, :], X[2:5, :], X[2:7, :] +and N turns out to be 7. )DOC"); } }; @@ -66,7 +67,25 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override {} + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Out"), + "Input(Out) of SequenceSoftmaxGradOp should not be null."); + PADDLE_ENFORCE( + ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) of SequenceSoftmaxGradOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequenceSoftmaxOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@GRAD) of SequenceSoftmaxOp should not be null."); + + PADDLE_ENFORCE_EQ( + ctx->GetInputDim("Out"), + ctx->GetInputDim(framework::GradVarName("Out")), + "Input(Out) and Input(Out@GRAD) of SequenceSoftmaxGradOp should be of " + "the same shape."); + + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } }; } // namespace operators @@ -81,4 +100,4 @@ REGISTER_OP_CPU_KERNEL( ops::SequenceSoftmaxKernel); REGISTER_OP_CPU_KERNEL( sequence_softmax_grad, - ops::SequenceSoftmaxGradKernel); + ops::SequenceSoftmaxGradKernel); diff --git a/paddle/operators/sequence_softmax_op.h b/paddle/operators/sequence_softmax_op.h index f39c2ec6c3..ca5cef4fc6 100644 --- a/paddle/operators/sequence_softmax_op.h +++ b/paddle/operators/sequence_softmax_op.h @@ -16,19 +16,13 @@ limitations under the License. */ #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" -#include "paddle/operators/math/softmax_function.h" +#include "paddle/operators/math/softmax.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; -template -using EigenVector = framework::EigenVector; -template -using EigenMatrix = framework::EigenMatrix; template class SequenceSoftmaxKernel : public framework::OpKernel { @@ -38,7 +32,17 @@ class SequenceSoftmaxKernel : public framework::OpKernel { auto* out = ctx.Output("Out"); auto lod = x->lod(); + auto dims = x->dims(); + + PADDLE_ENFORCE_GE( + dims[0], + /* batch_size */ static_cast(lod[0].size() - 1), + "The first dimension of Input(X) should be larger than batch size."); + const size_t level = lod.size() - 1; + PADDLE_ENFORCE_EQ(x->numel(), static_cast(lod[level].back()), + "The width of each timestep in Input(X) of " + "SequenceSoftmaxOp should be 1."); out->mutable_data(ctx.GetPlace()); for (int i = 0; i < static_cast(lod[level].size()) - 1; ++i) { @@ -48,10 +52,10 @@ class SequenceSoftmaxKernel : public framework::OpKernel { Tensor out_i = out->Slice(start_pos, end_pos); // Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos) - framework::DDim dims = framework::make_ddim({1UL, end_pos - start_pos}); - x_i.Resize(dims); - out_i.Resize(dims); - math::SoftmaxFunctor()(&x_i, &out_i, ctx); + framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos}); + x_i.Resize(dims_i); + out_i.Resize(dims_i); + math::SoftmaxFunctor()(ctx, &x_i, &out_i); } } }; @@ -59,7 +63,32 @@ class SequenceSoftmaxKernel : public framework::OpKernel { template class SequenceSoftmaxGradKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override {} + void Compute(const framework::ExecutionContext& ctx) const override { + auto* out = ctx.Input("Out"); + auto* out_grad = ctx.Input(framework::GradVarName("Out")); + auto* x = ctx.Input("X"); + auto* x_grad = ctx.Output(framework::GradVarName("X")); + + auto lod = x->lod(); + const size_t level = lod.size() - 1; + + x_grad->mutable_data(ctx.GetPlace()); + for (int i = 0; i < static_cast(lod[level].size()) - 1; ++i) { + int start_pos = static_cast(lod[level][i]); + int end_pos = static_cast(lod[level][i + 1]); + + Tensor out_i = out->Slice(start_pos, end_pos); + Tensor out_grad_i = out_grad->Slice(start_pos, end_pos); + Tensor x_grad_i = x_grad->Slice(start_pos, end_pos); + + // Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos) + framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos}); + out_i.Resize(dims_i); + out_grad_i.Resize(dims_i); + x_grad_i.Resize(dims_i); + math::SoftmaxGradFunctor()(ctx, &out_i, &out_grad_i, &x_grad_i); + } + } }; } // namespace operators diff --git a/python/paddle/v2/framework/tests/test_sequence_softmax_op.py b/python/paddle/v2/framework/tests/test_sequence_softmax_op.py index d0667c1308..b54a56aa6d 100644 --- a/python/paddle/v2/framework/tests/test_sequence_softmax_op.py +++ b/python/paddle/v2/framework/tests/test_sequence_softmax_op.py @@ -5,7 +5,7 @@ from op_test import OpTest def stable_softmax(x): """Compute the softmax of vector x in a numerically stable way.""" - shiftx = x - np.max(x) + shiftx = x - np.max(x).clip(-64.) exps = np.exp(shiftx) return exps / np.sum(exps) @@ -30,6 +30,9 @@ class TestSequenceSoftmaxOp(OpTest): def test_check_output(self): self.check_output() + def test_check_grad(self): + self.check_grad(["X"], "Out", max_relative_error=0.01) + if __name__ == "__main__": unittest.main() -- GitLab