/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {
namespace jit {

// Key for kernels whose only attribute is an int: the value itself,
// converted to size_t (negative values wrap per unsigned conversion rules).
template <>
size_t JitCodeKey<int>(const int& d) {
  const size_t key = static_cast<size_t>(d);
  return key;
}

// Key for kernels whose only attribute is an int64_t: the value itself,
// converted to size_t.
template <>
size_t JitCodeKey<int64_t>(const int64_t& d) {
  const size_t key = static_cast<size_t>(d);
  return key;
}

// TODO(TJ): refine and benchmark JitCodeKey generation
// Number of key bits reserved for one encoded activation type; 3 bits leave
// room for up to 2^3 = 8 distinct activation codes.
constexpr int act_type_shift = 3;
// Maps an activation KernelType to a small integer code used when packing
// JitCodeKey values: Identity=0, Exp=1, Relu=2, Sigmoid=3, Tanh=4.
// Any other KernelType is rejected via PADDLE_THROW.
static inline int act_type_convert(KernelType type) {
  switch (type) {
    case kVIdentity:
      return 0;
    case kVExp:
      return 1;
    case kVRelu:
      return 2;
    case kVSigmoid:
      return 3;
    case kVTanh:
      return 4;
    default:
      break;
  }
  PADDLE_THROW("Unsupported act type %d", type);
  // PADDLE_THROW is expected to throw; this return only satisfies the
  // compiler's missing-return check.
  return 0;
}

// Packs an LSTM attribute into a size_t key. Bit layout, low to high:
//   bit 0                        : use_peephole
//   next act_type_shift bits     : act_gate code
//   next act_type_shift bits     : act_cand code
//   next act_type_shift bits     : act_cell code
//   remaining high bits          : d
// The fields do not overlap, so '+' acts as a bitwise OR here.
template <>
size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
  constexpr int gate_offset = 1;
  constexpr int cand_offset = gate_offset + act_type_shift;
  constexpr int cell_offset = cand_offset + act_type_shift;
  constexpr int dim_offset = cell_offset + act_type_shift;

  // Converted in gate, cand, cell order so that an unsupported activation is
  // reported for the first offending field.
  const int gate_key = act_type_convert(attr.act_gate) << gate_offset;
  const int cand_key = act_type_convert(attr.act_cand) << cand_offset;
  const int cell_key = act_type_convert(attr.act_cell) << cell_offset;
  const size_t dim_key = static_cast<size_t>(attr.d) << dim_offset;
  return dim_key + gate_key + cand_key + cell_key + attr.use_peephole;
}

// Packs a GRU attribute into a size_t key. Bit layout, low to high:
//   lowest act_type_shift bits   : act_gate code
//   next act_type_shift bits     : act_cand code
//   remaining high bits          : d
template <>
size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
  // Separate statements fix the conversion order (gate, then cand), so the
  // error for an unsupported activation is deterministic.
  const int gate_key = act_type_convert(attr.act_gate);
  const int cand_key = act_type_convert(attr.act_cand) << act_type_shift;
  const size_t dim_key = static_cast<size_t>(attr.d) << (act_type_shift * 2);
  return dim_key + gate_key + cand_key;
}

// Packs a sequence-pool attribute into a size_t key: the low pool_type_shift
// bits hold the pool type and the remaining high bits hold w.
// NOTE(review): assumes static_cast<int>(attr.type) fits in 3 bits (< 8) —
// confirm against the pool-type enum definition.
template <>
size_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) {
  constexpr int pool_type_shift = 3;
  const size_t width_key = static_cast<size_t>(attr.w) << pool_type_shift;
  return width_key + static_cast<int>(attr.type);
}

// Packs a matmul attribute into a size_t key with one field_bits-wide slot per
// dimension: k in the lowest slot, n in the middle slot, m above both.
// NOTE(review): distinct shapes get distinct keys only while n and k are below
// 2^21, and the 42-bit shift of m requires a 64-bit size_t — confirm callers
// stay within that range.
template <>
size_t JitCodeKey<matmul_attr_t>(const matmul_attr_t& attr) {
  constexpr int field_bits = 21;
  const size_t m_key = static_cast<size_t>(attr.m) << (field_bits * 2);
  const size_t n_key = static_cast<size_t>(attr.n) << field_bits;
  return m_key + n_key + attr.k;
}

// Embedding-seq-pool kernels are keyed solely by table_width.
template <>
size_t JitCodeKey<emb_seq_pool_attr_t>(const emb_seq_pool_attr_t& attr) {
  const size_t key = static_cast<size_t>(attr.table_width);
  return key;
}

// SGD kernels are keyed solely by grad_width.
template <>
size_t JitCodeKey<sgd_attr_t>(const sgd_attr_t& attr) {
  const size_t key = static_cast<size_t>(attr.grad_width);
  return key;
}

}  // namespace jit
}  // namespace operators
}  // namespace paddle