xpu_context.cc 9.0 KB
Newer Older
W
Wilber 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/phi/backends/xpu/xpu_context.h"
W
Wilber 已提交
16

W
Wilber 已提交
17
#include <memory>
W
Wilber 已提交
18

19 20
#include "glog/logging.h"

21
#include "paddle/phi/api/ext/exception.h"
J
james 已提交
22
#include "paddle/phi/backends/xpu/enforce_xpu.h"
23
#include "paddle/phi/common/place.h"
24
#include "paddle/phi/core/os_info.h"
W
Wilber 已提交
25 26 27 28 29 30
#include "xpu/runtime.h"
#include "xpu/runtime_ex.h"
#include "xpu/xdnn.h"

namespace xpu = baidu::xpu::api;

31
namespace phi {
W
Wilber 已提交
32

W
Wilber 已提交
33 34
// Private implementation of XPUContext. It stores the xpu::api context for
// one place, the optional XPU Dataloader contexts (enabled through the
// XPU_PADDLE_XDL_CONTEXTS environment variable) and a BKCL communicator
// handle that is only referenced, never destroyed, by this object.
//
// NOTE(review): xdl_context_map_ is read and written without any visible
// synchronization; confirm that callers serialize access from dataloader
// threads.
struct XPUContext::Impl {
  // Allocates the L3 cache buffer for this context's device and registers it
  // with the xpu::api context.
  //
  // l3_size: size handed to xpu_malloc; replaced by the value of the
  //          XPU_PADDLE_L3_SIZE environment variable when it is set.
  //
  // Does nothing when the device is not among the selected XPU devices. The
  // allocation is best effort: when the buffer pointer is still null after
  // xpu_malloc, the L3 manager is left untouched.
  //
  // NOTE(review): a previously registered buffer is freed before the new
  // allocation; if that allocation fails the L3 manager presumably still
  // refers to the freed buffer - confirm against the xpu::api contract.
  void SetL3Cache(int l3_size = 14155776) {
    const int MAX_XPU_NUM = 16;
    // One slot per device id, shared by all contexts (function-local static).
    static void* l3ptrs[MAX_XPU_NUM] = {nullptr};

    // Read the variable once so the null test and the parse use the same
    // value.
    if (const char* env_l3_size = std::getenv("XPU_PADDLE_L3_SIZE")) {
      l3_size = atoi(env_l3_size);
    }

    const int dev_id = place_.GetDeviceId();
    bool is_selected = false;
    for (const auto& selected_id : backends::xpu::GetXPUSelectedDevices()) {
      if (selected_id == dev_id) {
        is_selected = true;
        break;
      }
    }
    if (!is_selected) {
      return;
    }

    // l3ptrs has a fixed size; an id outside [0, MAX_XPU_NUM) would index it
    // out of bounds.
    PD_CHECK(dev_id >= 0 && dev_id < MAX_XPU_NUM,
             "the xpu device id is out of range of the l3 cache table.");

    void*& l3_slot = l3ptrs[dev_id];
    if (l3_slot != nullptr) {
      xpu_free(l3_slot);
      l3_slot = nullptr;
    }
    xpu_malloc(static_cast<void**>(&l3_slot), l3_size, XPU_MEM_L3);
    if (l3_slot != nullptr) {
      PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
      context_->_l3_mgr.set(l3_slot, l3_size);
      VLOG(3) << "xpu place " << dev_id << " set l3 size " << l3_size;
    }
  }

  // Returns true when XPU_PADDLE_XDL_CONTEXTS is set and the calling thread
  // is not named "MainThread".
  bool IsDataloader() const {
    if (std::getenv("XPU_PADDLE_XDL_CONTEXTS") == nullptr) {
      return false;
    }
    const std::string cur_thread_name = phi::GetCurrentThreadName();
    VLOG(3) << "XPU Dataloader: current thread at Get Context = "
            << cur_thread_name;
    return cur_thread_name != "MainThread";
  }

