kernel_key.cc 2.0 KB
Newer Older
T
tensor-tang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include "paddle/fluid/operators/jit/kernel_key.h"

namespace paddle {
namespace operators {
namespace jit {

/// JIT code key for kernels parameterized by a single int: the value itself,
/// converted to size_t.
template <>
size_t JitCodeKey<int>(const int& d) {
  const size_t key = static_cast<size_t>(d);
  return key;
}

26 27
// Number of bits reserved for each activation-type field when activation
// types are packed into a JIT code key: supports up to 2^3 = 8 activation
// types per field.
constexpr int act_type_shift = 3;

T
tensor-tang 已提交
28 29 30 31 32 33 34 35 36
/// Packs an lstm_attr_t into a JIT code key.
///
/// Bit layout, from least to most significant:
///   bit 0                          : use_peephole (added as-is; presumably 0/1)
///   bits [1, 1 + s)                : act_gate
///   bits [1 + s, 1 + 2s)           : act_cand
///   bits [1 + 2s, 1 + 3s)          : act_cell
///   bits [1 + 3s, ...)             : d
/// where s == act_type_shift.
///
/// NOTE(review): fields are combined with addition and are not masked, so each
/// activation value is assumed to fit in act_type_shift bits. Larger values
/// would overlap the neighbouring field and could make distinct attributes
/// share a key. Confirm against the activation enum definition.
template <>
size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
  constexpr int gate_offset = 1;  // bit 0 holds use_peephole
  constexpr int cand_offset = gate_offset + act_type_shift;
  constexpr int cell_offset = cand_offset + act_type_shift;
  constexpr int dim_offset = cell_offset + act_type_shift;

  const int gate_bits = static_cast<int>(attr.act_gate) << gate_offset;
  const int cand_bits = static_cast<int>(attr.act_cand) << cand_offset;
  const int cell_bits = static_cast<int>(attr.act_cell) << cell_offset;

  size_t key = attr.d;
  key <<= dim_offset;
  key += gate_bits;
  key += cand_bits;
  key += cell_bits;
  key += attr.use_peephole;
  return key;
}
37 38 39 40 41 42 43 44

/// Packs a gru_attr_t into a JIT code key.
///
/// Bit layout, from least to most significant:
///   bits [0, s)     : act_gate
///   bits [s, 2s)    : act_cand
///   bits [2s, ...)  : d
/// where s == act_type_shift.
///
/// NOTE(review): fields are added, not masked; activation values are assumed
/// to be < 2^act_type_shift. Confirm against the activation enum definition.
template <>
size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
  const int gate_bits = static_cast<int>(attr.act_gate);
  const int cand_bits = static_cast<int>(attr.act_cand) << act_type_shift;

  size_t key = attr.d;
  key <<= act_type_shift * 2;
  key += gate_bits;
  key += cand_bits;
  return key;
}

T
tensor-tang 已提交
45 46
/// Packs a seq_pool_attr_t into a JIT code key.
///
/// Bit layout: the pooling type occupies the low pool_type_shift (3) bits and
/// w occupies the remaining high bits.
///
/// NOTE(review): attr.type is added without masking, so it is assumed to be
/// < 2^3. Confirm against the pool-type enum definition.
template <>
size_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) {
  constexpr int pool_type_shift = 3;  // bits reserved for the pooling type
  size_t key = attr.w;
  return (key << pool_type_shift) + static_cast<int>(attr.type);
}

52 53 54 55 56 57 58
/// Packs a matmul_attr_t (m, n, k) into a JIT code key using 21-bit fields.
///
/// Bit layout, from least to most significant:
///   bits [0, 21)   : k
///   bits [21, 42)  : n
///   bits [42, ...) : m
///
/// NOTE(review): n and k are assumed to be < 2^21 (they are not masked). A
/// 64-bit size_t is also assumed, because shifting by 42 is undefined on a
/// narrower size_t.
template <>
size_t JitCodeKey<matmul_attr_t>(const matmul_attr_t& attr) {
  constexpr int field_bits = 21;
  const size_t m_part = static_cast<size_t>(attr.m) << (field_bits * 2);
  const size_t n_part = static_cast<size_t>(attr.n) << field_bits;
  return m_part + n_part + attr.k;
}

59 60 61 62 63
/// JIT code key for emb_seq_pool kernels: only attr.table_width participates;
/// no other attribute field contributes to the key.
/// NOTE(review): confirm the generated code depends on no other field of
/// emb_seq_pool_attr_t, otherwise distinct attributes would share cached code.
template <>
size_t JitCodeKey<emb_seq_pool_attr_t>(const emb_seq_pool_attr_t& attr) {
  const size_t key = static_cast<size_t>(attr.table_width);
  return key;
}

T
tensor-tang 已提交
64 65 66
}  // namespace jit
}  // namespace operators
}  // namespace paddle