/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

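// Normalizes a length-`len` vector in place so that its entries sum to 1, and
// returns the original (pre-normalization) sum. Used to rescale the alpha and
// beta vectors at every step of the forward/backward recursions.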
template <typename T>
static inline T NormalizeL1(T* x, size_t len) {
  T sum = 0.;
  for (size_t i = 0; i < len; ++i) sum += x[i];
  // (This comment is from the old LinearChainCRFLayer.)
  // Right now, we just bet that sum won't be zero. If this really happens, we
  // will figure out what should be done then.
  PADDLE_ENFORCE_GT(
      sum, 0., platform::errors::InvalidArgument(
                   "The unnormalized probabilities of all possible unfinished "
                   "sequences must be greater than 0."));
  T s = 1. / sum;
  for (size_t i = 0; i < len; ++i) x[i] *= s;
  return sum;
}

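// A simple functor that multiplies its argument by a fixed scalar; used with
// Eigen's unaryExpr to scale the emission gradient by the upstream gradient of
// the log-likelihood.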
template <typename T>
struct ScalarMul {
  explicit ScalarMul(const T& scalar) : scalar(scalar) {}
  T operator()(const T& val) const { return val * scalar; }

  T scalar;
};

using framework::LoDTensor;
using framework::LoD;
using framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename DeviceContext, typename T>
class LinearChainCRFOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const Tensor* emission_weights = ctx.Input<framework::Tensor>("Emission");
    const Tensor* transition_weights =
        ctx.Input<framework::Tensor>("Transition");

    Tensor* emission_exps = ctx.Output<Tensor>("EmissionExps");
    Tensor* transition_exps = ctx.Output<Tensor>("TransitionExps");
    Tensor* alpha = ctx.Output<Tensor>("Alpha");
    Tensor* ll = ctx.Output<Tensor>("LogLikelihood");

    // Because this computation only runs on the CPU, the memory for all the
    // outputs is explicitly allocated on CPU memory.
    emission_exps->mutable_data<T>(platform::CPUPlace());
    alpha->mutable_data<T>(platform::CPUPlace());
    transition_exps->mutable_data<T>(platform::CPUPlace());
    auto emission_dims = emission_weights->dims();

    const Tensor* label = ctx.Input<framework::Tensor>("Label");
    Tensor emission_weights_tmp = *emission_weights;
    Tensor label_tmp = *label;
    Tensor emission_exps_tmp = *emission_exps;
    Tensor alpha_tmp = *alpha;
    int64_t seq_num = 0;
    int64_t batch_size;
    int64_t tag_num;
    const int64_t* length_data = nullptr;
    framework::LoD in_lod;
    if (ctx.HasInput("Length")) {
      const Tensor* label_length = ctx.Input<framework::Tensor>("Length");
      length_data = label_length->data<int64_t>();
      seq_num = label_length->numel();
      PADDLE_ENFORCE_EQ(
          seq_num, emission_dims[0],
          platform::errors::InvalidArgument(
              "The size of Input(Length) must be equal to "
              "emission_dims[0]. But input size = %d, emission_dims[0] = %d.",
              seq_num, emission_dims[0]));
      auto label_dims = label->dims();
      PADDLE_ENFORCE_EQ(
          seq_num, label_dims[0],
          platform::errors::InvalidArgument(
              "The size of Input(Length) must be equal to "
              "label_dims[0]. But input size = %d, label_dims[0] = %d.",
              seq_num, label_dims[0]));

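      // With Input(Length), Emission/Label/Alpha are padded 3-D tensors of
      // shape [seq_num, max_seq_len, tag_num]; flatten them into 2-D views so
      // that each sequence can be addressed by a [start_pos, end_pos) slice,
      // just as in the LoD case.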
      batch_size = emission_dims[0] * emission_dims[1];
      tag_num = emission_dims[2];
      emission_weights_tmp.Resize({batch_size, tag_num});
      label_tmp.Resize({batch_size, 1});
      alpha_tmp.Resize({batch_size, tag_num});
      emission_exps_tmp.Resize({batch_size, tag_num});
      math::set_constant(ctx.device_context(), emission_exps, 0.0);
      math::set_constant(ctx.device_context(), alpha, 0.0);
    } else {
      in_lod = ctx.Input<LoDTensor>("Label")->lod();
      PADDLE_ENFORCE_NE(in_lod.size(), 0,
                        platform::errors::InvalidArgument(
                            "Input(Label) must be a sequence."));
      seq_num = in_lod[0].size() - 1;
      batch_size = emission_dims[0];
      tag_num = emission_dims[1];
    }

    // Resize the output tensor to its correct dimension.
    ll->Resize({seq_num, 1});
    ll->mutable_data<T>(platform::CPUPlace());
    // Now, all the inputs and outputs should be on the CPU memory.
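    // emission_row_max stores the per-row maximum of the emission weights.
    // Subtracting it before exponentiation keeps the exponentials in a
    // numerically safe range; the subtracted max is accounted for when the
    // log-likelihood is accumulated in ForwardOneSequence.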
    Tensor emission_row_max;
    emission_row_max.mutable_data<T>(
        framework::make_ddim({static_cast<int64_t>(batch_size), 1}),
        platform::CPUPlace());
    auto& place = *ctx.template device_context<platform::CPUDeviceContext>()
                       .eigen_device();
    auto x = EigenMatrix<T>::From(emission_weights_tmp);
    auto x_row_max = EigenMatrix<T>::From(emission_row_max);
    x_row_max.device(place) =
        x.maximum(Eigen::DSizes<int, 1>(1))
            .reshape(Eigen::DSizes<int, 2>(static_cast<int>(batch_size), 1));
    auto x_exps = EigenMatrix<T>::From(emission_exps_tmp);
    x_exps.device(place) =
        (x - x_row_max.broadcast(Eigen::DSizes<int, 2>(1, tag_num))).exp();
    auto w = EigenMatrix<T>::From(*transition_weights);
    auto w_exps = EigenMatrix<T>::From(*transition_exps);
    w_exps.device(place) = w.exp();
    T* log_likelihood = ll->data<T>();
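    // Run the forward algorithm independently on every sequence in the batch.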
    for (int64_t i = 0; i < seq_num; ++i) {
      int64_t start_pos = 0;
      int64_t end_pos = 0;
      if (ctx.HasInput("Length")) {
        start_pos = i * emission_dims[1];
        end_pos = start_pos + length_data[i];
      } else {
        start_pos = static_cast<int64_t>(in_lod[0][i]);
        end_pos = static_cast<int64_t>(in_lod[0][i + 1]);
      }
      if (end_pos == start_pos) {
        // If an empty input sequence is given, pad 0 for its cost.
        log_likelihood[i] = 0.;
        continue;
      }
      const Tensor one_seq = emission_weights_tmp.Slice(start_pos, end_pos);
      Tensor one_seq_row_max = emission_row_max.Slice(start_pos, end_pos);
      Tensor one_seq_exps = emission_exps_tmp.Slice(start_pos, end_pos);
      const Tensor one_seq_label = label_tmp.Slice(start_pos, end_pos);
      Tensor one_seq_alpha = alpha_tmp.Slice(start_pos, end_pos);
      log_likelihood[i] = ForwardOneSequence(
          one_seq, one_seq_row_max, one_seq_exps, *transition_weights,
          *transition_exps, one_seq_label, &one_seq_alpha);
    }
  };

 private:
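  // Runs the forward (alpha) recursion for a single sequence and returns its
  // negated log-likelihood, i.e. log(Z) minus the unnormalized log-score of
  // the given label path. Alpha is filled with the (L1-normalized) forward
  // vectors that the backward pass reuses.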
  T ForwardOneSequence(const Tensor& emission, const Tensor& emission_row_max,
                       const Tensor& emission_exps, const Tensor& trans_weights,
                       const Tensor& trans_weight_exps, const Tensor& label,
                       Tensor* alpha) const {
    const T* x = emission.data<T>();
    const T* x_row_max = emission_row_max.data<T>();
    const T* x_exps = emission_exps.data<T>();
    const T* w = trans_weights.data<T>();
    const T* w_exps = trans_weight_exps.data<T>();
    T* alpha_value = alpha->data<T>();

    auto x_dims = emission.dims();
    const size_t seq_length = x_dims[0];
    const size_t tag_num = x_dims[1];
    // The 1st row of w holds the transition weights from the start state.
    // The 2nd row of w holds the transition weights to the end state.
    // Transition weights between other tags begin from the 3rd row of w.
    const size_t state_trans_base_idx = 2;

    for (size_t i = 0; i < tag_num; ++i) {
      alpha_value[i] = w_exps[i] * x_exps[i];
    }
    T ll = -x_row_max[0] - std::log(NormalizeL1<T>(alpha_value, tag_num));

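    // Forward recursion:
    //   alpha[k][i] = x_exps[k][i] * sum_j alpha[k-1][j] * w_exps[j -> i],
    // with each row of alpha renormalized to sum to 1 for numerical safety.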
    for (size_t k = 1; k < seq_length; ++k) {
      for (size_t i = 0; i < tag_num; ++i) {
        T sum = 0.;
        for (size_t j = 0; j < tag_num; ++j) {
          sum += alpha_value[(k - 1) * tag_num + j] *  // (*)
                 w_exps[(j + state_trans_base_idx) * tag_num + i];
        }
        alpha_value[k * tag_num + i] = x_exps[k * tag_num + i] * sum;
      }
      // NormalizeL1 is to avoid underflow or overflow at (*).
      ll -= x_row_max[k] +
            std::log(NormalizeL1<T>(alpha_value + k * tag_num, tag_num));
    }
    T sum = 0.;
    for (size_t i = 0; i < tag_num; ++i) {
      sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps[tag_num + i];
    }
    ll -= std::log(sum);
    // Now ll is equal to -log(Z).

    const int64_t* lbl = label.data<int64_t>();
    PADDLE_ENFORCE_LT(
        static_cast<size_t>(*std::max_element(lbl, lbl + seq_length)), tag_num,
        platform::errors::InvalidArgument(
            "An invalid tag label that exceeds the largest tag number."));

    // Calculate the numerator part, which depends on the label sequence.
    ll += w[lbl[0]] /*start transition*/ + x[lbl[0]] +
          w[tag_num + lbl[seq_length - 1]] /*end transition*/;
    for (size_t k = 1; k < seq_length; ++k) {
      ll += x[k * tag_num + lbl[k]] +
            w[(lbl[k - 1] + state_trans_base_idx) * tag_num + lbl[k]];
    }
    return -ll;
  }
};

template <typename DeviceContext, typename T>
class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const Tensor* label = ctx.Input<Tensor>("Label");
    const Tensor* emission_exps = ctx.Input<Tensor>("EmissionExps");
    const Tensor* transition_exps = ctx.Input<Tensor>("TransitionExps");
    const Tensor* alpha = ctx.Input<Tensor>("Alpha");
    const T* ll_grad =
        ctx.Input<Tensor>(framework::GradVarName("LogLikelihood"))->data<T>();
    Tensor* emission_grad =
        ctx.Output<Tensor>(framework::GradVarName("Emission"));
    auto* emission_grad_data =
        emission_grad->mutable_data<T>(platform::CPUPlace());
    memset(emission_grad_data, 0, emission_grad->numel() * sizeof(T));
    Tensor alpha_tmp = *alpha;
    Tensor label_tmp = *label;
    Tensor emission_exps_tmp = *emission_exps;
    Tensor emission_grad_tmp = *emission_grad;
    // Get seq_num, depending on whether the padded (Length) input or the LoD
    // input is used.
    int64_t seq_num = 0;
    framework::LoD in_lod;
    const int64_t* length_data = nullptr;
    if (ctx.HasInput("Length")) {
      const Tensor* label_length = ctx.Input<framework::Tensor>("Length");
      length_data = label_length->data<int64_t>();
      seq_num = label_length->numel();
      auto emission_dims = emission_grad->dims();
      auto label_dims = label->dims();
      emission_grad_tmp.Resize(
          {emission_dims[0] * emission_dims[1], emission_dims[2]});
      label_tmp.Resize({label_dims[0] * label_dims[1], 1});
      alpha_tmp.Resize({emission_dims[0] * emission_dims[1], emission_dims[2]});
      emission_exps_tmp.Resize(
          {emission_dims[0] * emission_dims[1], emission_dims[2]});
    } else {
      in_lod = ctx.Input<LoDTensor>("Label")->lod();
      PADDLE_ENFORCE_NE(in_lod.size(), 0,
                        platform::errors::InvalidArgument(
                            "Input(Label) must be a sequence."));
      seq_num = static_cast<int64_t>(in_lod[0].size() - 1);
    }

    Tensor* transition_grad =
        ctx.Output<Tensor>(framework::GradVarName("Transition"));

    // TODO(caoying) Fix this constraint. When the Input(Emission) is from the
    // data reader operator, it can have no gradients.
    if (transition_grad) {
      transition_grad->mutable_data<T>(platform::CPUPlace());
      math::set_constant(ctx.device_context(), transition_grad, 0.);
    }
    // Now, all the inputs and outputs should be on the CPU memory.
    auto emission_dims = emission_exps->dims();
    // Beta is the memo table used in dynamic programming to calculate the
    // backward vectors. For a backward vector i (the i-th row of beta), it
    // captures the unnormalized probabilities of partial sequences starting
    // at position i.
    Tensor beta;
    beta.mutable_data<T>(emission_dims, platform::CPUPlace());
    if (ctx.HasInput("Length")) {
      beta.Resize({emission_dims[0] * emission_dims[1], emission_dims[2]});
    }

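    // Process each sequence independently; the emission gradient is written
    // per sequence, while the transition gradient is accumulated across all
    // sequences.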
    for (int64_t i = 0; i < seq_num; ++i) {
      int64_t start_pos = 0;
      int64_t end_pos = 0;
      if (ctx.HasInput("Length")) {
        start_pos = i * emission_dims[1];
        end_pos = start_pos + length_data[i];
      } else {
        start_pos = static_cast<int64_t>(in_lod[0][i]);
        end_pos = static_cast<int64_t>(in_lod[0][i + 1]);
      }

      if (end_pos == start_pos) {
        continue;
      }
      const Tensor one_seq_emission_exps =
          emission_exps_tmp.Slice(start_pos, end_pos);
      const Tensor one_seq_label = label_tmp.Slice(start_pos, end_pos);
      const Tensor one_seq_alpha = alpha_tmp.Slice(start_pos, end_pos);
      Tensor one_seq_beta = beta.Slice(start_pos, end_pos);
      Tensor one_seq_emission_grad =
          emission_grad_tmp.Slice(start_pos, end_pos);
      BackwardOneSequence(
          ctx.template device_context<platform::CPUDeviceContext>(), ll_grad[i],
          one_seq_emission_exps, *transition_exps, one_seq_alpha, one_seq_label,
          &one_seq_beta, transition_grad, &one_seq_emission_grad);
    }
  };

 private:
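  // Computes the backward (beta) vectors for one sequence and accumulates the
  // gradients w.r.t. the emission weights and (optionally) the transition
  // weights, scaled by ll_grad, the upstream gradient of the log-likelihood.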
  void BackwardOneSequence(const platform::CPUDeviceContext& ctx,
                           const T ll_grad, const Tensor& emission_exps,
                           const Tensor& transition_exps, const Tensor& alpha,
                           const Tensor& label, Tensor* beta,
                           Tensor* transition_grad,
                           Tensor* emission_grad) const {
    const T* w_exps = transition_exps.data<T>();
    const T* x_exps = emission_exps.data<T>();
    const int64_t* label_value = label.data<int64_t>();
    T* beta_value = beta->data<T>();
    auto x_dims = emission_exps.dims();
    const size_t seq_length = x_dims[0];
    const size_t tag_num = x_dims[1];
    const size_t state_trans_base_idx = 2;

    // Calculate the backward vectors: beta.
    // First, calculate the initialization state.
    for (size_t i = 0; i < tag_num; ++i) {
      beta_value[(seq_length - 1) * tag_num + i] = w_exps[tag_num + i];
    }
    NormalizeL1<T>(beta_value + (seq_length - 1) * tag_num, tag_num);
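    // Backward recursion:
    //   beta[k][i] = sum_j w_exps[i -> j] * x_exps[k+1][j] * beta[k+1][j],
    // again renormalizing each row to avoid underflow or overflow.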
    for (int k = static_cast<int>(seq_length) - 2; k >= 0; --k) {
      for (size_t i = 0; i < tag_num; ++i) {
        T sum = 0.;
        for (size_t j = 0; j < tag_num; ++j) {
          sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *  // (**)
                 x_exps[(k + 1) * tag_num + j] *
                 beta_value[(k + 1) * tag_num + j];
        }
        beta_value[k * tag_num + i] = sum;
      }
      // NormalizeL1 is to avoid underflow or overflow at (**).
      NormalizeL1<T>(beta_value + k * tag_num, tag_num);
    }

    auto x_grad_mat = EigenMatrix<T>::From(*emission_grad);
    auto alpha_mat = EigenMatrix<T>::From(alpha);
    auto beta_mat = EigenMatrix<T>::From(*beta);

    auto* place = ctx.eigen_device();
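    // alpha(k, i) * beta(k, i), normalized over the tags at position k, is the
    // marginal probability of tag i at position k. The emission gradient is
    // this marginal scaled by ll_grad, minus ll_grad at the gold label.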
    auto prob = alpha_mat * beta_mat;
    auto row_sum = prob.sum(Eigen::DSizes<int, 1>(1))
                       .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
                       .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
    x_grad_mat.device(*place) =
        (prob / row_sum).unaryExpr(ScalarMul<T>(ll_grad));

    for (size_t k = 0; k < seq_length; ++k) {
      x_grad_mat(k, label_value[k]) -= static_cast<T>(ll_grad);
    }

    if (transition_grad) {
      T* trans_grad = transition_grad->data<T>();
      for (size_t k = 0; k < tag_num; ++k) {
        // Do not multiply by the output gradient here, because x_grad_mat has
        // already done this.
        trans_grad[k] += x_grad_mat(/*from start state*/ 0, k);
        trans_grad[tag_num + k] +=
            x_grad_mat(/*to end state*/ seq_length - 1, k);
      }

      auto x_exps_mat = EigenMatrix<T>::From(emission_exps);

      // TODO(caoying): Fix this to avoid using this local variable if we can
      // profile the training process.
      Tensor tmp;
      tmp.mutable_data<T>(beta->dims(), platform::CPUPlace());
      auto tmp_mat = EigenMatrix<T>::From(tmp);
      auto prob = beta_mat * x_exps_mat;
      auto row_sum = prob.sum(Eigen::DSizes<int, 1>(1))
                         .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
                         .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
      tmp_mat.device(*place) = prob / row_sum;

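      // For every position k, add the model's expected transition counts
      // (scaled by ll_grad) and subtract ll_grad for the observed gold
      // transition label[k-1] -> label[k].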
      for (size_t k = 1; k < seq_length; ++k) {
        T sum = 0.;
        for (size_t i = 0; i < tag_num; ++i) {
          for (size_t j = 0; j < tag_num; ++j) {
            sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *  // (**)
                   alpha_mat(k - 1, i) * tmp_mat(k, j);
          }
        }
        sum = 1. / sum;
        for (size_t i = 0; i < tag_num; ++i) {
          for (size_t j = 0; j < tag_num; ++j) {
            trans_grad[(i + state_trans_base_idx) * tag_num + j] +=
                sum * w_exps[(i + state_trans_base_idx) * tag_num + j] *
                alpha_mat(k - 1, i) * tmp_mat(k, j) * ll_grad;
          }
        }
        trans_grad[(label_value[k - 1] + state_trans_base_idx) * tag_num +
                   label_value[k]] -= static_cast<T>(ll_grad);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle