place.h 4.2 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yi Wang 已提交
14
#pragma once
15

16 17 18
// #include <functional>
// #include <iostream>
// #include <vector>
19

Y
Yi Wang 已提交
20
#include "paddle/fluid/platform/enforce.h"
21
// #include "paddle/fluid/platform/variant.h"
22 23 24
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/device/npu/enforce_npu.h"
#endif
Y
Yi Wang 已提交
25

26
#include "paddle/pten/common/place.h"
Y
Yi Wang 已提交
27 28
namespace paddle {
namespace platform {
Y
Yi Wang 已提交
29

30 31 32 33 34 35 36 37 38
using Place = pten::Place;
using CPUPlace = pten::CPUPlace;
using CUDAPlace = pten::GPUPlace;
using CUDAPinnedPlace = pten::GPUPinnedPlace;
using NPUPlace = pten::NPUPlace;
using NPUPinnedPlace = pten::NPUPinnedPlace;
using XPUPlace = pten::XPUPlace;
using IPUPlace = pten::IPUPlace;
using MLUPlace = pten::MLUPlace;
Y
Yi Wang 已提交
39

Y
Yang Yu 已提交
40 41
using PlaceList = std::vector<Place>;

Y
Yi Wang 已提交
42
bool is_gpu_place(const Place &);
43
bool is_xpu_place(const Place &);
44
bool is_npu_place(const Place &);
F
fwenguang 已提交
45
bool is_mlu_place(const Place &);
J
jianghaicheng 已提交
46
bool is_ipu_place(const Place &);
Y
Yi Wang 已提交
47
bool is_cpu_place(const Place &);
C
chengduoZH 已提交
48
bool is_cuda_pinned_place(const Place &);
49
bool is_npu_pinned_place(const Place &);
Y
Yi Wang 已提交
50
bool places_are_same_class(const Place &, const Place &);
51
bool is_same_place(const Place &, const Place &);
Y
Yi Wang 已提交
52

Y
Yang Yu 已提交
53
template <typename Visitor>
54 55 56 57 58 59 60
typename Visitor::result_type VisitPlace(const Place &place,
                                         const Visitor &visitor) {
  switch (place.GetType()) {
    case pten::AllocationType::GPU: {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      platform::CUDAPlace p(place.GetDeviceId());
      return visitor(p);
61
#else
62 63 64
      PADDLE_THROW(platform::errors::Unavailable(
          "Paddle is not compiled with CUDA. Cannot visit cuda_pinned"));
      return typename Visitor::result_type();
65
#endif
66 67 68 69 70
    }
    case pten::AllocationType::GPUPINNED: {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      platform::CUDAPinnedPlace p;
      return visitor(p);
71
#else
72 73 74
      PADDLE_THROW(platform::errors::Unavailable(
          "Paddle is not compiled with CUDA. Cannot visit cuda_pinned"));
      return typename Visitor::result_type();
75
#endif
76 77 78 79 80
    }
    case pten::AllocationType::XPU: {
#ifdef PADDLE_WITH_XPU
      platform::XPUPlace p(place.GetDeviceId());
      return visitor(p);
81
#else
82 83 84
      PADDLE_THROW(paddle::platform::errors::Unavailable(
          "Paddle is not compiled with XPU. Cannot visit xpu device"));
      return typename Visitor::result_type();
J
jianghaicheng 已提交
85
#endif
86 87 88 89 90
    }
    case pten::AllocationType::NPU: {
#ifdef PADDLE_WITH_ASCEND_CL
      platform::NPUPlace p(place.GetDeviceId());
      return visitor(p);
F
fwenguang 已提交
91
#else
92 93 94
      PADDLE_THROW(platform::errors::Unavailable(
          "Paddle is not compiled with NPU. Cannot visit npu_pinned"));
      return typename Visitor::result_type();
F
fwenguang 已提交
95
#endif
96 97 98 99 100
    }
    case pten::AllocationType::NPUPINNED: {
#ifdef PADDLE_WITH_ASCEND_CL
      platform::NPUPinnedPlace p;
      return visitor(p);
J
jianghaicheng 已提交
101
#else
102 103 104
      PADDLE_THROW(platform::errors::Unavailable(
          "Paddle is not compiled with NPU. Cannot visit npu_pinned"));
      return typename Visitor::result_type();
105
#endif
106 107 108 109 110
    }
    case pten::AllocationType::IPU: {
#ifdef PADDLE_WITH_IPU
      platform::IPUPlace p(place.GetDeviceId());
      return visitor(p);
Y
Yang Yu 已提交
111
#else
112 113 114
      PADDLE_THROW(platform::errors::Unavailable(
          "Paddle is not compiled with IPU. Cannot visit ipu device"));
      return typename Visitor::result_type();
Y
Yang Yu 已提交
115
#endif
116 117 118 119 120
    }
    case pten::AllocationType::MLU: {
#ifdef PADDLE_WITH_MLU
      platform::MLUPlace p(place.GetDeviceId());
      return visitor(p);
C
chengduoZH 已提交
121
#else
122 123
      PADDLE_THROW(platform::errors::Unavailable(
          "Paddle is not compiled with MLU. Cannot visit mlu device"));
C
chengduoZH 已提交
124
#endif
125 126 127 128 129
    }
    default: {
      platform::CPUPlace p;
      return visitor(p);
    }
C
chengduoZH 已提交
130
  }
Y
Yang Yu 已提交
131 132
}

Y
Yi Wang 已提交
133 134
}  // namespace platform
}  // namespace paddle