gru.cc 4.1 KB
Newer Older
T
tensor-tang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include "paddle/fluid/operators/jit/gen/gru.h"
#include <stddef.h>  // offsetof
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"

namespace paddle {
namespace operators {
namespace jit {
namespace gen {

void GRUJitCode::genCode() {
  reg64_t reg_ptr_gates = rax;
  reg64_t reg_ptr_ht_1 = r9;
  reg64_t reg_ptr_ht = r10;
  mov(reg_ptr_gates, ptr[param1 + offsetof(gru_t, gates)]);
  mov(reg_ptr_ht_1, ptr[param1 + offsetof(gru_t, ht_1)]);
  mov(reg_ptr_ht, ptr[param1 + offsetof(gru_t, ht)]);
  ymm_t ymm_one = ymm_t(0);

  if (id_ == 2) {
    reg64_t reg_ptr_tmp = r11;
    mov(reg_ptr_tmp, reinterpret_cast<size_t>(exp_float_consts));
    vmovaps(ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]);
  }
  int offset = 0;
  int d = num_ * sizeof(float);
  for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
    ymm_t ymm_u = ymm_t(1);
    ymm_t ymm_r = ymm_t(2);
    ymm_t ymm_s = ymm_t(3);
    ymm_t ymm_ht_1 = ymm_t(4);
    // W: {W_update, W_reset; W_state}
    if (id_ == 0 || id_ == 2) {
      vmovups(ymm_u, ptr[reg_ptr_gates + offset]);
      vmovups(ymm_s, ptr[reg_ptr_gates + offset + 2 * d]);
    }
    if (id_ == 1) {
      vmovups(ymm_r, ptr[reg_ptr_gates + offset + d]);
    }
    if (id_ == 1 || id_ == 2) {
      vmovups(ymm_ht_1, ptr[reg_ptr_ht_1 + offset]);
    }

    if (id_ == 0) {
      // ht = act_gate(u) * act_cand(s)
      act<ymm_t>(ymm_u, ymm_u, act_gate_);
      act<ymm_t>(ymm_s, ymm_s, act_cand_);
      vmulps(ymm_s, ymm_s, ymm_u);
      vmovups(ptr[reg_ptr_ht + offset], ymm_s);
    } else if (id_ == 1) {
      // ht = act_gate(r) * ht_1
      act<ymm_t>(ymm_r, ymm_r, act_gate_);
      vmulps(ymm_r, ymm_r, ymm_ht_1);
      vmovups(ptr[reg_ptr_ht + offset], ymm_r);
    } else if (id_ == 2) {
      // ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
      ymm_t ymm_one_inner = ymm_t(ymm_one.getIdx());
      act<ymm_t>(ymm_u, ymm_u, act_gate_);
      act<ymm_t>(ymm_s, ymm_s, act_cand_);
      vmulps(ymm_s, ymm_s, ymm_u);
      vsubps(ymm_u, ymm_one_inner, ymm_u);
      vmulps(ymm_u, ymm_ht_1, ymm_u);
      vaddps(ymm_u, ymm_s, ymm_u);
      vmovups(ptr[reg_ptr_ht + offset], ymm_u);
    }
    offset += sizeof(float) * YMM_FLOAT_BLOCK;
  }
  ret();
}

// Defines `<kernel>Creator`, a JitCodeCreator<gru_attr_t> that builds
// `<kernel>JitCode` objects:
//   - UseMe():         true only when platform::MayIUse(platform::avx) holds
//                      and attr.d is a multiple of 8.
//   - CodeSize():      96 + (attr.d / YMM_FLOAT_BLOCK) * 96 * 2 * 8; the value
//                      is forwarded to the JitCode constructor.
//   - CreateJitCode(): constructs the JitCode from attr and CodeSize(attr).
#define DECLARE_GRU_CREATOR(kernel)                                 \
  class kernel##Creator : public JitCodeCreator<gru_attr_t> {       \
   public:                                                          \
    /* TODO(TJ): enable more */                                     \
    bool UseMe(const gru_attr_t& attr) const override {             \
      if (!platform::MayIUse(platform::avx)) {                      \
        return false;                                               \
      }                                                             \
      return attr.d % 8 == 0;                                       \
    }                                                               \
    size_t CodeSize(const gru_attr_t& attr) const override {        \
      const auto blocks = attr.d / YMM_FLOAT_BLOCK;                 \
      return 96 + blocks * 96 * 2 * 8;                              \
    }                                                               \
    std::unique_ptr<GenBase> CreateJitCode(                         \
        const gru_attr_t& attr) const override {                    \
      return make_unique<kernel##JitCode>(attr, CodeSize(attr));    \
    }                                                               \
  }

// One creator class per GRU kernel variant.
DECLARE_GRU_CREATOR(GRUH1);
DECLARE_GRU_CREATOR(GRUHtPart1);
DECLARE_GRU_CREATOR(GRUHtPart2);

// The helper macro is only meaningful inside this file.
#undef DECLARE_GRU_CREATOR

}  // namespace gen
}  // namespace jit
}  // namespace operators
}  // namespace paddle

// Short alias so the registrations below can name the creator classes without
// spelling out the full namespace path.
namespace gen = paddle::operators::jit::gen;

// Hand each GRU creator class to REGISTER_JITKERNEL_GEN together with its
// kernel name.
// NOTE(review): REGISTER_JITKERNEL_GEN is defined outside this file
// (presumably in jit/registry.h); confirm its exact registration semantics
// there.
REGISTER_JITKERNEL_GEN(gruh1, gen::GRUH1Creator);
REGISTER_JITKERNEL_GEN(gruhtpart1, gen::GRUHtPart1Creator);
REGISTER_JITKERNEL_GEN(gruhtpart2, gen::GRUHtPart2Creator);