place.h 4.8 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yi Wang 已提交
14
#pragma once
15

16 17 18
#include <string>
#include <vector>

// #include <functional>
// #include <iostream>
// #include <vector>

#include "paddle/fluid/platform/enforce.h"

#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/device/npu/enforce_npu.h"
#endif

#include "paddle/phi/common/place.h"
Y
Yi Wang 已提交
27 28
namespace paddle {
namespace platform {
Y
Yi Wang 已提交
29

30 31 32 33 34 35 36 37 38 39
// Aliases that re-export the phi place types under their historical
// paddle::platform names (e.g. platform::CUDAPlace is phi::GPUPlace), so
// legacy fluid code keeps compiling against the phi implementation.
using Place = phi::Place;
using CPUPlace = phi::CPUPlace;
using CUDAPlace = phi::GPUPlace;
using CUDAPinnedPlace = phi::GPUPinnedPlace;
using NPUPlace = phi::NPUPlace;
using NPUPinnedPlace = phi::NPUPinnedPlace;
using XPUPlace = phi::XPUPlace;
using IPUPlace = phi::IPUPlace;
using MLUPlace = phi::MLUPlace;
using CustomPlace = phi::CustomPlace;

// A sequence of places, e.g. the devices a job is distributed across.
using PlaceList = std::vector<Place>;

43 44 45 46 47 48 49 50 51
#ifdef PADDLE_WITH_CUSTOM_DEVICE
// Translates between a Place and the (device-type string, device id) pair
// used by the pluggable custom-device runtime. Definitions live in the
// corresponding .cc file.
class PlaceHelper {
 public:
  // Returns the device-type name of `place` (e.g. the plugin's type string).
  static std::string GetDeviceType(const Place &place);
  // Returns the device index of `place`.
  static size_t GetDeviceId(const Place &place);
  // Builds a Place from a device-type name and an optional device index.
  static Place CreatePlace(const std::string &dev_type, size_t dev_id = 0);
};
#endif

Y
Yi Wang 已提交
52
// Predicates reporting which kind of device a Place refers to.
// Implementations live in the corresponding .cc file.
bool is_gpu_place(const Place &);
bool is_xpu_place(const Place &);
bool is_npu_place(const Place &);
bool is_mlu_place(const Place &);
bool is_ipu_place(const Place &);
bool is_cpu_place(const Place &);
bool is_cuda_pinned_place(const Place &);
bool is_npu_pinned_place(const Place &);
bool is_custom_place(const Place &p);
// True when both places are of the same class; presumably compares only the
// allocation type, not the device id — confirm against the definition.
bool places_are_same_class(const Place &, const Place &);
// True when the two places compare equal (type and device identity).
bool is_same_place(const Place &, const Place &);
Y
Yi Wang 已提交
63

Y
Yang Yu 已提交
64
template <typename Visitor>
65 66 67
typename Visitor::result_type VisitPlace(const Place &place,
                                         const Visitor &visitor) {
  switch (place.GetType()) {
68
    case phi::AllocationType::GPU: {
69 70 71
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      platform::CUDAPlace p(place.GetDeviceId());
      return visitor(p);
72
#else
73 74 75
      PADDLE_THROW(platform::errors::Unavailable(
          "Paddle is not compiled with CUDA. Cannot visit cuda_pinned"));
      return typename Visitor::result_type();
76
#endif
77
    }
78
    case phi::AllocationType::GPUPINNED: {
79 80 81
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      platform::CUDAPinnedPlace p;
      return visitor(p);
82
#else
83 84 85
      PADDLE_THROW(platform::errors::Unavailable(
          "Paddle is not compiled with CUDA. Cannot visit cuda_pinned"));
      return typename Visitor::result_type();
86
#endif
87
    }
88
    case phi::AllocationType::XPU: {
89 90 91
#ifdef PADDLE_WITH_XPU
      platform::XPUPlace p(place.GetDeviceId());
      return visitor(p);
92
#else
93 94 95
      PADDLE_THROW(paddle::platform::errors::Unavailable(
          "Paddle is not compiled with XPU. Cannot visit xpu device"));
      return typename Visitor::result_type();
J
jianghaicheng 已提交
96
#endif
97
    }
98
    case phi::AllocationType::NPU: {
99 100 101
#ifdef PADDLE_WITH_ASCEND_CL
      platform::NPUPlace p(place.GetDeviceId());
      return visitor(p);
F
fwenguang 已提交
102
#else
103 104 105
      PADDLE_THROW(platform::errors::Unavailable(
          "Paddle is not compiled with NPU. Cannot visit npu_pinned"));
      return typename Visitor::result_type();
F
fwenguang 已提交
106
#endif
107
    }
108
    case phi::AllocationType::NPUPINNED: {
109 110 111
#ifdef PADDLE_WITH_ASCEND_CL
      platform::NPUPinnedPlace p;
      return visitor(p);
J
jianghaicheng 已提交
112
#else
113 114 115
      PADDLE_THROW(platform::errors::Unavailable(
          "Paddle is not compiled with NPU. Cannot visit npu_pinned"));
      return typename Visitor::result_type();
116
#endif
117
    }
118
    case phi::AllocationType::IPU: {
119 120 121
#ifdef PADDLE_WITH_IPU
      platform::IPUPlace p(place.GetDeviceId());
      return visitor(p);
Y
Yang Yu 已提交
122
#else
123 124 125
      PADDLE_THROW(platform::errors::Unavailable(
          "Paddle is not compiled with IPU. Cannot visit ipu device"));
      return typename Visitor::result_type();
Y
Yang Yu 已提交
126
#endif
127
    }
128
    case phi::AllocationType::MLU: {
129 130 131
#ifdef PADDLE_WITH_MLU
      platform::MLUPlace p(place.GetDeviceId());
      return visitor(p);
C
chengduoZH 已提交
132
#else
133 134
      PADDLE_THROW(platform::errors::Unavailable(
          "Paddle is not compiled with MLU. Cannot visit mlu device"));
135 136
#endif
    }
137
    case phi::AllocationType::CUSTOM: {
138 139 140 141 142 143
#ifdef PADDLE_WITH_CUSTOM_DEVICE
      platform::CustomPlace p(place.GetDeviceType(), place.GetDeviceId());
      return visitor(p);
#else
      PADDLE_THROW(platform::errors::Unavailable(
          "Paddle is not compiled with CUSTOM. Cannot visit custom device"));
C
chengduoZH 已提交
144
#endif
145 146 147 148 149
    }
    default: {
      platform::CPUPlace p;
      return visitor(p);
    }
C
chengduoZH 已提交
150
  }
Y
Yang Yu 已提交
151 152
}

Y
Yi Wang 已提交
153 154
}  // namespace platform
}  // namespace paddle