heter_resource.h 3.7 KB
Newer Older
T
Thunderbrook 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <cstddef>
#include <map>
#include <memory>
#include <vector>
20 21

#ifdef PADDLE_WITH_CUDA
T
Thunderbrook 已提交
22
#include "paddle/fluid/platform/cuda_device_guard.h"
23 24 25 26 27 28 29
#endif

#ifdef PADDLE_WITH_XPU_KP
#include <xpu/runtime.h>  // NOLINT
#include "paddle/fluid/platform/device/xpu/xpu_info.h"
#endif

T
Thunderbrook 已提交
30 31
#include "paddle/fluid/platform/enforce.h"

T
Thunderbrook 已提交
32
#ifdef PADDLE_WITH_HETERPS
T
Thunderbrook 已提交
33 34 35 36

namespace paddle {
namespace framework {

37 38
// Backend-selected stream handle type returned by HeterPsResource's
// local_stream / remote_stream / comm_stream accessors. Exactly one branch
// applies, depending on which backend macro is defined; if neither is
// defined, ppStream is not declared.
#if defined(PADDLE_WITH_CUDA)
using ppStream = cudaStream_t;
#elif defined(PADDLE_WITH_XPU_KP)
using ppStream = XPUStream;
#endif

#if defined(PADDLE_WITH_CUDA)
T
Thunderbrook 已提交
45 46
class GPUResource {
 public:
47
  GPUResource(std::vector<int>& device_id, int index);  // NOLINT
T
Thunderbrook 已提交
48 49 50 51 52 53
  virtual ~GPUResource();
  GPUResource(const GPUResource&) = delete;
  GPUResource& operator=(const GPUResource&) = delete;

  int dev_id() const { return dev_id_; }
  int index() const { return index_; }
54
  gpuStream_t local_stream(int num) { return local_streams_[num]; }
55
  gpuStream_t remote_stream(int num) { return remote_streams_[num]; }
56
  gpuStream_t comm_stream(int num) { return comm_streams_[num]; }
T
Thunderbrook 已提交
57 58 59

  int dev_id_;
  int index_;
60
  std::vector<int> dev_ids_;
61
  std::vector<gpuStream_t> remote_streams_;
62 63
  std::vector<gpuStream_t> local_streams_;
  std::vector<gpuStream_t> comm_streams_;
T
Thunderbrook 已提交
64
};
F
Fan Zhang 已提交
65

66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
#elif defined(PADDLE_WITH_XPU_KP)
// Per-device resource bundle for the XPU-KP backend: the device id it is
// bound to, its index, the list of device ids it was built from, and three
// groups of streams (local, remote, comm).
//
// NOTE(review): construction/destruction (stream creation and release) is
// defined out of line; presumably the bound device is device_id[index] —
// confirm against the .cc definition.
class XPUResource {
 public:
  // `device_id` is taken by non-const reference to match the out-of-line
  // definition; do not change the signature here alone.
  XPUResource(std::vector<int>& device_id, int index);  // NOLINT
  virtual ~XPUResource();
  // Non-copyable: copies would duplicate the stream handles held below.
  XPUResource(const XPUResource&) = delete;
  XPUResource& operator=(const XPUResource&) = delete;

  int dev_id() const { return dev_id_; }
  int index() const { return index_; }

  // Stream accessors. `num` indexes the corresponding vector with
  // std::vector::operator[], so it is not bounds-checked. They only read
  // member state and are therefore const.
  XPUStream local_stream(int num) const { return local_streams_[num]; }
  XPUStream remote_stream(int num) const { return remote_streams_[num]; }
  XPUStream comm_stream(int num) const { return comm_streams_[num]; }

  int dev_id_;
  int index_;
  std::vector<int> dev_ids_;
  std::vector<XPUStream> remote_streams_;
  std::vector<XPUStream> local_streams_;
  std::vector<XPUStream> comm_streams_;
};
#endif

// Backend-selected aliases. Exactly one branch applies, depending on which
// backend macro is defined:
//   DevResource    - per-device resource class stored by HeterPsResource
//   DevPlace       - the platform place type for that backend
//   AnyDeviceGuard - the platform device-guard type for that backend
#if defined(PADDLE_WITH_CUDA)
using DevResource = GPUResource;
using DevPlace = platform::CUDAPlace;
using AnyDeviceGuard = platform::CUDADeviceGuard;
#elif defined(PADDLE_WITH_XPU_KP)
using DevResource = XPUResource;
using DevPlace = platform::XPUPlace;
using AnyDeviceGuard = platform::XPUDeviceGuard;
#endif

class HeterPsResource {
 public:
101
  explicit HeterPsResource(const std::vector<int>& dev_ids);
T
Thunderbrook 已提交
102 103 104 105
  HeterPsResource(const HeterPsResource&) = delete;
  HeterPsResource& operator=(const HeterPsResource&) = delete;
  virtual ~HeterPsResource() {}
  void enable_p2p();
106
  int total_device();
T
Thunderbrook 已提交
107 108
  int get_index_by_devid(int devid);
  int dev_id(int num);
Y
yaoxuefeng 已提交
109
  void set_multi_mf(int multi_mf_dim, int max_mf_dim);
Y
yaoxuefeng 已提交
110 111
  int multi_mf() { return multi_mf_dim_; }
  int max_mf_dim() { return max_mf_dim_; }
F
Fan Zhang 已提交
112

113 114 115
  ppStream local_stream(int dev_num, int stream_num);
  ppStream remote_stream(int dev_num, int stream_num);
  ppStream comm_stream(int dev_num, int stream_num);
T
Thunderbrook 已提交
116

117
  std::vector<std::shared_ptr<DevResource>> resources_;
T
Thunderbrook 已提交
118 119
  std::vector<int> dev_ids_;
  std::map<int, int> devid_2_index_;
Y
yaoxuefeng 已提交
120 121
  int multi_mf_dim_{0};
  int max_mf_dim_{0};
T
Thunderbrook 已提交
122 123 124 125 126
};

}  // end namespace framework
}  // end namespace paddle
#endif