target_wrapper.h 5.1 KB
Newer Older
S
superjomn 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <iostream>
S
superjomn 已提交
17
#include <sstream>
S
superjomn 已提交
18
#include <string>
C
Chunwei 已提交
19
#include "paddle/fluid/lite/api/place.h"
20
#include "paddle/fluid/lite/utils/cp_logging.h"
C
Chunwei 已提交
21

S
superjomn 已提交
22 23 24 25
#ifdef LITE_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#endif
S
superjomn 已提交
26 27 28 29

namespace paddle {
namespace lite {

C
Chunwei 已提交
30 31 32 33 34 35 36 37 38 39 40
// Re-export the place-related vocabulary from lite_api so code inside this
// namespace can refer to it unqualified. Kept in alphabetical order.
using lite_api::DataLayoutRepr;
using lite_api::DataLayoutToStr;
using lite_api::DataLayoutType;
using lite_api::Place;
using lite_api::PrecisionRepr;
using lite_api::PrecisionToStr;
using lite_api::PrecisionType;
using lite_api::PrecisionTypeLength;
using lite_api::TargetRepr;
using lite_api::TargetToStr;
using lite_api::TargetType;
S
superjomn 已提交
41

S
superjomn 已提交
42 43
// Direction of a memory copy, from the source memory space to the
// destination memory space. Values are spelled out explicitly; they are the
// same as the implicit 0..3 sequence.
enum class IoDirection {
  HtoH = 0,  // host   -> host
  HtoD = 1,  // host   -> device
  DtoH = 2,  // device -> host
  DtoD = 3,  // device -> device
};

// This interface should be specified by each kind of target.
S
Superjomn 已提交
51
template <TargetType Target, typename StreamTy = int, typename EventTy = int>
S
superjomn 已提交
52 53
class TargetWrapper {
 public:
S
Superjomn 已提交
54 55
  using stream_t = StreamTy;
  using event_t = EventTy;
S
superjomn 已提交
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70

  static size_t num_devices() { return 0; }
  static size_t maximum_stream() { return 0; }

  static void CreateStream(stream_t* stream) {}
  static void DestroyStream(const stream_t& stream) {}

  static void CreateEvent(event_t* event) {}
  static void DestroyEvent(const event_t& event) {}

  static void RecordEvent(const event_t& event) {}
  static void SyncEvent(const event_t& event) {}

  static void StreamSync(const stream_t& stream) {}

71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
  static void* Malloc(size_t size) {
    LOG(FATAL) << "Unimplemented malloc for " << TargetToStr(Target);
    return nullptr;
  }
  static void Free(void* ptr) { LOG(FATAL) << "Unimplemented"; }

  static void MemcpySync(void* dst, const void* src, size_t size,
                         IoDirection dir) {
    LOG(FATAL) << "Unimplemented";
  }
  static void MemcpyAsync(void* dst, const void* src, size_t size,
                          IoDirection dir, const stream_t& stream) {
    MemcpySync(dst, src, size, dir);
  }
};

// This interface should be specified by each kind of target.
S
Superjomn 已提交
88 89
using TargetWrapperHost = TargetWrapper<TARGET(kHost)>;
using TargetWrapperX86 = TargetWrapperHost;
90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
template <>
class TargetWrapper<TARGET(kHost)> {
 public:
  using stream_t = int;
  using event_t = int;

  static size_t num_devices() { return 0; }
  static size_t maximum_stream() { return 0; }

  static void CreateStream(stream_t* stream) {}
  static void DestroyStream(const stream_t& stream) {}

  static void CreateEvent(event_t* event) {}
  static void DestroyEvent(const event_t& event) {}

  static void RecordEvent(const event_t& event) {}
  static void SyncEvent(const event_t& event) {}

  static void StreamSync(const stream_t& stream) {}

  static void* Malloc(size_t size);
  static void Free(void* ptr);
S
superjomn 已提交
112

S
Superjomn 已提交
113
  static void MemcpySync(void* dst, const void* src, size_t size,
114
                         IoDirection dir);
S
Superjomn 已提交
115 116
  static void MemcpyAsync(void* dst, const void* src, size_t size,
                          IoDirection dir, const stream_t& stream) {
S
superjomn 已提交
117 118 119 120
    MemcpySync(dst, src, size, dir);
  }
};

S
superjomn 已提交
121
#ifdef LITE_WITH_CUDA
// CUDA specialization of TargetWrapper, keyed on the CUDA runtime's stream
// and event handle types. Malloc, Free, MemcpySync and MemcpyAsync are only
// declared here; their definitions are not in this header.
// NOTE(review): num_devices()/maximum_stream() return 0 and the stream/event
// hooks are empty stubs that never call the CUDA runtime — confirm this is
// intentional before relying on them.
template <>
class TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t> {
 public:
  using stream_t = cudaStream_t;
  using event_t = cudaEvent_t;

  // Capacity queries (stubbed to 0).
  static size_t num_devices() { return 0; }
  static size_t maximum_stream() { return 0; }

  // Stream / event lifecycle (stubbed as no-ops).
  static void CreateStream(stream_t* /*stream*/) {}
  static void DestroyStream(const stream_t& /*stream*/) {}
  static void CreateEvent(event_t* /*event*/) {}
  static void DestroyEvent(const event_t& /*event*/) {}
  static void RecordEvent(const event_t& /*event*/) {}
  static void SyncEvent(const event_t& /*event*/) {}
  static void StreamSync(const stream_t& /*stream*/) {}

  // Memory management (defined out of line).
  static void* Malloc(size_t size);
  static void Free(void* ptr);

  // Memory copies (defined out of line).
  static void MemcpySync(void* dst, const void* src, size_t size,
                         IoDirection dir);
  static void MemcpyAsync(void* dst, const void* src, size_t size,
                          IoDirection dir, const stream_t& stream);
};

// Short name for the CUDA wrapper specialization above.
using TargetWrapperCuda =
    TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;
#endif  // LITE_WITH_CUDA

155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
template <TargetType Target>
void CopySync(void* dst, void* src, size_t size, IoDirection dir) {
  switch (Target) {
    case TARGET(kX86):
    case TARGET(kHost):
    case TARGET(kARM):
      TargetWrapperX86::MemcpySync(dst, src, size, IoDirection::HtoH);
      break;
#ifdef LITE_WITH_CUDA
    case TARGET(kCUDA):
      TargetWrapperCuda::MemcpySync(dst, src, size, dir);
#endif
  }
}

S
superjomn 已提交
170 171
}  // namespace lite
}  // namespace paddle