  Impl() : place_(XPUPlace()) {}

  explicit Impl(const Place& place) : place_(place) {}

  // Impl destroys raw xpu handles in its destructor, so a copy would destroy
  // them twice.
  Impl(const Impl&) = delete;
  Impl& operator=(const Impl&) = delete;

  // Destroys the main context when it is owned, and every dataloader context
  // stored in xdl_context_map_.
  ~Impl() {
    if (owned_ && context_ != nullptr) {
      backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId());
      DestroyXContext(context_, stream_owned_);
      context_ = nullptr;
    }
    // Release dataloader contexts whenever some were created, even if the
    // environment variable has been unset in the meantime.
    if (!xdl_context_map_.empty()) {
      backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId());
      for (const auto& entry : xdl_context_map_) {
        if (entry.second != nullptr) {
          DestroyXContext(entry.second, true);
        }
      }
      xdl_context_map_.clear();
    }
  }

  // Waits on ctx's stream, destroys the stream when it is set and
  // destroy_stream is true, then destroys the context. ctx must not be null.
  static void DestroyXContext(xpu::Context* ctx, bool destroy_stream) {
    xpu_wait(ctx->xpu_stream);
    if (ctx->xpu_stream && destroy_stream) {
      // manually destroy XPUStream here until xpu::api integrates this work
      // into Context dtor
      xpu_stream_destroy(ctx->xpu_stream);
      ctx->xpu_stream = nullptr;
    }
    xpu::destroy_context(ctx);
  }

  const Place& GetPlace() const { return place_; }

  // Returns the stream of the dataloader context when called from a
  // dataloader thread that already has one, otherwise the stream of the main
  // context.
  XPUStream stream() const {
    if (IsDataloader()) {
      // The dataloader context is created by the non-const GetXContext(); it
      // may not exist yet, so it must not be dereferenced unconditionally.
      if (xpu::Context* xdl_ctx = GetXdlCtx()) {
        return xdl_ctx->xpu_stream;
      }
    }
    PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
    return context_->xpu_stream;
  }

  // Set external stream for context. The stream is marked as not owned, so
  // the destructor will not destroy it.
  //
  // NOTE(review): a stream created earlier by CreateStream() is overwritten
  // here without being destroyed - confirm whether that is intended.
  void SetStream(void* stream) {
    PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
    stream_owned_ = false;
    context_->set_stream(static_cast<XPUStream>(stream));
  }

  // Returns the main context; raises through PD_CHECK when it is null.
  xpu::Context* GetXContext() const {
    PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
    return context_;
  }

  // Returns the BKCL context; raises through PD_CHECK when it is null.
  xpu::BKCLContext_t GetBkclContext() const {
    PD_CHECK(bkcl_context_ != nullptr, "the xpu bkcl_context is nullptr.");
    return bkcl_context_;
  }

  // Overload GetXContext function to set and get
  // contexts of XPU Dataloader threads, and keep old GetXContext Method.
  // On a dataloader thread the context is created on first use.
  xpu::Context* GetXContext() {
    if (IsDataloader()) {
      SetXdlCtx();
      xpu::Context* xdl_ctx = GetXdlCtx();
      PD_CHECK(xdl_ctx != nullptr, "the xpu dataloader context is nullptr.");
      return xdl_ctx;
    }

    PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
    return context_;
  }

  // Blocks until the relevant stream(s) have finished. On a dataloader thread
  // only the dataloader context (if any) is waited on; otherwise the main
  // context is waited on, followed by the dataloader context if one exists.
  void Wait() const {
    if (IsDataloader()) {
      if (xpu::Context* xdl_ctx = GetXdlCtx()) {
        xpu_wait(xdl_ctx->xpu_stream);
      }
      return;
    }

    backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId());
    PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
    xpu_wait(context_->xpu_stream);
    if (xpu::Context* xdl_ctx = GetXdlCtx()) {
      xpu_wait(xdl_ctx->xpu_stream);
    }
  }

  // Creates the main context on this place's device, marks it as owned,
  // prepares the dataloader context map when enabled, records the XPU version
  // and sets up the L3 cache.
  //
  // NOTE(review): a context already stored in context_ is overwritten without
  // being destroyed - confirm that Init() is not expected to be called twice.
  void Init() {
    owned_ = true;
    backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId());
    LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: "
                            << static_cast<int>(place_.GetDeviceId());
    context_ = xpu::create_context();
    if (std::getenv("XPU_PADDLE_XDL_CONTEXTS") != nullptr) {
      // Initialize XPU Dataloader threads contexts map
      InitializeXdlContexts();
    }
    xpu_version_ = backends::xpu::get_xpu_version(place_.GetDeviceId());
    SetL3Cache();
  }

  // Replaces the stored context pointer. owned_ is not modified here.
  void SetXContext(xpu::Context* context) { context_ = context; }

  void SetBkclContext(xpu::BKCLContext_t context) { bkcl_context_ = context; }

  // Creates a stream for the main context unless it already has one, and
  // marks the stream as owned so that the destructor destroys it.
  void CreateStream() {
    PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
    if (context_->xpu_stream) {
      VLOG(3) << "xpu stream is already created for current context";
      return;
    }
    PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&context_->xpu_stream));
    stream_owned_ = true;
  }

  // Methods of XPU Dataloader threads contexts map,
  // currently, need set 'export XPU_PADDLE_XDL_CONTEXTS=1'
  // to open XPU Dataloader context map
  void InitializeXdlContexts() {
    if (std::getenv("XPU_PADDLE_XDL_CONTEXTS") == nullptr) {
      return;
    }
    for (const auto& tp : phi::GetAllThreadNames()) {
      // Prefix test without building a temporary substring.
      if (tp.second.compare(0, 10, "Dataloader") == 0) {
        // SetXdlCtx() is keyed by the current process id, so one call is
        // equivalent to calling it for every matching thread.
        SetXdlCtx();
        break;
      }
    }
  }

  // Creates a context for the current process id if none is stored yet. A
  // failed creation (null) is not cached, so a later call can try again.
  void SetXdlCtx() {
    const auto pid = phi::GetProcessId();
    if (xdl_context_map_.find(pid) != xdl_context_map_.end()) {
      return;
    }
    xpu::Context* new_ctx = xpu::create_context();
    if (new_ctx != nullptr) {
      xdl_context_map_.emplace(pid, new_ctx);
    }
  }

  // Returns the context stored for the current process id, or nullptr.
  xpu::Context* GetXdlCtx() const {
    const auto found = xdl_context_map_.find(phi::GetProcessId());
    return found == xdl_context_map_.end() ? nullptr : found->second;
  }

  // Returns a snapshot of every stored dataloader context pointer.
  std::vector<xpu::Context*> GetAllXdlCtxs() const {
    std::vector<xpu::Context*> ctxs;
    ctxs.reserve(xdl_context_map_.size());
    for (const auto& it : xdl_context_map_) {
      ctxs.emplace_back(it.second);
    }
    return ctxs;
  }

  // True once Init() has created context_; gates destruction of context_.
  bool owned_{false};
  // True when the stream was created by CreateStream(); cleared by SetStream().
  bool stream_owned_{false};
  Place place_;
  // Value-initialized so that reads before Init()/SetXpuVersion() are defined.
  backends::xpu::XPUVersion xpu_version_{};
  // Only written through the XPUContext setters; 0 until then.
  int runtime_version_{0};
  int driver_version_{0};
  xpu::Context* context_{nullptr};
  // Dataloader contexts keyed by the value of phi::GetProcessId().
  std::unordered_map<uint32_t, xpu::Context*> xdl_context_map_;

  // NOTE: Distributed communicator, distributed framework manages its
  // resources, XPUContext only holds references.
  xpu::BKCLContext_t bkcl_context_{nullptr};
};

