// linear_chain_crf_op.cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/linear_chain_crf_op.h"
#include <memory>

namespace paddle {
namespace operators {

C
caoying03 已提交
22
class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
C
caoying03 已提交
23
 public:
Y
Yu Yang 已提交
24
  void Make() override {
25
    AddInput("Emission",
26 27
             "(phi::DenseTensor<float>). When a phi::DenseTensor "
             "input,A 2-D phi::DenseTensor"
28 29 30 31 32
             " with shape [N x D], where N is the size of the "
             "mini-batch and D is the total tag number. The unscaled emission "
             "weight matrix for the linear chain CRF. When a Tensor input,"
             "A Tensor with shape [N x S x D], where N is batch number,"
             "S is max length of sequences, D is the total tag number."
33
             "A phi::DenseTensor with type float32, float64.");
C
Cao Ying 已提交
34
    AddInput("Transition",
K
kexinzhao 已提交
35
             "(Tensor, default Tensor<float>) A 2-D Tensor with shape "
C
Cao Ying 已提交
36 37 38
             "[(D + 2) x D]. The learnable parameter for the linear_chain_crf "
             "operator. See more details in the operator's comments.");
    AddInput("Label",
39
             "(phi::DenseTensor<int64_t>), when a phi::DenseTensor input,  "
C
Cao Ying 已提交
40
             "[N x 1], where N is the total element number in a mini-batch. "
41
             "when a Tensor input, [N x S], where N is batch number. "
42
             "S is max length of sequences. The ground truth."
43
             "A  phi::DenseTensor with int64.");
44
    AddInput("Length",
45
             "(Tensor, default Tensor<int64_t>) A Tensor with shape "
46 47
             "[M x 1], where M is the sequence number in a mini-batch."
             "A Tensor with type int64.")
48
        .AsDispensable();
C
caoying03 已提交
49 50
    AddOutput(
        "Alpha",
51
        "(Tensor, default Tensor<float>), the same shape with Emission. "
52 53 54
        "The forward vectors for the entire batch. Denote it as $\alpha$. "
        "$\alpha$ is a memo table used to calculate the normalization "
        "factor in CRF. $\alpha[k, v]$ stores the unnormalized "
S
Shuangchi He 已提交
55 56
        "probabilities of all possible unfinished sequences of tags that end "
        "at position $k$ with tag $v$. For each $k$, "
57 58
        "$\alpha[k, v]$ is a vector of length $D$ with a component for "
        "each tag value $v$. This vector is called a forward vecotr and "
C
caoying03 已提交
59 60
        "will also be used in backward computations.")
        .AsIntermediate();
C
Cao Ying 已提交
61 62
    AddOutput(
        "EmissionExps",
63
        "(Tensor, default Tensor<float>), the same shape with Emission. "
C
Cao Ying 已提交
64 65
        "The exponentials of Input(Emission). This is an intermediate "
        "computational result in forward computation, and will be reused in "
66
        "backward computation."
67
        "A phi::DenseTensor with type float32, float64.")
C
caoying03 已提交
68
        .AsIntermediate();
C
Cao Ying 已提交
69 70
    AddOutput(
        "TransitionExps",
K
kexinzhao 已提交
71
        "(Tensor, default Tensor<float>) A 2-D Tensor with shape "
C
Cao Ying 已提交
72 73
        "[(D + 2) x D]. The exponentials of Input(Transition). This is an "
        "intermediate computational result in forward computation, and "
74
        "will be reused in backward computation."
75
        "A phi::DenseTensor with type float32, float64.")
C
caoying03 已提交
76
        .AsIntermediate();
C
caoying03 已提交
77 78
    AddOutput(
        "LogLikelihood",
K
kexinzhao 已提交
79
        "(Tensor, default Tensor<float>) The logarithm of the conditional "
C
caoying03 已提交
80 81
        "likelihood of each training sample in a mini-batch. This is a 2-D "
        "tensor with shape [S x 1], where S is the sequence number in a "
C
caoying03 已提交
82
        "mini-batch. Note: S is equal to the sequence number in a mini-batch. "
83
        "A Tensor with type float32, float64.");
C
caoying03 已提交
84 85 86
    AddComment(R"DOC(
Conditional Random Field defines an undirected probabilistic graph with nodes
denoting random variables and edges denoting dependencies between these
87 88 89
variables. CRF learns the conditional probability $P(Y|X)$, where
$X = (x_1, x_2, ... , x_n)$ are structured inputs and
$Y = (y_1, y_2, ... , y_n)$ are labels for the inputs.
C
caoying03 已提交
90 91 92

Linear chain CRF is a special case of CRF that is useful for sequence labeling
task. Sequence labeling tasks do not assume a lot of conditional
C
caoying03 已提交
93 94 95
independences among inputs. The only constraint they impose is that the input
and output must be linear sequences. Thus, the graph of such a CRF is a simple
chain or a line, which results in the linear chain CRF.
C
caoying03 已提交
96

C
caoying03 已提交
97
This operator implements the Forward-Backward algorithm for the linear chain
K
kexinzhao 已提交
98 99
CRF. Please refer to http://www.cs.columbia.edu/~mcollins/fb.pdf and
http://cseweb.ucsd.edu/~elkan/250Bwinter2012/loglinearCRFs.pdf for details.
C
caoying03 已提交
100 101

Equation:
Y
yi.wu 已提交
102

103
1. Denote Input(Emission) to this operator as $x$ here.
K
kexinzhao 已提交
104
2. The first D values of Input(Transition) to this operator are for starting
105
weights, denoted as $a$ here.
K
kexinzhao 已提交
106
3. The next D values of Input(Transition) of this operator are for ending
107
weights, denoted as $b$ here.
K
kexinzhao 已提交
108
4. The remaning values of Input(Transition) are for transition weights,
109 110
denoted as $w$ here.
5. Denote Input(Label) as $s$ here.
C
caoying03 已提交
111

112 113 114 115 116 117 118
The probability of a sequence $s$ of length $L$ is defined as:
$$P(s) = (1/Z) \exp(a_{s_1} + b_{s_L}
                + \sum_{l=1}^L x_{s_l}
                + \sum_{l=2}^L w_{s_{l-1},s_l})$$

where $Z$ is a normalization value so that the sum of $P(s)$ over
all possible sequences is 1, and $x$ is the emission feature weight
C
caoying03 已提交
119 120
to the linear chain CRF.

K
kexinzhao 已提交
121
Finally, the linear chain CRF operator outputs the logarithm of the conditional
C
caoying03 已提交
122 123 124
likelihood of each training sample in a mini-batch.

NOTE:
Y
yi.wu 已提交
125

C
caoying03 已提交
126 127 128 129
1. The feature function for a CRF is made up of the emission features and the
transition features. The emission feature weights are NOT computed in
this operator. They MUST be computed first before this operator is called.

C
caoying03 已提交
130
2. Because this operator performs global normalization over all possible
C
caoying03 已提交
131 132 133 134
sequences internally, it expects UNSCALED emission feature weights.
Please do not call this op with the emission feature being output of any
nonlinear activation.

135
3. The 2nd dimension of Input(Emission) MUST be equal to the tag number.
C
caoying03 已提交
136 137 138 139 140

)DOC");
  }
};

C
caoying03 已提交
141
class LinearChainCRFOp : public framework::OperatorWithKernel {
C
caoying03 已提交
142 143 144
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

C
caoying03 已提交
145
  void InferShape(framework::InferShapeContext* ctx) const override {
146 147 148 149
    OP_INOUT_CHECK(
        ctx->HasInput("Emission"), "Input", "Emission", "LinearChainCRF");
    OP_INOUT_CHECK(
        ctx->HasInput("Transition"), "Input", "Transition", "LinearChainCRF");
150 151
    OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "LinearChainCRF");

152 153 154 155 156
    OP_INOUT_CHECK(
        ctx->HasOutput("Alpha"), "Output", "Alpha", "LinearChainCRF");
    OP_INOUT_CHECK(ctx->HasOutput("EmissionExps"),
                   "Output",
                   "EmissionExps",
157
                   "LinearChainCRF");
158 159 160
    OP_INOUT_CHECK(ctx->HasOutput("TransitionExps"),
                   "Output",
                   "TransitionExps",
161
                   "LinearChainCRF");
162 163 164
    OP_INOUT_CHECK(ctx->HasOutput("LogLikelihood"),
                   "Output",
                   "LogLikelihood",
165
                   "LinearChainCRF");
C
caoying03 已提交
166

C
caoying03 已提交
167
    auto transition_dims = ctx->GetInputDim("Transition");
168 169
    PADDLE_ENFORCE_EQ(transition_dims.size(),
                      2UL,
170 171 172
                      platform::errors::InvalidArgument(
                          "The Input(Transition) should be a 2-D tensor. But "
                          "received: input rank %u, input shape [%s].",
173 174
                          transition_dims.size(),
                          transition_dims));
X
xuezhong 已提交
175 176 177 178 179 180 181
    bool check = true;
    if ((!ctx->IsRuntime()) &&
        (transition_dims[0] <= 0 || transition_dims[1] <= 0)) {
      check = false;
    }
    if (check) {
      PADDLE_ENFORCE_EQ(
182 183
          transition_dims[0] - 2,
          transition_dims[1],
184 185 186 187 188
          platform::errors::InvalidArgument(
              "An invalid dimension for the Input(Transition), which should "
              "be a 2-D tensor with shape [(D + 2) x D]. But received: input "
              "rank %u, "
              "input shape [%s].",
189 190
              transition_dims.size(),
              transition_dims));
X
xuezhong 已提交
191
    }
192
    auto emission_dims = ctx->GetInputDim("Emission");
193
    if (ctx->HasInput("Length")) {
194 195
      PADDLE_ENFORCE_EQ(emission_dims.size(),
                        3,
196 197 198
                        platform::errors::InvalidArgument(
                            "The Input(Emission) should be a 3-D tensor. But "
                            "received: input rank %u, input shape [%s].",
199 200
                            emission_dims.size(),
                            emission_dims));
201
      auto label_dims = ctx->GetInputDim("Label");
202 203 204 205
      PADDLE_ENFORCE_EQ(
          (label_dims.size() == 3UL && label_dims[2] == 1) ||
              (label_dims.size() == 2UL),
          true,
206 207 208 209
          platform::errors::InvalidArgument(
              "The Input(Label) should be a 3-D tensor with last dimension "
              "fixed to 1 or a 2-D tensor in padding mode. But received: input "
              "rank %u, input shape [%s].",
210 211
              label_dims.size(),
              label_dims));
212
      if (ctx->IsRuntime()) {
213 214
        PADDLE_ENFORCE_EQ(emission_dims[0],
                          label_dims[0],
215 216 217 218 219 220
                          platform::errors::InvalidArgument(
                              "The batch size of Input(Emission) "
                              "and Input(Label) should be the same. But "
                              "received Input(Emission): "
                              "rank %u, shape [%s]; received Input(Label): "
                              "rank %u, shape [%s].",
221 222 223 224 225 226
                              emission_dims.size(),
                              emission_dims,
                              label_dims.size(),
                              label_dims));
        PADDLE_ENFORCE_EQ(emission_dims[1],
                          label_dims[1],
227 228 229 230 231 232
                          platform::errors::InvalidArgument(
                              "The max length of Input(Emission) "
                              "and Input(Label) should be the same. But "
                              "received Input(Emission): "
                              "rank %u, shape [%s]; received Input(Label): "
                              "rank %u, shape [%s].",
233 234 235 236
                              emission_dims.size(),
                              emission_dims,
                              label_dims.size(),
                              label_dims));
237
      }
238
    } else {
239
      PADDLE_ENFORCE_EQ(
240 241
          emission_dims.size(),
          2,
242 243 244
          platform::errors::InvalidArgument(
              "The Input(Emission) should be a 2-D tensor. But received: "
              "input rank %u, input shape [%s].",
245 246
              emission_dims.size(),
              emission_dims));
247
      if (ctx->IsRuntime()) {
248 249
        PADDLE_ENFORCE_EQ(emission_dims[1],
                          transition_dims[1],
250 251 252 253 254 255 256
                          platform::errors::InvalidArgument(
                              "The 2nd dimension of the Input(Emission) and "
                              "the Input(Transition) "
                              "should be equal to the tag number. But received "
                              "Input(Emission): rank "
                              "%u, shape [%s]; received Input(Transition): "
                              "rank %u, shape [%s].",
257 258 259 260
                              emission_dims.size(),
                              emission_dims,
                              transition_dims.size(),
                              transition_dims));
261
      }
262 263

      auto label_dims = ctx->GetInputDim("Label");
264
      PADDLE_ENFORCE_EQ(
265 266
          label_dims.size(),
          2,
267 268 269 270
          platform::errors::InvalidArgument(
              "The Input(Label) should be a 2-D tensor with the 2nd "
              "dimensions fixed to 1. But received: input rank %u, "
              "input shape [%s].",
271 272
              label_dims.size(),
              label_dims));
273 274
      if (ctx->IsRuntime()) {
        PADDLE_ENFORCE_EQ(
275 276
            emission_dims[0],
            label_dims[0],
277 278 279 280 281
            platform::errors::InvalidArgument(
                "The first dimension of Input(Emission) and Input(Label) "
                "should be the same. But received Input(Emission): rank %u, "
                "shape "
                "[%s]; received Input(Label): rank %u, shape [%s].",
282 283 284
                emission_dims.size(),
                emission_dims,
                label_dims.size(),
285
                label_dims));
286
      }
287
    }
C
caoying03 已提交
288
    ctx->SetOutputDim("Alpha", emission_dims);
C
caoying03 已提交
289 290
    ctx->SetOutputDim("EmissionExps", emission_dims);
    ctx->SetOutputDim("TransitionExps", transition_dims);
C
caoying03 已提交
291
    // TODO(caoying) This is tricky. The 1st dimension of Output(LogLikelihood)
292
    // is the sequence number in a mini-batch. The dimension set here should be
C
caoying03 已提交
293 294
    // resized to its correct size in the function Compute. Fix this once we can
    // get LoD information in the InferShape interface.
C
caoying03 已提交
295 296 297
    ctx->SetOutputDim("LogLikelihood", {emission_dims[0], 1});
  }

C
caoying03 已提交
298
 protected:
C
Cao Ying 已提交
299 300
  // Explicitly set that the data type of computation kernel of linear_chain_crf
  // is determined by its input "Emission".
301
  phi::KernelKey GetExpectedKernelType(
C
caoying03 已提交
302
      const framework::ExecutionContext& ctx) const override {
303
    return phi::KernelKey(
304 305
        OperatorWithKernel::IndicateVarDataType(ctx, "Emission"),
        platform::CPUPlace());
C
caoying03 已提交
306
  }
C
caoying03 已提交
307 308
};

C
caoying03 已提交
309
class LinearChainCRFGradOp : public framework::OperatorWithKernel {
C
caoying03 已提交
310 311 312
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

C
caoying03 已提交
313
  void InferShape(framework::InferShapeContext* ctx) const override {
314 315 316
    OP_INOUT_CHECK(ctx->HasInput("EmissionExps"),
                   "Input",
                   "EmissionExps",
317
                   "LinearChainCRFGrad");
318 319 320
    OP_INOUT_CHECK(ctx->HasInput("TransitionExps"),
                   "Input",
                   "TransitionExps",
321 322
                   "LinearChainCRFGrad");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("LogLikelihood")),
323 324
                   "Input",
                   framework::GradVarName("LogLikelihood"),
325
                   "LinearChainCRFGrad");
C
caoying03 已提交
326

327
    auto transition_exps_dims = ctx->GetInputDim("TransitionExps");
328
    auto emission_exps_dims = ctx->GetInputDim("EmissionExps");
C
caoying03 已提交
329 330
    if (ctx->HasOutput(framework::GradVarName("Emission"))) {
      ctx->SetOutputDim(framework::GradVarName("Emission"), emission_exps_dims);
331
      if (ctx->HasInput("Length") == false) {
332 333
        ctx->ShareLoD("Emission", framework::GradVarName("Emission"));
      }
C
caoying03 已提交
334
    }
