// target_wrapper.h
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>

namespace paddle {
namespace lite {

// Hardware target that a kernel or its inputs/outputs are placed on.
// kLastAsPlaceHolder is not a real target: it is a sentinel that follows every
// other enumerator, so its integer value can be used to derive counts.
enum class TargetType : int {
  kUnk = 0,
  kHost,
  kX86,
  kCUDA,
  kAny,  // any target
  kLastAsPlaceHolder,
};
// Numeric precision of a kernel's data. Every enumerator is spelled out with
// its integer value; kLastAsPlaceHolder is a sentinel placed after all others.
enum class PrecisionType : int {
  kUnk = 0,
  kFloat = 1,
  kInt8 = 2,
  kAny = 3,  // any precision
  kLastAsPlaceHolder = 4,
};
// Data layout of a kernel's inputs/outputs. kLastAsPlaceHolder is a sentinel
// that follows every other enumerator.
enum class DataLayoutType : int {
  kUnk = 0,
  kNCHW,
  kAny,  // any data layout
  kLastAsPlaceHolder,
};

// Helper macros that name a specific TargetType enumerator, or its int value.
#define TARGET(item__) paddle::lite::TargetType::item__
#define TARGET_VAL(item__) static_cast<int>(TARGET(item__))
// Helper macros that name a specific PrecisionType enumerator, or its int
// value.
#define PRECISION(item__) paddle::lite::PrecisionType::item__
#define PRECISION_VAL(item__) static_cast<int>(PRECISION(item__))
// Helper macros that name a specific DataLayoutType enumerator, or its int
// value (DATALAYOUT_VAL mirrors TARGET_VAL / PRECISION_VAL).
#define DATALAYOUT(item__) paddle::lite::DataLayoutType::item__
#define DATALAYOUT_VAL(item__) static_cast<int>(DATALAYOUT(item__))

// Number of precisions / targets excluding kUnk and the kLastAsPlaceHolder
// sentinel (kAny is counted). kFloat and kHost are the enumerators directly
// after kUnk, so subtracting them from the sentinel skips only kUnk.
constexpr const int kNumPrecisions =
    PRECISION_VAL(kLastAsPlaceHolder) - PRECISION_VAL(kFloat);
constexpr const int kNumTargets =
    TARGET_VAL(kLastAsPlaceHolder) - TARGET_VAL(kHost);

static const std::string target2string[] = {"unk", "host", "x86", "cuda",
                                            "any"};
S
superjomn 已提交
59 60 61 62
static const std::string& TargetToStr(TargetType target) {
  return target2string[static_cast<int>(target)];
}

S
Superjomn 已提交
63
static const std::string precision2string[] = {"unk", "float", "int8", "any"};
S
superjomn 已提交
64 65 66 67
static const std::string& PrecisionToStr(PrecisionType precision) {
  return precision2string[static_cast<int>(precision)];
}

S
Superjomn 已提交
68
static const std::string datalayout2string[] = {"unk", "NCHW", "any"};
S
superjomn 已提交
69 70 71 72
static const std::string& DataLayoutToStr(DataLayoutType x) {
  return datalayout2string[static_cast<int>(x)];
}

S
superjomn 已提交
73
/*
S
superjomn 已提交
74 75
 * Place specifies the execution context of a Kernel or input/output for a
 * kernel. It is used to make the analysis of the MIR more clear and accurate.
S
superjomn 已提交
76
 */
S
superjomn 已提交
77
struct Place {
S
superjomn 已提交
78 79 80
  TargetType target{TARGET(kUnk)};
  PrecisionType precision{PRECISION(kUnk)};
  DataLayoutType layout{DATALAYOUT(kUnk)};
S
Superjomn 已提交
81
  short device{0};  // device ID
S
superjomn 已提交
82

S
superjomn 已提交
83 84
  Place() = default;
  Place(TargetType target, PrecisionType precision,
S
Superjomn 已提交
85 86 87
        DataLayoutType layout = DATALAYOUT(kNCHW), short device = 0)
      : target(target), precision(precision), layout(layout), device(device) {}

S
superjomn 已提交
88 89 90 91 92 93 94
  bool is_valid() const {
    return target != TARGET(kUnk) && precision != PRECISION(kUnk) &&
           layout != DATALAYOUT(kUnk);
  }

  size_t hash() const;

S
Superjomn 已提交
95 96 97 98
  bool operator==(const Place& other) const {
    return target == other.target && precision == other.precision &&
           layout == other.layout && device == other.device;
  }
S
superjomn 已提交
99

S
Superjomn 已提交
100 101
  bool operator!=(const Place& other) const { return !(*this == other); }

S
Superjomn 已提交
102
  friend bool operator<(const Place& a, const Place& b);
S
superjomn 已提交
103

S
Superjomn 已提交
104 105 106 107 108
  friend std::ostream& operator<<(std::ostream& os, const Place& other) {
    os << other.DebugString();
    return os;
  }

S
Superjomn 已提交
109
  std::string DebugString() const;
S
superjomn 已提交
110
};
S
superjomn 已提交
111

S
superjomn 已提交
112 113
// Memory copy directions.
enum class IoDirection {
  HtoH = 0,  // Host to host
  HtoD,      // Host to device
  DtoH,      // Device to host
  DtoD,      // Device to device
};

// This interface should be specified by each kind of target.
S
Superjomn 已提交
121
template <TargetType Target, typename StreamTy = int, typename EventTy = int>
S
superjomn 已提交
122 123
class TargetWrapper {
 public:
S
Superjomn 已提交
124 125
  using stream_t = StreamTy;
  using event_t = EventTy;
S
superjomn 已提交
126 127 128 129 130 131 132 133 134 135 136 137 138 139 140

  static size_t num_devices() { return 0; }
  static size_t maximum_stream() { return 0; }

  static void CreateStream(stream_t* stream) {}
  static void DestroyStream(const stream_t& stream) {}

  static void CreateEvent(event_t* event) {}
  static void DestroyEvent(const event_t& event) {}

  static void RecordEvent(const event_t& event) {}
  static void SyncEvent(const event_t& event) {}

  static void StreamSync(const stream_t& stream) {}

S
superjomn 已提交
141 142
  static void* Malloc(size_t size) { return new char[size]; }
  static void Free(void* ptr) { delete[] static_cast<char*>(ptr); }
S
superjomn 已提交
143

S
Superjomn 已提交
144 145 146 147
  static void MemcpySync(void* dst, const void* src, size_t size,
                         IoDirection dir) {}
  static void MemcpyAsync(void* dst, const void* src, size_t size,
                          IoDirection dir, const stream_t& stream) {
S
superjomn 已提交
148 149 150 151 152 153
    MemcpySync(dst, src, size, dir);
  }
};

}  // namespace lite
}  // namespace paddle