246 247 248
// Builds a context with a default-constructed implementation object and
// initializes it immediately.
XPUContext::XPUContext()
    : DeviceContext(),
      impl_(std::make_unique<Impl>()) {
  impl_->Init();
}
W
Wilber 已提交
249

W
Wilber 已提交
250
// Builds a context whose implementation object is bound to `place`, then
// initializes it immediately.
XPUContext::XPUContext(const XPUPlace& place)
    : DeviceContext(),
      impl_(std::make_unique<Impl>(place)) {
  impl_->Init();
}
W
Wilber 已提交
254 255 256

// Defaulted in this translation unit, where Impl is a complete type; the
// members (including impl_) are destroyed by their own destructors.
XPUContext::~XPUContext() = default;

W
Wilber 已提交
257
// Returns the place stored by the implementation object.
const Place& XPUContext::GetPlace() const {
  return impl_->GetPlace();
}
W
Wilber 已提交
258

259 260
// Returns the XPU stream handle reported by the implementation object.
XPUStream XPUContext::stream() const {
  return impl_->stream();
}

C
csy0225 已提交
261 262 263 264 265 266 267 268 269 270 271 272 273 274
// Forwards an externally provided stream handle to the implementation
// object.
void XPUContext::SetStream(void* stream) {
  impl_->SetStream(stream);
}

