// [web-view residue] lstm.cc, 4.8 KB (blame view with "Newer / Older" navigation).
// Lines 1-116 committed by tensor-tang.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include "paddle/fluid/operators/jit/gen/lstm.h"
#include <stddef.h>  // offsetof
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"

namespace paddle {
namespace operators {
namespace jit {
namespace gen {

// Emits the LSTM cell kernel for one time step using AVX (ymm) instructions.
//
// The generated function reads the gates/ct_1/ct/ht/wp pointers from the
// lstm_t struct passed in param1, then, per element:
//   C_t = act_cand(c) * act_gate(i [+ C_t-1 * wp0])
//         + C_t-1 * act_gate(f [+ C_t-1 * wp1])
//   H_t = act_cell(C_t) * act_gate(o [+ C_t * wp2])
// and stores C_t to ct and H_t to ht.
//   - The bracketed peephole terms are emitted only when use_peephole_ is set.
//   - When compute_c1h1_ is set, C_t-1 is never read and every term involving
//     it is skipped, i.e. C_t = act_cand(c) * act_gate(i).
//
// Layout: gates holds four consecutive num_-float vectors [c | i | f | o]
// (byte offsets 0, d, 2d, 3d). wp holds three consecutive num_-float vectors
// [wp0 | wp1 | wp2].
//
// Only num_ / YMM_FLOAT_BLOCK full blocks are emitted; there is no tail, so
// num_ must be a multiple of YMM_FLOAT_BLOCK. The creators' CanBeUsed requires
// attr.d to be a multiple of 8 (presumably num_ == attr.d; TODO confirm in
// lstm.h).
void LSTMJitCode::genCode() {
  // The prologue/epilogue are only emitted on the peephole path, presumably
  // because reg_ptr_wp lives in r12, which is callee-saved in the x86-64
  // calling conventions. NOTE(review): confirm preCode()/postCode() save and
  // restore r12.
  if (use_peephole_) {
    preCode();
  }
  // Pointer registers for the fields of the lstm_t argument (param1).
  reg64_t reg_ptr_gates = rax;
  reg64_t reg_ptr_ct_1 = r9;
  reg64_t reg_ptr_ct = r10;
  reg64_t reg_ptr_ht = r11;
  reg64_t reg_ptr_wp = r12;
  mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]);
  mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]);
  mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]);
  mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
  // wp is only loaded (and later dereferenced) with peephole connections.
  if (use_peephole_) {
    mov(reg_ptr_wp, ptr[param1 + offsetof(lstm_t, wp)]);
  }

  // offset: byte offset of the current block within each num_-float vector.
  // d: byte size of one num_-float vector, i.e. the stride between the
  //    c/i/f/o vectors in gates and between wp0/wp1/wp2 in wp.
  int offset = 0;
  int d = num_ * sizeof(float);
  // One iteration per YMM_FLOAT_BLOCK floats. No tail is emitted, so any
  // num_ % YMM_FLOAT_BLOCK trailing elements would be left unprocessed.
  for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
    /* gates layout: [c | i | f | o] (W_ch, W_ih, W_fh, W_oh), num_ floats each */
    // ymm0-ymm3: c/i/f/o gate inputs; ymm4: C_t-1; ymm5-ymm7: peephole
    // weights wp0/wp1/wp2.
    ymm_t ymm_c = ymm_t(0);
    ymm_t ymm_i = ymm_t(1);
    ymm_t ymm_f = ymm_t(2);
    ymm_t ymm_o = ymm_t(3);
    ymm_t ymm_ct_1 = ymm_t(4);
    ymm_t ymm_wp0 = ymm_t(5);
    ymm_t ymm_wp1 = ymm_t(6);
    ymm_t ymm_wp2 = ymm_t(7);
    vmovups(ymm_c, ptr[reg_ptr_gates + offset]);
    vmovups(ymm_i, ptr[reg_ptr_gates + offset + d]);
    vmovups(ymm_f, ptr[reg_ptr_gates + offset + 2 * d]);
    vmovups(ymm_o, ptr[reg_ptr_gates + offset + 3 * d]);
    // In the C1H1 variant C_t-1 is never read; all its terms are skipped.
    if (!compute_c1h1_) {
      vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]);
    }
    if (use_peephole_) {
      vmovups(ymm_wp0, ptr[reg_ptr_wp + offset]);
      vmovups(ymm_wp1, ptr[reg_ptr_wp + offset + d]);
      vmovups(ymm_wp2, ptr[reg_ptr_wp + offset + 2 * d]);
    }
    /* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */
    // act_cand(c)
    // NOTE(review): this relies on act<> not clobbering ymm0-ymm7 or the
    // pointer registers above; confirm against the act helpers.
    act<ymm_t>(ymm_c, ymm_c, act_cand_);
    // act_gate(i) or act_gate(ct_1 * wp0 + i)
    if (!compute_c1h1_ && use_peephole_) {
      vmulps(ymm_wp0, ymm_ct_1, ymm_wp0);
      vaddps(ymm_i, ymm_i, ymm_wp0);
    }
    act<ymm_t>(ymm_i, ymm_i, act_gate_);
    // ymm_c = act_cand(c) * act_gate(i); this is already C_t when compute_c1h1_.
    vmulps(ymm_c, ymm_c, ymm_i);
    if (!compute_c1h1_) {
      // act_gate(f) or act_gate(ct_1 * wp1 + f)
      if (use_peephole_) {
        vmulps(ymm_wp1, ymm_ct_1, ymm_wp1);
        vaddps(ymm_f, ymm_f, ymm_wp1);
      }
      act<ymm_t>(ymm_f, ymm_f, act_gate_);
      // ct = act_gate(f) * C_t-1 + act_cand(c) * act_gate(i), kept in ymm_f
      vmulps(ymm_f, ymm_f, ymm_ct_1);
      vaddps(ymm_f, ymm_f, ymm_c);
    }
    /* H_t = act_cell(C_t) * act_gate(o) */
    // act_cell(C_t)
    // C_t lives in ymm_c (C1H1 variant) or ymm_f (otherwise).
    ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f;
    // Reuse ymm_i as scratch: act_gate(i) is no longer needed after the
    // multiply above.
    ymm_t ymm_tmp = ymm_i;
    // NOTE(review): ymm_ct is read again below (peephole term and store), so
    // this relies on act<> leaving its source register intact; confirm.
    act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
    // act_gate(o) or act_gate(ct * wp2 + o)
    if (use_peephole_) {
      vmulps(ymm_wp2, ymm_ct, ymm_wp2);
      vaddps(ymm_o, ymm_o, ymm_wp2);
    }
    act<ymm_t>(ymm_o, ymm_o, act_gate_);
    // ht = act_gate(o) * act_cell(C_t)
    vmulps(ymm_o, ymm_o, ymm_tmp);
    // save ct and ht
    vmovups(ptr[reg_ptr_ct + offset], ymm_ct);
    vmovups(ptr[reg_ptr_ht + offset], ymm_o);
    // Advance to the next block of YMM_FLOAT_BLOCK floats.
    offset += sizeof(float) * YMM_FLOAT_BLOCK;
  }

  // postCode() presumably undoes preCode() and returns. Without the peephole
  // path no prologue was emitted, so a bare ret suffices.
  if (use_peephole_) {
    postCode();
  } else {
    ret();
  }
}

