From d972de569d272d9f505fb0b9d2a704d867df77f0 Mon Sep 17 00:00:00 2001
From: 陈沧夜
Date: Fri, 31 Mar 2023 18:06:21 +0800
Subject: [PATCH] Delete the paddle/fluid/platform/device/mlu directory
 (#52382)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
 .../fluid/platform/device/mlu/CMakeLists.txt  |  29 --
 .../fluid/platform/device/mlu/cncl_helper.h   |  57 ---
 .../platform/device/mlu/device_context.cc     |  86 ----
 .../platform/device/mlu/device_context.h      | 186 -------
 .../device/mlu/device_context_allocator.h     | 162 ------
 .../device/mlu/device_context_test.cc         |  82 ----
 paddle/fluid/platform/device/mlu/enforce.h    | 167 -------
 .../fluid/platform/device/mlu/enforce_test.cc |  72 ---
 .../device/mlu/mlu_collective_helper.cc       | 190 -------
 paddle/fluid/platform/device/mlu/mlu_info.cc  | 464 ------------------
 paddle/fluid/platform/device/mlu/mlu_info.h   | 190 -------
 .../platform/device/mlu/mlu_resource_pool.cc  | 105 ----
 .../platform/device/mlu/mlu_resource_pool.h   |  64 ---
 .../fluid/platform/device/mlu/mlu_stream.cc   |  84 ----
 paddle/fluid/platform/device/mlu/mlu_stream.h | 101 ----
 15 files changed, 2039 deletions(-)
 delete mode 100644 paddle/fluid/platform/device/mlu/CMakeLists.txt
 delete mode 100644 paddle/fluid/platform/device/mlu/cncl_helper.h
 delete mode 100644 paddle/fluid/platform/device/mlu/device_context.cc
 delete mode 100644 paddle/fluid/platform/device/mlu/device_context.h
 delete mode 100644 paddle/fluid/platform/device/mlu/device_context_allocator.h
 delete mode 100644 paddle/fluid/platform/device/mlu/device_context_test.cc
 delete mode 100644 paddle/fluid/platform/device/mlu/enforce.h
 delete mode 100644 paddle/fluid/platform/device/mlu/enforce_test.cc
 delete mode 100644 paddle/fluid/platform/device/mlu/mlu_collective_helper.cc
 delete mode 100644 paddle/fluid/platform/device/mlu/mlu_info.cc
 delete mode 100644 paddle/fluid/platform/device/mlu/mlu_info.h
 delete mode 100644 paddle/fluid/platform/device/mlu/mlu_resource_pool.cc
 delete mode 100644 paddle/fluid/platform/device/mlu/mlu_resource_pool.h
 delete mode 100644 paddle/fluid/platform/device/mlu/mlu_stream.cc
 delete mode 100644 paddle/fluid/platform/device/mlu/mlu_stream.h

diff --git a/paddle/fluid/platform/device/mlu/CMakeLists.txt b/paddle/fluid/platform/device/mlu/CMakeLists.txt
deleted file mode 100644
index c723eb149b8..00000000000
--- a/paddle/fluid/platform/device/mlu/CMakeLists.txt
+++ /dev/null
@@ -1,29 +0,0 @@
-if(NOT WITH_MLU)
-  return()
-endif()
-
-cc_test(mlu_enforce_test SRCS enforce_test.cc)
-cc_library(
-  mlu_info
-  SRCS mlu_info.cc
-  DEPS enforce glog malloc monitor neuware_lib)
-cc_library(
-  mlu_stream
-  SRCS mlu_stream.cc
-  DEPS mlu_info stream_callback_manager eigen3 ${MKLDNN_CTX_DEPS})
-cc_library(
-  mlu_device_context
-  SRCS device_context.cc
-  DEPS mlu_stream)
-cc_test(
-  mlu_device_context_test
-  SRCS device_context_test.cc
-  DEPS mlu_device_context)
-cc_library(
-  mlu_collective_helper
-  SRCS mlu_collective_helper.cc
-  DEPS mlu_stream mlu_info)
-cc_library(
-  mlu_resource_pool
-  SRCS mlu_resource_pool.cc
-  DEPS mlu_info)
diff --git a/paddle/fluid/platform/device/mlu/cncl_helper.h b/paddle/fluid/platform/device/mlu/cncl_helper.h
deleted file mode 100644
index 634e420d5ce..00000000000
--- a/paddle/fluid/platform/device/mlu/cncl_helper.h
+++ /dev/null
@@ -1,57 +0,0 @@
-/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#ifdef PADDLE_WITH_CNCL
-#include <cncl.h>
-#include <cnrt.h>
-
-#include <iostream>
-#include <map>
-#include <mutex>  // NOLINT
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "paddle/fluid/framework/data_type.h"
-#include "paddle/fluid/platform/collective_helper.h"
-#include "paddle/fluid/platform/device/mlu/enforce.h"
-#include "paddle/fluid/platform/float16.h"
-
-namespace paddle {
-namespace platform {
-
-inline cnclDataType_t ToCNCLDataType(framework::proto::VarType::Type type) {
-  if (type == framework::proto::VarType::FP32) {
-    return cnclFloat32;
-  } else if (type == framework::proto::VarType::FP16) {
-    return cnclFloat16;
-  } else if (type == framework::proto::VarType::INT32) {
-    return cnclInt32;
-  } else if (type == framework::proto::VarType::INT16) {
-    return cnclInt16;
-  } else if (type == framework::proto::VarType::INT8) {
-    return cnclInt8;
-  } else if (type == framework::proto::VarType::UINT8) {
-    return cnclUint8;
-  } else {
-    PADDLE_THROW(platform::errors::Unimplemented(
-        "This datatype in cncl is not supported."));
-  }
-}
-
-}  // namespace platform
-}  // namespace paddle
-#endif
diff --git a/paddle/fluid/platform/device/mlu/device_context.cc b/paddle/fluid/platform/device/mlu/device_context.cc
deleted file mode 100644
index 796d7006834..00000000000
--- a/paddle/fluid/platform/device/mlu/device_context.cc
+++ /dev/null
@@ -1,86 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef PADDLE_WITH_MLU
-#include "paddle/fluid/platform/device/mlu/device_context.h"
-#endif
-
-namespace paddle {
-namespace platform {
-
-#ifdef PADDLE_WITH_MLU
-thread_local std::unordered_map<const MLUDeviceContext*,
-                                std::shared_ptr<MLUContext>>
-    MLUDeviceContext::thread_ctx_;
-thread_local std::mutex MLUDeviceContext::ctx_mtx_;
-
-MLUContext::MLUContext(const MLUPlace& place, const int priority) {
-  place_ = place;
-  MLUDeviceGuard guard(place_.device);
-  stream_.reset(new stream::MLUStream(place_, priority));
-  InitCNNLContext();
-  InitMLUOPContext();
-}
-
-MLUContext::~MLUContext() {
-  MLUDeviceGuard guard(place_.device);
-  DestoryCNNLContext();
-  DestoryMLUOPContext();
-}
-
-MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) {
-  MLUDeviceGuard guard(place_.device);
-  compute_capability_ = GetMLUComputeCapability(place_.device);
-  driver_version_ = GetMLUDriverVersion(place_.device);
-  runtime_version_ = GetMLURuntimeVersion(place_.device);
-  cnnl_version_ = GetMLUCnnlVersion(place_.device);
-  mluOp_version_ = GetMLUOpVersion(place_.device);
-
-  LOG_FIRST_N(WARNING, 1)
-      << "Please NOTE: device: " << static_cast<int>(place_.device)
-      << ", MLU Compute Capability: " << compute_capability_ / 10 << "."
-      << compute_capability_ % 10
-      << ", Driver API Version: " << driver_version_ / 10000 << "."
-      << (driver_version_ / 100) % 100 << "." << driver_version_ % 100
-      << ", Runtime API Version: " << runtime_version_ / 10000 << "."
-      << (runtime_version_ / 100) % 100 << "." << runtime_version_ % 100
-      << ", Cnnl API Version: " << cnnl_version_ / 10000 << "."
-      << (cnnl_version_ / 100) % 100 << "." << cnnl_version_ % 100
-      << ", MluOp API Version: " << mluOp_version_ / 10000 << "."
-      << (mluOp_version_ / 100) % 100 << "." << mluOp_version_ % 100;
-
-  default_ctx_.reset(new MLUContext(place_));
-}
-
-MLUDeviceContext::~MLUDeviceContext() {}
-
-const Place& MLUDeviceContext::GetPlace() const { return place_; }
-
-void MLUDeviceContext::Wait() const { context()->Stream()->Wait(); }
-
-int MLUDeviceContext::GetComputeCapability() const {
-  return compute_capability_;
-}
-
-mluCnnlHandle MLUDeviceContext::cnnl_handle() const {
-  return context()->CnnlHandle();
-}
-
-mluOpHandle MLUDeviceContext::mluOp_handle() const {
-  return context()->MluOpHandle();
-}
-
-mluStream MLUDeviceContext::stream() const { return context()->RawStream(); }
-
-#endif
-}  // namespace platform
-}  // namespace paddle
diff --git a/paddle/fluid/platform/device/mlu/device_context.h b/paddle/fluid/platform/device/mlu/device_context.h
deleted file mode 100644
index a430e18a34a..00000000000
--- a/paddle/fluid/platform/device/mlu/device_context.h
+++ /dev/null
@@ -1,186 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-#pragma once
-
-#ifdef PADDLE_WITH_MLU
-#include <memory>
-
-#include "paddle/fluid/platform/device/mlu/enforce.h"
-#include "paddle/fluid/platform/device/mlu/mlu_stream.h"
-#include "paddle/fluid/platform/device_context.h"
-#ifdef PADDLE_WITH_CNCL
-#include <cncl.h>
-#endif
-
-namespace Eigen {
-struct DefaultDevice;
-struct GpuDevice;
-}  // namespace Eigen
-
-namespace paddle {
-namespace platform {
-
-class MLUContext {
- public:
-  MLUContext() = default;
-  explicit MLUContext(const MLUPlace& place, const int priority = 0);
-
-  ~MLUContext();
-
-  const MLUPlace& Place() const { return place_; }
-
-  const std::unique_ptr<Eigen::DefaultDevice>& EigenDevice() const {
-    return eigen_device_;
-  }
-
-  const std::unique_ptr<stream::MLUStream>& Stream() const { return stream_; }
-
-  stream::MLUStream* SetStream(stream::MLUStream* new_stream_ptr) {
-    auto* old_stream_ptr = stream_.release();
-    stream_.reset(new_stream_ptr);
-    return old_stream_ptr;
-  }
-
-  const mluStream& RawStream() { return stream_->raw_stream(); }
-
-  const mluCnnlHandle& CnnlHandle() const { return cnnl_handle_; }
-
-  const mluOpHandle& MluOpHandle() const { return mluOp_handle_; }
-
- private:
-  void InitCNNLContext() {
-    PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreate(&cnnl_handle_));
-    PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetQueue(cnnl_handle_, RawStream()));
-  }
-
-  void InitMLUOPContext() {
-    PADDLE_ENFORCE_MLU_SUCCESS(mluOpCreate(&mluOp_handle_));
-    PADDLE_ENFORCE_MLU_SUCCESS(mluOpSetQueue(mluOp_handle_, RawStream()));
-  }
-
-  void DestoryCNNLContext() {
-    if (cnnl_handle_) {
-      PADDLE_ENFORCE_MLU_SUCCESS(cnnlDestroy(cnnl_handle_));
-    }
-    cnnl_handle_ = nullptr;
-  }
-
-  void DestoryMLUOPContext() {
-    if (mluOp_handle_) {
-      PADDLE_ENFORCE_MLU_SUCCESS(mluOpDestroy(mluOp_handle_));
-    }
-    mluOp_handle_ = nullptr;
-  }
-
-  MLUPlace place_;
-  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
-  std::unique_ptr<stream::MLUStream> stream_;
-  mluCnnlHandle cnnl_handle_;
-  mluOpHandle mluOp_handle_;
-
-  DISABLE_COPY_AND_ASSIGN(MLUContext);
-};
-
-class MLUDeviceContext
-    : public DeviceContext,
-      public phi::TypeInfoTraits<DeviceContext, MLUDeviceContext> {
- public:
-  explicit MLUDeviceContext(MLUPlace place);
-  virtual ~MLUDeviceContext();
-  Eigen::DefaultDevice* eigen_device() const { return nullptr; }
-  const Place& GetPlace() const override;
-
-  int GetComputeCapability() const;
-
-  /*! \brief Wait for all operations completion in the stream. */
-  void Wait() const override;
-
-  /*! \brief Return cnnl handle in the device context. */
-  mluCnnlHandle cnnl_handle() const;
-
-  /*! \brief Return mluOp handle in the device context. */
-  mluOpHandle mluOp_handle() const;
-
-  /*! \brief Return mlu stream in the device context. */
-  mluStream stream() const;
-
-#ifdef PADDLE_WITH_CNCL
-  /*! \brief Return cncl communicators. */
-  cnclComm_t cncl_comm() const { return cncl_comm_; }
-
-  /*! \brief Set cncl communicators. */
-  void set_cncl_comm(cnclComm_t comm) { cncl_comm_ = comm; }
-#endif
-
-  template <typename Callback>
-  void RecordEvent(mluEventHandle ev, Callback callback) const {
-    return context()->Stream()->RecordEvent(ev, callback);
-  }
-
-  template <typename Callback>
-  void AddStreamCallback(Callback&& callback) const {
-    return context()->Stream()->AddCallback(callback);
-  }
-
-  void WaitStreamCallback() const {
-    return context()->Stream()->WaitCallback();
-  }
-
-  void ResetDefaultContext(const int priority) {
-    default_ctx_.reset(new MLUContext(place_, priority));
-  }
-
-  void ResetThreadContext(const int priority) {
-    std::lock_guard<std::mutex> guard(ctx_mtx_);
-    thread_ctx_[this].reset(new MLUContext(place_, priority));
-  }
-
-  std::shared_ptr<MLUContext> context() const {
-    if (!thread_ctx_.count(this)) {
-      return default_ctx_;
-    }
-    return thread_ctx_.at(this);
-  }
-
-  static const char* name() { return "MLUDeviceContext"; }
-
- private:
-  int compute_capability_;
-  int driver_version_;
-  int runtime_version_;
-  int cnnl_version_;
-  int mluOp_version_;
-  MLUPlace place_;
-  std::shared_ptr<MLUContext> default_ctx_;
-
-  // The thread_local static variable will be released before the
-  // global static variable, so avoid using it in dtor.
-  static thread_local std::unordered_map<const MLUDeviceContext*,
-                                         std::shared_ptr<MLUContext>>
-      thread_ctx_;
-  static thread_local std::mutex ctx_mtx_;
-
-#ifdef PADDLE_WITH_CNCL
-  cnclComm_t cncl_comm_{nullptr};
-#endif
-
-  DISABLE_COPY_AND_ASSIGN(MLUDeviceContext);
-};
-
-template <>
-struct DefaultDeviceContextType<MLUPlace> {
-  using TYPE = MLUDeviceContext;
-};
-
-#endif
-
-}  // namespace platform
-}  // namespace paddle
diff --git a/paddle/fluid/platform/device/mlu/device_context_allocator.h b/paddle/fluid/platform/device/mlu/device_context_allocator.h
deleted file mode 100644
index 706cab4f54a..00000000000
--- a/paddle/fluid/platform/device/mlu/device_context_allocator.h
+++ /dev/null
@@ -1,162 +0,0 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include <map>
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include "paddle/fluid/memory/allocation/allocator.h"
-#include "paddle/fluid/platform/device/mlu/device_context.h"
-#include "paddle/fluid/platform/device/mlu/mlu_info.h"
-#include "paddle/fluid/platform/place.h"
-
-namespace paddle {
-
-namespace platform {
-class MLUDeviceContext;
-}  // namespace platform
-
-namespace memory {
-namespace allocation {
-
-/**
- * MLUDeviceContextAllocation is a wrapper of the underlying allocation.
- * MLUDeviceContextAllocation adds an MLU stream callback for the underlying
- * allocation so that MLUDeviceContextAllocation can be used in an MLU stream
- * which deletes the allocation in the callback.
- */
-class MLUDeviceContextAllocation : public Allocation {
- public:
-  explicit MLUDeviceContextAllocation(AllocationPtr allocation)
-      : Allocation(allocation->ptr(), allocation->size(), allocation->place()),
-        underlying_allocation_(std::move(allocation)) {}
-
-  ~MLUDeviceContextAllocation() {
-    PADDLE_ENFORCE_NOT_NULL(
-        dev_ctx_,
-        platform::errors::PreconditionNotMet(
-            "Device context is not set for MLUDeviceContextAllocation"));
-    auto *p_allocation = underlying_allocation_.release();
-    VLOG(4) << "Adding callback to delete MLUDeviceContextAllocation at "
-            << p_allocation;
-    dev_ctx_->AddStreamCallback([p_allocation] {
-      VLOG(4) << "Delete MLUDeviceContextAllocation at " << p_allocation;
-      Allocator::AllocationDeleter(p_allocation);
-    });
-  }
-
-  void SetMLUDeviceContext(const platform::MLUDeviceContext *dev_ctx) {
-    dev_ctx_ = dev_ctx;
-  }
-
- private:
-  AllocationPtr underlying_allocation_;
-  const platform::MLUDeviceContext *dev_ctx_{nullptr};
-};
-
-/**
- * MLUDeviceContextAllocator will allocate a MLUDeviceContextAllocation
- * after waiting for a self-created event on the default stream. It does so to
- * let the non-default stream be able to allocate MLU memory which will be
- * released by a stream callback.
- */
-class MLUDeviceContextAllocator : public Allocator {
- public:
-  explicit MLUDeviceContextAllocator(platform::MLUPlace place,
-                                     mluStream default_stream)
-      : place_(place), default_stream_(default_stream) {
-    platform::MLUDeviceGuard guard(place_.device);
-    PADDLE_ENFORCE_MLU_SUCCESS(cnrtNotifierCreate(&event_));
-  }
-
-  ~MLUDeviceContextAllocator() {
-    if (event_) {
-      platform::MLUDeviceGuard guard(place_.device);
-      PADDLE_ENFORCE_MLU_SUCCESS(cnrtNotifierDestroy(event_));
-    }
-  }
-
- protected:
-  phi::Allocation *AllocateImpl(size_t size) override {
-    PADDLE_ENFORCE_NOT_NULL(
-        default_stream_,
-        platform::errors::PreconditionNotMet(
-            "Default stream is not set for MLUDeviceContextAllocator"));
-    platform::MLUDeviceGuard guard(place_.device);
-    auto allocation =
-        new MLUDeviceContextAllocation(memory::Alloc(place_, size));
-    // Wait for the event on stream
-    PADDLE_ENFORCE_MLU_SUCCESS(cnrtPlaceNotifier(event_, default_stream_));
-    PADDLE_ENFORCE_MLU_SUCCESS(cnrtWaitNotifier(event_));
-    return allocation;
-  }
-
-  void FreeImpl(phi::Allocation *allocation) override { delete allocation; }
-
- private:
-  platform::MLUPlace place_;
-  mluEventHandle event_{nullptr};
-  mluStream default_stream_{nullptr};
-};
-
-/**
- * MLUDeviceContextAllocatorPool is a singleton that stores a mapping from
- * MLUPlace(s) to std::shared_ptr<MLUDeviceContextAllocator>. When a
- * MLUDeviceContext's compute stream isn't the default stream, it can call
- * this class to allocate MLU memory which will be released by a callback
- * after stream execution.
- */
-class MLUDeviceContextAllocatorPool {
- public:
-  static MLUDeviceContextAllocatorPool &Instance() {
-    static MLUDeviceContextAllocatorPool pool;
-    return pool;
-  }
-
-  AllocationPtr Alloc(const platform::MLUDeviceContext &dev_ctx, size_t size) {
-    auto iter = allocators_.find(dev_ctx.GetPlace());
-    PADDLE_ENFORCE_NE(
-        iter,
-        allocators_.end(),
-        platform::errors::NotFound("No allocator found for MLUPlace."));
-    auto &allocator = iter->second;
-    AllocationPtr allocation = allocator->Allocate(size);
-    static_cast<MLUDeviceContextAllocation *>(allocation.get())
-        ->SetMLUDeviceContext(&dev_ctx);
-    return allocation;
-  }
-
- private:
-  MLUDeviceContextAllocatorPool() {
-    std::vector<int> devices = platform::GetMLUSelectedDevices();
-    for (int i : devices) {
-      auto place = platform::MLUPlace(i);
-      auto compute_stream =
-          platform::DeviceContextPool::Instance().GetByPlace(place)->stream();
-      auto allocator = std::shared_ptr<MLUDeviceContextAllocator>(
-          new MLUDeviceContextAllocator(place, compute_stream));
-      allocators_.insert(make_pair(place, allocator));
-    }
-  }
-
-  std::map<platform::MLUPlace, std::shared_ptr<MLUDeviceContextAllocator>>
-      allocators_;
-};
-
-}  // namespace allocation
-}  // namespace memory
-}  // namespace paddle
diff --git a/paddle/fluid/platform/device/mlu/device_context_test.cc b/paddle/fluid/platform/device/mlu/device_context_test.cc
deleted file mode 100644
index 41f79c7092e..00000000000
--- a/paddle/fluid/platform/device/mlu/device_context_test.cc
+++ /dev/null
@@ -1,82 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-#include "paddle/fluid/platform/device/mlu/device_context.h"
-
-#include <vector>
-
-#include "glog/logging.h"
-#include "gtest/gtest.h"
-
-TEST(Device, Init) {
-  using paddle::platform::DeviceContext;
-  using paddle::platform::MLUContext;
-  using paddle::platform::MLUDeviceContext;
-  using paddle::platform::MLUPlace;
-
-  int count = paddle::platform::GetMLUDeviceCount();
-  for (int i = 0; i < count; i++) {
-    MLUDeviceContext* device_context = new MLUDeviceContext(MLUPlace(i));
-    std::shared_ptr<MLUContext> ctx = device_context->context();
-    ASSERT_NE(nullptr, ctx);
-    delete device_context;
-  }
-}
-
-TEST(Device, MLUDeviceContext) {
-  using paddle::mluCnnlHandle;
-  using paddle::platform::MLUDeviceContext;
-  using paddle::platform::MLUPlace;
-
-  int count = paddle::platform::GetMLUDeviceCount();
-  for (int i = 0; i < count; i++) {
-    MLUDeviceContext* device_context = new MLUDeviceContext(MLUPlace(i));
-    mluCnnlHandle mlu_handle = device_context->cnnl_handle();
-    ASSERT_NE(nullptr, mlu_handle);
-    delete device_context;
-  }
-}
-
-TEST(Device, MLUStream) {
-  using paddle::mluStream;
-  using paddle::platform::MLUDeviceContext;
-  using paddle::platform::MLUPlace;
-
-  int count = paddle::platform::GetMLUDeviceCount();
-  for (int i = 0; i < count; i++) {
-    MLUDeviceContext* device_context = new MLUDeviceContext(MLUPlace(i));
-    mluStream mlu_stream = device_context->stream();
-    ASSERT_NE(nullptr, mlu_stream);
-    delete device_context;
-  }
-}
-
-TEST(Device, DeviceContextPool) {
-  using paddle::platform::CPUPlace;
-  using paddle::platform::DeviceContextPool;
-  using paddle::platform::MLUDeviceContext;
-  using paddle::platform::MLUPlace;
-  using paddle::platform::Place;
-
-  DeviceContextPool& pool = DeviceContextPool::Instance();
-  auto cpu_dev_ctx1 = pool.Get(CPUPlace());
-  auto cpu_dev_ctx2 = pool.Get(CPUPlace());
-  ASSERT_EQ(cpu_dev_ctx2, cpu_dev_ctx1);
-
-  std::vector<Place> mlu_places;
-  int count = paddle::platform::GetMLUDeviceCount();
-  for (int i = 0; i < count; ++i) {
-    auto dev_ctx = pool.Get(MLUPlace(i));
-    ASSERT_NE(dev_ctx, nullptr);
-  }
-}
diff --git a/paddle/fluid/platform/device/mlu/enforce.h b/paddle/fluid/platform/device/mlu/enforce.h
deleted file mode 100644
index 8b0d0bb36f5..00000000000
--- a/paddle/fluid/platform/device/mlu/enforce.h
+++ /dev/null
@@ -1,167 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/fluid/platform/enforce.h"
-#ifdef PADDLE_WITH_MLU
-#include "paddle/fluid/platform/device/mlu/mlu_info.h"
-#endif  // PADDLE_WITH_MLU
-
-#ifdef PADDLE_WITH_MLU
-DECLARE_int64(gpu_allocator_retry_time);
-#endif
-
-namespace paddle {
-namespace platform {
-
-#ifdef PADDLE_WITH_MLU
-namespace details {
-template <typename T>
-struct MLUStatusType {};
-
-#define DEFINE_MLU_STATUS_TYPE(type, success_value, proto_type) \
-  template <>                                                   \
-  struct MLUStatusType<type> {                                  \
-    using Type = type;                                          \
-    static constexpr Type kSuccess = success_value;             \
-    static constexpr const char* kTypeString = #proto_type;     \
-  }
-
-DEFINE_MLU_STATUS_TYPE(cnrtStatus, cnrtSuccess, CNRT);
-DEFINE_MLU_STATUS_TYPE(cnnlStatus, CNNL_STATUS_SUCCESS, CNNL);
-DEFINE_MLU_STATUS_TYPE(mluOpStatus, MLUOP_STATUS_SUCCESS, MLUOP);
-DEFINE_MLU_STATUS_TYPE(cnStatus, CN_SUCCESS, CN);
-#ifdef PADDLE_WITH_CNCL
-DEFINE_MLU_STATUS_TYPE(cnclStatus, CNCL_RET_SUCCESS, CNCL);
-#endif
-
-}  // namespace details
-
-/*************** CNRT ERROR ***************/
-inline bool is_error(cnrtStatus e) { return e != cnrtSuccess; }
-
-inline std::string build_mlu_error_msg(cnrtStatus e) {
-  std::ostringstream sout;
-  sout << "MLU CNRT error(" << e << "), " << cnrtGetErrorName(e) << ": "
-       << cnrtGetErrorStr(e);
-  return sout.str();
-}
-
-/*************** CNNL ERROR ***************/
-inline bool is_error(cnnlStatus stat) { return stat != CNNL_STATUS_SUCCESS; }
-
-inline std::string build_mlu_error_msg(cnnlStatus stat) {
-  std::ostringstream sout;
-  sout << "MLU CNNL error(" << stat << "), " << cnnlGetErrorString(stat)
-       << ". ";
-  return sout.str();
-}
-
-/*************** MLU OP ERROR ***************/
-inline bool is_error(mluOpStatus stat) { return stat != MLUOP_STATUS_SUCCESS; }
-
-inline std::string build_mlu_error_msg(mluOpStatus stat) {
-  std::ostringstream sout;
-  sout << "MLU OP error(" << stat << "), " << mluOpGetErrorString(stat) << ". ";
-  return sout.str();
-}
-
-/*************** CN API ERROR ***************/
-inline bool is_error(cnStatus stat) { return stat != CN_SUCCESS; }
-
-inline std::string build_mlu_error_msg(cnStatus stat) {
-  const char* error_name;
-  const char* error_string;
-  cnGetErrorName(stat, &error_name);
-  cnGetErrorString(stat, &error_string);
-
-  std::ostringstream sout;
-  sout << "MLU CN error(" << static_cast<int>(stat) << "), " << error_name
-       << " : " << error_string << ". ";
-  return sout.str();
-}
-
-/*************** CNCL ERROR ***************/
-#ifdef PADDLE_WITH_CNCL
-inline bool is_error(cnclStatus e) { return e != CNCL_RET_SUCCESS; }
-
-inline std::string build_mlu_error_msg(cnclStatus e) {
-  std::ostringstream sout;
-  sout << "MLU CNCL error(" << e << "), " << cnclGetErrorStr(e) << ". ";
-  return sout.str();
-}
-#endif
-
-#define PADDLE_ENFORCE_MLU_SUCCESS(COND)                       \
-  do {                                                         \
-    auto __cond__ = (COND);                                    \
-    using __MLU_STATUS_TYPE__ = decltype(__cond__);            \
-    constexpr auto __success_type__ =                          \
-        ::paddle::platform::details::MLUStatusType<            \
-            __MLU_STATUS_TYPE__>::kSuccess;                    \
-    if (UNLIKELY(__cond__ != __success_type__)) {              \
-      auto __summary__ = ::paddle::platform::errors::External( \
-          ::paddle::platform::build_mlu_error_msg(__cond__));  \
-      __THROW_ERROR_INTERNAL__(__summary__);                   \
-    }                                                          \
-  } while (0)
-
-#define PADDLE_ENFORCE_MLU_LAUNCH_SUCCESS(OP)                  \
-  do {                                                         \
-    auto res = cnrtGetLastError();                             \
-    if (UNLIKELY(res != cnrtSuccess)) {                        \
-      auto msg = ::paddle::platform::build_mlu_error_msg(res); \
-      PADDLE_THROW(platform::errors::Fatal(                    \
-          "CNRT error after kernel (%s): %s", OP, msg));       \
-    }                                                          \
-  } while (0)
-
-inline void retry_sleep(unsigned milliseconds) {
-  if (milliseconds < 1000) {
-    // usleep argument must be less than 1,000,000. Reference:
-    // https://pubs.opengroup.org/onlinepubs/7908799/xsh/usleep.html
-    usleep(milliseconds * 1000);
-  } else {
-    // clip to sleep in seconds because we can not and don't have to
-    // sleep for exact milliseconds
-    sleep(milliseconds / 1000);
-  }
-}
-
-#define PADDLE_RETRY_MLU_SUCCESS(COND)                                  \
-  do {                                                                  \
-    auto __cond__ = (COND);                                             \
-    int retry_count = 1;                                                \
-    using __MLU_STATUS_TYPE__ = decltype(__cond__);                     \
-    constexpr auto __success_type__ =                                   \
-        ::paddle::platform::details::MLUStatusType<                     \
-            __MLU_STATUS_TYPE__>::kSuccess;                             \
-    while (UNLIKELY(__cond__ != __success_type__) && retry_count < 5) { \
-      retry_sleep(FLAGS_gpu_allocator_retry_time);                      \
-      __cond__ = (COND);                                                \
-      ++retry_count;                                                    \
-    }                                                                   \
-    if (UNLIKELY(__cond__ != __success_type__)) {                       \
-      auto __summary__ = ::paddle::platform::errors::External(          \
-          ::paddle::platform::build_mlu_error_msg(__cond__));           \
-      __THROW_ERROR_INTERNAL__(__summary__);                            \
-    }                                                                   \
-  } while (0)
-
-#undef DEFINE_MLU_STATUS_TYPE
-#endif  // PADDLE_WITH_MLU
-
-}  // namespace platform
-}  // namespace paddle
diff --git a/paddle/fluid/platform/device/mlu/enforce_test.cc b/paddle/fluid/platform/device/mlu/enforce_test.cc
deleted file mode 100644
index 4ff7b12c446..00000000000
--- a/paddle/fluid/platform/device/mlu/enforce_test.cc
+++ /dev/null
@@ -1,72 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/platform/device/mlu/enforce.h"
-
-#include <iostream>
-
-#include "gtest/gtest.h"
-
-#ifdef PADDLE_WITH_MLU
-template <typename T>
-bool CheckMluStatusSuccess(T value, const std::string& msg = "success") {
-  PADDLE_ENFORCE_MLU_SUCCESS(value);
-  return true;
-}
-
-template <typename T>
-bool CheckMluStatusFailure(T value, const std::string& msg) {
-  try {
-    PADDLE_ENFORCE_MLU_SUCCESS(value);
-    return false;
-  } catch (paddle::platform::EnforceNotMet& error) {
-    std::string ex_msg = error.what();
-    std::cout << ex_msg << std::endl;
-    return ex_msg.find(msg) != std::string::npos;
-  }
-}
-
-TEST(mlu_enforce, mlu_success) {
-  EXPECT_TRUE(CheckMluStatusSuccess(cnrtSuccess));
-  EXPECT_TRUE(CheckMluStatusFailure(cnrtErrorArgsInvalid, "invalid argument"));
-  EXPECT_TRUE(CheckMluStatusFailure(cnrtErrorMemcpyDirectionInvalid,
-                                    "invalid memcpy direction"));
-  EXPECT_TRUE(
-      CheckMluStatusFailure(cnrtErrorDeviceInvalid, "invalid device ordinal"));
-
-  EXPECT_TRUE(CheckMluStatusSuccess(CNNL_STATUS_SUCCESS));
-  EXPECT_TRUE(CheckMluStatusFailure(CNNL_STATUS_NOT_INITIALIZED, "CNNL error"));
-  EXPECT_TRUE(CheckMluStatusFailure(CNNL_STATUS_ALLOC_FAILED, "CNNL error"));
-  EXPECT_TRUE(CheckMluStatusFailure(CNNL_STATUS_BAD_PARAM, "CNNL error"));
-  EXPECT_TRUE(CheckMluStatusFailure(CNNL_STATUS_INTERNAL_ERROR, "CNNL error"));
-
-  EXPECT_TRUE(CheckMluStatusSuccess(CN_SUCCESS));
-  EXPECT_TRUE(CheckMluStatusFailure(
-      CN_ERROR_NOT_READY,
-      "Asynchronous operations issued previously not completed yet"));
-  EXPECT_TRUE(
-      CheckMluStatusFailure(CN_ERROR_NOT_INITIALIZED, "initialization error"));
-  EXPECT_TRUE(
-      CheckMluStatusFailure(CN_ERROR_INVALID_VALUE, "invalid argument"));
-  EXPECT_TRUE(CheckMluStatusFailure(CN_MEMORY_ERROR_OUT_OF_MEMORY,
-                                    "device has no memory to alloc"));
-#ifdef PADDLE_WITH_CNCL
-  EXPECT_TRUE(CheckMluStatusSuccess(CNCL_RET_SUCCESS));
-  EXPECT_TRUE(CheckMluStatusFailure(CNCL_RET_ERR_INTERNAL, "CNCL error"));
-  EXPECT_TRUE(CheckMluStatusFailure(CNCL_RET_ERR_NULL_POINTER, "CNCL error"));
-  EXPECT_TRUE(CheckMluStatusFailure(CNCL_RET_ERR_INIT, "CNCL error"));
-  EXPECT_TRUE(CheckMluStatusFailure(CNCL_RET_ERR_NOT_INIT, "CNCL error"));
-  EXPECT_TRUE(CheckMluStatusFailure(CNCL_RET_ERR_REINIT, "CNCL error"));
-  EXPECT_TRUE(
-      CheckMluStatusFailure(CNCL_RET_ERR_INVALID_VERSION, "CNCL error"));
-#endif
-}
-#endif
diff --git a/paddle/fluid/platform/device/mlu/mlu_collective_helper.cc b/paddle/fluid/platform/device/mlu/mlu_collective_helper.cc
deleted file mode 100644
index cb98f73f1e9..00000000000
--- a/paddle/fluid/platform/device/mlu/mlu_collective_helper.cc
+++ /dev/null
@@ -1,190 +0,0 @@
-/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#if defined(PADDLE_WITH_CNCL)
-#include <cncl.h>
-
-#include "paddle/fluid/platform/collective_helper.h"
-#include "paddle/fluid/platform/device/mlu/enforce.h"
-
-namespace paddle {
-namespace platform {
-
-class CNCLCommImpl : public CNCLComm {
- public:
-  void set_ring_id(int ring_id) { ring_id_ = ring_id; }
-  int ring_id() const override { return ring_id_; }
-
-  void set_nranks(int nranks) { nranks_ = nranks; }
-  int nranks() const override { return nranks_; }
-
-  void set_rank(int rank) { rank_ = rank; }
-  int rank() const override { return rank_; }
-
-  int device_id() const override { return dev_ctx_->GetPlace().device; }
-
-  void set_comm(cnclComm_t comm) { comm_ = comm; }
-  cnclComm_t comm() const override { return comm_; }
-
-  mluStream stream() const override { return dev_ctx_->stream(); }
-
-  void set_dev_ctx(std::unique_ptr<MLUDeviceContext>&& dev_ctx) {
-    dev_ctx_ = std::move(dev_ctx);
-  }
-  MLUDeviceContext* dev_context() const override { return dev_ctx_.get(); }
-
-  ~CNCLCommImpl() {
-    if (comm_) {
-      PADDLE_ENFORCE_MLU_SUCCESS(cnclFreeComm(comm_));
-    }
-  }
-
- private:
-  int ring_id_;
-  int nranks_;
-  int rank_;
-  cnclComm_t comm_;
-  std::unique_ptr<MLUDeviceContext> dev_ctx_;
-};
-
-CNCLComm* CNCLCommContext::CreateComm(
-    cnclCliqueId* cncl_id, int nranks, int rank, int dev_id, int ring_id) {
-  PADDLE_ENFORCE_NOT_NULL(cncl_id,
-                          platform::errors::InvalidArgument(
-                              "The cncl unique id should not be null."));
-  PADDLE_ENFORCE_GT(
-      nranks,
-      1,
-      platform::errors::InvalidArgument(
-          "Expected nranks > 1. But received nranks is %d.", nranks));
-  PADDLE_ENFORCE_GE(rank,
-                    0,
-                    platform::errors::InvalidArgument(
-                        "Expected rank >= 0. But received rank is %d.", rank));
-  PADDLE_ENFORCE_LT(
-      rank,
-      nranks,
-      platform::errors::InvalidArgument(
-          "Expected rank < nranks. But received rank is %d, nranks is %d.",
-          rank,
-          nranks));
-  PADDLE_ENFORCE_GE(
-      dev_id,
-      0,
-      platform::errors::InvalidArgument(
-          "Expected dev_id >= 0. But received dev_id is %d.", dev_id));
-
-  cnclComm_t comm;
-  int dev_list[] = {dev_id};
-  int rank_list[] = {rank};
-  SetMLUDeviceId(dev_id);
-  PADDLE_ENFORCE_MLU_SUCCESS(
-      cnclInitComms(&comm, 1, dev_list, rank_list, nranks, cncl_id));
-
-  auto* comm_wrapper = AssignCNCLComm(comm, nranks, rank, dev_id, ring_id);
-
-  VLOG(1) << "cncl communicator of rank " << rank << " in ring " << ring_id
-          << " has been created on device " << dev_id;
-
-  std::call_once(once_flag_, []() {
-    std::atexit([]() { CNCLCommContext::Instance().ReleaseCNCLComms(); });
-  });
-
-  return comm_wrapper;
-}
-
-void CNCLCommContext::CreateAllCNCLComms(const std::vector<int>& dev_ids,
-                                         int ring_id) {
-  PADDLE_ENFORCE_GT(
-      dev_ids.size(),
-      0,
-      platform::errors::InvalidArgument("Expected the size of dev_ids > 0. But "
-                                        "received the size of dev_ids is %d.",
-                                        dev_ids.size()));
-
-  const int kDevices = dev_ids.size();
-  cnclComm_t comms[kDevices];
-  int* rank_list = new int[kDevices];
-  for (int i = 0; i < kDevices; i++) {
-    rank_list[i] = i;
-  }
-  cnclCliqueId clique_id;
-  PADDLE_ENFORCE_MLU_SUCCESS(cnclGetCliqueId(&clique_id));
-  PADDLE_ENFORCE_MLU_SUCCESS(cnclInitComms(comms,
-                                           dev_ids.size(),
-                                           dev_ids.data(),
-                                           rank_list,
-                                           dev_ids.size(),
-                                           &clique_id));
-
-  PADDLE_ENFORCE_EQ(comm_map_.count(ring_id),
-                    0,
-                    platform::errors::InvalidArgument(
-                        "Expected comm_map_.count(ring_id) = 0. But received "
-                        "comm_map_.count(ring_id) is %d.",
-                        comm_map_.count(ring_id)));
-  for (size_t i = 0; i < dev_ids.size(); ++i) {
-    AssignCNCLComm(comms[i], dev_ids.size(), i, dev_ids[i], ring_id);
-    VLOG(1) << "cncl communicator of rank " << i << " in ring " << ring_id
-            << " has been created on device " << dev_ids[i];
-  }
-
-  std::call_once(once_flag_, []() {
-    std::atexit([]() { CNCLCommContext::Instance().ReleaseCNCLComms(); });
-  });
-  delete[] rank_list;
-}
-
-CNCLComm* CNCLCommContext::AssignCNCLComm(
-    cnclComm_t comm, int nranks, int rank, int dev_id, int ring_id) {
-  std::unique_ptr<MLUDeviceContext> dev_ctx(
-      new MLUDeviceContext(MLUPlace(dev_id)));
-
-  CNCLCommImpl* c = new CNCLCommImpl;
-  c->set_ring_id(ring_id);
-  c->set_nranks(nranks);
-  c->set_rank(rank);
-  c->set_comm(comm);
-  c->set_dev_ctx(std::move(dev_ctx));
-
-  comm_map_mutex_.lock();
-  if (comm_map_.count(ring_id) == 0) {
-    comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<CNCLComm>>());
-  }
-  auto& dev2comm = comm_map_[ring_id];
-
-  dev2comm.emplace(dev_id, std::unique_ptr<CNCLComm>(c));
-  comm_map_mutex_.unlock();
-
-  if (ring_id == 0) {
-    auto* dev_ctx = static_cast<MLUDeviceContext*>(
-        platform::DeviceContextPool::Instance().Get(
-            platform::MLUPlace(dev_id)));
-    dev_ctx->set_cncl_comm(comm);
-  }
-
-  return comm_map_[ring_id][dev_id].get();
-}
-
-void CNCLCommContext::ReleaseCNCLComms() {
-  for (auto& p : comm_map_) {
-    for (auto& q : p.second) {
-      q.second.reset();
-    }
-  }
-}
-
-}  // namespace platform
-}  // namespace paddle
-#endif
diff --git a/paddle/fluid/platform/device/mlu/mlu_info.cc b/paddle/fluid/platform/device/mlu/mlu_info.cc
deleted file mode 100644
index f4df71d9847..00000000000
--- a/paddle/fluid/platform/device/mlu/mlu_info.cc
+++ /dev/null
@@ -1,464 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/platform/device/mlu/mlu_info.h"
-
-#include <algorithm>
-#include <mutex>
-
-#include "gflags/gflags.h"
-#include "paddle/fluid/memory/malloc.h"
-#include "paddle/fluid/platform/device/mlu/enforce.h"
-#include "paddle/fluid/platform/lock_guard_ptr.h"
-#include "paddle/fluid/platform/monitor.h"
-#include "paddle/fluid/string/split.h"
-
-DECLARE_double(fraction_of_gpu_memory_to_use);
-DECLARE_uint64(initial_gpu_memory_in_mb);
-DECLARE_uint64(reallocate_gpu_memory_in_mb);
-DECLARE_uint64(gpu_memory_limit_mb);
-
-constexpr static float fraction_reserve_mlu_memory = 0.05f;
-
-PADDLE_DEFINE_EXPORTED_string(
-    selected_mlus,
-    "",
-    "A list of device ids separated by comma, like: 0,1,2,3. "
-    "This option is useful when doing multi-process training and "
-    "each process has only one device (MLU). If you want to use "
-    "all visible devices, set this to empty string. NOTE: the "
-    "reason of doing this is that we want to use P2P communication "
-    "between MLU devices; using MLU_VISIBLE_DEVICES can only use "
-    "shared memory.");
-
-USE_MLU_MEM_STAT;
-namespace paddle {
-namespace platform {
-
-static int GetMLUDeviceCountImpl() {
-  int x, y, z;
-  // When cnrtDriverGetVersion is executed, the device is initialized;
-  // there is no longer a need to call cnrtInit().
-  cnrtStatus stat = cnrtDriverGetVersion(&x, &y, &z);
-  if (stat != cnrtSuccess) {
-    VLOG(2) << "MLU Driver Version can't be detected. No MLU driver!";
-    return 0;
-  }
-
-  const auto *mlu_visible_devices = std::getenv("MLU_VISIBLE_DEVICES");
-  if (mlu_visible_devices != nullptr) {
-    std::string mlu_visible_devices_str(mlu_visible_devices);
-    if (std::all_of(mlu_visible_devices_str.begin(),
-                    mlu_visible_devices_str.end(),
-                    [](char ch) { return ch == ' '; })) {
-      VLOG(2) << "MLU_VISIBLE_DEVICES is set to be "
-                 "empty. No MLU detected.";
-      return 0;
-    }
-  }
-
-  int count;
-  PADDLE_ENFORCE_MLU_SUCCESS(cnDeviceGetCount(&count));
-  return count;
-}
-
-int GetMLUDeviceCount() {
-  static auto dev_cnt = GetMLUDeviceCountImpl();
-  return dev_cnt;
-}
-
-std::vector<int> GetMLUSelectedDevices() {
-  // use user specified MLUs in single-node multi-process mode.
-  std::vector<int> devices;
-  if (!FLAGS_selected_mlus.empty()) {
-    auto devices_str = paddle::string::Split(FLAGS_selected_mlus, ',');
-    for (auto id : devices_str) {
-      devices.push_back(atoi(id.c_str()));
-    }
-  } else {
-    int count = GetMLUDeviceCount();
-    for (int i = 0; i < count; ++i) {
-      devices.push_back(i);
-    }
-  }
-  return devices;
-}
-
-void CheckDeviceId(int id) {
-  PADDLE_ENFORCE_LT(id,
-                    GetMLUDeviceCount(),
-                    platform::errors::InvalidArgument(
-                        "Device id must be less than MLU count, "
-                        "but received id is: %d. MLU count is: %d.",
-                        id,
-                        GetMLUDeviceCount()));
-}
-
-int GetMLUDriverVersion(int id) {
-  CheckDeviceId(id);
-  int x, y, z;
-  PADDLE_ENFORCE_MLU_SUCCESS(cnrtDriverGetVersion(&x, &y, &z));
-  return x * 10000 + y * 100 + z;
-}
-
-int GetMLURuntimeVersion(int id) {
-  CheckDeviceId(id);
-  int x, y, z;
-  PADDLE_ENFORCE_MLU_SUCCESS(cnrtGetLibVersion(&x, &y, &z));
-  return x * 10000 + y * 100 + z;
-}
-
-int GetMLUCnnlVersion(int id) {
-  CheckDeviceId(id);
-  int x, y, z;
-  cnnlGetLibVersion(&x, &y, &z);
-  return x * 10000 + y * 100 + z;
-}
-
-int GetMLUOpVersion(int id) {
-  CheckDeviceId(id);
-  int x, y, z;
-  mluOpGetLibVersion(&x, &y, &z);
-  return x * 10000 + y * 100 + z;
-}
-
-int GetMLUCurrentDeviceId() {
-  int device_id;
-  PADDLE_ENFORCE_MLU_SUCCESS(cnrtGetDevice(&device_id));
-  return device_id;
-}
-
-void SetMLUDeviceId(int id) {
-  CheckDeviceId(id);
-  PADDLE_RETRY_MLU_SUCCESS(cnrtSetDevice(id));
-}
-
-void GetMLUDeviceHandle(int device_ordinal, mluDeviceHandle *device) {
-  cnStatus res = cnDeviceGet(device, device_ordinal);
-  if (res != CN_SUCCESS) {
-    VLOG(2) << "failed to get handle of MLU Device.";
-  }
-  PADDLE_ENFORCE_MLU_SUCCESS(res);
-}
-
-int GetMLUComputeCapability(int id) {
-  CheckDeviceId(id);
-  mluDeviceHandle device;
-  GetMLUDeviceHandle(id, &device);
-
-  int major, minor;
-  cnStatus major_stat = cnDeviceGetAttribute(
-      &major, CN_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device);
-  cnStatus minor_stat = cnDeviceGetAttribute(
-      &minor, CN_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device);
-  PADDLE_ENFORCE_MLU_SUCCESS(major_stat);
-  PADDLE_ENFORCE_MLU_SUCCESS(minor_stat);
-
-  return major * 10 + minor;
-}
-
-void MLUMemoryUsage(size_t *available, size_t *total) {
-  size_t actual_available, actual_total;
-  RecordedMLUMemGetInfo(available,
-                        total,
-                        &actual_available,
-                        &actual_total,
-                        platform::GetMLUCurrentDeviceId());
-}
-
-size_t MLUAvailableMemToAlloc() {
-  size_t total = 0;
-  size_t available = 0;
-  MLUMemoryUsage(&available, &total);
-  size_t reserving =
-      static_cast<size_t>(fraction_reserve_mlu_memory * available);
-  // If available size is less than minimum chunk size, no usable memory exists
-  size_t available_to_alloc = available - reserving;
-  size_t min_chunk_size = MLUMinChunkSize();
-  if (available_to_alloc < min_chunk_size) {
-    available_to_alloc = 0;
-  }
-  VLOG(10) << "MLU usage " << ((total - available) >> 20) << "M/"
-           << (total >> 20) << "M, " << (available_to_alloc >> 20)
-           << "M available to allocate";
-  return available_to_alloc;
-}
-
-size_t MLUMaxAllocSize() {
-  return std::max(MLUInitAllocSize(), MLUReallocSize());
-}
-
-static size_t MLUAllocSize(bool realloc) {
-  size_t available_to_alloc = MLUAvailableMemToAlloc();
-  PADDLE_ENFORCE_GT(
-      available_to_alloc,
-      0,
-      platform::errors::ResourceExhausted("Not enough available MLU memory."));
-  // If FLAGS_initial_gpu_memory_in_mb is 0, then initial memory will be
-  // allocated by fraction
-  size_t flag_mb = realloc ? FLAGS_reallocate_gpu_memory_in_mb
-                           : FLAGS_initial_gpu_memory_in_mb;
-  size_t alloc_bytes =
-      (flag_mb > 0ul
-           ? flag_mb << 20
-           : available_to_alloc * FLAGS_fraction_of_gpu_memory_to_use);
-  PADDLE_ENFORCE_GE(
-      available_to_alloc,
-      alloc_bytes,
-      platform::errors::ResourceExhausted("Not enough available MLU memory."));
-  VLOG(10) << "Alloc size is " << (alloc_bytes >> 20)
-           << " MiB, is it Re-alloc: " << realloc;
-  return alloc_bytes;
-}
-
-size_t MLUInitAllocSize() { return MLUAllocSize(/* realloc = */ false); }
-
-size_t MLUReallocSize() { return MLUAllocSize(/* realloc = */ true); }
-
-size_t MLUMaxChunkSize() {
-  size_t max_chunk_size = MLUMaxAllocSize();
-  VLOG(10) << "Max chunk size " << (max_chunk_size >> 20) << "M";
-  return max_chunk_size;
-}
-
-void MLUMemcpyD2HAsync(void *dst,
-                       const void *src,
-                       size_t num,
-                       mluStream stream) {
-  PADDLE_ENFORCE_MLU_SUCCESS(cnrtMemcpyAsync(
-      dst, const_cast<void *>(src), num, stream, cnrtMemcpyDevToHost));
-}
-
-void MLUMemcpyD2HSync(void *dst, const void *src, size_t num) {
-  PADDLE_ENFORCE_MLU_SUCCESS(
-      cnrtMemcpy(dst, const_cast<void *>(src), num, cnrtMemcpyDevToHost));
-}
-
-void MLUMemcpyH2DAsync(void *dst,
-                       const void *src,
-                       size_t num,
-                       mluStream stream) {
-  PADDLE_ENFORCE_MLU_SUCCESS(cnrtMemcpyAsync(
-      dst, const_cast<void *>(src), num, stream, cnrtMemcpyHostToDev));
-}
-void MLUMemcpyH2DSync(void *dst, const void *src, size_t num) {
-  PADDLE_ENFORCE_MLU_SUCCESS(
-      cnrtMemcpy(dst, const_cast<void *>(src), num, cnrtMemcpyHostToDev));
-}
-
-void MLUMemcpyD2DAsync(void *dst,
-                       const void *src,
-                       size_t num,
-                       mluStream stream) {
-  PADDLE_ENFORCE_MLU_SUCCESS(cnrtMemcpyAsync(
-      dst, const_cast<void *>(src), num, stream, cnrtMemcpyDevToDev));
-}
-void MLUMemcpyD2DSync(void *dst, const void *src, size_t num) {
-  PADDLE_ENFORCE_MLU_SUCCESS(
-      cnrtMemcpy(dst, const_cast<void *>(src), num, cnrtMemcpyDevToDev));
-}
-
-void MLUMemcpyPeerAsync(void *dst,
-                        int dst_device,
-                        const void *src,
-                        int src_device,
-                        size_t num,
-                        mluStream stream) {
-  PADDLE_ENFORCE_MLU_SUCCESS(cnrtMemcpyPeerAsync(
-      dst, dst_device, const_cast<void *>(src), src_device, num, stream));
-}
-
-void MLUMemcpyPeerSync(
-    void *dst, int dst_device, const void *src, int src_device, size_t num) {
-  PADDLE_ENFORCE_MLU_SUCCESS(cnrtMemcpyPeer(
-      dst, dst_device, const_cast<void *>(src), src_device, num));
-}
-
-void MLUMemsetAsync(void *dst, int value, size_t count, mluStream stream) {
-  PADDLE_ENFORCE_MLU_SUCCESS(cnrtMemsetAsync(dst, value, count, stream));
-}
-
-void MLUStreamSync(mluStream stream) {
-  PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueSync(stream));
-}
-
-static void RaiseNonOutOfMemoryError(cnrtStatus *status) {
-  if (*status == cnrtErrorNoMem) {
-    *status = cnrtSuccess;
-  }
-  PADDLE_ENFORCE_MLU_SUCCESS(*status);
-
-  *status = cnrtGetLastError();
-  if (*status == cnrtErrorNoMem) {
-    *status = cnrtSuccess;
-  }
-  PADDLE_ENFORCE_MLU_SUCCESS(*status);
-}
-
-class RecordedMLUMallocHelper {
- private:
-  explicit RecordedMLUMallocHelper(int dev_id, uint64_t limit_size = 0)
-      : dev_id_(dev_id), limit_size_(limit_size) {
-    if (NeedRecord()) {
-      mtx_.reset(new std::mutex());
-    }
-  }
-
-  DISABLE_COPY_AND_ASSIGN(RecordedMLUMallocHelper);
-
- public:
-  static RecordedMLUMallocHelper *Instance(int dev_id) {
-    std::call_once(once_flag_, [] {
-      int dev_cnt = GetMLUDeviceCount();
-      instances_.reserve(dev_cnt);
-      for (int i = 0; i < dev_cnt; ++i) {
-        instances_.emplace_back(
-            new RecordedMLUMallocHelper(i, FLAGS_gpu_memory_limit_mb << 20));
-      }
-    });
-
-    PADDLE_ENFORCE_GE(
-        dev_id,
-        0,
-        platform::errors::OutOfRange(
-            "Device id must be not less than 0, but got %d.", dev_id));
-    PADDLE_ENFORCE_LT(
-        dev_id,
-        instances_.size(),
-        platform::errors::OutOfRange("Device id %d exceeds mlu card number %d.",
-                                     dev_id,
-                                     instances_.size()));
-    return instances_[dev_id].get();
-  }
-
-  /**
-   * Try to allocate `size` mlu memory. Only cnrtErrorNoMem
-   * or cnrtSuccess would be returned, and the cnrtGetLastError() flag
-   * would be cleared.
-   */
-  cnrtStatus Malloc(void **ptr, size_t size) {
-    LockGuardPtr<std::mutex> lock(mtx_);
-    if (UNLIKELY(NeedRecord() && cur_size_.load() + size > limit_size_)) {
-      return cnrtErrorNoMem;
-    }
-
-    MLUDeviceGuard guard(dev_id_);
-    auto result = cnrtMalloc(ptr, size);
-    if (result == cnrtSuccess) {
-      cur_size_.fetch_add(size);
-      STAT_INT_ADD("STAT_mlu" + std::to_string(dev_id_) + "_mem_size", size);
-      return cnrtSuccess;
-    } else {
-      RaiseNonOutOfMemoryError(&result);
-      // Non out of memory error would be raised inside
-      // RaiseNonOutOfMemoryError.
-      // Therefore, we can return cnrtErrorNoMem directly here.
-      return cnrtErrorNoMem;
-    }
-  }
-
-  /**
-   * Free mlu memory. Usually, free is not allowed to raise error.
-   * If it does raise error, the process should be crashed.
-   */
-  void Free(void *ptr, size_t size) {
-    MLUDeviceGuard guard(dev_id_);
-    auto err = cnrtFree(ptr);
-    PADDLE_ENFORCE_MLU_SUCCESS(err);
-    if (NeedRecord()) {
-      cur_size_.fetch_sub(size);
-    }
-    STAT_INT_SUB("STAT_mlu" + std::to_string(dev_id_) + "_mem_size", size);
-  }
-
-  bool GetMemInfo(size_t *avail,
-                  size_t *total,
-                  size_t *actual_avail,
-                  size_t *actual_total) {
-    {
-      MLUDeviceGuard guard(dev_id_);
-      auto result = cnrtMemGetInfo(actual_avail, actual_total);
-      if (result != cnrtSuccess) {
-        *actual_avail = 0;
-      }
-      RaiseNonOutOfMemoryError(&result);
-    }
-
-    if (NeedRecord()) {
-      std::lock_guard<std::mutex> guard(*mtx_);
-      *avail = std::min(*actual_avail, limit_size_ - cur_size_.load());
-      *total = std::min(*actual_total, limit_size_);
-      return *total < *actual_total;
-    } else {
-      *avail = *actual_avail;
-      *total = *actual_total;
-      return false;
-    }
-  }
-
-  inline bool NeedRecord() const { return limit_size_ != 0; }
-
-  uint64_t RecordedSize() const { return cur_size_.load(); }
-
-  uint64_t LimitSize() const { return limit_size_; }
-
- private:
-  const int dev_id_;
-  const uint64_t limit_size_;
-  std::atomic<uint64_t> cur_size_{0};
-
-  mutable std::unique_ptr<std::mutex> mtx_;
-
-  static std::once_flag once_flag_;
-  static std::vector<std::unique_ptr<RecordedMLUMallocHelper>> instances_;
-};  // NOLINT
-
-std::once_flag RecordedMLUMallocHelper::once_flag_;
-std::vector<std::unique_ptr<RecordedMLUMallocHelper>>
-    RecordedMLUMallocHelper::instances_;
-
-cnrtStatus RecordedMLUMalloc(void **ptr, size_t size, int dev_id) {
-  return RecordedMLUMallocHelper::Instance(dev_id)->Malloc(ptr, size);
-}
-
-void RecordedMLUFree(void *p, size_t size, int dev_id) {
-  return RecordedMLUMallocHelper::Instance(dev_id)->Free(p, size);
-}
-
-bool RecordedMLUMemGetInfo(size_t *avail,
-                           size_t *total,
-                           size_t *actual_avail,
-                           size_t *actual_total,
-                           int dev_id) {
-  return RecordedMLUMallocHelper::Instance(dev_id)->GetMemInfo(
-      avail, total, actual_avail, actual_total);
-}
-
-uint64_t RecordedMLUMallocSize(int dev_id) {
-  return RecordedMLUMallocHelper::Instance(dev_id)->RecordedSize();
-}
-
-bool IsMLUMallocRecorded(int dev_id) {
-  return RecordedMLUMallocHelper::Instance(dev_id)->NeedRecord();
-}
-
-void EmptyCache(void) {
-  std::vector<int> devices = GetMLUSelectedDevices();
-  for (auto device : devices) {
-    memory::Release(MLUPlace(device));
-  }
-}
-
-}  // namespace platform
-}  // namespace paddle
diff --git a/paddle/fluid/platform/device/mlu/mlu_info.h b/paddle/fluid/platform/device/mlu/mlu_info.h
deleted file mode 100644
index 435e71cf105..00000000000
--- a/paddle/fluid/platform/device/mlu/mlu_info.h
+++ /dev/null
@@ -1,190 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#ifdef PADDLE_WITH_MLU
-#include <cn_api.h>
-#include <cndev.h>
-#include <cnnl.h>
-#include <cnpapi.h>
-#include <cnrt.h>
-#include <mlu_op.h>
-#ifdef PADDLE_WITH_CNCL
-#include <cncl.h>
-#endif
-#include <vector>
-#include "paddle/phi/backends/mlu/mlu_info.h"
-
-namespace paddle {
-
-using cnStatus = CNresult;
-using cnrtStatus = cnrtRet_t;
-using cnnlStatus = cnnlStatus_t;
-using mluOpStatus = mluOpStatus_t;
-#ifdef PADDLE_WITH_CNCL
-using cnclStatus = cnclResult_t;
-#endif
-using mluStream = cnrtQueue_t;
-using mluCnnlHandle = cnnlHandle_t;
-using mluOpHandle = mluOpHandle_t;
-using mluEventHandle = cnrtNotifier_t;
-using mluDeviceHandle = CNdev;
-
-namespace platform {
-
-//! Get the driver version of the ith MLU.
-int GetMLUDriverVersion(int id);
-
-//! Get the runtime version of the ith MLU.
-int GetMLURuntimeVersion(int id);
-
-//! Get the cnnl version of the ith MLU.
-int GetMLUCnnlVersion(int id);
-
-//! Get the mluOp version of the ith MLU.
-int GetMLUOpVersion(int id);
-
-//! Get the total number of MLU devices in system.
-int GetMLUDeviceCount();
-
-//! Get a list of device ids from environment variable or use all.
-std::vector<int> GetMLUSelectedDevices();
-
-//! Get the current MLU device id in system.
-int GetMLUCurrentDeviceId();
-
-//! Set the MLU device id for next execution.
-void SetMLUDeviceId(int device_id);
-
-//! Get a handle of device ids.
-void GetMLUDeviceHandle(int device_ordinal, mluDeviceHandle* device);
-
-//! Get the compute capability of the ith MLU (format: major * 10 + minor)
-int GetMLUComputeCapability(int id);
-
-//! Get the memory usage of current MLU device.
-void MLUMemoryUsage(size_t* available, size_t* total);
-
-//! Get the available memory to allocate, which is the size of available mlu
-//! minus reserving.
-size_t MLUAvailableMemToAlloc();
-
-//! Get the maximum allocation size of current MLU device.
-size_t MLUMaxAllocSize();
-
-//! Get the initial allocation size of current MLU device.
-size_t MLUInitAllocSize();
-
-//! Get the re-allocation size of current MLU device.
-size_t MLUReallocSize();
-
-using phi::backends::mlu::MLUMinChunkSize;
-
-//! Get the maximum chunk size for MLU buddy allocator.
-size_t MLUMaxChunkSize();
-
-//! Copy memory from address device to host asynchronously.
-void MLUMemcpyD2HAsync(void* dst,
-                       const void* src,
-                       size_t num,
-                       mluStream stream);
-
-//! Copy memory from address device to host synchronously.
-void MLUMemcpyD2HSync(void* dst, const void* src, size_t num);
-
-//! Copy memory from address host to device asynchronously.
-void MLUMemcpyH2DAsync(void* dst,
-                       const void* src,
-                       size_t num,
-                       mluStream stream);
-
-//! Copy memory from address host to device synchronously.
-void MLUMemcpyH2DSync(void* dst, const void* src, size_t num);
-
-//! Copy memory from address device to device asynchronously in a single device.
-void MLUMemcpyD2DAsync(void* dst,
-                       const void* src,
-                       size_t num,
-                       mluStream stream);
-
-//! Copy memory from address device to device synchronously in a single device.
-void MLUMemcpyD2DSync(void* dst, const void* src, size_t num);
-
-//! Copy memory from one device to another device asynchronously.
-void MLUMemcpyPeerAsync(void* dst,
-                        int dst_place,
-                        const void* src,
-                        int src_place,
-                        size_t num,
-                        mluStream stream);
-
-//! Copy memory from one device to another device synchronously.
-void MLUMemcpyPeerSync(
-    void* dst, int dst_place, const void* src, int src_place, size_t num);
-
-//! Set memory dst with value count size asynchronously
-void MLUMemsetAsync(void* dst, int value, size_t count, mluStream stream);
-
-//! Blocks until stream has completed all operations.
-void MLUStreamSync(mluStream stream);
-
-//! MLUMalloc with recorded info
-cnrtStatus RecordedMLUMalloc(void** ptr, size_t size, int dev_id);
-
-//! MLUFree with recorded info
-void RecordedMLUFree(void* p, size_t size, int dev_id);
-
-//! Get available and total mlu memory with considering limitation
-bool RecordedMLUMemGetInfo(size_t* avail,
-                           size_t* total,
-                           size_t* actual_avail,
-                           size_t* actual_total,
-                           int dev_id);
-
-//! Get recorded mluMalloc size. If record is disabled, return 0.
-uint64_t RecordedMLUMallocSize(int dev_id);
-
-bool IsMLUMallocRecorded(int dev_id);
-
-//! Empty idle cached memory held by the allocator.
-void EmptyCache(void);
-
-class MLUDeviceGuard {
- public:
-  explicit inline MLUDeviceGuard(int dev_id) {
-    int prev_id = platform::GetMLUCurrentDeviceId();
-    if (prev_id != dev_id) {
-      prev_id_ = prev_id;
-      platform::SetMLUDeviceId(dev_id);
-    }
-  }
-
-  inline ~MLUDeviceGuard() {
-    if (prev_id_ != -1) {
-      platform::SetMLUDeviceId(prev_id_);
-    }
-  }
-
-  MLUDeviceGuard(const MLUDeviceGuard& o) = delete;
-  MLUDeviceGuard& operator=(const MLUDeviceGuard& o) = delete;
-
- private:
-  int prev_id_{-1};
-};
-
-}  // namespace platform
-}  // namespace paddle
-
-#endif
diff --git a/paddle/fluid/platform/device/mlu/mlu_resource_pool.cc b/paddle/fluid/platform/device/mlu/mlu_resource_pool.cc
deleted file mode 100644
index fffd06f63af..00000000000
--- a/paddle/fluid/platform/device/mlu/mlu_resource_pool.cc
+++ /dev/null
@@ -1,105 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -#if defined(PADDLE_WITH_MLU) -#include "paddle/fluid/platform/device/mlu/mlu_resource_pool.h" - -namespace paddle { -namespace platform { - -MluStreamResourcePool::MluStreamResourcePool() { - int dev_cnt = platform::GetMLUDeviceCount(); - pool_.reserve(dev_cnt); - for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) { - auto creator = [dev_idx] { - platform::SetMLUDeviceId(dev_idx); - mluStream stream; - cnrtQueueCreate(&stream); - return stream; - }; - - auto deleter = [dev_idx](mluStream stream) { - platform::SetMLUDeviceId(dev_idx); - cnrtQueueDestroy(stream); - }; - - pool_.emplace_back(ResourcePool::Create(creator, deleter)); - } -} - -MluStreamResourcePool& MluStreamResourcePool::Instance() { - static MluStreamResourcePool pool; - return pool; -} - -std::shared_ptr MluStreamResourcePool::New(int dev_idx) { - PADDLE_ENFORCE_GE( - dev_idx, - 0, - platform::errors::InvalidArgument( - "The dev_idx should be not less than 0, but got %d.", dev_idx)); - PADDLE_ENFORCE_LT( - dev_idx, - pool_.size(), - platform::errors::OutOfRange( - "The dev_idx should be less than device count %d, but got %d.", - pool_.size(), - dev_idx)); - return pool_[dev_idx]->New(); -} - -MluEventResourcePool::MluEventResourcePool() { - int dev_cnt = platform::GetMLUDeviceCount(); - pool_.reserve(dev_cnt); - for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) { - auto creator = [dev_idx] { - platform::SetMLUDeviceId(dev_idx); - mluEventHandle event; - cnrtNotifierCreate(&event); - return event; - }; - - auto deleter = [dev_idx](mluEventHandle event) { - platform::SetMLUDeviceId(dev_idx); - cnrtNotifierDestroy(event); - }; - - pool_.emplace_back(ResourcePool::Create(creator, deleter)); - } -} - -MluEventResourcePool& MluEventResourcePool::Instance() { - static MluEventResourcePool pool; - return pool; -} - -std::shared_ptr MluEventResourcePool::New(int dev_idx) { - PADDLE_ENFORCE_GE( - dev_idx, - 0, - platform::errors::InvalidArgument( - "The dev_idx should be not less than 0, but got %d.", dev_idx)); - PADDLE_ENFORCE_LT( - dev_idx, - pool_.size(), - platform::errors::OutOfRange( - "The dev_idx should be less than device count %d, but got %d.", - pool_.size(), - dev_idx)); - return pool_[dev_idx]->New(); -} - -} // namespace platform -} // namespace paddle -#endif diff --git a/paddle/fluid/platform/device/mlu/mlu_resource_pool.h b/paddle/fluid/platform/device/mlu/mlu_resource_pool.h deleted file mode 100644 index b0e2af7f024..00000000000 --- a/paddle/fluid/platform/device/mlu/mlu_resource_pool.h +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
diff --git a/paddle/fluid/platform/device/mlu/mlu_resource_pool.h b/paddle/fluid/platform/device/mlu/mlu_resource_pool.h
deleted file mode 100644
index b0e2af7f024..00000000000
--- a/paddle/fluid/platform/device/mlu/mlu_resource_pool.h
+++ /dev/null
@@ -1,64 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#if defined(PADDLE_WITH_MLU)
-#include <memory>
-#include <type_traits>
-#include <vector>
-
-#include "paddle/fluid/platform/device/mlu/mlu_info.h"
-#include "paddle/fluid/platform/resource_pool.h"
-
-namespace paddle {
-namespace platform {
-
-using MluStreamObject = std::remove_pointer<mluStream>::type;
-using MluEventObject = std::remove_pointer<mluEventHandle>::type;
-
-class MluStreamResourcePool {
- public:
-  std::shared_ptr<MluStreamObject> New(int dev_idx);
-
-  static MluStreamResourcePool &Instance();
-
- private:
-  MluStreamResourcePool();
-
-  DISABLE_COPY_AND_ASSIGN(MluStreamResourcePool);
-
- private:
-  std::vector<std::shared_ptr<ResourcePool<MluStreamObject>>> pool_;
-};
-
-class MluEventResourcePool {
- public:
-  std::shared_ptr<MluEventObject> New(int dev_idx);
-
-  static MluEventResourcePool &Instance();
-
- private:
-  MluEventResourcePool();
-
-  DISABLE_COPY_AND_ASSIGN(MluEventResourcePool);
-
- private:
-  std::vector<std::shared_ptr<ResourcePool<MluEventObject>>> pool_;
-};
-
-}  // namespace platform
-}  // namespace paddle
-
-#endif
diff --git a/paddle/fluid/platform/device/mlu/mlu_stream.cc b/paddle/fluid/platform/device/mlu/mlu_stream.cc
deleted file mode 100644
index 9e6aac2983b..00000000000
--- a/paddle/fluid/platform/device/mlu/mlu_stream.cc
+++ /dev/null
@@ -1,84 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/platform/device/mlu/mlu_stream.h"
-
-#include "paddle/fluid/platform/device/mlu/device_context.h"
-
-namespace paddle {
-namespace platform {
-namespace stream {
-
-bool MLUStream::Init(const MLUPlace& place, const int priority) {
-  PADDLE_ENFORCE_EQ(is_mlu_place(place),
-                    true,
-                    platform::errors::InvalidArgument(
-                        "MLU stream must be created using mlu place."));
-  place_ = place;
-  MLUDeviceGuard guard(place_.device);
-  PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueCreate(&stream_));
-  callback_manager_.reset(new StreamCallbackManager<mluStream>(stream_));
-  VLOG(3) << "MLUStream Init stream: " << stream_;
-  return true;
-}
-
-void MLUStream::Destroy() {
-  MLUDeviceGuard guard(place_.device);
-  Wait();
-  WaitCallback();
-  if (stream_) {
-    PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueDestroy(stream_));
-  }
-  stream_ = nullptr;
-}
-
-void MLUStream::Wait() const {
-  PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueSync(stream_));
-}
-
-MLUStream* get_current_mlu_stream(int deviceId) {
-#ifdef PADDLE_WITH_MLU
-  if (deviceId == -1) {
-    deviceId = platform::GetMLUCurrentDeviceId();
-  }
-  auto& pool = platform::DeviceContextPool::Instance();
-  platform::Place device = MLUPlace(deviceId);
-  auto stream = static_cast<MLUDeviceContext*>(pool.Get(device))
-                    ->context()
-                    ->Stream()
-                    .get();
-  return stream;
-#else
-  PADDLE_THROW(platform::errors::Unavailable(
-      "Paddle is not compiled with MLU. Cannot visit mlu current stream."));
-  return nullptr;
-#endif
-}
-
-MLUStream* set_current_mlu_stream(MLUStream* stream) {
-#ifdef PADDLE_WITH_MLU
-  auto& device = stream->GetPlace();
-  auto& pool = platform::DeviceContextPool::Instance();
-  return static_cast<MLUDeviceContext*>(pool.Get(device))
-      ->context()
-      ->SetStream(stream);
-#else
-  PADDLE_THROW(platform::errors::Unavailable(
-      "Paddle is not compiled with MLU. Cannot visit mlu current stream."));
-  return nullptr;
-#endif
-}
-
-}  // namespace stream
-}  // namespace platform
-}  // namespace paddle
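A minimal sketch (not part of the patch) of synchronizing the current stream through the helpers deleted above; it assumes a PADDLE_WITH_MLU build with an initialized DeviceContextPool, and SyncCurrentStream is a hypothetical helper name:

#include "paddle/fluid/platform/device/mlu/mlu_stream.h"

void SyncCurrentStream() {
  // deviceId == -1 selects the currently active device, as handled inside
  // get_current_mlu_stream above.
  auto* stream = paddle::platform::stream::get_current_mlu_stream(-1);
  stream->Wait();  // cnrtQueueSync on the underlying queue
}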
diff --git a/paddle/fluid/platform/device/mlu/mlu_stream.h b/paddle/fluid/platform/device/mlu/mlu_stream.h
deleted file mode 100644
index b20949f3bfe..00000000000
--- a/paddle/fluid/platform/device/mlu/mlu_stream.h
+++ /dev/null
@@ -1,101 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include <cstdint>
-#include <memory>
-
-#include "paddle/fluid/platform/device/mlu/enforce.h"
-#include "paddle/fluid/platform/device/mlu/mlu_info.h"
-#include "paddle/fluid/platform/macros.h"
-#include "paddle/fluid/platform/place.h"
-#include "paddle/fluid/platform/stream_callback_manager.h"
-
-namespace paddle {
-namespace platform {
-namespace stream {
-
-#ifdef PADDLE_WITH_MLU
-class MLUStream final {
- public:
-  MLUStream() = default;
-  explicit MLUStream(const MLUPlace& place, const int priority = 0) {
-    Init(place, priority);
-  }
-  virtual ~MLUStream() { Destroy(); }
-
-  bool Init(const MLUPlace& place, const int priority = 0);
-
-  template <typename Callback>
-  void AddCallback(Callback&& callback) const {
-    callback_manager_->AddCallback(callback);
-  }
-
-  template <typename Callback>
-  void RecordEvent(mluEventHandle event, Callback callback) const {
-    callback();
-    PADDLE_ENFORCE_MLU_SUCCESS(cnPlaceNotifier(event, stream_));
-  }
-
-  void RecordEvent(mluEventHandle event) const {
-    PADDLE_ENFORCE_MLU_SUCCESS(cnPlaceNotifier(event, stream_));
-  }
-
-  void WaitEvent(mluEventHandle event) const {
-    PADDLE_ENFORCE_MLU_SUCCESS(cnWaitNotifier(event));
-  }
-
-  void Wait() const;
-  void WaitCallback() const { callback_manager_->Wait(); }
-
-  const mluStream& raw_stream() const { return stream_; }
-
-  void Destroy();
-
-  bool Query() const {
-    cnrtStatus stat = cnrtQueueQuery(stream_);
-    if (stat == cnrtSuccess) {
-      return true;
-    }
-    if (stat == cnrtErrorNotReady) {
-      return false;
-    }
-    PADDLE_ENFORCE_MLU_SUCCESS(stat);
-    return false;
-  }
-
-  void Synchronize() const {
-    PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueSync(stream_));
-  }
-
-  const MLUPlace& GetPlace() const { return place_; }
-
- private:
-  MLUPlace place_;
-  mluStream stream_{nullptr};
-  int priority_{0};
-  std::unique_ptr<StreamCallbackManager<mluStream>> callback_manager_;
-
-  DISABLE_COPY_AND_ASSIGN(MLUStream);
-};
-
-MLUStream* get_current_mlu_stream(int deviceId);
-MLUStream* set_current_mlu_stream(MLUStream* stream);
-
-#endif
-
-}  // namespace stream
-}  // namespace platform
-}  // namespace paddle
--
GitLab
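For reference, a minimal sketch (not part of the patch) of driving the deleted MLUStream class directly, combining Init, Query, and the host-callback machinery shown above; it assumes a PADDLE_WITH_MLU build, and PollUntilDone is a hypothetical helper name:

#include "paddle/fluid/platform/device/mlu/mlu_stream.h"

void PollUntilDone() {
  // The constructor calls Init(place, priority), which creates the queue
  // and the per-stream callback manager.
  paddle::platform::stream::MLUStream stream(paddle::platform::MLUPlace(0));

  stream.AddCallback([] { /* runs on the callback manager's host thread */ });

  // Query() is non-blocking: cnrtSuccess maps to true, cnrtErrorNotReady
  // maps to false, and any other status is raised as an enforce error.
  while (!stream.Query()) {
    // spin (or yield) until the queue drains
  }

  stream.WaitCallback();  // drain host callbacks before the stream dies
}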