// Stores the XPU version; the integer is converted to
// backends::xpu::XPUVersion without any range validation.
void XPUContext::SetXpuVersion(int version) {
  const auto typed_version = static_cast<backends::xpu::XPUVersion>(version);
  impl_->xpu_version_ = typed_version;
}

// Stores the runtime version number in the implementation object.
void XPUContext::SetRuntimeVersion(int version) {
  auto& runtime_version = impl_->runtime_version_;
  runtime_version = version;
}

// Stores the driver version number in the implementation object.
void XPUContext::SetDriverVersion(int version) {
  auto& driver_version = impl_->driver_version_;
  driver_version = version;
}

W
Wilber 已提交
275
// Returns the XPU version currently stored in the implementation object.
backends::xpu::XPUVersion XPUContext::xpu_version() const {
  const backends::xpu::XPUVersion stored_version = impl_->xpu_version_;
  return stored_version;
}

// Returns the xpu::api context obtained from the implementation object's
// GetXContext().
xpu::Context* XPUContext::x_context() const {
  return impl_->GetXContext();
}

// Returns the BKCL context obtained from the implementation object's
// GetBkclContext().
xpu::BKCLContext_t XPUContext::bkcl_context() const {
  xpu::BKCLContext_t bkcl_ctx = impl_->GetBkclContext();
  return bkcl_ctx;
}

// Delegates to the implementation object's Wait().
void XPUContext::Wait() const {
  impl_->Wait();
}

W
Wilber 已提交
287
// Hands the given xpu::api context pointer to the implementation object.
void XPUContext::SetXContext(xpu::Context* context) {
  Impl& impl = *impl_;
  impl.SetXContext(context);
}

W
Wilber 已提交
291 292
// Forwards the requested L3 size to the implementation object's
// SetL3Cache().
void XPUContext::SetL3Cache(int l3_size) {
  impl_->SetL3Cache(l3_size);
}

293 294
// Returns the implementation object's answer to IsDataloader().
bool XPUContext::IsDataloader() const {
  return impl_->IsDataloader();
}

W
Wilber 已提交
295
// Hands the given BKCL context handle to the implementation object.
void XPUContext::SetBkclContext(xpu::BKCLContext_t context) {
  Impl& impl = *impl_;
  impl.SetBkclContext(context);
}

299 300
// Delegates to the implementation object's CreateStream().
void XPUContext::CreateStream() {
  impl_->CreateStream();
}

W
Wilber 已提交
301
// Delegates to the implementation object's Init().
void XPUContext::Init() {
  impl_->Init();
}
302
}  // namespace phi