xpu_context.cc 4.5 KB
Newer Older
W
Wilber 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/phi/backends/xpu/xpu_context.h"
W
Wilber 已提交
16

W
Wilber 已提交
17
#include <memory>
W
Wilber 已提交
18

19 20
#include "paddle/phi/api/ext/exception.h"
#include "paddle/phi/common/place.h"
W
Wilber 已提交
21 22 23 24 25 26
#include "xpu/runtime.h"
#include "xpu/runtime_ex.h"
#include "xpu/xdnn.h"

namespace xpu = baidu::xpu::api;

27
namespace phi {
W
Wilber 已提交
28

W
Wilber 已提交
29 30
// Private implementation of XPUContext. Holds the place this context is bound
// to, the xpu::Context handle, the cached XPU version and a borrowed BKCL
// communicator handle.
struct XPUContext::Impl {
  // Attaches an L3 buffer for this context's device to context_'s L3 manager.
  //
  // l3_size is the requested buffer size; if the XPU_PADDLE_L3_SIZE
  // environment variable is set, its value (parsed with atoi) overrides the
  // argument. Nothing happens when the device of place_ is not among the
  // devices returned by GetXPUSelectedDevices().
  //
  // The buffer is allocated at most once per device id and kept in a
  // function-local static table, so later calls reuse the first allocation.
  // The size registered with the L3 manager is therefore capped at the size
  // that was actually allocated.
  //
  // Fails a PD_CHECK if the device id does not fit the static table, or if a
  // buffer is available but context_ is still nullptr.
  //
  // NOTE(review): the static tables are not synchronized; concurrent calls
  // from several threads look unsafe -- confirm against callers.
  void SetL3Cache(int l3_size = 14155776) {
    constexpr int kMaxXpuNum = 16;
    static void* l3ptrs[kMaxXpuNum] = {nullptr};
    // Size each entry of l3ptrs was allocated with (0 while unallocated).
    static int l3sizes[kMaxXpuNum] = {0};

    // Look the override up once instead of calling getenv twice.
    if (const char* env_l3_size = std::getenv("XPU_PADDLE_L3_SIZE")) {
      l3_size = atoi(env_l3_size);
    }

    const int dev_id = static_cast<int>(place_.GetDeviceId());
    const auto selected_xpus = backends::xpu::GetXPUSelectedDevices();
    for (const auto& selected : selected_xpus) {
      if (place_.GetDeviceId() != selected) {
        continue;
      }

      // The device id indexes fixed-size static tables; reject ids that
      // would read or write outside of them.
      const bool dev_id_in_range = (dev_id >= 0) && (dev_id < kMaxXpuNum);
      PD_CHECK(dev_id_in_range,
               "the xpu device id is out of the range supported by L3 cache.");

      if (l3ptrs[dev_id] == nullptr) {
        xpu_malloc(static_cast<void**>(&l3ptrs[dev_id]), l3_size, XPU_MEM_L3);
        if (l3ptrs[dev_id] != nullptr) {
          l3sizes[dev_id] = l3_size;
        }
      }

      if (l3ptrs[dev_id] != nullptr) {
        PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
        // A reused buffer may have been allocated with a smaller size than
        // the one requested now; never register more than was allocated.
        const int usable_size =
            l3_size < l3sizes[dev_id] ? l3_size : l3sizes[dev_id];
        context_->_l3_mgr.set(l3ptrs[dev_id], usable_size);
        VLOG(3) << "xpu place " << dev_id << " set l3 size " << usable_size;
      }
      break;
    }
  }

  // Binds the implementation to a default-constructed XPUPlace.
  Impl() : place_(XPUPlace()) {}

  // Binds the implementation to the given place.
  explicit Impl(const Place& place) : place_(place) {}

  // Destroys the xpu::Context only when it was created by Init() (owned_).
  ~Impl() {
    if (owned_ && context_ != nullptr) {
      xpu::destroy_context(context_);
      context_ = nullptr;
    }
  }

  // Returns the place this context is bound to.
  const Place& GetPlace() const { return place_; }

  // Stores `stream` in the xpu::Context. Fails a PD_CHECK if no context has
  // been created or injected yet.
  void SetStream(XPUStream stream) {
    PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
    context_->xpu_stream = stream;
  }

  // Returns the stream stored in the xpu::Context. Fails a PD_CHECK if the
  // context or its stream is nullptr.
  XPUStream stream() const {
    PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
    auto s = context_->xpu_stream;
    PD_CHECK(s != nullptr, "the xpu stream is nullptr.");
    return s;
  }

  // Returns the xpu::Context handle. Fails a PD_CHECK if it is nullptr.
  xpu::Context* GetXContext() const {
    PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
    return context_;
  }

  // Returns the BKCL communicator handle. Fails a PD_CHECK if it is nullptr.
  xpu::BKCLContext_t GetBkclContext() const {
    PD_CHECK(bkcl_context_ != nullptr, "the xpu bkcl_context is nullptr.");
    return bkcl_context_;
  }

