/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#pragma once

#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {
namespace jit {
namespace gen {

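// Generates JIT code that pools an h x w float matrix row-wise into a single
// row of width w: kSum adds the rows, while kAvg and kSqrt additionally scale
// the sum by 1/h and 1/sqrt(h) respectively.
//
// A minimal usage sketch, assuming a kernel signature of
// void(const float* src, float* dst, const seq_pool_attr_t* attr) and a
// getCode<>() accessor on the JitCode base (both taken from surrounding
// context, not verified in this header):
//
//   seq_pool_attr_t attr = ...;  // h rows, width w, pooling type
//   SeqPoolJitCode jit(attr);
//   auto kernel = jit.getCode<void (*)(const float*, float*,
//                                      const seq_pool_attr_t*)>();
//   kernel(src, dst, &attr);  // src: h * w floats, dst: w floats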
class SeqPoolJitCode : public JitCode {
 public:
  explicit SeqPoolJitCode(const seq_pool_attr_t& attr,
                          size_t code_size = 256 * 1024,
                          void* code_ptr = nullptr)
      : JitCode(code_size, code_ptr), w_(attr.w), type_(attr.type) {
    if (!(type_ == SeqPoolType::kSum || type_ == SeqPoolType::kAvg ||
          type_ == SeqPoolType::kSqrt)) {
      LOG(FATAL) << "Only sum, avg and sqrt pooling are supported";
    }
    fp_h_[0] = 1.f;
    this->genCode();
  }

  virtual const char* name() const {
    std::string base = "SeqPoolJitCode";
    if (type_ == SeqPoolType::kSum) {
      base += "_Sum";
    } else if (type_ == SeqPoolType::kAvg) {
      base += "_Avg";
    } else if (type_ == SeqPoolType::kSqrt) {
      base += "_Sqrt";
    }
    base += ("_W" + std::to_string(w_));
    // Cache in a member: returning base.c_str() directly would hand back a
    // pointer into a local string that is destroyed on return.
    name_ = base;
    return name_.c_str();
  }
  void genCode() override;

 protected:
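  // Pools one group of `max_num_regs` vector blocks (`block` floats each)
  // starting at byte offset `w_offset`: the first row is loaded into regs
  // 0..max_num_regs-1, the remaining h - 1 rows are accumulated into them,
  // and the result is scaled by fp_h_ for kAvg/kSqrt before being stored.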
  template <typename JMM>
  void pool_height(int w_offset, int block, int max_num_regs) {
    int offset = w_offset;
    for (int i = 0; i < max_num_regs; ++i) {
      vmovups(JMM(i), ptr[param_src + offset]);
      offset += sizeof(float) * block;
    }
    cmp(reg32_int_h, 1);
    Label l_next_h, l_h_done;
    jle(l_h_done, T_NEAR);
    mov(reg_h_i, 1);
    mov(reg_tmp, param_src);
    add(reg_tmp, w_ * sizeof(float) + w_offset);
    L(l_next_h);
    {
      mov(reg_ptr_src_i, reg_tmp);
      for (int i = 0; i < max_num_regs; ++i) {
        vmovups(JMM(i + max_num_regs), ptr[reg_ptr_src_i]);
        // always accumulate a plain sum; avg/sqrt rescale after the loop
        vaddps(JMM(i), JMM(i), JMM(i + max_num_regs));
        add(reg_ptr_src_i, sizeof(float) * block);
      }
      inc(reg_h_i);
      add(reg_tmp, w_ * sizeof(float));
      cmp(reg_h_i, reg32_int_h);
      jl(l_next_h, T_NEAR);
    }
    L(l_h_done);
    // scale (kAvg/kSqrt) and store the pooled result
    if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
      mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
      vbroadcastss(JMM(max_num_regs), ptr[reg_tmp]);
    }
    offset = w_offset;
    for (int i = 0; i < max_num_regs; ++i) {
      if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
        vmulps(JMM(i), JMM(i), JMM(max_num_regs));
      }
      vmovups(ptr[param_dst + offset], JMM(i));
      offset += sizeof(float) * block;
    }
  }

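  // Same as pool_height, but for the tail of a row whose width is not a
  // multiple of the vector block: the remaining `rest` floats are processed
  // in xmm-sized chunks of 4, 2 and 1 floats.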
  void pool_height_of_rest_width(int rest, int w_offset, int max_num_regs) {
    const int rest_used_num_regs = load_rest(rest, w_offset, 0);
    const bool has_block4 = rest / 4 > 0;
    const bool has_block2 = (rest % 4) / 2 > 0;
    const bool has_block1 = (rest % 2) == 1;
    cmp(reg32_int_h, 1);
    Label l_next_h, l_h_done;
    jle(l_h_done, T_NEAR);
    mov(reg_h_i, 1);
    mov(reg_tmp, param_src);
    add(reg_tmp, w_ * sizeof(float) + w_offset);
    L(l_next_h);
    {
      int reg_idx = 0;
      mov(reg_ptr_src_i, reg_tmp);
      if (has_block4) {
        vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
        add(reg_ptr_src_i, sizeof(float) * 4);
        reg_idx++;
      }
      if (has_block2) {
        vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
        add(reg_ptr_src_i, sizeof(float) * 2);
        reg_idx++;
      }
      if (has_block1) {
        vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
        reg_idx++;
      }
      PADDLE_ENFORCE_EQ(reg_idx, rest_used_num_regs,
                        "Every row should use the same number of registers");
      for (int i = 0; i < reg_idx; ++i) {
        vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs));
      }
      inc(reg_h_i);
      add(reg_tmp, w_ * sizeof(float));
      cmp(reg_h_i, reg32_int_h);
      jl(l_next_h, T_NEAR);
    }
    L(l_h_done);
    // scale (kAvg/kSqrt) and store the pooled result
    if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
      mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
      vbroadcastss(xmm_t(max_num_regs), ptr[reg_tmp]);
      for (int i = 0; i < rest_used_num_regs; ++i) {
        vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs));
      }
    }
    save_rest(rest, w_offset);
  }

  // Loads the `rest` tail floats in 4/2/1-element chunks into xmm registers,
  // starting from reg_start (offset by num_shift_regs); returns the next
  // register index, i.e. the number of used registers when reg_start is 0.
  int load_rest(int rest, int w_offset, const int num_shift_regs,
                const int reg_start = 0) {
    const bool has_block4 = rest / 4 > 0;
    const bool has_block2 = (rest % 4) / 2 > 0;
    const bool has_block1 = (rest % 2) == 1;
    int reg_idx = reg_start;
    if (has_block4) {
      vmovups(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
      w_offset += sizeof(float) * 4;
      reg_idx++;
    }
    if (has_block2) {
      vmovq(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
      w_offset += sizeof(float) * 2;
      reg_idx++;
    }
    if (has_block1) {
      vmovss(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
      reg_idx++;
    }
    return reg_idx;
  }

  // Stores the `rest` tail floats in 4/2/1-element chunks from xmm registers,
  // starting at reg_start.
  void save_rest(int rest, int w_offset, int reg_start = 0) {
    const bool has_block4 = rest / 4 > 0;
    const bool has_block2 = (rest % 4) / 2 > 0;
    const bool has_block1 = (rest % 2) == 1;
    int reg_idx = reg_start;
    if (has_block4) {
      vmovups(ptr[param_dst + w_offset], xmm_t(reg_idx));
      w_offset += sizeof(float) * 4;
      reg_idx++;
    }
    if (has_block2) {
      vmovq(ptr[param_dst + w_offset], xmm_t(reg_idx));
      w_offset += sizeof(float) * 2;
      reg_idx++;
    }
    if (has_block1) {
      vmovss(ptr[param_dst + w_offset], xmm_t(reg_idx));
    }
  }

 private:
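  // Scale factor applied for kAvg/kSqrt pooling. Only the initial value of
  // 1.f is visible in this header; presumably the code emitted by genCode()
  // overwrites it with 1/h or 1/sqrt(h) at runtime before it is broadcast.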
  float ALIGN32_BEG fp_h_[1] ALIGN32_END;
  int w_;
  SeqPoolType type_;
  mutable std::string name_;  // backing storage for name()
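
  // Kernel arguments: source, destination and attribute pointers arrive in
  // the first three integer argument registers (abi_param1..3 from
  // jitcode.h).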
  reg64_t param_src{abi_param1};
  reg64_t param_dst{abi_param2};
  reg64_t param_attr{abi_param3};
  reg64_t reg_tmp{rax};

  reg32_t reg32_int_h{r8d};
  reg32_t reg32_fp_h{r9d};

  reg64_t reg_h_i{r10};
  reg64_t reg_ptr_src_i{r11};
};

}  // namespace gen
}  // namespace jit
}  // namespace operators
}  // namespace paddle