/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <string>
#include <vector>

#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/place.h"

namespace paddle {
namespace platform {
Y
Yi Wang 已提交
26

27 28 29 30 31 32 33 34
using Place = phi::Place;
using CPUPlace = phi::CPUPlace;
using CUDAPlace = phi::GPUPlace;
using CUDAPinnedPlace = phi::GPUPinnedPlace;
using NPUPinnedPlace = phi::NPUPinnedPlace;
using XPUPlace = phi::XPUPlace;
using IPUPlace = phi::IPUPlace;
using CustomPlace = phi::CustomPlace;
Y
Yi Wang 已提交
35

Y
Yang Yu 已提交
36 37
using PlaceList = std::vector<Place>;

38 39 40 41 42 43 44 45 46
#ifdef PADDLE_WITH_CUSTOM_DEVICE
// Static helpers for custom-device places. This class is only declared when
// Paddle is built with PADDLE_WITH_CUSTOM_DEVICE; its member functions are
// defined outside this header.
class PlaceHelper {
 public:
  // Presumably constructs a Place of device type `dev_type` with index
  // `dev_id` (which defaults to 0).
  static Place CreatePlace(const std::string &dev_type, size_t dev_id = 0);

  // Presumably returns the device-type name carried by `place`.
  static std::string GetDeviceType(const Place &place);

  // Presumably returns the device index carried by `place`.
  static size_t GetDeviceId(const Place &place);
};
#endif  // PADDLE_WITH_CUSTOM_DEVICE

Y
Yi Wang 已提交
47
bool is_gpu_place(const Place &);
48
bool is_xpu_place(const Place &);
J
jianghaicheng 已提交
49
bool is_ipu_place(const Place &);
Y
Yi Wang 已提交
50
bool is_cpu_place(const Place &);
C
chengduoZH 已提交
51
bool is_cuda_pinned_place(const Place &);
52
bool is_custom_place(const Place &p);
Y
Yi Wang 已提交
53
bool places_are_same_class(const Place &, const Place &);
54
bool is_same_place(const Place &, const Place &);
Y
Yi Wang 已提交
55

Y
Yang Yu 已提交
56
template <typename Visitor>
57 58 59
typename Visitor::result_type VisitPlace(const Place &place,
                                         const Visitor &visitor) {
  switch (place.GetType()) {
60
    case phi::AllocationType::GPU: {
61 62 63
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      platform::CUDAPlace p(place.GetDeviceId());
      return visitor(p);
64
#else
65 66 67
      PADDLE_THROW(platform::errors::Unavailable(
          "Paddle is not compiled with CUDA. Cannot visit cuda_pinned"));
      return typename Visitor::result_type();
68
#endif
69
    }
70
    case phi::AllocationType::GPUPINNED: {
71 72 73
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      platform::CUDAPinnedPlace p;
      return visitor(p);
74
#else
75 76 77
      PADDLE_THROW(platform::errors::Unavailable(
          "Paddle is not compiled with CUDA. Cannot visit cuda_pinned"));
      return typename Visitor::result_type();
78
#endif
79
    }
80
    case phi::AllocationType::XPU: {
81 82 83
#ifdef PADDLE_WITH_XPU
      platform::XPUPlace p(place.GetDeviceId());
      return visitor(p);
84
#else
85 86 87
      PADDLE_THROW(paddle::platform::errors::Unavailable(
          "Paddle is not compiled with XPU. Cannot visit xpu device"));
      return typename Visitor::result_type();
J
jianghaicheng 已提交
88
#endif
89
    }
90
    case phi::AllocationType::NPUPINNED: {
91 92 93 94
      PADDLE_THROW(platform::errors::Unavailable(
          "Paddle is not compiled with NPU. Cannot visit npu_pinned"));
      return typename Visitor::result_type();
    }
95
    case phi::AllocationType::IPU: {
96 97 98
#ifdef PADDLE_WITH_IPU
      platform::IPUPlace p(place.GetDeviceId());
      return visitor(p);
Y
Yang Yu 已提交
99
#else
100 101 102
      PADDLE_THROW(platform::errors::Unavailable(
          "Paddle is not compiled with IPU. Cannot visit ipu device"));
      return typename Visitor::result_type();
103 104
#endif
    }
105
    case phi::AllocationType::CUSTOM: {
106 107 108 109 110 111
#ifdef PADDLE_WITH_CUSTOM_DEVICE
      platform::CustomPlace p(place.GetDeviceType(), place.GetDeviceId());
      return visitor(p);
#else
      PADDLE_THROW(platform::errors::Unavailable(
          "Paddle is not compiled with CUSTOM. Cannot visit custom device"));
C
chengduoZH 已提交
112
#endif
113 114 115 116 117
    }
    default: {
      platform::CPUPlace p;
      return visitor(p);
    }
C
chengduoZH 已提交
118
  }
Y
Yang Yu 已提交
119 120
}

Y
Yi Wang 已提交
121 122
}  // namespace platform
}  // namespace paddle