  // Selects this context's device, then calls xpu_wait on the context's
  // stream. Fails a PD_CHECK if the context is nullptr.
  void Wait() const {
    backends::xpu::SetXPUDeviceId(place_.GetDeviceId());
    PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
    xpu_wait(context_->xpu_stream);
  }

  // Creates an xpu::Context owned by this object (under a device guard for
  // place_), caches the device's XPU version and sets up the L3 cache with
  // the default size.
  //
  // NOTE(review): calling Init() again overwrites context_ without destroying
  // the previous handle -- confirm callers invoke it only once.
  void Init() {
    owned_ = true;
    backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId());
    LOG_FIRST_N(WARNING, 1)
        << "Please NOTE: xpu device: " << static_cast<int>(place_.device);
    context_ = xpu::create_context();
    xpu_version_ = backends::xpu::get_xpu_version(place_.device);
    SetL3Cache();
  }

  // Replaces the xpu::Context handle. owned_ is left unchanged, so the
  // destructor destroys this handle only if Init() has also been called.
  void SetXContext(xpu::Context* context) { context_ = context; }

  // Stores the BKCL communicator handle; it is never destroyed by this class.
  void SetBkclContext(xpu::BKCLContext_t context) { bkcl_context_ = context; }

  // True once Init() has run; makes the destructor destroy context_.
  bool owned_{false};
  Place place_;
  // Value-initialized so that xpu_version() does not read an indeterminate
  // value when Init() was never called.
  backends::xpu::XPUVersion xpu_version_{};
  xpu::Context* context_{nullptr};

  // NOTE: Distributed communicator, distributed framework manages its
  // resources, XPUContext only holds references.
  xpu::BKCLContext_t bkcl_context_{nullptr};
};

W
Wilber 已提交
117
// Builds a context whose Impl is default-constructed; Impl's default
// constructor only records a default XPUPlace.
XPUContext::XPUContext()
    : DeviceContext(),
      impl_(std::make_unique<Impl>()) {
}
W
Wilber 已提交
118

W
Wilber 已提交
119 120
// Builds a context whose Impl is bound to `place`.
XPUContext::XPUContext(const XPUPlace& place)
    : DeviceContext(),
      impl_(std::make_unique<Impl>(place)) {
}
W
Wilber 已提交
121 122 123

// Defaulted here, after Impl has been fully defined in this file, so that
// impl_ can be destroyed with a complete Impl type.
XPUContext::~XPUContext() = default;

W
Wilber 已提交
124
// Returns the place held by the implementation object.
const Place& XPUContext::GetPlace() const {
  const Impl& impl = *impl_;
  return impl.GetPlace();
}
W
Wilber 已提交
125

126 127
// Forwards `stream` to Impl::SetStream, which stores it in the xpu::Context.
void XPUContext::SetXPUStream(XPUStream stream) {
  impl_->SetStream(stream);
}

128 129
// Returns the stream reported by Impl::stream().
XPUStream XPUContext::stream() const {
  const Impl& impl = *impl_;
  return impl.stream();
}

W
Wilber 已提交
130
// Returns the XPU version stored in the implementation object; Impl::Init()
// is the only place in this file that writes it.
backends::xpu::XPUVersion XPUContext::xpu_version() const {
  const Impl& impl = *impl_;
  return impl.xpu_version_;
}

// Returns the xpu::Context handle via Impl::GetXContext(), which checks that
// the handle is not nullptr.
xpu::Context* XPUContext::x_context() const {
  const Impl& impl = *impl_;
  return impl.GetXContext();
}

// Returns the BKCL communicator handle via Impl::GetBkclContext(), which
// checks that the handle is not nullptr.
xpu::BKCLContext_t XPUContext::bkcl_context() const {
  const Impl& impl = *impl_;
  return impl.GetBkclContext();
}

// Delegates to Impl::Wait(), which calls xpu_wait on the context's stream.
void XPUContext::Wait() const {
  impl_->Wait();
}

W
Wilber 已提交
142
// Hands `context` to the implementation object, replacing its current
// xpu::Context handle.
void XPUContext::SetXContext(xpu::Context* context) {
  Impl& impl = *impl_;
  impl.SetXContext(context);
}

W
Wilber 已提交
146 147 148
// Forwards the requested L3 size to Impl::SetL3Cache.
void XPUContext::SetL3Cache(int l3_size) {
  impl_->SetL3Cache(l3_size);
}

// Hands the BKCL communicator handle to the implementation object.
void XPUContext::SetBkclContext(xpu::BKCLContext_t context) {
  Impl& impl = *impl_;
  impl.SetBkclContext(context);
}

W
Wilber 已提交
152 153
// Delegates to Impl::Init(), which creates the xpu::Context, records the XPU
// version and sets up the L3 cache.
void XPUContext::Init() {
  impl_->Init();
}

154
}  // namespace phi