/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <cstddef>
#include <map>
#include <memory>
#include <vector>

#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif

#ifdef PADDLE_WITH_XPU_KP
#include <xpu/runtime.h>  // NOLINT
#include "paddle/fluid/platform/device/xpu/xpu_info.h"
#endif

#include "paddle/fluid/platform/enforce.h"

#ifdef PADDLE_WITH_HETERPS

namespace paddle {
namespace framework {

// ppStream names the native stream handle type of whichever backend is
// compiled in: cudaStream_t for CUDA builds, XPUStream for XPU KP builds.
// When neither backend macro is defined, no alias is declared.
#ifdef PADDLE_WITH_CUDA
using ppStream = cudaStream_t;
#elif defined(PADDLE_WITH_XPU_KP)
using ppStream = XPUStream;
#endif

#if defined(PADDLE_WITH_CUDA)
class GPUResource {
 public:
46
  GPUResource(std::vector<int>& device_id, int index);  // NOLINT
T
Thunderbrook 已提交
47 48 49 50 51 52
  virtual ~GPUResource();
  GPUResource(const GPUResource&) = delete;
  GPUResource& operator=(const GPUResource&) = delete;

  int dev_id() const { return dev_id_; }
  int index() const { return index_; }
53
  gpuStream_t local_stream(int num) { return local_streams_[num]; }
54
  gpuStream_t remote_stream(int num) { return remote_streams_[num]; }
55
  gpuStream_t comm_stream(int num) { return comm_streams_[num]; }
T
Thunderbrook 已提交
56 57 58

  int dev_id_;
  int index_;
59
  std::vector<int> dev_ids_;
60
  std::vector<gpuStream_t> remote_streams_;
61 62
  std::vector<gpuStream_t> local_streams_;
  std::vector<gpuStream_t> comm_streams_;
T
Thunderbrook 已提交
63
};
#elif defined(PADDLE_WITH_XPU_KP)
// Per-device resources for one XPU (KP) device: the device id, this object's
// index, the list of participating device ids, and three groups of streams
// (local, remote, comm) addressed by slot number. Copying is disabled.
// NOTE(review): the constructor and destructor are defined out of line;
// presumably the constructor sets dev_id_ from device_id[index] and creates
// the streams — confirm in the .cc file.
class XPUResource {
 public:
  // The parameter stays a non-const reference so that this declaration keeps
  // matching the out-of-line constructor definition.
  XPUResource(std::vector<int>& device_id, int index);  // NOLINT
  virtual ~XPUResource();
  XPUResource(const XPUResource&) = delete;
  XPUResource& operator=(const XPUResource&) = delete;

  int dev_id() const { return dev_id_; }
  int index() const { return index_; }

  // Stream lookups. `num` indexes the corresponding stream vector with no
  // bounds check, so callers must pass an existing slot.
  XPUStream local_stream(int num) const { return local_streams_[num]; }
  XPUStream remote_stream(int num) const { return remote_streams_[num]; }
  XPUStream comm_stream(int num) const { return comm_streams_[num]; }

  // Data members are public as in the original header; declaration order is
  // unchanged. The ints default to -1 so they are never read uninitialized
  // before the constructor assigns them.
  int dev_id_{-1};
  int index_{-1};
  std::vector<int> dev_ids_;
  std::vector<XPUStream> remote_streams_;
  std::vector<XPUStream> local_streams_;
  std::vector<XPUStream> comm_streams_;
};
#endif

// Backend-neutral aliases used by the HeterPS code. For each compiled-in
// backend they select the device place type, the per-device resource class
// and the device guard type. When neither backend macro is defined, no
// aliases are declared.
#ifdef PADDLE_WITH_CUDA
using DevPlace = platform::CUDAPlace;
using DevResource = GPUResource;
using AnyDeviceGuard = platform::CUDADeviceGuard;
#elif defined(PADDLE_WITH_XPU_KP)
using DevPlace = platform::XPUPlace;
using DevResource = XPUResource;
using AnyDeviceGuard = platform::XPUDeviceGuard;
#endif

class HeterPsResource {
 public:
99
  explicit HeterPsResource(const std::vector<int>& dev_ids);
T
Thunderbrook 已提交
100 101 102 103
  HeterPsResource(const HeterPsResource&) = delete;
  HeterPsResource& operator=(const HeterPsResource&) = delete;
  virtual ~HeterPsResource() {}
  void enable_p2p();
104
  int total_device();
T
Thunderbrook 已提交
105 106
  int get_index_by_devid(int devid);
  int dev_id(int num);
Y
yaoxuefeng 已提交
107
  void set_multi_mf(int multi_mf_dim, int max_mf_dim);
108 109 110
  ppStream local_stream(int dev_num, int stream_num);
  ppStream remote_stream(int dev_num, int stream_num);
  ppStream comm_stream(int dev_num, int stream_num);
T
Thunderbrook 已提交
111

112
  std::vector<std::shared_ptr<DevResource>> resources_;
T
Thunderbrook 已提交
113 114
  std::vector<int> dev_ids_;
  std::map<int, int> devid_2_index_;
Y
yaoxuefeng 已提交
115 116
  int multi_mf_dim_{0};
  int max_mf_dim_{0};
T
Thunderbrook 已提交
117 118 119 120 121
};

}  // end namespace framework
}  // end namespace paddle
#endif