sequence_softmax_op.cc 7.1 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

W
Wu Yi 已提交
15
#include "paddle/fluid/operators/sequence_ops/sequence_softmax_op.h"
16

17
#include <string>
18 19 20 21 22 23 24 25

namespace paddle {
namespace operators {

class SequenceSoftmaxOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

26
  void InferShape(framework::InferShapeContext* ctx) const override {
27 28
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SequenceSoftmax");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SequenceSoftmax");
29 30

    ctx->ShareDim("X", /*->*/ "Out");
31
    ctx->ShareLoD("X", /*->*/ "Out");
32
  }
33 34 35 36 37

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // choose cudnn kernel if the runtime supported.
38 39
    bool use_cudnn =
        ctx.HasAttr("use_cudnn") ? ctx.Attr<bool>("use_cudnn") : false;
40
    bool runtime_cudnn_support = false;
41
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
42 43 44 45 46 47 48 49 50 51
    if (platform::is_gpu_place(ctx.GetPlace())) {
      auto& dev_ctx =
          ctx.template device_context<platform::CUDADeviceContext>();
      runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
    }
#endif
    framework::LibraryType library_ = framework::LibraryType::kPlain;
    if (use_cudnn && runtime_cudnn_support) {
      library_ = framework::LibraryType::kCUDNN;
    }
52 53 54
    std::string data_format = ctx.HasAttr("data_format")
                                  ? ctx.Attr<std::string>("data_format")
                                  : "AnyLayout";
55
    return framework::OpKernelType(
56 57 58 59
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.GetPlace(),
        framework::StringToDataLayout(data_format),
        library_);
60
  }
61 62 63 64
};

// Declares the proto of the sequence_softmax operator: one input "X", one
// output "Out", the extra attributes "use_cudnn" and "data_format", and the
// user-facing operator documentation.
class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(LoDTensor) 1-D or 2-D input LoDTensor with the 2-nd dimension "
             "of length 1.");
    AddOutput("Out",
              "(LoDTensor) 1-D or 2-D output LoDTensor with the 2-nd dimension "
              "of length 1.");
    AddAttr<bool>(
        "use_cudnn",
        "(bool, default false) Only used in cudnn kernel, need install cudnn")
        .SetDefault(false)
        .AsExtra();
    // The description must agree with the default registered below
    // ("AnyLayout"); the previous text advertised both NCHW and NHWC as the
    // default and contained a dangling "Only used in" fragment.
    AddAttr<std::string>(
        "data_format",
        "(string, default \"AnyLayout\") "
        "An optional string from: \"NHWC\", \"NCHW\", \"AnyLayout\". "
        "Specify the data format of the output data, "
        "the input will be transformed automatically. ")
        .SetDefault("AnyLayout")
        .AsExtra();
    AddComment(R"DOC(
Sequence Softmax Operator.

SequenceSoftmaxOp computes the softmax activation among all time-steps for each
sequence. The dimension of each time-step should be 1. Thus, the shape of
input Tensor can be either [N, 1] or [N], where N is the sum of the length
of all sequences.

The algorithm works as follows:

    for i-th sequence in a mini-batch:

$$
Out(X[lod[i]:lod[i+1], :]) = \
\frac{\exp(X[lod[i]:lod[i+1], :])} \
{\sum(\exp(X[lod[i]:lod[i+1], :]))}
$$

For example, for a mini-batch of 3 sequences with variable-length,
each containing 2, 3, 2 time-steps, the lod of which is [0, 2, 5, 7],
then softmax will be computed among X[0:2, :], X[2:5, :], X[5:7, :]
and N turns out to be 7.

)DOC");
  }
};

class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

116
  void InferShape(framework::InferShapeContext* ctx) const override {
117
    OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "SequenceSoftmaxGrad");
118 119 120 121
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
                   "Input",
                   "Out@GRAD",
                   "SequenceSoftmaxGrad");
122
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SequenceSoftmaxGrad");
123 124 125 126
    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")),
                   "Output",
                   "X@GRAD",
                   "SequenceSoftmaxGrad");
127 128 129

    auto out_dim = ctx->GetInputDim("Out");
    auto out_grad_dim = ctx->GetInputDim(framework::GradVarName("Out"));
130
    PADDLE_ENFORCE_EQ(
131 132
        out_dim,
        out_grad_dim,
133 134 135 136
        platform::errors::InvalidArgument(
            "The shape of Input(Out) and Input(Out@GRAD) of "
            "SequenceSoftmaxGrad operator do not match. The Input(Out)'s shape "
            "is [%s], the Input(Out@GRAD)'s shape is [%s].",
137 138
            out_dim,
            out_grad_dim));
139 140 141

    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }
142 143 144 145 146

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // choose cudnn kernel if the runtime supported.
147 148
    bool use_cudnn =
        ctx.HasAttr("use_cudnn") ? ctx.Attr<bool>("use_cudnn") : false;
149
    bool runtime_cudnn_support = false;
150
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
151 152 153 154 155 156 157 158 159 160
    if (platform::is_gpu_place(ctx.GetPlace())) {
      auto& dev_ctx =
          ctx.template device_context<platform::CUDADeviceContext>();
      runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
    }
#endif
    framework::LibraryType library_ = framework::LibraryType::kPlain;
    if (use_cudnn && runtime_cudnn_support) {
      library_ = framework::LibraryType::kCUDNN;
    }
161 162 163
    std::string data_format = ctx.HasAttr("data_format")
                                  ? ctx.Attr<std::string>("data_format")
                                  : "AnyLayout";
164
    return framework::OpKernelType(
165 166 167 168
        OperatorWithKernel::IndicateVarDataType(ctx, "Out"),
        ctx.GetPlace(),
        framework::StringToDataLayout(data_format),
        library_);
169
  }
170 171
};

172
// Declares the no-need-buffer inferer for the grad op, naming input "X".
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequenceSoftmaxGradOpNoNeedBufferVarsInferer,
                                    "X");


175 176 177 178
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Forward operator: op class, proto maker, and the default grad-op makers for
// both OpDesc and imperative OpBase.
REGISTER_OPERATOR(
    sequence_softmax,
    ops::SequenceSoftmaxOp,
    ops::SequenceSoftmaxOpMaker,
    paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
    paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);

// Backward operator, registered together with its no-need-buffer inferer.
REGISTER_OPERATOR(sequence_softmax_grad,
                  ops::SequenceSoftmaxGradOp,
                  ops::SequenceSoftmaxGradOpNoNeedBufferVarsInferer);

// CPU kernels (float and double) for the forward op.
REGISTER_OP_CPU_KERNEL(sequence_softmax,
                       ops::SequenceSoftmaxKernel<phi::CPUContext, float>,
                       ops::SequenceSoftmaxKernel<phi::CPUContext, double>);

// CPU kernels (float and double) for the backward op.
REGISTER_OP_CPU_KERNEL(sequence_softmax_grad,
                       ops::SequenceSoftmaxGradKernel<phi::CPUContext, float>,
                       ops::SequenceSoftmaxGradKernel<phi::CPUContext, double>);