335

C
caoying03 已提交
336 337 338
    if (ctx->HasOutput(framework::GradVarName("Transition"))) {
      ctx->SetOutputDim(framework::GradVarName("Transition"),
                        transition_exps_dims);
S
sneaxiy 已提交
339
      ctx->ShareLoD("Transition", framework::GradVarName("Transition"));
C
caoying03 已提交
340
    }
C
caoying03 已提交
341
  }
C
caoying03 已提交
342 343 344

 protected:
  // Explicitly set that the data type of output of the linear_chain_crf_grad
C
caoying03 已提交
345
  // operator is determined by its input: gradients of LogLikelihood.
346
  phi::KernelKey GetExpectedKernelType(
C
caoying03 已提交
347
      const framework::ExecutionContext& ctx) const override {
348 349 350
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
                              ctx, framework::GradVarName("LogLikelihood")),
                          platform::CPUPlace());
C
caoying03 已提交
351
  }
C
caoying03 已提交
352 353
};

H
hong 已提交
354 355
template <typename T>
class LinearChainCRFGradMaker : public framework::SingleGradOpMaker<T> {
S
sneaxiy 已提交
356
 public:
H
hong 已提交
357
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
S
sneaxiy 已提交
358 359

 protected:
360
  void Apply(GradOpPtr<T> op) const override {
S
sneaxiy 已提交
361
    op->SetType("linear_chain_crf_grad");
H
hong 已提交
362 363 364 365 366 367 368 369 370
    op->SetAttrMap(this->Attrs());
    op->SetInput("Emission", this->Input("Emission"));
    op->SetInput("Transition", this->Input("Transition"));
    op->SetInput("Label", this->Input("Label"));
    op->SetInput("Alpha", this->Output("Alpha"));
    op->SetInput("EmissionExps", this->Output("EmissionExps"));
    op->SetInput("TransitionExps", this->Output("TransitionExps"));
    if (this->HasInput("Length")) {
      op->SetInput("Length", this->Input("Length"));
371
    }
S
sneaxiy 已提交
372
    op->SetInput(framework::GradVarName("LogLikelihood"),
H
hong 已提交
373
                 this->OutputGrad("LogLikelihood"));
S
sneaxiy 已提交
374

H
hong 已提交
375 376
    op->SetOutput(framework::GradVarName("Emission"),
                  this->InputGrad("Emission"));
S
sneaxiy 已提交
377
    op->SetOutput(framework::GradVarName("Transition"),
H
hong 已提交
378
                  this->InputGrad("Transition"));
S
sneaxiy 已提交
379 380 381
  }
};

382
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LinearChainCRFGradNoNeedBufferVarsInferer,
383 384
                                    "Transition",
                                    "Emission");
S
sneaxiy 已提交
385

C
caoying03 已提交
386 387 388 389
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
390 391
REGISTER_OPERATOR(linear_chain_crf,
                  ops::LinearChainCRFOp,
H
hong 已提交
392 393 394
                  ops::LinearChainCRFOpMaker,
                  ops::LinearChainCRFGradMaker<paddle::framework::OpDesc>,
                  ops::LinearChainCRFGradMaker<paddle::imperative::OpBase>);
395 396
REGISTER_OPERATOR(linear_chain_crf_grad,
                  ops::LinearChainCRFGradOp,
397
                  ops::LinearChainCRFGradNoNeedBufferVarsInferer);
L
Leo Chen 已提交
398 399 400
REGISTER_OP_CPU_KERNEL(linear_chain_crf,
                       ops::LinearChainCRFOpKernel<phi::CPUContext, float>,
                       ops::LinearChainCRFOpKernel<phi::CPUContext, double>);
401 402
REGISTER_OP_CPU_KERNEL(
    linear_chain_crf_grad,
L
Leo Chen 已提交
403 404
    ops::LinearChainCRFGradOpKernel<phi::CPUContext, float>,
    ops::LinearChainCRFGradOpKernel<phi::CPUContext, double>);