// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
#include <type_traits>

#include "paddle/infrt/tensor/dense_host_tensor.h"

namespace infrt {
namespace kernel {
namespace infershaped {

// The parameter type that infershaped::count tallies: a const reference to a
// pten::DenseTensor. Matching is done with std::is_same, so only this exact
// type counts (e.g. a non-const reference or a by-value tensor does not).
using KeyType = const ::pten::DenseTensor&;

// Integer type used for all counts in this namespace.
// NOTE: it is unsigned 8-bit, so a count above 255 would wrap.
using CountType = uint8_t;

// Tag overload: a "true" compile-time tag contributes one to a count.
constexpr CountType value(std::true_type) { return CountType{1}; }

// Tag overload: a "false" compile-time tag contributes nothing.
constexpr CountType value(std::false_type) { return CountType{0}; }

// Yields 1 when T is exactly KeyType and 0 otherwise.
// std::is_same<...>::type is precisely std::true_type or std::false_type, so
// constructing it selects the matching tag overload above.
template <typename T>
constexpr CountType value() {
  return value(typename std::is_same<T, KeyType>::type{});
}

template <typename FirstArg>
constexpr CountType count(CountType num) {
  return num;
}

template <typename FirstArg>
constexpr CountType count() {
  return 0;
}

template <>
constexpr CountType count<KeyType>(CountType num) {
  return num + 1;
}

template <>
constexpr CountType count<KeyType>() {
  return 1;
}

template <typename FirstArg, typename SecondArg, typename... RestOfArgs>
constexpr CountType count(CountType num) {
  return count<SecondArg, RestOfArgs...>(num + value<FirstArg>());
}

template <typename FirstArg, typename SecondArg, typename... RestOfArgs>
constexpr CountType count() {
  return count<SecondArg, RestOfArgs...>(value<FirstArg>());
}

}  // namespace infershaped

template <typename F>
struct InferShapeHelper;

template <typename Return, typename... Args>
struct InferShapeHelper<Return (*)(Args...)> {
  static constexpr int count = infershaped::count<Args...>();
};

}  // namespace kernel
}  // namespace infrt