/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/warpctc_op.h"

#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif

namespace paddle {
namespace operators {

// Forward operator computing the CTC loss (and, in the same call, the
// gradient stored in Output(WarpCTCGrad)) through the warp-ctc library.
class WarpCTCOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Checks that every required input/output is present, validates
  // Attr(blank) against the width of Input(Logits), and sets the shape of
  // Output(Loss).
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Logits"),
                   "Input(Logits) of WarpCTCOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"),
                   "Input(Label) of WarpCTCOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("WarpCTCGrad"),
                   "Output(WarpCTCGrad) of WarpCTCOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Loss"),
                   "Output(Loss) of WarpCTCOp should not be null.");

    auto logits_dims = ctx->GetInputDim("Logits");
    // The row width is product(dims) / dims[0]; reject a zero leading
    // dimension up front instead of dividing by zero.
    PADDLE_ENFORCE(logits_dims[0] != 0,
                   "The first dimension of Input(Logits) of WarpCTCOp "
                   "should not be 0.");
    int sequence_width =
        static_cast<int>(framework::product(logits_dims) / logits_dims[0]);
    int blank = ctx->Attrs().Get<int>("blank");
    PADDLE_ENFORCE((blank >= 0) && (blank < sequence_width),
                   "The value of Attr(blank) should be in interval [0, %d).",
                   sequence_width);
    // TODO(liuyiqun): it is tricky to set the wrong dimension here.
    // NOTE(review): Output(Loss) is documented as [batch_size, 1], but
    // logits_dims[0] is Lp (the sum of all sequence lengths); presumably the
    // kernel resizes Loss at run time once the LoD is known — confirm.
    ctx->SetOutputDim("Loss", {logits_dims[0], 1});
  }

 protected:
  // Selects the kernel by the data type of Input(Logits); the cuDNN library
  // is chosen when CUDA support is compiled in and CanCUDNNBeUsed(ctx) holds.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
    if (platform::CanCUDNNBeUsed(ctx)) {
      library = framework::LibraryType::kCUDNN;
    }
#endif
    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
    return framework::OpKernelType(ctx.Input<Tensor>("Logits")->type(),
                                   ctx.device_context(), layout, library);
  }
};

// Declares the proto (inputs, outputs, attributes and documentation) of the
// warpctc operator.
class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Logits",
             "(LoDTensor, default: LoDTensor<float>), the unscaled "
             "probabilities of variable-length sequences, which is a 2-D "
             "Tensor with LoD information. Its shape is "
             "[Lp, num_classes + 1], where Lp is the sum of all input "
             "sequences' length and num_classes is the true number of classes "
             "(not including the blank label).");
    AddInput("Label",
             "(LoDTensor, default: LoDTensor<int>), the ground truth "
             "of variable-length sequence, which is a 2-D Tensor with LoD "
             "information. It is of the shape [Lg, 1], where Lg is the sum of "
             "all labels' length.");
    // Intermediate: produced together with Loss by one warp-ctc call and
    // consumed by the gradient operator.
    AddOutput("WarpCTCGrad",
              "(Tensor, default: Tensor<float>), a temporary "
              "output Tensor to store the gradients of warp-ctc, which is "
              "computed with loss together in one call. It is a 3-D Tensor of "
              "the shape [max_sequence_length, batch_size, num_classes + 1].")
        .AsIntermediate();
    AddOutput("Loss",
              "(Tensor, default: Tensor<float>), the Connectionist "
              "Temporal Classification (CTC) loss, which is a 2-D Tensor of "
              "the shape [batch_size, 1].");
    AddAttr<int>("blank",
                 "(int, default: 0), the blank label of Connectionist "
                 "Temporal Classification (CTC) loss, which is in the "
                 "half-open interval [0, num_classes + 1).")
        .SetDefault(0);
    AddAttr<bool>("norm_by_times",
                  "(bool, default: false), whether to "
                  "normalize the gradients by the number of time-steps, "
                  "which is also the sequence's length.")
        .SetDefault(false);
    AddAttr<bool>("use_cudnn",
                  "(bool, default: false), whether to "
                  "use cudnn kernel.")
        .SetDefault(false);
    AddComment(R"DOC(
An operator integrating the open-source
[warp-ctc](https://github.com/baidu-research/warp-ctc) library, which is used in
[Deep Speech 2: End-to-End Speech Recognition in English and Mandarin](
https://arxiv.org/pdf/1512.02595v1.pdf),
to compute Connectionist Temporal Classification (CTC) loss.
It can be aliased as softmax with ctc, since a native softmax activation is
integrated into the warp-ctc library to normalize values for each row of the
input tensor.

More details of CTC loss can be found by referring to
[Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with
Recurrent Neural Networks](
http://machinelearning.wustl.edu/mlpapers/paper_files/icml2006_GravesFGS06.pdf).
)DOC");
  }
};

// Gradient operator of warpctc: produces Logits@GRAD from the gradient that
// the forward pass stored in WarpCTCGrad.
class WarpCTCGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Logits@GRAD takes the shape and LoD of Input(Logits), so Logits must be
  // present as well as WarpCTCGrad.
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("WarpCTCGrad"),
                   "Input(WarpCTCGrad) of WarpCTCGradOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Logits"),
                   "Input(Logits) of WarpCTCGradOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
                   "Output(Logits@GRAD) of WarpCTCGradOp should not be null.");
    ctx->SetOutputDim(framework::GradVarName("Logits"),
                      ctx->GetInputDim("Logits"));
    ctx->ShareLoD("Logits", /*->*/ framework::GradVarName("Logits"));
  }

 protected:
  // Selects the kernel by the data type of Input(Logits) on the current
  // device.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<Tensor>("Logits")->type(),
                                   ctx.device_context());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Forward op; DefaultGradOpDescMaker<true> generates the warpctc_grad desc.
REGISTER_OPERATOR(warpctc, ops::WarpCTCOp, ops::WarpCTCOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(warpctc_grad, ops::WarpCTCGradOp);

// CPU kernels are registered for float only.
REGISTER_OP_CPU_KERNEL(
    warpctc, ops::WarpCTCKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
    warpctc_grad,
    ops::WarpCTCGradKernel<paddle::platform::CPUDeviceContext, float>);