/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <mutex>

#include "paddle/phi/common/place.h"
#include "paddle/phi/core/macros.h"
#include "paddle/utils/flat_hash_map.h"

// Forward declarations of the phi context types referenced by this header.
// Only their names are needed here; their full definitions are not included.
namespace phi {
class GPUContext;
class CPUContext;
class DeviceContext;
}  // namespace phi

namespace paddle {
namespace experimental {

/**
 * Maps an AllocationType to the concrete phi context class returned by the
 * typed DeviceContextPool::Get<T>() overload.
 *
 * The primary template is declared but not defined; only the specializations
 * below provide a `TYPE` member. Requesting `TYPE` for any AllocationType
 * without a visible specialization therefore fails to compile.
 */
template <AllocationType T>
struct DefaultDeviceContextType;

/// AllocationType::GPU maps to phi::GPUContext.
template <>
struct DefaultDeviceContextType<AllocationType::GPU> {
  using TYPE = phi::GPUContext;
};

/// AllocationType::CPU maps to phi::CPUContext.
template <>
struct DefaultDeviceContextType<AllocationType::CPU> {
  using TYPE = phi::CPUContext;
};

/**
 * The DeviceContextPool here is just a mirror of the DeviceContextPool in
 * fluid, and does not manage the life cycle of the DeviceContext.
 * It is mainly used for external custom operator calls and high-performance
 * C++ APIs.
 *
 * Since DeviceContextPool in fluid is a global singleton, it always exists
 * in program running, so DeviceContextPool here can always access the correct
 * DeviceContext pointer.
 *
 * In order not to depend on the fluid's DeviceContextPool,
 * the DeviceContextPool here needs to be initialized in the fluid, and cannot
 * be initialized by itself.
 */
class DeviceContextPool {
 public:
  static DeviceContextPool& Instance();

63
  const phi::DeviceContext* Get(const Place& place);
64 65 66 67

  phi::DeviceContext* GetMutable(const Place& place);

  template <AllocationType T>
68
  const typename DefaultDeviceContextType<T>::TYPE* Get(const Place& place) {
69 70 71 72 73
    return reinterpret_cast<const typename DefaultDeviceContextType<T>::TYPE*>(
        Get(place));
  }

 private:
74 75
  DeviceContextPool() = default;

76 77
  paddle::flat_hash_map<Place, const phi::DeviceContext*, Place::Hash>
      context_map_;
78
  std::mutex mutex_;
79 80 81 82 83 84

  DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};

}  // namespace experimental
}  // namespace paddle