// DECLARE_LSTM_CREATOR(name) defines `name##Creator`, a
// JitCodeCreator<lstm_attr_t> that builds `name##JitCode`:
//   - CanBeUsed: requires AVX and a hidden size attr.d that is a whole number
//     of YMM_FLOAT_BLOCK lanes, because LSTMJitCode::genCode only emits full
//     ymm blocks and has no tail loop.
//   - CodeSize: size passed to the name##JitCode constructor (presumably the
//     capacity of its code buffer), linear in the number of ymm blocks.
//     NOTE(review): the constants (96 fixed, 90 * 4 * 8 per block) are not
//     derived in this file; confirm they still bound the emitted code if
//     genCode grows.
//   - CreateJitCode: constructs the generator with that code size.
// Only block comments may appear inside the macro body. A line comment there
// would extend over the next line, which the trailing backslash splices in.
#define DECLARE_LSTM_CREATOR(name)                                   \
  class name##Creator : public JitCodeCreator<lstm_attr_t> {         \
   public:                                                           \
    /* TODO(TJ): enable more */                                      \
    bool CanBeUsed(const lstm_attr_t& attr) const override {         \
      return platform::MayIUse(platform::avx) &&                     \
             attr.d % YMM_FLOAT_BLOCK == 0;                          \
    }                                                                \
    size_t CodeSize(const lstm_attr_t& attr) const override {        \
      return 96 + attr.d / YMM_FLOAT_BLOCK * 90 * 4 * 8;             \
    }                                                                \
    std::unique_ptr<GenBase> CreateJitCode(                          \
        const lstm_attr_t& attr) const override {                    \
      return make_unique<name##JitCode>(attr, CodeSize(attr));       \
    }                                                                \
  }

// Define LSTMCtHtCreator and LSTMC1H1Creator, which build LSTMCtHtJitCode and
// LSTMC1H1JitCode respectively (names are token-pasted by the macro).
DECLARE_LSTM_CREATOR(LSTMCtHt);
DECLARE_LSTM_CREATOR(LSTMC1H1);

// The helper macro is only needed for the two declarations above.
#undef DECLARE_LSTM_CREATOR

}  // namespace gen
}  // namespace jit
}  // namespace operators
}  // namespace paddle

// File-scope alias so the global-namespace registrations in this file can
// name the creators as gen::...Creator.
namespace gen = paddle::operators::jit::gen;

// [web-view residue] Lines 141-142 committed by tensor-tang.
// Register the two LSTM code-generator creators under their kernel types.
// REGISTER_JITKERNEL_GEN presumably comes from jit/registry.h (included at the
// top of this file). It is invoked here outside namespace paddle, which
// suggests it must be used at global scope. TODO: confirm in registry.h.
REGISTER_JITKERNEL_GEN(kLSTMCtHt, gen::LSTMCtHtCreator);
REGISTER_JITKERNEL_GEN(kLSTMC1H1, gen::LSTMC1H1Creator);