/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

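// NormalizeL1 rescales x in place so that its elements sum to 1 and returns
// the original sum. The forward/backward recursions below call it on each row
// of alpha (and beta) to keep the values in a numerically safe range; the
// forward pass adds log(sum) back when accumulating the log partition
// function.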
template <typename T>
static inline T NormalizeL1(T* x, size_t len) {
  T sum = 0.;
  for (size_t i = 0; i < len; ++i) sum += x[i];
  // (This comment is from the old LinearChainCRFLayer.)
  // Right now, we just bet that sum won't be zero. If this really happens, we
  // will figure out what should be done then.
  PADDLE_ENFORCE(sum,
                 "The unnormalized probabilities of all possible unfinished "
                 "sequences must be greater than 0.");
  T s = 1. / sum;
  for (size_t i = 0; i < len; ++i) x[i] *= s;
  return sum;
}

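// ScalarMul is a simple functor for Eigen's unaryExpr: it multiplies every
// element by a fixed scalar. It is used in the backward kernel to scale the
// emission gradient by the upstream gradient of the log-likelihood.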
template <typename T>
struct ScalarMul {
  explicit ScalarMul(const T& scalar) : scalar(scalar) {}
  T operator()(const T& val) const { return val * scalar; }

  T scalar;
};

using framework::LoDTensor;
using framework::LoD;
using framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename DeviceContext, typename T>
class LinearChainCRFOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const Tensor* emission_weights = ctx.Input<framework::Tensor>("Emission");
    const Tensor* transition_weights =
        ctx.Input<framework::Tensor>("Transition");

    Tensor* emission_exps = ctx.Output<Tensor>("EmissionExps");
    Tensor* transition_exps = ctx.Output<Tensor>("TransitionExps");
    Tensor* alpha = ctx.Output<Tensor>("Alpha");
    Tensor* ll = ctx.Output<Tensor>("LogLikelihood");

    // Because the computation only runs on the CPU, the memory for all the
    // outputs is explicitly allocated on CPU memory.
    auto* emission_exps_data =
        emission_exps->mutable_data<T>(platform::CPUPlace());
    auto* alpha_data = alpha->mutable_data<T>(platform::CPUPlace());
    transition_exps->mutable_data<T>(platform::CPUPlace());
    // Resize the output tensor to its correct dimension.
    memset(emission_exps_data, 0, emission_exps->numel() * sizeof(T));
    memset(alpha_data, 0, alpha->numel() * sizeof(T));
    auto emission_dims = emission_weights->dims();

    const Tensor* label = ctx.Input<framework::Tensor>("Label");
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    Tensor emission_weights_tmp = ctx.AllocateTmpTensor<T, DeviceContext>(
        emission_weights->dims(), dev_ctx);
    emission_weights_tmp.ShareDataWith(*emission_weights);
    Tensor label_tmp =
        ctx.AllocateTmpTensor<T, DeviceContext>(label->dims(), dev_ctx);
    label_tmp.ShareDataWith(*label);
    Tensor emission_exps_tmp =
        ctx.AllocateTmpTensor<T, DeviceContext>(emission_exps->dims(), dev_ctx);
    emission_exps_tmp.ShareDataWith(*emission_exps);
    Tensor alpha_tmp =
        ctx.AllocateTmpTensor<T, DeviceContext>(alpha->dims(), dev_ctx);
    alpha_tmp.ShareDataWith(*alpha);
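    // The operator accepts two input formats:
    //   1. Padded inputs: Input("length") is present, Emission/Label are 3-D
    //      tensors of shape [batch, max_seq_len, ...] and are reshaped to 2-D
    //      below so that every sequence occupies a fixed stride of rows.
    //   2. LoD inputs: Emission/Label are 2-D LoDTensors and sequence
    //      boundaries are read from the LoD of Input("Label").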
    size_t seq_num = 0;
    size_t batch_size;
    size_t tag_num;
    const int64_t* length_data = nullptr;
    framework::Vector<size_t> in_lod;
    if (ctx.HasInput("length")) {
      const Tensor* label_length = ctx.Input<framework::Tensor>("length");
      length_data = label_length->data<int64_t>();
      seq_num = label_length->numel();
      batch_size = emission_dims[0] * emission_dims[1];
      tag_num = emission_dims[2];
      emission_weights_tmp.Resize(
          {emission_dims[0] * emission_dims[1], emission_dims[2]});
      auto label_dims = label->dims();
      label_tmp.Resize({label_dims[0] * label_dims[1], label_dims[2]});
      alpha_tmp.Resize({emission_dims[0] * emission_dims[1], emission_dims[2]});
      emission_exps_tmp.Resize(
          {emission_dims[0] * emission_dims[1], emission_dims[2]});
      PADDLE_ENFORCE_EQ(seq_num, emission_dims[0],
                        "the size of Input(length) must be equal to "
                        "emission_dims[0].");
      PADDLE_ENFORCE_EQ(seq_num, label_dims[0],
                        "the size of Input(length) must be equal to "
                        "label_dims[0].");
    } else {
      seq_num = ctx.Input<LoDTensor>("Label")->lod()[0].size() - 1;
      batch_size = emission_dims[0];
      tag_num = emission_dims[1];
      in_lod = ctx.Input<LoDTensor>("Label")->lod()[0];
      PADDLE_ENFORCE_NE(in_lod.size(), 0, "Input(Label) must be a sequence.");
    }

    ll->Resize({static_cast<int>(seq_num), 1});
    ll->mutable_data<T>(platform::CPUPlace());
    // Now, all the inputs and outputs should be on the CPU memory.
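    // Subtract the per-row maximum of the emission scores before
    // exponentiating, so that exp() cannot overflow. The subtracted maxima are
    // added back in ForwardOneSequence when the log partition function is
    // accumulated.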
    Tensor emission_row_max;
    emission_row_max.mutable_data<T>(
        framework::make_ddim({static_cast<int64_t>(batch_size), 1}),
        platform::CPUPlace());
    auto& place = *ctx.template device_context<platform::CPUDeviceContext>()
                       .eigen_device();
    auto x = EigenMatrix<T>::From(emission_weights_tmp);
    auto x_row_max = EigenMatrix<T>::From(emission_row_max);
    x_row_max.device(place) =
        x.maximum(Eigen::DSizes<int, 1>(1))
            .reshape(Eigen::DSizes<int, 2>(static_cast<int>(batch_size), 1));
    auto x_exps = EigenMatrix<T>::From(emission_exps_tmp);
    x_exps.device(place) =
        (x - x_row_max.broadcast(Eigen::DSizes<int, 2>(1, tag_num))).exp();
    auto w = EigenMatrix<T>::From(*transition_weights);
    auto w_exps = EigenMatrix<T>::From(*transition_exps);
    w_exps.device(place) = w.exp();
    T* log_likelihood = ll->data<T>();
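    // Run the forward algorithm sequence by sequence. For padded inputs every
    // sequence starts at a fixed offset of i * max_seq_len rows and only the
    // first length_data[i] positions are valid; for LoD inputs the boundaries
    // come from the LoD.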
    for (size_t i = 0; i < seq_num; ++i) {
      int start_pos = 0;
      int end_pos = 0;
      if (ctx.HasInput("length")) {
        if (length_data[i] == 0) continue;
        start_pos = i * emission_dims[1];
        end_pos = start_pos + static_cast<int>(length_data[i]);
      } else {
        start_pos = static_cast<int>(in_lod[i]);
        end_pos = static_cast<int>(in_lod[i + 1]);
      }
      if (end_pos == start_pos) {
        // If an empty input sequence is given, its log-likelihood is set to 0.
        log_likelihood[i] = 0.;
        continue;
      }
      const Tensor one_seq = emission_weights_tmp.Slice(start_pos, end_pos);
      Tensor one_seq_row_max = emission_row_max.Slice(start_pos, end_pos);
      Tensor one_seq_exps = emission_exps_tmp.Slice(start_pos, end_pos);
      const Tensor one_seq_label = label_tmp.Slice(start_pos, end_pos);
      Tensor one_seq_alpha = alpha_tmp.Slice(start_pos, end_pos);
      log_likelihood[i] = ForwardOneSequence(
          one_seq, one_seq_row_max, one_seq_exps, *transition_weights,
          *transition_exps, one_seq_label, &one_seq_alpha);
    }
  };

 private:
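  // Runs the (rescaled) forward algorithm on a single sequence: fills alpha
  // with the normalized forward vectors and returns the negative
  // log-likelihood of the given label sequence, i.e. log Z minus the score of
  // the labeled path.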
  T ForwardOneSequence(const Tensor& emission, const Tensor& emission_row_max,
                       const Tensor& emission_exps, const Tensor& trans_weights,
                       const Tensor& trans_weight_exps, const Tensor& label,
                       Tensor* alpha) const {
    const T* x = emission.data<T>();
    const T* x_row_max = emission_row_max.data<T>();
    const T* x_exps = emission_exps.data<T>();
    const T* w = trans_weights.data<T>();
    const T* w_exps = trans_weight_exps.data<T>();
    T* alpha_value = alpha->data<T>();

    auto x_dims = emission.dims();
    const size_t seq_length = x_dims[0];
    const size_t tag_num = x_dims[1];
    // The 1st row of w holds the transition weights from the start state.
    // The 2nd row of w holds the transition weights to the end state.
    // Transition weights between the other tags begin at the 3rd row of w.
    const size_t state_trans_base_idx = 2;

    for (size_t i = 0; i < tag_num; ++i) {
      alpha_value[i] = w_exps[i] * x_exps[i];
    }
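    // alpha_0 has just been initialized from the start-transition and
    // (shifted) emission exps. Each row of alpha is rescaled to sum to 1; ll
    // accumulates -log Z by adding back the subtracted row maxima and the logs
    // of the rescaling factors.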
    T ll = -x_row_max[0] - std::log(NormalizeL1<T>(alpha_value, tag_num));

    for (size_t k = 1; k < seq_length; ++k) {
      for (size_t i = 0; i < tag_num; ++i) {
        T sum = 0.;
        for (size_t j = 0; j < tag_num; ++j) {
          sum += alpha_value[(k - 1) * tag_num + j] *  // (*)
                 w_exps[(j + state_trans_base_idx) * tag_num + i];
        }
        alpha_value[k * tag_num + i] = x_exps[k * tag_num + i] * sum;
      }
      // NormalizeL1 is to avoid underflow or overflow at (*).
      ll -= x_row_max[k] +
            std::log(NormalizeL1<T>(alpha_value + k * tag_num, tag_num));
    }
    T sum = 0.;
    for (size_t i = 0; i < tag_num; ++i) {
      sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps[tag_num + i];
    }
    ll -= std::log(sum);
    // Now ll is equal to -log(Z).

    const int64_t* lbl = label.data<int64_t>();
    PADDLE_ENFORCE_LT(
        static_cast<size_t>(*std::max_element(lbl, lbl + seq_length)), tag_num,
        "An invalid tag label that exceeds the largest tag number.");

    // Calculate the numerator part, which depends on the label sequence.
    ll += w[lbl[0]] /*start transition*/ + x[lbl[0]] +
          w[tag_num + lbl[seq_length - 1]] /*end transition*/;
    for (size_t k = 1; k < seq_length; ++k) {
      ll += x[k * tag_num + lbl[k]] +
            w[(lbl[k - 1] + state_trans_base_idx) * tag_num + lbl[k]];
    }
    return -ll;
  }
};

template <typename DeviceContext, typename T>
class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const Tensor* label = ctx.Input<Tensor>("Label");
    const Tensor* emission_exps = ctx.Input<Tensor>("EmissionExps");
    const Tensor* transition_exps = ctx.Input<Tensor>("TransitionExps");
    const Tensor* alpha = ctx.Input<Tensor>("Alpha");
    const T* ll_grad =
        ctx.Input<Tensor>(framework::GradVarName("LogLikelihood"))->data<T>();
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    Tensor* emission_grad =
        ctx.Output<Tensor>(framework::GradVarName("Emission"));
    auto* emission_grad_data =
        emission_grad->mutable_data<T>(platform::CPUPlace());
    memset(emission_grad_data, 0, emission_grad->numel() * sizeof(T));
    Tensor alpha_tmp =
        ctx.AllocateTmpTensor<T, DeviceContext>(alpha->dims(), dev_ctx);
    alpha_tmp.ShareDataWith(*alpha);
    Tensor label_tmp =
        ctx.AllocateTmpTensor<T, DeviceContext>(label->dims(), dev_ctx);
    label_tmp.ShareDataWith(*label);
    Tensor emission_exps_tmp =
        ctx.AllocateTmpTensor<T, DeviceContext>(emission_exps->dims(), dev_ctx);
    emission_exps_tmp.ShareDataWith(*emission_exps);
    Tensor emission_grad_tmp =
        ctx.AllocateTmpTensor<T, DeviceContext>(emission_grad->dims(), dev_ctx);
    emission_grad_tmp.ShareDataWith(*emission_grad);
    // Get seq_num, depending on whether the padded ("length") input or the
    // LoD input format is used.
    size_t seq_num = 0;
    framework::Vector<size_t> lod;
    const int64_t* length_data = nullptr;
    if (ctx.HasInput("length")) {
      const Tensor* label_length = ctx.Input<framework::Tensor>("length");
      length_data = label_length->data<int64_t>();
      seq_num = label_length->numel();
      auto emission_dims = emission_grad->dims();
      auto label_dims = label->dims();
      emission_grad_tmp.Resize(
          {emission_dims[0] * emission_dims[1], emission_dims[2]});
      label_tmp.Resize({label_dims[0] * label_dims[1], label_dims[2]});
      alpha_tmp.Resize({emission_dims[0] * emission_dims[1], emission_dims[2]});
      emission_exps_tmp.Resize(
          {emission_dims[0] * emission_dims[1], emission_dims[2]});
    } else {
      seq_num = ctx.Input<LoDTensor>("Label")->lod()[0].size() - 1;
      lod = ctx.Input<LoDTensor>("Label")->lod()[0];
      PADDLE_ENFORCE_NE(lod.size(), 0, "Input(Label) must be a sequence.");
    }

    Tensor* transition_grad =
        ctx.Output<Tensor>(framework::GradVarName("Transition"));

    // TODO(caoying) Fix this constraint. When the Input(Emission) is from the
    // data reader operator, it can have no gradients.
    if (transition_grad) {
      transition_grad->mutable_data<T>(platform::CPUPlace());
      math::set_constant(ctx.device_context(), transition_grad, 0.);
    }
    // Now, all the inputs and outputs should be on the CPU memory.
    auto emission_dims = emission_exps->dims();
    // Beta is the memo table used in dynamic programming to calculate the
    // backward vectors. For a backward vector i (the i-th row of beta), it
    // captures the unnormalized probabilities of partial sequences starting
    // at position i.
    Tensor beta;
    auto* beta_data = beta.mutable_data<T>(emission_dims, platform::CPUPlace());
    memset(beta_data, 0, beta.numel() * sizeof(T));
    if (ctx.HasInput("length")) {
      beta.Resize({emission_dims[0] * emission_dims[1], emission_dims[2]});
    }
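    // Accumulate the gradients sequence by sequence. The emission gradient is
    // written into per-sequence slices, while the transition gradient (if
    // requested) is shared by all sequences and accumulated across them.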
    for (size_t i = 0; i < seq_num; ++i) {
      int start_pos = 0;
      int end_pos = 0;
      if (ctx.HasInput("length")) {
        if (length_data[i] == 0) continue;
        start_pos = i * emission_dims[1];
        end_pos = start_pos + static_cast<int>(length_data[i]);
      } else {
        start_pos = static_cast<int>(lod[i]);
        end_pos = static_cast<int>(lod[i + 1]);
      }
      const Tensor one_seq_emission_exps =
          emission_exps_tmp.Slice(start_pos, end_pos);
      const Tensor one_seq_label = label_tmp.Slice(start_pos, end_pos);
      const Tensor one_seq_alpha = alpha_tmp.Slice(start_pos, end_pos);
      Tensor one_seq_beta = beta.Slice(start_pos, end_pos);
      Tensor one_seq_emission_grad =
          emission_grad_tmp.Slice(start_pos, end_pos);
      BackwardOneSequence(
          ctx.template device_context<platform::CPUDeviceContext>(), ll_grad[i],
          one_seq_emission_exps, *transition_exps, one_seq_alpha, one_seq_label,
          &one_seq_beta, transition_grad, &one_seq_emission_grad);
    }
  };

 private:
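  // Computes the backward vectors beta for one sequence and accumulates the
  // gradients of the negative log-likelihood w.r.t. the emission weights and,
  // if transition_grad is not null, the transition weights, both scaled by
  // the upstream gradient ll_grad.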
  void BackwardOneSequence(const platform::CPUDeviceContext& ctx,
                           const T ll_grad, const Tensor& emission_exps,
                           const Tensor& transition_exps, const Tensor& alpha,
                           const Tensor& label, Tensor* beta,
                           Tensor* transition_grad,
                           Tensor* emission_grad) const {
    const T* w_exps = transition_exps.data<T>();
    const T* x_exps = emission_exps.data<T>();
    const int64_t* label_value = label.data<int64_t>();
    T* beta_value = beta->data<T>();
    auto x_dims = emission_exps.dims();
    const size_t seq_length = x_dims[0];
    const size_t tag_num = x_dims[1];
    const size_t state_trans_base_idx = 2;

    // Calculate the backward vectors: beta.
    // First, calculate the initial state: the last row of beta.
    for (size_t i = 0; i < tag_num; ++i) {
      beta_value[(seq_length - 1) * tag_num + i] = w_exps[tag_num + i];
    }
    NormalizeL1<T>(beta_value + (seq_length - 1) * tag_num, tag_num);
    for (int k = static_cast<int>(seq_length) - 2; k >= 0; --k) {
      for (size_t i = 0; i < tag_num; ++i) {
        T sum = 0.;
        for (size_t j = 0; j < tag_num; ++j) {
          sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *  // (**)
                 x_exps[(k + 1) * tag_num + j] *
                 beta_value[(k + 1) * tag_num + j];
        }
        beta_value[k * tag_num + i] = sum;
      }
      // NormalizeL1 is to avoid underflow or overflow at (**).
      NormalizeL1<T>(beta_value + k * tag_num, tag_num);
    }

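    // The gradient w.r.t. the emission scores at each position is the marginal
    // probability of each tag, alpha * beta normalized per row, scaled by
    // ll_grad; one unit of ll_grad is then subtracted at the gold label of
    // every position.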
    auto x_grad_mat = EigenMatrix<T>::From(*emission_grad);
    auto alpha_mat = EigenMatrix<T>::From(alpha);
    auto beta_mat = EigenMatrix<T>::From(*beta);

    auto* place = ctx.eigen_device();
    auto prob = alpha_mat * beta_mat;
    auto row_sum = prob.sum(Eigen::DSizes<int, 1>(1))
                       .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
                       .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
    x_grad_mat.device(*place) =
        (prob / row_sum).unaryExpr(ScalarMul<T>(ll_grad));

    for (size_t k = 0; k < seq_length; ++k) {
      x_grad_mat(k, label_value[k]) -= static_cast<T>(ll_grad);
    }

    if (transition_grad) {
      T* trans_grad = transition_grad->data<T>();
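      // Gradients for the start and end transition rows are read directly
      // from the emission gradient at the first and last positions, which
      // already carries the ll_grad scaling.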
      for (size_t k = 0; k < tag_num; ++k) {
        // Do not multiply by the output gradient here, because x_grad_mat has
        // already done this.
        trans_grad[k] += x_grad_mat(/*from start state*/ 0, k);
        trans_grad[tag_num + k] +=
            x_grad_mat(/*to end state*/ seq_length - 1, k);
      }

      auto x_exps_mat = EigenMatrix<T>::From(emission_exps);

      // TODO(caoying): Fix this to avoid using this local variable if we can
      // profile the training process.
      Tensor tmp;
      tmp.mutable_data<T>(beta->dims(), platform::CPUPlace());
      auto tmp_mat = EigenMatrix<T>::From(tmp);
      auto prob = beta_mat * x_exps_mat;
      auto row_sum = prob.sum(Eigen::DSizes<int, 1>(1))
                         .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
                         .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
      tmp_mat.device(*place) = prob / row_sum;
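      // For every position k >= 1, the expected probability of taking the
      // transition i -> j is proportional to
      //   w_exps[i -> j] * alpha(k - 1, i) * tmp(k, j),
      // where tmp is the row-normalized product of beta and the emission exps
      // computed above. The loop below normalizes these terms, scales them by
      // ll_grad, and subtracts ll_grad for the transition actually taken by
      // the label sequence.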

      for (size_t k = 1; k < seq_length; ++k) {
        T sum = 0.;
        for (size_t i = 0; i < tag_num; ++i) {
          for (size_t j = 0; j < tag_num; ++j) {
            sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *  // (**)
                   alpha_mat(k - 1, i) * tmp_mat(k, j);
          }
        }
        sum = 1. / sum;
        for (size_t i = 0; i < tag_num; ++i) {
          for (size_t j = 0; j < tag_num; ++j) {
            trans_grad[(i + state_trans_base_idx) * tag_num + j] +=
                sum * w_exps[(i + state_trans_base_idx) * tag_num + j] *
                alpha_mat(k - 1, i) * tmp_mat(k, j) * ll_grad;
          }
        }
        trans_grad[(label_value[k - 1] + state_trans_base_idx) * tag_num +
                   label_value[k]] -= static_cast<T>(ll_grad);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle