diff --git a/CMakeLists.txt b/CMakeLists.txt index 7c650b6ed7a5f920723f117fc08002cc1d783588..f122dbb9cfc0951e79070523f7944881d82aaa10 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,6 +43,7 @@ option(WITH_ONEMKL "Compile PaddlePaddle with oneMKL" OFF) option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND}) option(WITH_TENSORRT "Compile PaddlePaddle with NVIDIA TensorRT" OFF) option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF) +option(WITH_MLU "Compile PaddlePaddle with CAMBRICON MLU" OFF) option(WITH_WIN_DUMP_DBG "Compile with windows core dump debug mode" OFF) option(WITH_ASCEND "Compile PaddlePaddle with ASCEND" OFF) option(WITH_ROCM "Compile PaddlePaddle with ROCM platform" OFF) @@ -64,6 +65,9 @@ endif() if (WITH_GPU AND WITH_ROCM) message(FATAL_ERROR "Error when compile CUDA and ROCM at the same time") endif() +if (WITH_GPU AND WITH_MLU) + message(FATAL_ERROR "Error when compile GPU and MLU at the same time") +endif() if(WITH_GPU AND NOT APPLE) enable_language(CUDA) @@ -302,6 +306,10 @@ if(WITH_GPU) endif() endif() +if(WITH_MLU) + include(neuware) +endif() + if(WITH_ROCM) include(hip) include(miopen) # set miopen libraries, must before configure diff --git a/cmake/configure.cmake b/cmake/configure.cmake index a77f9f72ca6adeb4173f67662da687b7eeef4cf7..32ba2ff3ac627304c9d3095cca58a4a071d6b5b7 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -102,6 +102,11 @@ if(WITH_IPU) add_definitions(-DPADDLE_WITH_IPU) endif() +if(WITH_MLU) + message(STATUS "Compile with MLU!") + add_definitions(-DPADDLE_WITH_MLU) +endif() + if(WITH_GPU) add_definitions(-DPADDLE_WITH_CUDA) add_definitions(-DEIGEN_USE_GPU) diff --git a/cmake/external/concurrentqueue.cmake b/cmake/external/concurrentqueue.cmake new file mode 100644 index 0000000000000000000000000000000000000000..9e4331ae6fdea01f4aad4a85c7b84ac3ff317098 --- /dev/null +++ b/cmake/external/concurrentqueue.cmake @@ -0,0 +1,42 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +include(ExternalProject) + +set(CONCURRENTQUEUE_PROJECT "extern_concurrentqueue") +set(CONCURRENTQUEUE_VER "v1.0.3") +SET(CONCURRENTQUEUE_URL_MD5 118e5bb661b567634647312991e10222) +set(CONCURRENTQUEUE_PREFIX_URL "https://github.com/cameron314/concurrentqueue/archive/refs/tags") +set(CONCURRENTQUEUE_URL "${CONCURRENTQUEUE_PREFIX_URL}/${CONCURRENTQUEUE_VER}.tar.gz") + +MESSAGE(STATUS "CONCURRENTQUEUE_VERSION: ${CONCURRENTQUEUE_VER}, CONCURRENTQUEUE_URL: ${CONCURRENTQUEUE_URL}") + +set(CONCURRENTQUEUE_PREFIX_DIR ${THIRD_PARTY_PATH}/concurrentqueue) +set(CONCURRENTQUEUE_SOURCE_DIR ${THIRD_PARTY_PATH}/concurrentqueue/src/) +set(CONCURRENTQUEUE_INCLUDE_DIR "${CONCURRENTQUEUE_SOURCE_DIR}/extern_concurrentqueue") + +ExternalProject_Add( + ${CONCURRENTQUEUE_PROJECT} + ${EXTERNAL_PROJECT_LOG_ARGS} + URL ${CONCURRENTQUEUE_URL} + URL_MD5 ${CONCURRENTQUEUE_URL_MD5} + PREFIX ${CONCURRENTQUEUE_PREFIX_DIR} + DOWNLOAD_NO_PROGRESS 1 + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + UPDATE_COMMAND "" + ) + +include_directories(${CONCURRENTQUEUE_INCLUDE_DIR}) diff --git a/cmake/neuware.cmake b/cmake/neuware.cmake new file mode 100644 index 0000000000000000000000000000000000000000..7219f5f72595bde56413080584b7b6b494e70905 --- /dev/null +++ b/cmake/neuware.cmake @@ -0,0 +1,22 @@ +if(NOT WITH_MLU) + return() +endif() + +if(NOT ENV{NEUWARE_HOME}) + set(NEUWARE_HOME "/usr/local/neuware") +else() + set(NEUWARE_HOME $ENV{NEUWARE_HOME}) +endif() +message(STATUS "NEUWARE_HOME: " ${NEUWARE_HOME}) + +set(NEUWARE_INCLUDE_DIR ${NEUWARE_HOME}/include) +set(NEUWARE_LIB_DIR ${NEUWARE_HOME}/lib64) + +INCLUDE_DIRECTORIES(${NEUWARE_INCLUDE_DIR}) + +set(CNNL_LIB ${NEUWARE_LIB_DIR}/libcnnl.so) +set(CNRT_LIB ${NEUWARE_LIB_DIR}/libcnrt.so) +set(CNDRV_LIB ${NEUWARE_LIB_DIR}/libcndrv.so) + +generate_dummy_static_lib(LIB_NAME "neuware_lib" GENERATOR "neuware.cmake") +TARGET_LINK_LIBRARIES(neuware_lib ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB}) diff --git a/cmake/operators.cmake b/cmake/operators.cmake index a537719cc75829c3fd756b3cbb74c43753ed46ea..673b33900d67356ec4ae27d9b327c2cd2d282965 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -12,6 +12,7 @@ function(op_library TARGET) set(hip_cc_srcs) set(xpu_cc_srcs) set(npu_cc_srcs) + set(mlu_cc_srcs) set(cudnn_cu_cc_srcs) set(miopen_cu_cc_srcs) set(cudnn_cu_srcs) @@ -24,6 +25,10 @@ function(op_library TARGET) if (WITH_ASCEND_CL) set(op_common_deps ${op_common_deps} npu_op_runner) endif() + if (WITH_MLU) + set(op_common_deps ${op_common_deps} mlu_baseop) + endif() + # Option `UNITY` is used to specify that operator `TARGET` will compiles with Unity Build. set(options UNITY) set(oneValueArgs "") @@ -98,6 +103,12 @@ function(op_library TARGET) list(APPEND npu_cc_srcs ${NPU_FILE}.cc) endif() endif() + if(WITH_MLU) + string(REPLACE "_op" "_op_mlu" MLU_FILE "${TARGET}") + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MLU_FILE}.cc) + list(APPEND mlu_cc_srcs ${MLU_FILE}.cc) + endif() + endif() else() foreach(src ${op_library_SRCS}) if(WITH_ROCM AND ${src} MATCHES ".*_cudnn_op.cu$") @@ -122,6 +133,8 @@ function(op_library TARGET) list(APPEND xpu_cc_srcs ${src}) elseif(WITH_ASCEND_CL AND ${src} MATCHES ".*_op_npu.cc$") list(APPEND npu_cc_srcs ${src}) + elseif(WITH_MLU AND ${src} MATCHES ".*_op_mlu.cc$") + list(APPEND mlu_cc_srcs ${src}) elseif(${src} MATCHES ".*\\.cc$") list(APPEND cc_srcs ${src}) else() @@ -196,7 +209,7 @@ function(op_library TARGET) # Unity Build relies on global option `WITH_UNITY_BUILD` and local option `UNITY`. if(WITH_UNITY_BUILD AND op_library_UNITY) # Combine the cc source files. - compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs}) + compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs} ${mlu_cc_srcs}) if(TARGET ${UNITY_TARGET}) # If `UNITY_TARGET` exists, add source files to `UNITY_TARGET`. target_sources(${UNITY_TARGET} PRIVATE ${unity_target_cc_sources}) @@ -207,7 +220,7 @@ function(op_library TARGET) # Add alias library to handle dependencies. add_library(${TARGET} ALIAS ${UNITY_TARGET}) else() - cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs} DEPS ${op_library_DEPS} + cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs} ${mlu_cc_srcs} DEPS ${op_library_DEPS} ${op_common_deps}) endif() endif() @@ -262,8 +275,10 @@ function(op_library TARGET) list(LENGTH xpu_cc_srcs xpu_cc_srcs_len) list(LENGTH miopen_cu_cc_srcs miopen_cu_cc_srcs_len) list(LENGTH npu_cc_srcs npu_cc_srcs_len) + list(LENGTH mlu_cc_srcs mlu_cc_srcs_len) if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0 AND - ${hip_srcs_len} EQUAL 0 AND ${hip_cc_srcs_len} EQUAL 0 AND ${miopen_cu_cc_srcs_len} EQUAL 0 AND ${xpu_cc_srcs_len} EQUAL 0 AND ${npu_cc_srcs_len} EQUAL 0) + ${hip_srcs_len} EQUAL 0 AND ${hip_cc_srcs_len} EQUAL 0 AND ${miopen_cu_cc_srcs_len} EQUAL 0 AND ${xpu_cc_srcs_len} EQUAL 0 AND + ${npu_cc_srcs_len} EQUAL 0 AND ${mlu_cc_srcs_len} EQUAL 0) file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n") set(pybind_flag 1) endif() @@ -322,6 +337,24 @@ function(op_library TARGET) endif() file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${NPU_TARGET}, NPU);\n") endif() + if (WITH_MLU AND ${mlu_cc_srcs_len} GREATER 0) + file(READ ${ORIGINAL_TARGET}_mlu.cc TARGET_MLU_CONTENT) + # It is different from the logic above, becareful + string(REGEX MATCH "REGISTER_OP_MLU_KERNEL\\(.*" multi_mlu_register "${TARGET_MLU_CONTENT}") + # [ \t\r\n]* is used for blank characters + string(REGEX MATCH "REGISTER_OP_MLU_KERNEL\\([ \t\r\n]*[a-z0-9_]*," one_mlu_register "${multi_mlu_register}") + + if (one_mlu_register STREQUAL "") + string(REPLACE "_op" "" MLU_TARGET "${TARGET}") + else () + string(REPLACE "REGISTER_OP_MLU_KERNEL(" "" MLU_TARGET "${one_mlu_register}") + string(REPLACE "," "" MLU_TARGET "${MLU_TARGET}") + # [ \t\r\n]+ is used for blank characters. + # Here we use '+' instead of '*' since it is a REPLACE operation. + string(REGEX REPLACE "[ \t\r\n]+" "" MLU_TARGET "${MLU_TARGET}") + endif() + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${MLU_TARGET}, MLU);\n") + endif() # pybind USE_OP_DEVICE_KERNEL for MKLDNN if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0) @@ -369,11 +402,11 @@ function(register_operators) set(multiValueArgs EXCLUDES DEPS) cmake_parse_arguments(register_operators "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc") string(REPLACE "_mkldnn" "" OPS "${OPS}") string(REPLACE "_xpu" "" OPS "${OPS}") string(REPLACE "_npu" "" OPS "${OPS}") + string(REPLACE "_mlu" "" OPS "${OPS}") string(REPLACE ".cc" "" OPS "${OPS}") list(REMOVE_DUPLICATES OPS) list(LENGTH register_operators_DEPS register_operators_DEPS_len) diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index 5c6147573797e49024618e433a743450d34761c4..38e7d2039af8ad3c129bc8d7adf9e893c6d33fdb 100644 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -278,6 +278,11 @@ if(WITH_XPU) list(APPEND third_party_deps extern_xpu) endif(WITH_XPU) +if(WITH_MLU) + include(external/concurrentqueue) # download, build, install concurrentqueue + list(APPEND third_party_deps extern_concurrentqueue) +endif(WITH_MLU) + if(WITH_PSLIB) include(external/pslib) # download, build, install pslib list(APPEND third_party_deps extern_pslib) diff --git a/paddle/fluid/eager/accumulation/gradient_accumulation.cc b/paddle/fluid/eager/accumulation/gradient_accumulation.cc index 723bf5387c60a9552f74aecc378b0cafed957d96..9bc24dd28756a54eefcdc4f75c0450f073c257ae 100644 --- a/paddle/fluid/eager/accumulation/gradient_accumulation.cc +++ b/paddle/fluid/eager/accumulation/gradient_accumulation.cc @@ -116,6 +116,22 @@ class TensorAddFunctor : public boost::static_visitor<> { } #endif +#ifdef PADDLE_WITH_MLU + void operator()(const paddle::platform::MLUPlace& place) { + PADDLE_THROW(paddle::platform::errors::PermissionDenied( + "Gradient accumulation on place (%s) " + "is not supported in imperative mode", + place)); + } +#else + void operator()(const paddle::platform::MLUPlace& place) { + PADDLE_THROW(paddle::platform::errors::PermissionDenied( + "Gradient accumulation on place (%s) " + "is not supported in imperative mode", + place)); + } +#endif + #ifdef PADDLE_WITH_IPU void operator()(const paddle::platform::IPUPlace& place) { PADDLE_THROW(paddle::platform::errors::PermissionDenied( diff --git a/paddle/fluid/framework/dlpack_tensor.cc b/paddle/fluid/framework/dlpack_tensor.cc index 5e450234c405cd9a9ade2e89978ce9566e4d8d67..cef1016aa53403c19791d5ab70db966ab8df3a0a 100644 --- a/paddle/fluid/framework/dlpack_tensor.cc +++ b/paddle/fluid/framework/dlpack_tensor.cc @@ -101,6 +101,11 @@ struct DLDeviceVisitor : public boost::static_visitor<::DLDevice> { "platform::NPUPinnedPlace is not supported")); } + inline ::DLDevice operator()(const platform::MLUPlace &place) const { + PADDLE_THROW( + platform::errors::Unimplemented("platform::MLUPlace is not supported")); + } + inline ::DLDevice operator()(const platform::CUDAPlace &place) const { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) ::DLDevice device; diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 9e572614779916bba54dd354c6039c7741dab8bc..d669f2ab11d6c0ea1d46dcad254fadcc754487de 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -490,6 +490,19 @@ void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx, #else PADDLE_THROW( platform::errors::Unimplemented("No NPU gc found in CPU/NPU paddle")); +#endif + } else if (platform::is_mlu_place(place_)) { +#ifdef PADDLE_WITH_MLU + if (IsFastEagerDeletionModeEnabled()) { + gc.reset(new MLUUnsafeFastGarbageCollector( + BOOST_GET_CONST(platform::MLUPlace, place_), max_memory_size)); + } else { + gc.reset(new MLUDefaultStreamGarbageCollector( + BOOST_GET_CONST(platform::MLUPlace, place_), max_memory_size)); + } +#else + PADDLE_THROW( + platform::errors::Unimplemented("No MLU gc found in CPU/MLU paddle")); #endif } } diff --git a/paddle/fluid/framework/garbage_collector.cc b/paddle/fluid/framework/garbage_collector.cc index 06d1ef84c19559161fb044a194f8a5c2873469e6..b2d976fea047662da96ea290ee47f30017b5aaf3 100644 --- a/paddle/fluid/framework/garbage_collector.cc +++ b/paddle/fluid/framework/garbage_collector.cc @@ -152,6 +152,56 @@ void NPUUnsafeFastGarbageCollector::ClearCallback( #endif +#ifdef PADDLE_WITH_MLU +MLUDefaultStreamGarbageCollector::MLUDefaultStreamGarbageCollector( + const platform::MLUPlace &place, size_t max_memory_size) + : GarbageCollector(place, max_memory_size) {} + +void MLUDefaultStreamGarbageCollector::Wait() const { + static_cast(this->dev_ctx_) + ->WaitStreamCallback(); +} + +void MLUDefaultStreamGarbageCollector::ClearCallback( + const std::function &callback) { + static_cast(this->dev_ctx_) + ->AddStreamCallback(callback); +} +MLUUnsafeFastGarbageCollector::MLUUnsafeFastGarbageCollector( + const platform::MLUPlace &place, size_t max_memory_size) + : GarbageCollector(place, max_memory_size) {} + +void MLUUnsafeFastGarbageCollector::ClearCallback( + const std::function &callback) { + callback(); +} + +MLUStreamGarbageCollector::MLUStreamGarbageCollector( + const platform::MLUPlace &place, size_t max_memory_size) + : GarbageCollector(place, max_memory_size) { + platform::MLUDeviceGuard guard(place.device); + PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueCreate(&stream_)); + callback_manager_.reset( + new platform::StreamCallbackManager(stream_)); +} + +MLUStreamGarbageCollector::~MLUStreamGarbageCollector() { + auto place = BOOST_GET_CONST(platform::MLUPlace, this->dev_ctx_->GetPlace()); + platform::MLUDeviceGuard guard(place.device); + PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueSync(stream_)); + PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueDestroy(stream_)); +} + +mluStream MLUStreamGarbageCollector::stream() const { return stream_; } + +void MLUStreamGarbageCollector::Wait() const { callback_manager_->Wait(); } + +void MLUStreamGarbageCollector::ClearCallback( + const std::function &callback) { + callback_manager_->AddCallback(callback); +} +#endif + int64_t GetEagerDeletionThreshold() { return FLAGS_eager_delete_tensor_gb < 0 ? -1 diff --git a/paddle/fluid/framework/garbage_collector.h b/paddle/fluid/framework/garbage_collector.h index 0cfeda37c222e71af551b09e704cd8f974a3fbe2..dbb3ab7e9e69e145bf3c8c0e2e32b0abbae9319c 100644 --- a/paddle/fluid/framework/garbage_collector.h +++ b/paddle/fluid/framework/garbage_collector.h @@ -22,6 +22,9 @@ #include "gflags/gflags.h" #include "paddle/fluid/platform/device_context.h" +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/device_context.h" +#endif namespace paddle { namespace platform { @@ -163,6 +166,46 @@ class NPUUnsafeFastGarbageCollector : public GarbageCollector { }; #endif +#ifdef PADDLE_WITH_MLU +class MLUDefaultStreamGarbageCollector : public GarbageCollector { + public: + MLUDefaultStreamGarbageCollector(const platform::MLUPlace &place, + size_t max_memory_size); + + void Wait() const override; + + protected: + void ClearCallback(const std::function &callback) override; +}; + +class MLUUnsafeFastGarbageCollector : public GarbageCollector { + public: + MLUUnsafeFastGarbageCollector(const platform::MLUPlace &place, + size_t max_memory_size); + + protected: + void ClearCallback(const std::function &callback) override; +}; +class MLUStreamGarbageCollector : public GarbageCollector { + public: + MLUStreamGarbageCollector(const platform::MLUPlace &place, + size_t max_memory_size); + + ~MLUStreamGarbageCollector(); + + void Wait() const override; + + mluStream stream() const; + + protected: + void ClearCallback(const std::function &callback) override; + + private: + mluStream stream_; + std::unique_ptr> callback_manager_; +}; +#endif + template void GarbageCollector::Add(Container &&objs) { Add(std::forward(objs), []() {}); diff --git a/paddle/fluid/framework/library_type.h b/paddle/fluid/framework/library_type.h index f7539aa4859578fe8ba50e328d7c2f9a85540c91..6fdd128b0d3bfb50f55c6fa59d40a3242869be76 100644 --- a/paddle/fluid/framework/library_type.h +++ b/paddle/fluid/framework/library_type.h @@ -67,6 +67,8 @@ inline LibraryType StringToLibraryType(const char* ctype) { return LibraryType::kPlain; } else if (s == std::string("CUDA")) { return LibraryType::kPlain; + } else if (s == std::string("MLU")) { + return LibraryType::kPlain; } else { PADDLE_THROW(platform::errors::Unimplemented( "Unknown LibraryType string (%s), only support library type string " diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h index 39496cb26776e765f9675e5041cd13104d374c6a..c45bf32d8b710cb35ec5f86a4a8ba2e1078537e6 100644 --- a/paddle/fluid/framework/op_registry.h +++ b/paddle/fluid/framework/op_registry.h @@ -336,6 +336,9 @@ struct OpKernelRegistrarFunctorEx { return GetResultHelper(out, xpu); } + bool GetResult(const framework::Tensor& out, + const platform::MLUPlace& mlu) const { + PADDLE_THROW( + platform::errors::Unimplemented("Not supported on place (%s) ", mlu)); + return true; + } + bool GetResult(const framework::Tensor& out, const platform::CUDAPlace& gpu) const { return GetResultHelper(out, gpu); @@ -824,6 +831,10 @@ struct BothFalseVisitor : public boost::static_visitor<> { // TODO(zhiqiu) } + void VisitorImpl(const platform::MLUPlace& mlu) const { + PADDLE_THROW(platform::errors::Unimplemented("MLUPlace is not supported")); + } + void VisitorImpl(const platform::CPUPlace& cpu) const { int num = in_.numel(); const bool* in_ptr = in_.data(); diff --git a/paddle/fluid/framework/tensor_util.h b/paddle/fluid/framework/tensor_util.h index 73829898be961d0f6a3ee1a43d6f80c0bf7d3376..575e2171652a258fd966eecb828ac171fdc7ecfb 100644 --- a/paddle/fluid/framework/tensor_util.h +++ b/paddle/fluid/framework/tensor_util.h @@ -30,6 +30,9 @@ limitations under the License. */ #include "paddle/fluid/memory/allocation/npu_pinned_allocator.h" #endif #include "paddle/fluid/platform/device_context.h" +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/device_context.h" +#endif namespace paddle { namespace framework { @@ -239,6 +242,14 @@ void TensorFromVector(const std::vector& src, reinterpret_cast(ctx).stream()); } #endif +#ifdef PADDLE_WITH_MLU + if (platform::is_mlu_place(dst_place)) { + memory::Copy( + BOOST_GET_CONST(platform::MLUPlace, dst_place), dst_ptr, src_place, + src_ptr, size, + reinterpret_cast(ctx).stream()); + } +#endif } // The fully specialized function should be inline to avoid @@ -371,6 +382,14 @@ void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx, size, nullptr); } #endif +#ifdef PADDLE_WITH_MLU + else if (platform::is_mlu_place(src.place())) { // NOLINT + memory::Copy( + dst_place, dst_ptr, BOOST_GET_CONST(platform::MLUPlace, src.place()), + src_ptr, size, + reinterpret_cast(ctx).stream()); + } +#endif } template <> @@ -412,6 +431,14 @@ inline void TensorToVector(const Tensor& src, BOOST_GET_CONST(platform::NPUPlace, src.place()), src_ptr, size, nullptr); } +#endif +#ifdef PADDLE_WITH_MLU + else if (platform::is_mlu_place(src.place())) { // NOLINT + memory::Copy( + dst_place, dst_ptr, BOOST_GET_CONST(platform::MLUPlace, src.place()), + src_ptr, size, + reinterpret_cast(ctx).stream()); + } #endif for (unsigned int i = 0; i < src.numel(); i++) { (*dst)[i] = static_cast(array[i]); diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc index 6aad54fba86e481937f0462aef4cbbc35932f023..d98609273a61f690069eec860fee99de51aa9707 100644 --- a/paddle/fluid/imperative/gradient_accumulator.cc +++ b/paddle/fluid/imperative/gradient_accumulator.cc @@ -125,6 +125,23 @@ class TensorAddFunctor : public boost::static_visitor<> { } #endif +#ifdef PADDLE_WITH_MLU + void operator()(const platform::MLUPlace& place) { + // TODO(fwg): SUPPORT it + PADDLE_THROW(platform::errors::PermissionDenied( + "Gradient accumulation on place (%s) " + "is not supported in imperative mode", + place)); + } +#else + void operator()(const platform::MLUPlace& place) { + PADDLE_THROW(platform::errors::PermissionDenied( + "Gradient accumulation on place (%s) " + "is not supported in imperative mode", + place)); + } +#endif + #ifdef PADDLE_WITH_ASCEND_CL void operator()(const platform::NPUPlace& place) { // TODO(zhiqiu): SUPPORT it diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 1d06a63e38f8d1ec4ed52b158fbfd62c135ac59c..682916a9b323b638dfe8c7bfb1e675ab133bab15 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -132,6 +132,16 @@ paddle::framework::GarbageCollector* Tracer::MutableGarbageCollectorIfNotExists( PADDLE_THROW(platform::errors::PermissionDenied( "Paddle can't use NPU device since it's not compiled with NPU," "Please recompile or reinstall Paddle with NPU support.")); +#endif + } else if (platform::is_mlu_place(place)) { +#if defined(PADDLE_WITH_MLU) + gc.reset(new framework::MLUDefaultStreamGarbageCollector( + BOOST_GET_CONST(platform::MLUPlace, place), 0)); + VLOG(10) << "Created GarbageCollector at " << place; +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't use MLU device since it's not compiled with MLU," + "Please recompile or reinstall Paddle with MLU support.")); #endif } else { PADDLE_THROW(platform::errors::PreconditionNotMet( @@ -207,6 +217,14 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, #else PADDLE_THROW(platform::errors::PreconditionNotMet( "PaddlePaddle should compile with NPU if use NPUPlace.")); +#endif + } else if (platform::is_mlu_place(place)) { +#ifdef PADDLE_WITH_MLU + platform::SetMLUDeviceId( + BOOST_GET_CONST(platform::MLUPlace, place).device); +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with MLU if use MLUPlace.")); #endif } diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 3e9c2ba269fe2f0d3bf8f87da308883687f9b427..a53c6a8dbeb12a1de64c3207632d0bc5d0158a00 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -56,6 +56,10 @@ #include "paddle/fluid/platform/device/ipu/ipu_info.h" #endif +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/mlu_info.h" +#endif + PADDLE_DEFINE_EXPORTED_int64( gpu_allocator_retry_time, 10000, "The retry time (milliseconds) when allocator fails " @@ -167,6 +171,11 @@ class AllocatorFacadePrivate { InitNaiveBestFitNPUAllocator(platform::NPUPlace(dev_id)); } InitNaiveBestFitNPUPinnedAllocator(); +#endif +#ifdef PADDLE_WITH_MLU + for (int dev_id = 0; dev_id < platform::GetMLUDeviceCount(); ++dev_id) { + InitNaiveBestFitMLUAllocator(platform::MLUPlace(dev_id)); + } #endif break; } @@ -201,6 +210,11 @@ class AllocatorFacadePrivate { for (int dev_id = 0; dev_id < platform::GetIPUDeviceCount(); ++dev_id) { InitNaiveBestFitIPUAllocator(platform::IPUPlace(dev_id)); } +#endif +#ifdef PADDLE_WITH_MLU + for (int dev_id = 0; dev_id < platform::GetMLUDeviceCount(); ++dev_id) { + InitNaiveBestFitMLUAllocator(platform::MLUPlace(dev_id)); + } #endif break; } @@ -228,6 +242,11 @@ class AllocatorFacadePrivate { InitThreadLocalCUDAAllocator(platform::CUDAPlace(dev_id)); } InitNaiveBestFitCUDAPinnedAllocator(); +#endif +#ifdef PADDLE_WITH_MLU + for (int dev_id = 0; dev_id < platform::GetMLUDeviceCount(); ++dev_id) { + InitNaiveBestFitMLUAllocator(platform::MLUPlace(dev_id)); + } #endif break; } @@ -637,6 +656,12 @@ class AllocatorFacadePrivate { } #endif +#ifdef PADDLE_WITH_MLU + void InitNaiveBestFitMLUAllocator(platform::MLUPlace p) { + allocators_[p] = std::make_shared(p); + } +#endif + #ifdef PADDLE_WITH_ASCEND_CL void InitNaiveBestFitNPUAllocator(platform::NPUPlace p) { allocators_[p] = std::make_shared(p); @@ -673,6 +698,13 @@ class AllocatorFacadePrivate { platform::CUDAPlace p(i); system_allocators_[p] = std::make_shared(p); } +#endif +#ifdef PADDLE_WITH_MLU + int device_count = platform::GetMLUDeviceCount(); + for (int i = 0; i < device_count; ++i) { + platform::XPUPlace p(i); + system_allocators_[p] = std::make_shared(p); + } #endif } @@ -705,6 +737,12 @@ class AllocatorFacadePrivate { places.emplace_back(platform::IPUPlace(dev_id)); } #endif +#ifdef PADDLE_WITH_MLU + int device_count = platform::GetMLUDeviceCount(); + for (int dev_id = 0; dev_id < device_count; ++dev_id) { + places.emplace_back(platform::MLUPlace(dev_id)); + } +#endif for (auto& p : places) { zero_size_allocators_[p] = std::make_shared(p); diff --git a/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc b/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc index 41dcf277d7a11e6fa1b90e103b148abffa28704a..d2319dacdd33f8ddd9eb1df625c76cd4b7e59b96 100644 --- a/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc +++ b/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc @@ -35,6 +35,9 @@ #ifdef PADDLE_WITH_ASCEND_CL #include "paddle/fluid/platform/device/npu/npu_info.h" #endif +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/mlu_info.h" +#endif PADDLE_DEFINE_EXPORTED_bool( init_allocated_mem, false, @@ -651,6 +654,134 @@ uint64_t Release( #endif } +// For MLU +#ifdef PADDLE_WITH_MLU +class MLUBuddyAllocatorList { + private: + MLUBuddyAllocatorList() : devices_(platform::GetMLUSelectedDevices()) { + auto mlu_num = devices_.size(); + allocators_.resize(mlu_num); + init_flags_.reserve(mlu_num); + for (size_t i = 0; i < mlu_num; ++i) { + init_flags_.emplace_back(new std::once_flag()); + } + } + + static MLUBuddyAllocatorList *CreateNewInstance() { + return new MLUBuddyAllocatorList(); + } + + public: + static MLUBuddyAllocatorList *Instance() { + static auto *instance = CreateNewInstance(); + return instance; + } + + BuddyAllocator *Get(int mlu_id) { + auto pos = std::distance( + devices_.begin(), std::find(devices_.begin(), devices_.end(), mlu_id)); + PADDLE_ENFORCE_LT(pos, devices_.size(), + platform::errors::OutOfRange( + "The index exceeds the size of devices, the size of " + "devices is %d, the index is %d", + devices_.size(), pos)); + + std::call_once(*init_flags_[pos], [this, pos] { + platform::SetMLUDeviceId(devices_[pos]); + allocators_[pos].reset(new BuddyAllocator( + std::unique_ptr( + new detail::MLUAllocator(devices_[pos])), + platform::MLUMinChunkSize(), platform::MLUMaxChunkSize())); + VLOG(10) << "\n\nNOTE:\n" + << "You can set GFlags environment variable " + << "(mlu reuse gpu GFlags) " + << "'FLAGS_fraction_of_gpu_memory_to_use' " + << "or 'FLAGS_initial_gpu_memory_in_mb' " + << "or 'FLAGS_reallocate_gpu_memory_in_mb' " + << "to change the memory size for MLU usage.\n" + << "Current 'FLAGS_fraction_of_gpu_memory_to_use' value is " + << FLAGS_fraction_of_gpu_memory_to_use + << ". Current 'FLAGS_initial_gpu_memory_in_mb' value is " + << FLAGS_initial_gpu_memory_in_mb + << ". Current 'FLAGS_reallocate_gpu_memory_in_mb' value is " + << FLAGS_reallocate_gpu_memory_in_mb << "\n\n"; + }); + + return allocators_[pos].get(); + } + + private: + std::vector devices_; + std::vector> init_flags_; + std::vector> allocators_; +}; + +BuddyAllocator *GetMLUBuddyAllocator(int mlu_id) { + return MLUBuddyAllocatorList::Instance()->Get(mlu_id); +} +#endif + +template <> +size_t Used(const platform::MLUPlace &place) { +#ifdef PADDLE_WITH_MLU + return GetMLUBuddyAllocator(place.device)->Used(); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "'MLUPlace' is not supported in CPU only device.")); +#endif +} + +template <> +void *Alloc(const platform::MLUPlace &place, size_t size) { +#ifdef PADDLE_WITH_MLU + auto *buddy_allocator = GetMLUBuddyAllocator(place.device); + auto *ptr = buddy_allocator->Alloc(size); + if (ptr == nullptr) { + platform::MLUDeviceGuard(place.device); + size_t avail = 0, total = 0; + platform::MLUMemoryUsage(&avail, &total); + PADDLE_THROW(platform::errors::ResourceExhausted( + "Cannot allocate %s in MLU %d, avaliable %s, total %s, MLUMinChunkSize " + "%s, MLUMinChunkSize %s, MLU memory used: %s.", + string::HumanReadableSize(size), place.device, + string::HumanReadableSize(avail), string::HumanReadableSize(total), + string::HumanReadableSize(buddy_allocator->GetMinChunkSize()), + string::HumanReadableSize(buddy_allocator->GetMaxChunkSize()), + string::HumanReadableSize(Used(place)))); + } else { + if (FLAGS_init_allocated_mem) { + cnrtMemset(ptr, 0xEF, size); + } + } + return ptr; +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "'MLUPlace' is not supported in CPU only device.")); +#endif +} + +template <> +void Free(const platform::MLUPlace &place, void *p, + size_t size) { +#ifdef PADDLE_WITH_MLU + VLOG(10) << "Free pointer=" << p << " on " << platform::Place(place); + GetMLUBuddyAllocator(place.device)->Free(p); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "'MLUPlace' is not supported in CPU only device.")); +#endif +} + +template <> +uint64_t Release(const platform::MLUPlace &place) { +#ifdef PADDLE_WITH_MLU + return GetMLUBuddyAllocator(place.device)->Release(); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "'MLUPlace' is not supported in CPU only device.")); +#endif +} + struct AllocVisitor : public boost::static_visitor { inline explicit AllocVisitor(size_t size) : size_(size) {} diff --git a/paddle/fluid/memory/allocation/naive_best_fit_allocator_test.cc b/paddle/fluid/memory/allocation/naive_best_fit_allocator_test.cc index 1fe85dd699acf18387482d296c2c30f3bb2415cb..7d5cb5200a6a4ab47ef2140aad54533da787f727 100644 --- a/paddle/fluid/memory/allocation/naive_best_fit_allocator_test.cc +++ b/paddle/fluid/memory/allocation/naive_best_fit_allocator_test.cc @@ -77,6 +77,21 @@ TEST(NaiveBestFitAllocatorTest, NpuAlloc) { } #endif +#ifdef PADDLE_WITH_MLU +TEST(NaiveBestFitAllocatorTest, MluAlloc) { + NaiveBestFitAllocator alloc{platform::MLUPlace(0)}; + { + size_t size = (1 << 20); + auto allocation = alloc.Allocate(size); + } + sleep(10); + alloc.Release(platform::MLUPlace(0)); + + size_t size = (1 << 20); + auto allocation = alloc.Allocate(size); + alloc.Release(platform::MLUPlace(0)); +} +#endif } // namespace allocation } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/detail/CMakeLists.txt b/paddle/fluid/memory/detail/CMakeLists.txt index e9631ee739b9b8089a963a6aa84a9837010ad639..a039cd8f4186006915e3d5b46543d0b33cfe080a 100644 --- a/paddle/fluid/memory/detail/CMakeLists.txt +++ b/paddle/fluid/memory/detail/CMakeLists.txt @@ -8,6 +8,8 @@ elseif(WITH_ROCM) hip_library(system_allocator SRCS system_allocator.cc DEPS gflags cpu_info gpu_info place) elseif(${WITH_ASCEND_CL}) cc_library(system_allocator SRCS system_allocator.cc DEPS gflags cpu_info npu_info place) +elseif(WITH_MLU) + cc_library(system_allocator SRCS system_allocator.cc DEPS gflags cpu_info mlu_info place) else() cc_library(system_allocator SRCS system_allocator.cc DEPS gflags cpu_info place) endif() diff --git a/paddle/fluid/memory/detail/buddy_allocator.cc b/paddle/fluid/memory/detail/buddy_allocator.cc index e714a020165d17a0ceb5c93ccf01ee99a147cd95..e9b715c5cc3cf00b394a9cdffec836463c678b21 100644 --- a/paddle/fluid/memory/detail/buddy_allocator.cc +++ b/paddle/fluid/memory/detail/buddy_allocator.cc @@ -18,12 +18,16 @@ limitations under the License. */ #include "gflags/gflags.h" #include "glog/logging.h" -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ + defined(PADDLE_WITH_MLU) DECLARE_uint64(reallocate_gpu_memory_in_mb); #endif #ifdef PADDLE_WITH_ASCEND_CL DECLARE_uint64(reallocate_gpu_memory_in_mb); #endif +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/mlu_info.h" +#endif namespace paddle { namespace memory { @@ -259,6 +263,21 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool( } } #endif +#ifdef PADDLE_WITH_MLU + if (system_allocator_->UseGpu()) { + if ((total_used_ + total_free_) == 0) { + // Compute the allocation size for mlu for the first allocation. + allocate_bytes = std::max(platform::MLUInitAllocSize(), request_bytes); + } else { + // Compute the re-allocation size, we store the re-allocation size when + // user set FLAGS_reallocate_gpu_memory_in_mb to fix value. + if (realloc_size_ == 0 || FLAGS_reallocate_gpu_memory_in_mb == 0ul) { + realloc_size_ = platform::MLUReallocSize(); + } + allocate_bytes = std::max(realloc_size_, request_bytes); + } + } +#endif // Allocate a new block void* p = system_allocator_->Alloc(&index, allocate_bytes); diff --git a/paddle/fluid/memory/detail/buddy_allocator_test.cc b/paddle/fluid/memory/detail/buddy_allocator_test.cc index cd152843553a9f7abda9c4a81e7231a919be9c79..7d19115940fee0dde850eb5f432d89a84e9e8023 100644 --- a/paddle/fluid/memory/detail/buddy_allocator_test.cc +++ b/paddle/fluid/memory/detail/buddy_allocator_test.cc @@ -26,9 +26,12 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device/npu/npu_info.h" +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/mlu_info.h" +#endif #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ - defined(PADDLE_WITH_ASCEND_CL) + defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU) DECLARE_double(fraction_of_gpu_memory_to_use); DECLARE_uint64(initial_gpu_memory_in_mb); DECLARE_uint64(reallocate_gpu_memory_in_mb); @@ -370,6 +373,186 @@ TEST(BuddyAllocator, NpuFraction) { } #endif +#ifdef PADDLE_WITH_MLU +TEST(BuddyAllocator, MluFraction) { + // In a 16 GB machine, the pool size will be about 160 MB + FLAGS_fraction_of_gpu_memory_to_use = 0.01; + FLAGS_initial_gpu_memory_in_mb = 0; + FLAGS_reallocate_gpu_memory_in_mb = 0; + + BuddyAllocator buddy_allocator( + std::unique_ptr(new MLUAllocator(0)), + platform::MLUMinChunkSize(), platform::MLUMaxChunkSize()); + + // Less than pool size + TestBuddyAllocator(&buddy_allocator, 10); + TestBuddyAllocator(&buddy_allocator, 10 << 10); + TestBuddyAllocator(&buddy_allocator, 10 << 20); + buddy_allocator.Release(); + + // Greater than max chunk size + TestBuddyAllocator(&buddy_allocator, 600 << 20, + /* use_system_allocator = */ true); + TestBuddyAllocator(&buddy_allocator, 1 * static_cast(1 << 30), + /* use_system_allocator = */ true); +} + +TEST(BuddyAllocator, InitRealloc) { + FLAGS_initial_gpu_memory_in_mb = 100; + FLAGS_reallocate_gpu_memory_in_mb = 50; + + EXPECT_EQ(platform::MLUMaxChunkSize(), static_cast(100 << 20)); + + BuddyAllocator buddy_allocator( + std::unique_ptr(new MLUAllocator(0)), + platform::MLUMinChunkSize(), platform::MLUMaxChunkSize()); + + // Less then initial size and reallocate size + TestBuddyAllocator(&buddy_allocator, 10 << 20); + // Between initial size and reallocate size and not exceed pool + TestBuddyAllocator(&buddy_allocator, 80 << 20); + TestBuddyAllocator(&buddy_allocator, 99 << 20); + // Greater than max chunk size + TestBuddyAllocator(&buddy_allocator, 101 << 20, + /* use_system_allocator = */ true); + TestBuddyAllocator(&buddy_allocator, 1 * static_cast(1 << 30), + /* use_system_allocator = */ true); +} + +TEST(BuddyAllocator, ReallocSizeGreaterThanInit) { + FLAGS_initial_gpu_memory_in_mb = 5; + FLAGS_reallocate_gpu_memory_in_mb = 10; + + EXPECT_EQ(platform::MLUMaxChunkSize(), static_cast(10 << 20)); + + BuddyAllocator buddy_allocator( + std::unique_ptr(new MLUAllocator(0)), + platform::MLUMinChunkSize(), platform::MLUMaxChunkSize()); + + // Less than initial size and reallocate size + TestBuddyAllocator(&buddy_allocator, 1 << 20); + // Between initial size and reallocate size and exceed pool + TestBuddyAllocator(&buddy_allocator, 6 << 20); + TestBuddyAllocator(&buddy_allocator, 8 << 20); + TestBuddyAllocator(&buddy_allocator, 9 << 20); + // Greater than max trunk size + TestBuddyAllocator(&buddy_allocator, 11 << 20, + /* use_system_allocator = */ true); + TestBuddyAllocator(&buddy_allocator, 1 * static_cast(1 << 30), + /* use_system_allocator = */ true); +} + +TEST(BuddyAllocator, FractionRefillPool) { + FLAGS_fraction_of_gpu_memory_to_use = 0.6; + FLAGS_initial_gpu_memory_in_mb = 0; + FLAGS_reallocate_gpu_memory_in_mb = 0; + + size_t max_chunk_size = platform::MLUMaxChunkSize(); + BuddyAllocator buddy_allocator( + std::unique_ptr(new MLUAllocator(0)), + platform::MLUMinChunkSize(), max_chunk_size); + + // Less than pool size + int* p0 = TestBuddyAllocator(&buddy_allocator, max_chunk_size - 1000, + /* use_system_allocator = */ false, + /* free_ptr = */ false); + // Max chunk size should be same during allocation + EXPECT_EQ(max_chunk_size, buddy_allocator.GetMaxChunkSize()); + + size_t alloc = + platform::MLUAvailableMemToAlloc() * FLAGS_fraction_of_gpu_memory_to_use; + // Exceed pool trigger refilling size of fraction of avaiable mlu, and should + // be able to alloc 60% of the remaining MLU + int* p1 = TestBuddyAllocator(&buddy_allocator, alloc, + /* use_system_allocator = */ false, + /* free_ptr = */ false); + // Max chunk size should be same during allocation + EXPECT_EQ(max_chunk_size, buddy_allocator.GetMaxChunkSize()); + + alloc = + platform::MLUAvailableMemToAlloc() * FLAGS_fraction_of_gpu_memory_to_use; + // Exceed pool trigger refilling size of fraction of avaiable mlu, and should + // be able to alloc 60% of the remaining MLU + TestBuddyAllocator(&buddy_allocator, alloc, + /* use_system_allocator = */ false); + // Max chunk size should be same during allocation + EXPECT_EQ(max_chunk_size, buddy_allocator.GetMaxChunkSize()); + + buddy_allocator.Free(p0); + buddy_allocator.Free(p1); +} + +TEST(BuddyAllocator, AllocFromAvailable) { + FLAGS_fraction_of_gpu_memory_to_use = 0.7; + FLAGS_initial_gpu_memory_in_mb = 0; + FLAGS_reallocate_gpu_memory_in_mb = 0; + + size_t total = 0, available = 0; + platform::SetMLUDeviceId(0); + platform::MLUMemoryUsage(&available, &total); + + // Take half of available MLU + void* p; + + cnrtStatus result = cnrtMalloc(&p, available >> 1); + EXPECT_TRUE(result == cnrtSuccess); + + // BuddyAllocator should be able to alloc the remaining MLU + BuddyAllocator buddy_allocator( + std::unique_ptr(new MLUAllocator(0)), + platform::MLUMinChunkSize(), platform::MLUMaxChunkSize()); + + TestBuddyAllocator(&buddy_allocator, 10); + TestBuddyAllocator(&buddy_allocator, 10 << 10); + TestBuddyAllocator(&buddy_allocator, 10 << 20); + TestBuddyAllocator(&buddy_allocator, static_cast(1 << 30)); + + if (p) { + EXPECT_TRUE(cnrtFree(p) == cnrtSuccess); + } +} + +TEST(BuddyAllocator, AllocFromAvailableWhenFractionIsOne) { + FLAGS_fraction_of_gpu_memory_to_use = 1.0; + FLAGS_initial_gpu_memory_in_mb = 0; + FLAGS_reallocate_gpu_memory_in_mb = 0; + + void* p = nullptr; + + EXPECT_TRUE(cnrtMalloc(&p, static_cast(1) << 30) == cnrtSuccess); + + // BuddyAllocator should be able to alloc the remaining MLU + BuddyAllocator buddy_allocator( + std::unique_ptr(new MLUAllocator(0)), + platform::MLUMinChunkSize(), platform::MLUMaxChunkSize()); + + TestBuddyAllocator(&buddy_allocator, static_cast(1) << 30); + TestBuddyAllocator(&buddy_allocator, static_cast(1) << 30); + + if (p) { + EXPECT_TRUE(cnrtFree(p) == cnrtSuccess); + } +} + +TEST(BuddyAllocator, Release) { + // In a 8 GB machine, the pool size will be about 800 MB + FLAGS_fraction_of_gpu_memory_to_use = 0.1; + FLAGS_initial_gpu_memory_in_mb = 0; + FLAGS_reallocate_gpu_memory_in_mb = 0; + + BuddyAllocator buddy_allocator( + std::unique_ptr(new MLUAllocator(0)), + platform::MLUMinChunkSize(), platform::MLUMaxChunkSize()); + + // Less than pool size + TestBuddyAllocator(&buddy_allocator, 10); + TestBuddyAllocator(&buddy_allocator, 10 << 10); + TestBuddyAllocator(&buddy_allocator, 50 << 20); + + buddy_allocator.Release(); +} +#endif + } // namespace detail } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/detail/system_allocator.cc b/paddle/fluid/memory/detail/system_allocator.cc index b300f936f7a6831a2c1a4b158578f4426f2abae6..773122de6c3198b09c33241a0d6a09e9357f65a3 100644 --- a/paddle/fluid/memory/detail/system_allocator.cc +++ b/paddle/fluid/memory/detail/system_allocator.cc @@ -30,6 +30,9 @@ limitations under the License. */ #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device/npu/npu_info.h" #include "paddle/fluid/platform/enforce.h" +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/mlu_info.h" +#endif #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/platform/cuda_device_guard.h" @@ -365,6 +368,68 @@ bool NPUPinnedAllocator::UseGpu() const { return false; } #endif +#ifdef PADDLE_WITH_MLU +void* MLUAllocator::Alloc(size_t* index, size_t size) { + if (size <= 0) return nullptr; + + void* p; + auto result = platform::RecordedMLUMalloc(&p, size, mlu_id_); + + if (result == cnrtSuccess) { + *index = 0; + mlu_alloc_size_ += size; + return p; + } else { + size_t avail, total, actual_avail, actual_total; + bool is_limited = platform::RecordedMLUMemGetInfo( + &avail, &total, &actual_avail, &actual_total, mlu_id_); + size_t allocated = total - avail; + + std::string err_msg; + if (is_limited) { + auto limit_size = (total >> 20); + err_msg = string::Sprintf( + "\n 3) Set environment variable `FLAGS_gpu_memory_limit_mb` to a " + "larger value. Currently `FLAGS_gpu_memory_limit_mb` is %d, so the " + "maximum MLU memory usage is limited to %d MB.\n" + " The command is `export FLAGS_gpu_memory_limit_mb=xxx`.", + limit_size, limit_size); + } + + PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted( + "\n\nOut of memory error on MLU %d. " + "Cannot allocate %s memory on MLU %d, %s memory has been allocated and " + "available memory is only %s.\n\n" + "Please check whether there is any other process using MLU %d.\n" + "1. If yes, please stop them, or start PaddlePaddle on another MLU.\n" + "2. If no, please try one of the following suggestions:\n" + " 1) Decrease the batch size of your model.\n" + " 2) FLAGS_fraction_of_gpu_memory_to_use is %.2lf now, " + "please set it to a higher value but less than 1.0.\n" + " The command is " + "`export FLAGS_fraction_of_gpu_memory_to_use=xxx`.%s\n\n", + mlu_id_, string::HumanReadableSize(size), mlu_id_, + string::HumanReadableSize(allocated), string::HumanReadableSize(avail), + mlu_id_, FLAGS_fraction_of_gpu_memory_to_use, err_msg)); + } +} + +void MLUAllocator::Free(void* p, size_t size, size_t index) { + PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument( + "The index should be 0, index is %d", index)); + PADDLE_ENFORCE_GE(mlu_alloc_size_, size, + platform::errors::InvalidArgument( + "The size of memory (%d) to free exceeds the size of " + "allocated gpu memory (%d)", + size, mlu_alloc_size_)); + mlu_alloc_size_ -= size; + + platform::RecordedMLUFree(p, size, mlu_id_); +} + +bool MLUAllocator::UseGpu() const { return true; } +#endif + } // namespace detail } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/detail/system_allocator.h b/paddle/fluid/memory/detail/system_allocator.h index 92042f0bbae9f0d29d15b9ed266f57cfa7594412..975e2891b2472ad4aeb5c4a7d6f676c516350545 100644 --- a/paddle/fluid/memory/detail/system_allocator.h +++ b/paddle/fluid/memory/detail/system_allocator.h @@ -92,6 +92,21 @@ class NPUPinnedAllocator : public SystemAllocator { }; #endif +#ifdef PADDLE_WITH_MLU +class MLUAllocator : public SystemAllocator { + public: + explicit MLUAllocator(int mlu_id) : mlu_id_(mlu_id) {} + + virtual void* Alloc(size_t* index, size_t size); + virtual void Free(void* p, size_t size, size_t index); + virtual bool UseGpu() const; + + private: + size_t mlu_alloc_size_ = 0; + int mlu_id_; +}; +#endif + } // namespace detail } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/detail/system_allocator_test.cc b/paddle/fluid/memory/detail/system_allocator_test.cc index bb7f47f9d30ec431cca13c374c28541aabc08de2..d818459fb03a0a0e442c35b67c744b3e124c2e83 100644 --- a/paddle/fluid/memory/detail/system_allocator_test.cc +++ b/paddle/fluid/memory/detail/system_allocator_test.cc @@ -22,6 +22,9 @@ limitations under the License. */ #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/platform/device/gpu/gpu_info.h" #endif +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/enforce.h" +#endif DECLARE_bool(use_pinned_memory); @@ -92,3 +95,23 @@ TEST(NPUAllocator, Alloc) { TestAllocator(&a, 1); } #endif + +#ifdef PADDLE_WITH_MLU +TEST(MLUAllocator, Alloc) { + paddle::memory::detail::MLUAllocator a(0); + TestAllocator(&a, 2048); + TestAllocator(&a, 0); +} + +TEST(MLUAllocator, AllocFailure) { + paddle::memory::detail::MLUAllocator allocator(0); + size_t index; + size_t alloc_size = (static_cast(1) << 40); // Very large number + try { + allocator.Alloc(&index, alloc_size); + ASSERT_TRUE(false); + } catch (paddle::memory::allocation::BadAlloc&) { + PADDLE_ENFORCE_MLU_SUCCESS(cnrtGetLastError()); + } +} +#endif diff --git a/paddle/fluid/memory/memcpy.cc b/paddle/fluid/memory/memcpy.cc index 4de81435881ed4534bdda8f87a80307edc9eb9ae..2814f2f9501a8bef01526e9b9bc89e7d63fdca11 100644 --- a/paddle/fluid/memory/memcpy.cc +++ b/paddle/fluid/memory/memcpy.cc @@ -22,6 +22,10 @@ limitations under the License. */ #include "paddle/fluid/platform/device/xpu/xpu_header.h" #endif +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/mlu_info.h" +#endif + namespace paddle { namespace memory { @@ -631,5 +635,91 @@ void Copy( #endif +#ifdef PADDLE_WITH_MLU +template <> +void Copy(platform::CPUPlace dst_place, + void* dst, + platform::MLUPlace src_place, + const void* src, size_t num, + mluStream stream) { + if (UNLIKELY(num == 0)) return; + + platform::SetMLUDeviceId(src_place.device); + if (stream) { + VLOG(4) << "Async memory::Copy " << num << " Bytes from " << src_place + << " to " << dst_place << " by mlu stream(" << stream << ")"; + platform::RecordEvent record_event("MLUMemcpyD2HAsync:MLU->CPU"); + platform::MLUMemcpyD2HAsync(dst, src, num, stream); + } else { + VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place + << " to " << dst_place; + platform::RecordEvent record_event("MLUMemcpyD2HSync:MLU->CPU"); + platform::MLUMemcpyD2HSync(dst, src, num); + } +} + +template <> +void Copy(platform::MLUPlace dst_place, + void* dst, + platform::CPUPlace src_place, + const void* src, size_t num, + mluStream stream) { + if (UNLIKELY(num == 0)) return; + + platform::SetMLUDeviceId(dst_place.device); + if (stream) { + VLOG(4) << "Async memory::Copy " << num << " Bytes from " << src_place + << " to " << dst_place << " by mlu stream(" << stream << ")"; + platform::RecordEvent record_event("MLUMemcpyH2DAsync:CPU->MLU"); + platform::MLUMemcpyH2DAsync(dst, src, num, stream); + } else { + VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place + << " to " << dst_place; + platform::RecordEvent record_event("MLUMemcpyH2DSync:CPU->MLU"); + platform::MLUMemcpyH2DSync(dst, src, num); + } +} + +template <> +void Copy(platform::MLUPlace dst_place, + void* dst, + platform::MLUPlace src_place, + const void* src, size_t num, + mluStream stream) { + if (UNLIKELY(num == 0)) return; + + if (dst_place == src_place) { + platform::SetMLUDeviceId(dst_place.device); + if (stream) { + VLOG(4) << "Async memory::Copy " << num << " Bytes from " << src_place + << " to " << dst_place << " by mlu stream(" << stream << ")"; + platform::RecordEvent record_event( + "MLUMemcpyD2DAsync(same_mlu):MLU->MLU"); + platform::MLUMemcpyD2DAsync(dst, src, num, stream); + } else { + VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place + << " to " << dst_place; + platform::RecordEvent record_event("MLUMemcpyD2DSync(same_mlu):MLU->MLU"); + platform::MLUMemcpyD2DSync(dst, src, num); + } + } else { + if (stream) { + VLOG(4) << "Async memory::Copy " << num << " Bytes from " << src_place + << " to " << dst_place << " by mlu stream(" << stream << ")"; + platform::RecordEvent record_event("MLUMemcpyPeerAsync:MLU->MLU"); + platform::MLUMemcpyPeerAsync(dst, dst_place.device, src, src_place.device, + num, stream); + } else { + VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place + << " to " << dst_place; + platform::RecordEvent record_event("MLUMemcpyPeerSync:MLU->MLU"); + platform::MLUMemcpyPeerSync(dst, dst_place.device, src, src_place.device, + num); + } + } +} + +#endif // PADDLE_WITH_MLU + } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/memcpy.h b/paddle/fluid/memory/memcpy.h index 7d2d2526ab12453b4dcf0dc8b52366b43a3c832e..31d1a50e778f8c86400163a774af6dc04dce10ed 100644 --- a/paddle/fluid/memory/memcpy.h +++ b/paddle/fluid/memory/memcpy.h @@ -16,6 +16,9 @@ limitations under the License. */ #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/place.h" +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/device_context.h" +#endif namespace paddle { namespace memory { @@ -74,5 +77,25 @@ void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, aclrtStream stream); #endif +#ifdef PADDLE_WITH_MLU +/** + * \brief Copy memory from one place to another place. + * + * \param[in] DstPlace Destination allocation place (CPU or MLU). + * \param[in] dst Destination memory address. + * \param[in] SrcPlace Source allocation place (CPU or MLU). + * \param[in] src Source memory address. + * \param[in] num memory size in bytes to copy. + * \param[in] stream MLU stream. + * + * \note For MLU memory copy, MLU stream need to be specified + * for asynchronously memory copy. + * + */ +template +void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, + mluStream stream); +#endif + } // namespace memory } // namespace paddle diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 6b8567589872ffd7ce8788b5af17dc5ead42043e..1f1bc01c40d0d6bfcd54e3a412f7e5e2caa29c5a 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -52,6 +52,10 @@ if (WITH_LITE) add_subdirectory(lite) endif() +if (WITH_MLU) + add_subdirectory(mlu) +endif() + if(WITH_CINN) add_subdirectory(cinn) endif() diff --git a/paddle/fluid/operators/activation_op_mlu.cc b/paddle/fluid/operators/activation_op_mlu.cc new file mode 100644 index 0000000000000000000000000000000000000000..1ad581cf4ca2b84dcae30b60901bb6a72555ade9 --- /dev/null +++ b/paddle/fluid/operators/activation_op_mlu.cc @@ -0,0 +1,100 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the Licnse. */ + +#include +#include + +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/activation_op.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" +#include "paddle/fluid/platform/device/mlu/device_context.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class ActivationMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); + auto& dev_ctx = ctx.template device_context(); + + output->mutable_data(ctx.GetPlace()); + + MLUCnnlActivationDesc act_desc(act_mode, alpha_); + MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(input->type())); + MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(output->type())); + + MLUCnnl::Active(dev_ctx, act_desc.get(), input_desc.get(), + reinterpret_cast(input->data()), + output_desc.get(), + reinterpret_cast(output->data())); + } + + private: + float alpha_ = 1.0; +}; + +template +class ActivationGradMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* out = ctx.Input("Out"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto& dev_ctx = ctx.template device_context(); + + dx->mutable_data(ctx.GetPlace()); + + MLUCnnlTensorDesc dout_desc(*dout, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(dout->type())); + MLUCnnlTensorDesc out_desc(*out, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(out->type())); + MLUCnnlTensorDesc dx_desc(*dx, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(dx->type())); + MLUCnnlActivationDesc act_desc(act_mode, alpha_); + MLUCnnl::ActiveGrad( + dev_ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr, + dout_desc.get(), reinterpret_cast(dout->data()), + out_desc.get(), reinterpret_cast(out->data()), + dx_desc.get(), reinterpret_cast(dx->data())); + } + + private: + float alpha_ = 1.0; +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_MLU_KERNEL( + relu, ops::ActivationMLUKernel, + ops::ActivationMLUKernel); +REGISTER_OP_MLU_KERNEL( + relu_grad, ops::ActivationGradMLUKernel, + ops::ActivationGradMLUKernel); diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index daa4efa02ac5081e6ddcd0ed45f6e98c826557ba..1efddc4818671c089a7f86f3e1e4ce16590f01c4 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -187,6 +187,13 @@ void set_constant_with_place( framework::VisitDataType(tensor->type(), TensorSetConstantCPU(tensor, value)); } +template <> +void set_constant_with_place( + const platform::DeviceContext& context, framework::Tensor* tensor, + float value) { + PADDLE_THROW(platform::errors::Unimplemented("MLUPlace is not supported")); +} + template <> void set_constant_with_place( const platform::DeviceContext& context, framework::Tensor* tensor, diff --git a/paddle/fluid/operators/mlu/CMakeLists.txt b/paddle/fluid/operators/mlu/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..3fc411d6d13fa60169f0df8ccf3fe8d95af7c76e --- /dev/null +++ b/paddle/fluid/operators/mlu/CMakeLists.txt @@ -0,0 +1,5 @@ + +IF(WITH_MLU) + cc_library(mlu_baseop SRCS mlu_baseop.cc DEPS neuware_lib) + cc_test(activation_op_mlu_test SRCS activation_op_mlu_test.cc DEPS op_registry activation_op scope device_context executor) +ENDIF() diff --git a/paddle/fluid/operators/mlu/activation_op_mlu_test.cc b/paddle/fluid/operators/mlu/activation_op_mlu_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6e392bcc75e8246774d3f2c65740d5d02b5e4a14 --- /dev/null +++ b/paddle/fluid/operators/mlu/activation_op_mlu_test.cc @@ -0,0 +1,166 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/fluid/operators/activation_op.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/device/mlu/device_context.h" +#include "paddle/fluid/platform/place.h" + +namespace fw = paddle::framework; +namespace plat = paddle::platform; +namespace math = paddle::operators::math; + +USE_OP(relu); +USE_OP_DEVICE_KERNEL(relu, MLU); + +// relu +template +inline T relu(T x) { + return x > 0 ? x : 0.; +} + +template +inline T relu_grad_dx(T x, T out, T dout) { + return out > 0 ? dout : 0; +} + +template +void Compare(fw::Scope* scope, const plat::DeviceContext& ctx, + std::string op_type) { + // init + auto x = scope->Var("X"); + auto tensor_x = x->GetMutable(); + + const int num = 10; + std::vector init_x; + for (int64_t i = 0; i < num * num; ++i) { + init_x.push_back(static_cast(i - 50)); + } + TensorFromVector(init_x, ctx, tensor_x); + tensor_x->Resize({num, num}); + + auto place = ctx.GetPlace(); + auto out = scope->Var("Out"); + auto tensor_out = out->GetMutable(); + + fw::AttributeMap attrs; + auto op = fw::OpRegistry::CreateOp(op_type, {{"X", {"X"}}}, + {{"Out", {"Out"}}}, attrs); + op->Run(*scope, place); + + ctx.Wait(); + + // eval time + struct timeval start, end; + gettimeofday(&start, NULL); + + for (int i = 0; i < 100; i++) { + op->Run(*scope, place); + } + + ctx.Wait(); + + gettimeofday(&end, NULL); + int micros = + (((end.tv_sec - start.tv_sec) * 1000000) + end.tv_usec) - (start.tv_usec); + printf("used time: %d\n", micros / 100); + + // eval value + std::vector out_vec; + TensorToVector(*tensor_out, ctx, &out_vec); + + ctx.Wait(); + + for (uint32_t i = 0; i < out_vec.size(); i++) { + EXPECT_FLOAT_EQ(out_vec[i], relu(init_x[i])); + } +} + +template +void CompareGrad(fw::Scope* scope, const plat::DeviceContext& ctx, + std::string op_type) { + auto dout = scope->Var("DOut"); + auto tensor_dout = dout->GetMutable(); + auto out = scope->Var("Out"); + auto tensor_out = out->GetMutable(); + + const int num = 10; + std::vector init_dout; + for (int64_t i = 0; i < num * num; ++i) { + init_dout.push_back(static_cast(1.0)); + } + + std::vector init_out; + for (int64_t i = 0; i < num * num; ++i) { + init_out.push_back(static_cast(i - 50)); + } + + TensorFromVector(init_dout, ctx, tensor_dout); + tensor_dout->Resize({num, num}); + TensorFromVector(init_out, ctx, tensor_out); + tensor_out->Resize({num, num}); + + auto dx = scope->Var("DX"); + auto tensor_dx = dx->GetMutable(); + + // run + auto place = ctx.GetPlace(); + fw::AttributeMap attrs; + auto op = fw::OpRegistry::CreateOp(op_type, + {{"Out@GRAD", {"DOut"}}, {"Out", {"Out"}}}, + {{"X@GRAD", {"DX"}}}, attrs); + op->Run(*scope, place); + + ctx.Wait(); + + // eval time + struct timeval start, end; + gettimeofday(&start, NULL); + + for (int i = 0; i < 100; i++) { + op->Run(*scope, place); + } + + ctx.Wait(); + + gettimeofday(&end, NULL); + int micros = + (((end.tv_sec - start.tv_sec) * 1000000) + end.tv_usec) - (start.tv_usec); + printf("used time: %d\n", micros / 100); + + // eval value + std::vector dx_vec; + TensorToVector(*tensor_dx, ctx, &dx_vec); + + ctx.Wait(); + + for (uint32_t i = 0; i < dx_vec.size(); i++) { + EXPECT_FLOAT_EQ(dx_vec[i], + relu_grad_dx(dx_vec[i], init_out[i], init_dout[i])); + } +} + +TEST(relu, MLU_fp32) { + fw::Scope scope; + auto* ctx = plat::DeviceContextPool::Instance().Get(plat::MLUPlace(0)); + Compare(&scope, *ctx, "relu"); +} + +TEST(relu_grad, MLU_fp32) { + fw::Scope scope; + auto* ctx = plat::DeviceContextPool::Instance().Get(plat::MLUPlace(0)); + CompareGrad(&scope, *ctx, "relu_grad"); +} diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc new file mode 100644 index 0000000000000000000000000000000000000000..917692bfbd9d5366d19d09ec42458738cbceeb36 --- /dev/null +++ b/paddle/fluid/operators/mlu/mlu_baseop.cc @@ -0,0 +1,227 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/mlu/mlu_baseop.h" +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/framework.pb.h" + +namespace paddle { +namespace operators { + +class MLUCnnlTensorDescPool { + public: + cnnlTensorDescriptor_t Pop() { + cnnlTensorDescriptor_t raw_desc; + if (q_.try_dequeue(raw_desc)) { + return raw_desc; + } else { + cnnlCreateTensorDescriptor(&raw_desc); + return raw_desc; + } + } + + void Recycle(cnnlTensorDescriptor_t desc) { + cnnlResetTensorDescriptor(desc); + q_.enqueue(desc); + } + + ~MLUCnnlTensorDescPool() { + auto size = q_.size_approx(); + if (size > 0) { + std::vector vec(size); + q_.try_dequeue_bulk(vec.data(), size); + for (auto desc : vec) { + cnnlDestroyTensorDescriptor(desc); + } + } + } + + private: + moodycamel::ConcurrentQueue q_; +}; + +static MLUCnnlTensorDescPool g_cnnl_tensor_desc_pool; + +MLUCnnlTensorDesc &MLUCnnlTensorDesc::operator=(MLUCnnlTensorDesc &&rhs) { + if (raw_tensor_desc) { + g_cnnl_tensor_desc_pool.Recycle(raw_tensor_desc); + } + raw_tensor_desc = rhs.raw_tensor_desc; + rhs.raw_tensor_desc = nullptr; + return *this; +} + +MLUCnnlTensorDesc::MLUCnnlTensorDesc(const int tensor_dim, + const int dim_sizes[], + const cnnlDataType_t tensor_dtype) { + raw_tensor_desc = g_cnnl_tensor_desc_pool.Pop(); + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetTensorDescriptor( + raw_tensor_desc, CNNL_LAYOUT_ARRAY, tensor_dtype, tensor_dim, dim_sizes)); +} + +MLUCnnlTensorDesc::MLUCnnlTensorDesc(const int tensor_dim, + const int dim_sizes[], + const cnnlDataType_t tensor_dtype, + const cnnlTensorLayout_t layout) { + raw_tensor_desc = g_cnnl_tensor_desc_pool.Pop(); + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetTensorDescriptor( + raw_tensor_desc, layout, tensor_dtype, tensor_dim, dim_sizes)); +} + +MLUCnnlTensorDesc::MLUCnnlTensorDesc(const int tensor_dim, + const int dim_sizes[], + const cnnlDataType_t tensor_dtype, + int position) + : MLUCnnlTensorDesc(tensor_dim, dim_sizes, tensor_dtype) { + PADDLE_ENFORCE_MLU_SUCCESS( + cnnlSetTensorDescriptorPosition(raw_tensor_desc, position)); +} + +MLUCnnlTensorDesc::MLUCnnlTensorDesc(const int tensor_dim, + const int64_t dim_sizes[], + const cnnlDataType_t tensor_dtype) { + std::vector dim_sizes_int32(tensor_dim); + std::vector::const_iterator int64_cbegin(dim_sizes); + std::vector::const_iterator int64_cend(dim_sizes + tensor_dim); + std::transform(int64_cbegin, int64_cend, dim_sizes_int32.begin(), + &CheckedNarrowing); + raw_tensor_desc = g_cnnl_tensor_desc_pool.Pop(); + PADDLE_ENFORCE_MLU_SUCCESS( + cnnlSetTensorDescriptor(raw_tensor_desc, CNNL_LAYOUT_ARRAY, tensor_dtype, + tensor_dim, dim_sizes_int32.data())); +} + +MLUCnnlTensorDesc::MLUCnnlTensorDesc(const int tensor_dim, + const int64_t dim_sizes[], + const cnnlDataType_t tensor_dtype, + const cnnlTensorLayout_t layout) { + std::vector dim_sizes_int32(tensor_dim); + std::vector::const_iterator int64_cbegin(dim_sizes); + std::vector::const_iterator int64_cend(dim_sizes + tensor_dim); + std::transform(int64_cbegin, int64_cend, dim_sizes_int32.begin(), + &CheckedNarrowing); + raw_tensor_desc = g_cnnl_tensor_desc_pool.Pop(); + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetTensorDescriptor(raw_tensor_desc, layout, + tensor_dtype, tensor_dim, + dim_sizes_int32.data())); +} + +MLUCnnlTensorDesc::MLUCnnlTensorDesc(const int tensor_dim, + const int64_t dim_sizes[], + const cnnlDataType_t tensor_dtype, + int position) { + std::vector dim_sizes_int32(tensor_dim); + std::vector::const_iterator int64_cbegin(dim_sizes); + std::vector::const_iterator int64_cend(dim_sizes + tensor_dim); + std::transform(int64_cbegin, int64_cend, dim_sizes_int32.begin(), + &CheckedNarrowing); + raw_tensor_desc = g_cnnl_tensor_desc_pool.Pop(); + PADDLE_ENFORCE_MLU_SUCCESS( + cnnlSetTensorDescriptor(raw_tensor_desc, CNNL_LAYOUT_ARRAY, tensor_dtype, + tensor_dim, dim_sizes_int32.data())); + PADDLE_ENFORCE_MLU_SUCCESS( + cnnlSetTensorDescriptorPosition(raw_tensor_desc, position)); +} + +MLUCnnlTensorDesc::MLUCnnlTensorDesc(const Tensor &tensor, + const cnnlTensorLayout_t layout, + const cnnlDataType_t tensor_dtype) { + auto dims = framework::vectorize(tensor.dims()); + int tensor_dim = dims.size(); + raw_tensor_desc = g_cnnl_tensor_desc_pool.Pop(); + if (tensor_dim == 0) { + int scalar_dims[1] = {1}; + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetTensorDescriptor( + raw_tensor_desc, layout, tensor_dtype, 1, scalar_dims)); + } else { + std::vector tensor_dim_sizes_int(dims.begin(), dims.end()); + PADDLE_ENFORCE_MLU_SUCCESS( + cnnlSetTensorDescriptor(raw_tensor_desc, layout, tensor_dtype, + tensor_dim, tensor_dim_sizes_int.data())); + } +} + +MLUCnnlTensorDesc::MLUCnnlTensorDesc(const Tensor &tensor, + cnnlTensorLayout_t layout, + const cnnlDataType_t tensor_dtype, + int position) + : MLUCnnlTensorDesc(tensor, layout, tensor_dtype) { + PADDLE_ENFORCE_MLU_SUCCESS( + cnnlSetTensorDescriptorPosition(raw_tensor_desc, position)); +} + +MLUCnnlTensorDesc::MLUCnnlTensorDesc(const Tensor &tensor, + cnnlTensorLayout_t layout, + const cnnlDataType_t tensor_dtype, + int position, float scale) + : MLUCnnlTensorDesc(tensor, layout, tensor_dtype) { + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetTensorDescriptorPositionAndScale( + raw_tensor_desc, position, scale)); +} + +MLUCnnlTensorDesc::~MLUCnnlTensorDesc() { + if (raw_tensor_desc) { + g_cnnl_tensor_desc_pool.Recycle(raw_tensor_desc); + } +} + +MLUCnnlActivationDesc::MLUCnnlActivationDesc( + const cnnlActivationMode_t act_mode, const float ceof) { + PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateActivationDescriptor(&active_desc_)); + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetActivationDescriptor( + active_desc_, act_mode, CNNL_NOT_PROPAGATE_NAN, ceof)); +} + +const cnnlActivationDescriptor_t MLUCnnlActivationDesc::get() const { + return active_desc_; +} + +MLUCnnlActivationDesc::~MLUCnnlActivationDesc() { + if (active_desc_) { + PADDLE_ENFORCE_MLU_SUCCESS(cnnlDestroyActivationDescriptor(active_desc_)); + } +} + +/* static */ void MLUCnnl::Active(const platform::MLUDeviceContext &ctx, + cnnlActivationDescriptor_t active_desc, + const cnnlTensorDescriptor_t input_desc, + const void *input, + const cnnlTensorDescriptor_t output_desc, + void *output) { + cnnlHandle_t handle = ctx.cnnl_handle(); + + PADDLE_ENFORCE_MLU_SUCCESS(cnnlActivationForward( + handle, active_desc, NULL, input_desc, input, NULL, output_desc, output)); +} + +/* static */ void MLUCnnl::ActiveGrad( + const platform::MLUDeviceContext &ctx, + cnnlActivationDescriptor_t active_desc, const void *alpha, const void *beta, + const cnnlTensorDescriptor_t y_desc, const void *y, + const cnnlTensorDescriptor_t diff_y_desc, const void *diff_y, + const cnnlTensorDescriptor_t x_desc, const void *x, + const cnnlTensorDescriptor_t diff_x_desc, void *diff_x) { + cnnlHandle_t handle = ctx.cnnl_handle(); + + PADDLE_ENFORCE_MLU_SUCCESS( + cnnlActivationBackward(handle, active_desc, alpha, y_desc, y, diff_y_desc, + diff_y, x_desc, x, beta, diff_x_desc, diff_x)); +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h new file mode 100644 index 0000000000000000000000000000000000000000..e0a2735e0ea4dc855cb3cf3dab2917cd5d040685 --- /dev/null +++ b/paddle/fluid/operators/mlu/mlu_baseop.h @@ -0,0 +1,168 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include + +#include +#include + +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/type_defs.h" +#include "paddle/fluid/platform/device/mlu/enforce.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using DataLayout = framework::DataLayout; +using DeviceContextPool = platform::DeviceContextPool; + +template +inline cnnlDataType_t ToCnnlDataType(const T& t) { + auto type = framework::ToDataType(t); + return ToCnnlDataType(type); +} + +template <> +inline cnnlDataType_t ToCnnlDataType(const framework::proto::VarType::Type& t) { + cnnlDataType_t type = CNNL_DTYPE_FLOAT; + switch (t) { + case framework::proto::VarType::FP16: + type = CNNL_DTYPE_HALF; + break; + case framework::proto::VarType::FP32: + type = CNNL_DTYPE_FLOAT; + break; + case framework::proto::VarType::INT8: + type = CNNL_DTYPE_INT8; + break; + case framework::proto::VarType::INT32: + type = CNNL_DTYPE_INT32; + break; + case framework::proto::VarType::INT64: + type = CNNL_DTYPE_INT64; + break; + case framework::proto::VarType::BOOL: + type = CNNL_DTYPE_BOOL; + break; + default: + break; + } + return type; +} + +// Converts (via narrowing) a type T value to a type U, and checks that the +// value has no value change due to the conversion. +template +NarrowT CheckedNarrowing(const WideT& wide) { + NarrowT narrow = wide; + CHECK_EQ(narrow, wide) + << "checked narrowing failed; values not equal post-conversion"; + return narrow; +} + +cnnlDeviceType_t GetCnnlDev(int dev_ordinal); + +using CnnlTensorDesc = cnnlTensorDescriptor_t; + +class MLUCnnlTensorDesc { + public: + MLUCnnlTensorDesc() {} + + // SE_DISALLOW_COPY_AND_ASSIGN + MLUCnnlTensorDesc(const MLUCnnlTensorDesc& desc) = delete; + MLUCnnlTensorDesc& operator=(const MLUCnnlTensorDesc&) = delete; + + MLUCnnlTensorDesc(MLUCnnlTensorDesc&& rhs) + : raw_tensor_desc(rhs.raw_tensor_desc) { + rhs.raw_tensor_desc = nullptr; + } + + MLUCnnlTensorDesc& operator=(MLUCnnlTensorDesc&& rhs); + + MLUCnnlTensorDesc(const int tensor_dim, const int dim_sizes[], + const cnnlDataType_t tensor_dtype); + + MLUCnnlTensorDesc(const int tensor_dim, const int dim_sizes[], + const cnnlDataType_t tensor_dtype, + const cnnlTensorLayout_t layout); + + MLUCnnlTensorDesc(const int tensor_dim, const int dim_sizes[], + const cnnlDataType_t tensor_dtype, int position); + + MLUCnnlTensorDesc(const int tensor_dim, const int64_t dim_sizes[], + const cnnlDataType_t tensor_dtype); + + MLUCnnlTensorDesc(const int tensor_dim, const int64_t dim_sizes[], + const cnnlDataType_t tensor_dtype, + const cnnlTensorLayout_t layout); + + MLUCnnlTensorDesc(const int tensor_dim, const int64_t dim_sizes[], + const cnnlDataType_t tensor_dtype, int position); + + MLUCnnlTensorDesc(const Tensor& tensor, const cnnlTensorLayout_t layout, + const cnnlDataType_t tensor_dtype); + + MLUCnnlTensorDesc(const Tensor& tensor, cnnlTensorLayout_t layout, + const cnnlDataType_t tensor_dtype, int position); + + MLUCnnlTensorDesc(const Tensor& tensor, cnnlTensorLayout_t layout, + const cnnlDataType_t tensor_dtype, int position, + float scale); + + ~MLUCnnlTensorDesc(); + + const cnnlTensorDescriptor_t get() const { return raw_tensor_desc; } + + private: + cnnlTensorDescriptor_t raw_tensor_desc = nullptr; +}; + +class MLUCnnlActivationDesc { + public: + MLUCnnlActivationDesc(const MLUCnnlActivationDesc& desc) = delete; + MLUCnnlActivationDesc& operator=(const MLUCnnlActivationDesc& desc) = delete; + MLUCnnlActivationDesc(const cnnlActivationMode_t act_mode, const float ceof); + + const cnnlActivationDescriptor_t get() const; + ~MLUCnnlActivationDesc(); + + private: + cnnlActivationDescriptor_t active_desc_ = nullptr; +}; + +class MLUCnnl { + public: + static void Active(const platform::MLUDeviceContext& ctx, + cnnlActivationDescriptor_t active_desc, + const cnnlTensorDescriptor_t input_desc, const void* input, + const cnnlTensorDescriptor_t output_desc, void* output); + + static void ActiveGrad(const platform::MLUDeviceContext& ctx, + cnnlActivationDescriptor_t active_desc, + const void* alpha, const void* beta, + const cnnlTensorDescriptor_t y_desc, const void* y, + const cnnlTensorDescriptor_t diff_y_desc, + const void* diff_y, + const cnnlTensorDescriptor_t x_desc, const void* x, + const cnnlTensorDescriptor_t diff_x_desc, + void* diff_x); +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 728c6af1812ff912e0756a4267e6eb097a98a0e1..26bf5d8b1be9d2807ee4aca28426d09d967ee438 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -35,6 +35,7 @@ set(enforce_deps flags errors boost flags) if(WITH_GPU) set(enforce_deps ${enforce_deps} external_error_proto) endif() + cc_library(enforce INTERFACE SRCS enforce.cc DEPS ${enforce_deps}) cc_library(monitor SRCS monitor.cc) cc_test(enforce_test SRCS enforce_test.cc DEPS stringpiece enforce) @@ -82,13 +83,17 @@ IF(WITH_ASCEND_CL) set(NPU_CTX_DEPS npu_stream npu_info) ENDIF() +IF(WITH_MLU) + set(MLU_CTX_DEPS mlu_device_context) +ENDIF() + IF(WITH_MKLDNN) set(MKLDNN_CTX_DEPS mkldnn) ELSE() set(MKLDNN_CTX_DEPS) ENDIF() -IF(WITH_ASCEND_CL) +IF(WITH_ASCEND_CL OR WITH_MLU) cc_library(stream_callback_manager SRCS stream_callback_manager.cc DEPS simple_threadpool enforce) ENDIF() @@ -117,7 +122,7 @@ cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost) # avoiding cycle dependencies cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc xxhash ${STREAM_CALLBACK_DEPS} place eigen3 stringpiece cpu_helper cpu_info framework_proto ${IPU_CTX_DEPS} ${GPU_CTX_DEPS} ${NPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} - ${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS}) + ${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS} ${MLU_CTX_DEPS}) cc_library(collective_helper SRCS collective_helper.cc gen_comm_id_helper.cc DEPS framework_proto device_context enforce) if(WITH_ASCEND_CL) diff --git a/paddle/fluid/platform/device/CMakeLists.txt b/paddle/fluid/platform/device/CMakeLists.txt index 0cd07dec20e3ed545c605a5257052ebbbbb13b57..c5fe211470949e4023bb28c6f0e9d1888d66a1fc 100644 --- a/paddle/fluid/platform/device/CMakeLists.txt +++ b/paddle/fluid/platform/device/CMakeLists.txt @@ -15,3 +15,8 @@ ENDIF() IF(WITH_IPU) add_subdirectory(ipu) ENDIF() + +# MLU +IF(WITH_MLU) + add_subdirectory(mlu) +ENDIF() diff --git a/paddle/fluid/platform/device/mlu/CMakeLists.txt b/paddle/fluid/platform/device/mlu/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9ef4439f39b6a553e83747452b32d6dd6a2e999b --- /dev/null +++ b/paddle/fluid/platform/device/mlu/CMakeLists.txt @@ -0,0 +1,10 @@ +IF(WITH_MLU) + cc_test(mlu_enforce_test SRCS enforce_test.cc DEPS stringpiece) + + cc_library(mlu_info SRCS mlu_info.cc DEPS enforce glog monitor neuware_lib) + + cc_library(mlu_stream SRCS mlu_stream.cc DEPS boost mlu_info stream_callback_manager) + + cc_library(mlu_device_context SRCS device_context.cc DEPS mlu_stream ) + cc_test(mlu_device_context_test SRCS device_context_test.cc DEPS mlu_device_context) +ENDIF() diff --git a/paddle/fluid/platform/device/mlu/device_context.cc b/paddle/fluid/platform/device/mlu/device_context.cc new file mode 100644 index 0000000000000000000000000000000000000000..40bdde2653e3c4b0e7296b9d477568baec110e7c --- /dev/null +++ b/paddle/fluid/platform/device/mlu/device_context.cc @@ -0,0 +1,76 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/device_context.h" +#endif + +namespace paddle { +namespace platform { + +#ifdef PADDLE_WITH_MLU +thread_local std::unordered_map> + MLUDeviceContext::thread_ctx_; +thread_local std::mutex MLUDeviceContext::ctx_mtx_; + +MLUContext::MLUContext(const MLUPlace& place, const int priority) { + place_ = place; + MLUDeviceGuard guard(place_.device); + stream_.reset(new stream::MLUStream(place_, priority)); + InitCNNLContext(); +} + +MLUContext::~MLUContext() { + MLUDeviceGuard guard(place_.device); + DestoryCNNLContext(); +} + +MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) { + MLUDeviceGuard guard(place_.device); + compute_capability_ = GetMLUComputeCapability(place_.device); + driver_version_ = GetMLUDriverVersion(place_.device); + runtime_version_ = GetMLURuntimeVersion(place_.device); + + LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << place_.device + << ", MLU Compute Capability: " + << compute_capability_ / 10 << "." + << compute_capability_ % 10 + << ", Driver API Version: " << driver_version_ / 10000 + << "." << (driver_version_ / 100) % 100 << "." + << driver_version_ % 100 << ", Runtime API Version: " + << runtime_version_ / 10000 << "." + << (runtime_version_ / 100) % 100 << "." + << runtime_version_ % 100; + + default_ctx_.reset(new MLUContext(place_)); +} + +MLUDeviceContext::~MLUDeviceContext() {} + +Place MLUDeviceContext::GetPlace() const { return place_; } + +void MLUDeviceContext::Wait() const { context()->Stream()->Wait(); } + +int MLUDeviceContext::GetComputeCapability() const { + return compute_capability_; +} + +mluCnnlHandle MLUDeviceContext::cnnl_handle() const { + return context()->CnnlHandle(); +} + +mluStream MLUDeviceContext::stream() const { return context()->RawStream(); } + +#endif +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device/mlu/device_context.h b/paddle/fluid/platform/device/mlu/device_context.h new file mode 100644 index 0000000000000000000000000000000000000000..67c547dc69a8dcfadefab0e947f28527be27744d --- /dev/null +++ b/paddle/fluid/platform/device/mlu/device_context.h @@ -0,0 +1,148 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#ifdef PADDLE_WITH_MLU +#include +#include "paddle/fluid/platform/device/mlu/enforce.h" +#include "paddle/fluid/platform/device/mlu/mlu_stream.h" +#include "paddle/fluid/platform/device_context.h" + +namespace Eigen { +struct DefaultDevice; +struct GpuDevice; +} // namespace Eigen + +// class DeviceContext; + +namespace paddle { +namespace platform { + +class MLUContext { + public: + MLUContext() = default; + explicit MLUContext(const MLUPlace& place, const int priority = 0); + + ~MLUContext(); + + const MLUPlace& Place() const { return place_; } + + const std::unique_ptr& EigenDevice() const { + return eigen_device_; + } + + const std::unique_ptr& Stream() const { return stream_; } + + stream::MLUStream* SetStream(stream::MLUStream* new_stream_ptr) { + auto* old_stream_ptr = stream_.release(); + stream_.reset(new_stream_ptr); + return old_stream_ptr; + } + + const mluStream& RawStream() { return stream_->raw_stream(); } + + const mluCnnlHandle& CnnlHandle() const { return cnnl_handle_; } + + private: + void InitCNNLContext() { + PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreate(&cnnl_handle_)); + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetQueue(cnnl_handle_, RawStream())); + } + + void DestoryCNNLContext() { + if (cnnl_handle_) { + PADDLE_ENFORCE_MLU_SUCCESS(cnnlDestroy(cnnl_handle_)); + } + cnnl_handle_ = nullptr; + } + + MLUPlace place_; + std::unique_ptr eigen_device_; + std::unique_ptr stream_; + mluCnnlHandle cnnl_handle_; + + DISABLE_COPY_AND_ASSIGN(MLUContext); +}; + +class MLUDeviceContext : public DeviceContext { + public: + explicit MLUDeviceContext(MLUPlace place); + virtual ~MLUDeviceContext(); + Eigen::DefaultDevice* eigen_device() const { return nullptr; } + Place GetPlace() const override; + + int GetComputeCapability() const; + + /*! \brief Wait for all operations completion in the stream. */ + void Wait() const override; + + /*! \brief Return cnnl handle in the device context. */ + mluCnnlHandle cnnl_handle() const; + + /*! \brief Return mlu stream in the device context. */ + mluStream stream() const; + + template + void RecordEvent(mluEventHandle ev, Callback callback) const { + return context()->Stream()->RecordEvent(ev, callback); + } + + template + void AddStreamCallback(Callback&& callback) const { + return context()->Stream()->AddCallback(callback); + } + + void WaitStreamCallback() const { + return context()->Stream()->WaitCallback(); + } + + void ResetDefaultContext(const int priority) { + default_ctx_.reset(new MLUContext(place_, priority)); + } + + void ResetThreadContext(const int priority) { + std::lock_guard guard(ctx_mtx_); + thread_ctx_[this].reset(new MLUContext(place_, priority)); + } + + std::shared_ptr context() const { + if (!thread_ctx_.count(this)) { + return default_ctx_; + } + return thread_ctx_.at(this); + } + + private: + int compute_capability_; + int driver_version_; + int runtime_version_; + MLUPlace place_; + std::shared_ptr default_ctx_; + + // The thread_local static variable will be released before the + // global static variable, so avoid using it in dtor. + static thread_local std::unordered_map> + thread_ctx_; + static thread_local std::mutex ctx_mtx_; + + DISABLE_COPY_AND_ASSIGN(MLUDeviceContext); +}; + +template <> +struct DefaultDeviceContextType { + using TYPE = MLUDeviceContext; +}; + +#endif + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device/mlu/device_context_allocator.h b/paddle/fluid/platform/device/mlu/device_context_allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..9deab92af5cd6d31121637202215a3008d0c594c --- /dev/null +++ b/paddle/fluid/platform/device/mlu/device_context_allocator.h @@ -0,0 +1,162 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "paddle/fluid/memory/allocation/allocator.h" +#include "paddle/fluid/platform/device/mlu/device_context.h" +#include "paddle/fluid/platform/device/mlu/mlu_info.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { + +namespace platform { +class MLUDeviceContext; +} // namespace platform + +namespace memory { +namespace allocation { + +/** + * MLUDeviceContextAllocation is a wrapper of the underbeneath allocation. + * MLUDeviceContextAllocation adds a MLU stream callback for the underbeneath + * allocation so that MLUDeviceContextAllocation can be used in a MLU stream + * which deletes allocation in the callback. + */ +class MLUDeviceContextAllocation : public Allocation { + public: + explicit MLUDeviceContextAllocation(AllocationPtr allocation) + : Allocation(allocation->ptr(), allocation->size(), allocation->place()), + underlying_allocation_(std::move(allocation)) {} + + ~MLUDeviceContextAllocation() { + PADDLE_ENFORCE_NOT_NULL( + dev_ctx_, + platform::errors::PreconditionNotMet( + "Device context is not set for MLUDeviceContextAllocation")); + auto *p_allocation = underlying_allocation_.release(); + VLOG(4) << "Adding callback to delete MLUDeviceContextAllocation at " + << p_allocation; + dev_ctx_->AddStreamCallback([p_allocation] { + VLOG(4) << "Delete MLUDeviceContextAllocation at " << p_allocation; + AllocationDeleter()(p_allocation); + }); + } + + void SetMLUDeviceContext(const platform::MLUDeviceContext *dev_ctx) { + dev_ctx_ = dev_ctx; + } + + private: + AllocationPtr underlying_allocation_; + const platform::MLUDeviceContext *dev_ctx_{nullptr}; +}; + +/** + * MLUDeviceContextAllocator will allocate a MLUDeviceContextAllocation + * after waiting for a self-created event on the default stream. It does so to + * let the non-default stream be able to allocate GPU memory which will be + * released by stream callback + */ +class MLUDeviceContextAllocator : public Allocator { + public: + explicit MLUDeviceContextAllocator(platform::MLUPlace place, + mluStream default_stream) + : place_(place), default_stream_(default_stream) { + platform::MLUDeviceGuard guard(place_.device); + PADDLE_ENFORCE_MLU_SUCCESS(cnrtNotifierCreate(&event_)); + } + + ~MLUDeviceContextAllocator() { + if (event_) { + platform::MLUDeviceGuard guard(place_.device); + PADDLE_ENFORCE_MLU_SUCCESS(cnrtNotifierDestroy(event_)); + } + } + + protected: + Allocation *AllocateImpl(size_t size) override { + PADDLE_ENFORCE_NOT_NULL( + default_stream_, + platform::errors::PreconditionNotMet( + "Default stream is not set for MLUDeviceContextAllocator")); + platform::MLUDeviceGuard guard(place_.device); + auto allocation = + new MLUDeviceContextAllocation(memory::Alloc(place_, size)); + // Wait for the event on stream + PADDLE_ENFORCE_MLU_SUCCESS(cnrtPlaceNotifier(event_, default_stream_)); + PADDLE_ENFORCE_MLU_SUCCESS(cnrtWaitNotifier(event_)); + return allocation; + } + + void FreeImpl(Allocation *allocation) override { delete allocation; } + + private: + platform::MLUPlace place_; + mluEventHandle event_{nullptr}; + mluStream default_stream_{nullptr}; +}; + +/** + * MLUDeviceContextAllocatorPool is a singletion stores mapping from + * MLUPlace(s) to std::shared_ptr. When a + * MLUDeviceContext's compute stream isn't default stream, it can call this + * class to allocate GPU memory which will be released by a callback after + * stream execution. + */ +class MLUDeviceContextAllocatorPool { + public: + static MLUDeviceContextAllocatorPool &Instance() { + static MLUDeviceContextAllocatorPool pool; + return pool; + } + + AllocationPtr Alloc(const platform::MLUDeviceContext &dev_ctx, size_t size) { + auto iter = allocators_.find( + BOOST_GET_CONST(platform::MLUPlace, dev_ctx.GetPlace())); + PADDLE_ENFORCE_NE( + iter, allocators_.end(), + platform::errors::NotFound("No allocator found for MLUPlace.")); + auto &allocator = iter->second; + AllocationPtr allocation = allocator->Allocate(size); + static_cast(allocation.get()) + ->SetMLUDeviceContext(&dev_ctx); + return allocation; + } + + private: + MLUDeviceContextAllocatorPool() { + std::vector devices = platform::GetMLUSelectedDevices(); + for (int i : devices) { + auto place = platform::MLUPlace(i); + auto compute_stream = + platform::DeviceContextPool::Instance().GetByPlace(place)->stream(); + auto allocator = std::shared_ptr( + new MLUDeviceContextAllocator(place, compute_stream)); + allocators_.insert(make_pair(place, allocator)); + } + } + + std::map> + allocators_; +}; + +} // namespace allocation +} // namespace memory +} // namespace paddle diff --git a/paddle/fluid/platform/device/mlu/device_context_test.cc b/paddle/fluid/platform/device/mlu/device_context_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5caaa9dec1e4bab6bb0faac1d299f4f1db0ab477 --- /dev/null +++ b/paddle/fluid/platform/device/mlu/device_context_test.cc @@ -0,0 +1,82 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include "paddle/fluid/platform/device/mlu/device_context.h" + +#include + +#include "glog/logging.h" +#include "gtest/gtest.h" + +TEST(Device, Init) { + using paddle::platform::DeviceContext; + using paddle::platform::MLUDeviceContext; + using paddle::platform::MLUPlace; + using paddle::platform::MLUContext; + + int count = paddle::platform::GetMLUDeviceCount(); + for (int i = 0; i < count; i++) { + MLUDeviceContext* device_context = new MLUDeviceContext(MLUPlace(i)); + std::shared_ptr ctx = device_context->context(); + ASSERT_NE(nullptr, ctx); + delete device_context; + } +} + +TEST(Device, MLUDeviceContext) { + using paddle::platform::MLUDeviceContext; + using paddle::platform::MLUPlace; + using paddle::mluCnnlHandle; + + int count = paddle::platform::GetMLUDeviceCount(); + for (int i = 0; i < count; i++) { + MLUDeviceContext* device_context = new MLUDeviceContext(MLUPlace(i)); + mluCnnlHandle mlu_handle = device_context->cnnl_handle(); + ASSERT_NE(nullptr, mlu_handle); + delete device_context; + } +} + +TEST(Device, MLUStream) { + using paddle::platform::MLUDeviceContext; + using paddle::platform::MLUPlace; + using paddle::mluStream; + + int count = paddle::platform::GetMLUDeviceCount(); + for (int i = 0; i < count; i++) { + MLUDeviceContext* device_context = new MLUDeviceContext(MLUPlace(i)); + mluStream mlu_stream = device_context->stream(); + ASSERT_NE(nullptr, mlu_stream); + delete device_context; + } +} + +TEST(Device, DeviceContextPool) { + using paddle::platform::DeviceContextPool; + using paddle::platform::MLUDeviceContext; + using paddle::platform::Place; + using paddle::platform::CPUPlace; + using paddle::platform::MLUPlace; + + DeviceContextPool& pool = DeviceContextPool::Instance(); + auto cpu_dev_ctx1 = pool.Get(CPUPlace()); + auto cpu_dev_ctx2 = pool.Get(CPUPlace()); + ASSERT_EQ(cpu_dev_ctx2, cpu_dev_ctx1); + + std::vector mlu_places; + int count = paddle::platform::GetMLUDeviceCount(); + for (int i = 0; i < count; ++i) { + auto dev_ctx = pool.Get(MLUPlace(i)); + ASSERT_NE(dev_ctx, nullptr); + } +} diff --git a/paddle/fluid/platform/device/mlu/enforce.h b/paddle/fluid/platform/device/mlu/enforce.h new file mode 100644 index 0000000000000000000000000000000000000000..eecbad53cab93719e327b9ff9f567233f044cf5f --- /dev/null +++ b/paddle/fluid/platform/device/mlu/enforce.h @@ -0,0 +1,143 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/platform/enforce.h" +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/mlu_info.h" +#endif // PADDLE_WITH_MLU + +#ifdef PADDLE_WITH_MLU +DECLARE_int64(gpu_allocator_retry_time); +#endif + +namespace paddle { +namespace platform { + +#ifdef PADDLE_WITH_MLU +namespace details { +template +struct MLUStatusType {}; + +#define DEFINE_MLU_STATUS_TYPE(type, success_value, proto_type) \ + template <> \ + struct MLUStatusType { \ + using Type = type; \ + static constexpr Type kSuccess = success_value; \ + static constexpr const char* kTypeString = #proto_type; \ + } + +DEFINE_MLU_STATUS_TYPE(cnrtStatus, cnrtSuccess, CNRT); +DEFINE_MLU_STATUS_TYPE(cnnlStatus, CNNL_STATUS_SUCCESS, CNNL); +DEFINE_MLU_STATUS_TYPE(cnStatus, CN_SUCCESS, CN); + +} // namespace details + +/*************** CNRT ERROR ***************/ +inline bool is_error(cnrtStatus e) { return e != cnrtSuccess; } + +inline std::string build_mlu_error_msg(cnrtStatus e) { + std::ostringstream sout; + sout << "MLU CNRT error(" << e << "), " << cnrtGetErrorName(e) << ": " + << cnrtGetErrorStr(e); + return sout.str(); +} + +/*************** CNNL ERROR ***************/ +inline bool is_error(cnnlStatus stat) { return stat != CNNL_STATUS_SUCCESS; } + +inline std::string build_mlu_error_msg(cnnlStatus stat) { + std::ostringstream sout; + sout << "MLU CNNL error(" << stat << "), " << cnnlGetErrorString(stat) + << ". "; + return sout.str(); +} + +/*************** CN API ERROR ***************/ +inline bool is_error(cnStatus stat) { return stat != CN_SUCCESS; } + +inline std::string build_mlu_error_msg(cnStatus stat) { + const char* error_name; + const char* error_string; + cnGetErrorName(stat, &error_name); + cnGetErrorString(stat, &error_string); + + std::ostringstream sout; + sout << "MLU CN error(" << static_cast(stat) << "), " << error_name + << " : " << error_string << ". "; + return sout.str(); +} + +#define PADDLE_ENFORCE_MLU_SUCCESS(COND) \ + do { \ + auto __cond__ = (COND); \ + using __MLU_STATUS_TYPE__ = decltype(__cond__); \ + constexpr auto __success_type__ = \ + ::paddle::platform::details::MLUStatusType< \ + __MLU_STATUS_TYPE__>::kSuccess; \ + if (UNLIKELY(__cond__ != __success_type__)) { \ + auto __summary__ = ::paddle::platform::errors::External( \ + ::paddle::platform::build_mlu_error_msg(__cond__)); \ + __THROW_ERROR_INTERNAL__(__summary__); \ + } \ + } while (0) + +#define PADDLE_ENFORCE_MLU_LAUNCH_SUCCESS(OP) \ + do { \ + auto res = cnrtGetLastError(); \ + if (UNLIKELY(res != cnrtSuccess)) { \ + auto msg = ::paddle::platform::build_mlu_error_msg(res); \ + PADDLE_THROW(platform::errors::Fatal("CNRT error after kernel (%s): %s", \ + OP, msg)); \ + } \ + } while (0) + +inline void retry_sleep(unsigned milliseconds) { + if (milliseconds < 1000) { + // usleep argument must be less than 1,000,000. Reference: + // https://pubs.opengroup.org/onlinepubs/7908799/xsh/usleep.html + usleep(milliseconds * 1000); + } else { + // clip to sleep in seconds because we can not and don't have to + // sleep for exact milliseconds + sleep(milliseconds / 1000); + } +} + +#define PADDLE_RETRY_MLU_SUCCESS(COND) \ + do { \ + auto __cond__ = (COND); \ + int retry_count = 1; \ + using __MLU_STATUS_TYPE__ = decltype(__cond__); \ + constexpr auto __success_type__ = \ + ::paddle::platform::details::MLUStatusType< \ + __MLU_STATUS_TYPE__>::kSuccess; \ + while (UNLIKELY(__cond__ != __success_type__) && retry_count < 5) { \ + retry_sleep(FLAGS_gpu_allocator_retry_time); \ + __cond__ = (COND); \ + ++retry_count; \ + } \ + if (UNLIKELY(__cond__ != __success_type__)) { \ + auto __summary__ = ::paddle::platform::errors::External( \ + ::paddle::platform::build_mlu_error_msg(__cond__)); \ + __THROW_ERROR_INTERNAL__(__summary__); \ + } \ + } while (0) + +#undef DEFINE_MLU_STATUS_TYPE +#endif // PADDLE_WITH_MLU + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device/mlu/enforce_test.cc b/paddle/fluid/platform/device/mlu/enforce_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7241afba6aa52eabeaf3293641f4abf907a66b1b --- /dev/null +++ b/paddle/fluid/platform/device/mlu/enforce_test.cc @@ -0,0 +1,62 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/device/mlu/enforce.h" + +#include + +#include "gtest/gtest.h" + +#ifdef PADDLE_WITH_MLU +template +bool CheckMluStatusSuccess(T value, const std::string& msg = "success") { + PADDLE_ENFORCE_MLU_SUCCESS(value); + return true; +} + +template +bool CheckMluStatusFailure(T value, const std::string& msg) { + try { + PADDLE_ENFORCE_MLU_SUCCESS(value); + return false; + } catch (paddle::platform::EnforceNotMet& error) { + std::string ex_msg = error.what(); + std::cout << ex_msg << std::endl; + return ex_msg.find(msg) != std::string::npos; + } +} + +TEST(mlu_enforce, mlu_success) { + EXPECT_TRUE(CheckMluStatusSuccess(cnrtSuccess)); + EXPECT_TRUE(CheckMluStatusFailure(cnrtErrorArgsInvalid, "invalid argument")); + EXPECT_TRUE(CheckMluStatusFailure(cnrtErrorMemcpyDirectionInvalid, + "invalid memcpy direction")); + EXPECT_TRUE( + CheckMluStatusFailure(cnrtErrorDeviceInvalid, "invalid device ordinal")); + + EXPECT_TRUE(CheckMluStatusSuccess(CNNL_STATUS_SUCCESS)); + EXPECT_TRUE(CheckMluStatusFailure(CNNL_STATUS_NOT_INITIALIZED, "CNNL error")); + EXPECT_TRUE(CheckMluStatusFailure(CNNL_STATUS_ALLOC_FAILED, "CNNL error")); + EXPECT_TRUE(CheckMluStatusFailure(CNNL_STATUS_BAD_PARAM, "CNNL error")); + EXPECT_TRUE(CheckMluStatusFailure(CNNL_STATUS_INTERNAL_ERROR, "CNNL error")); + + EXPECT_TRUE(CheckMluStatusSuccess(CN_SUCCESS)); + EXPECT_TRUE(CheckMluStatusFailure( + CN_ERROR_NOT_READY, + "Asynchronous operations issued previously not completed yet")); + EXPECT_TRUE( + CheckMluStatusFailure(CN_ERROR_NOT_INITIALIZED, "initialization error")); + EXPECT_TRUE( + CheckMluStatusFailure(CN_ERROR_INVALID_VALUE, "invalid argument")); + EXPECT_TRUE(CheckMluStatusFailure(CN_MEMORY_ERROR_OUT_OF_MEMORY, + "device has no memory to alloc")); +} +#endif diff --git a/paddle/fluid/platform/device/mlu/mlu_info.cc b/paddle/fluid/platform/device/mlu/mlu_info.cc new file mode 100644 index 0000000000000000000000000000000000000000..7cad99bf5d22df39590d94bd56571d914d2b0193 --- /dev/null +++ b/paddle/fluid/platform/device/mlu/mlu_info.cc @@ -0,0 +1,426 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/device/mlu/mlu_info.h" +#include +#include +#include "gflags/gflags.h" +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/platform/device/mlu/enforce.h" +#include "paddle/fluid/platform/lock_guard_ptr.h" +#include "paddle/fluid/platform/monitor.h" +#include "paddle/fluid/string/split.h" + +DECLARE_double(fraction_of_gpu_memory_to_use); +DECLARE_uint64(initial_gpu_memory_in_mb); +DECLARE_uint64(reallocate_gpu_memory_in_mb); +DECLARE_uint64(gpu_memory_limit_mb); + +constexpr static float fraction_reserve_mlu_memory = 0.05f; + +PADDLE_DEFINE_EXPORTED_string( + selected_mlus, "", + "A list of device ids separated by comma, like: 0,1,2,3. " + "This option is useful when doing multi process training and " + "each process have only one device (MLU). If you want to use " + "all visible devices, set this to empty string. NOTE: the " + "reason of doing this is that we want to use P2P communication" + "between MLU devices, use MLU_VISIBLE_DEVICES can only use" + "share-memory only."); + +USE_MLU_MEM_STAT; +namespace paddle { +namespace platform { + +static int GetMLUDeviceCountImpl() { + int x, y, z; + // When cnrtDriverGetVersion is executed, the device is initialized, + // no longer needs to call cnrtInit(). + cnrtStatus stat = cnrtDriverGetVersion(&x, &y, &z); + if (stat != cnrtSuccess) { + VLOG(2) << "MLU Driver Version can't be detected. No MLU driver!"; + return 0; + } + + const auto *mlu_visible_devices = std::getenv("MLU_VISIBLE_DEVICES"); + if (mlu_visible_devices != nullptr) { + std::string mlu_visible_devices_str(mlu_visible_devices); + if (std::all_of(mlu_visible_devices_str.begin(), + mlu_visible_devices_str.end(), + [](char ch) { return ch == ' '; })) { + VLOG(2) << "MLU_VISIBLE_DEVICES is set to be " + "empty. No MLU detected."; + return 0; + } + } + + int count; + PADDLE_ENFORCE_MLU_SUCCESS(cnDeviceGetCount(&count)); + return count; +} + +int GetMLUDeviceCount() { + static auto dev_cnt = GetMLUDeviceCountImpl(); + return dev_cnt; +} + +std::vector GetMLUSelectedDevices() { + // use user specified MLUs in single-node multi-process mode. + std::vector devices; + if (!FLAGS_selected_mlus.empty()) { + auto devices_str = paddle::string::Split(FLAGS_selected_mlus, ','); + for (auto id : devices_str) { + devices.push_back(atoi(id.c_str())); + } + } else { + int count = GetMLUDeviceCount(); + for (int i = 0; i < count; ++i) { + devices.push_back(i); + } + } + return devices; +} + +void CheckDeviceId(int id) { + PADDLE_ENFORCE_LT(id, GetMLUDeviceCount(), + platform::errors::InvalidArgument( + "Device id must be less than MLU count, " + "but received id is: %d. MLU count is: %d.", + id, GetMLUDeviceCount())); +} + +int GetMLUDriverVersion(int id) { + CheckDeviceId(id); + int x, y, z; + PADDLE_ENFORCE_MLU_SUCCESS(cnrtDriverGetVersion(&x, &y, &z)); + return x * 10000 + y * 100 + z; +} + +int GetMLURuntimeVersion(int id) { + CheckDeviceId(id); + int x, y, z; + PADDLE_ENFORCE_MLU_SUCCESS(cnrtGetLibVersion(&x, &y, &z)); + return x * 10000 + y * 100 + z; +} + +int GetMLUCurrentDeviceId() { + int device_id; + PADDLE_ENFORCE_MLU_SUCCESS(cnrtGetDevice(&device_id)); + return device_id; +} + +void SetMLUDeviceId(int id) { + CheckDeviceId(id); + PADDLE_RETRY_MLU_SUCCESS(cnrtSetDevice(id)); +} + +void GetMLUDeviceHandle(int device_ordinal, mluDeviceHandle *device) { + cnStatus res = cnDeviceGet(device, device_ordinal); + if (res != CN_SUCCESS) { + VLOG(2) << "failed to get handle of MLU Device."; + } + PADDLE_ENFORCE_MLU_SUCCESS(res); +} + +int GetMLUComputeCapability(int id) { + CheckDeviceId(id); + mluDeviceHandle device; + GetMLUDeviceHandle(id, &device); + + int major, minor; + cnStatus major_stat = cnDeviceGetAttribute( + &major, CN_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device); + cnStatus minor_stat = cnDeviceGetAttribute( + &minor, CN_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device); + PADDLE_ENFORCE_MLU_SUCCESS(major_stat); + PADDLE_ENFORCE_MLU_SUCCESS(minor_stat); + + return major * 10 + minor; +} + +void MLUMemoryUsage(size_t *available, size_t *total) { + size_t actual_available, actual_total; + RecordedMLUMemGetInfo(available, total, &actual_available, &actual_total, + platform::GetMLUCurrentDeviceId()); +} + +size_t MLUAvailableMemToAlloc() { + size_t total = 0; + size_t available = 0; + MLUMemoryUsage(&available, &total); + size_t reserving = + static_cast(fraction_reserve_mlu_memory * available); + // If available size is less than minimum chunk size, no usable memory exists + size_t available_to_alloc = available - reserving; + size_t min_chunk_size = MLUMinChunkSize(); + if (available_to_alloc < min_chunk_size) { + available_to_alloc = 0; + } + VLOG(10) << "MLU usage " << ((total - available) >> 20) << "M/" + << (total >> 20) << "M, " << (available_to_alloc >> 20) + << "M available to allocate"; + return available_to_alloc; +} + +size_t MLUMaxAllocSize() { + return std::max(MLUInitAllocSize(), MLUReallocSize()); +} + +static size_t MLUAllocSize(bool realloc) { + size_t available_to_alloc = MLUAvailableMemToAlloc(); + PADDLE_ENFORCE_GT( + available_to_alloc, 0, + platform::errors::ResourceExhausted("Not enough available MLU memory.")); + // If FLAGS_initial_gpu_memory_in_mb is 0, then initial memory will be + // allocated by fraction + size_t flag_mb = realloc ? FLAGS_reallocate_gpu_memory_in_mb + : FLAGS_initial_gpu_memory_in_mb; + size_t alloc_bytes = + (flag_mb > 0ul ? flag_mb << 20 : available_to_alloc * + FLAGS_fraction_of_gpu_memory_to_use); + PADDLE_ENFORCE_GE( + available_to_alloc, alloc_bytes, + platform::errors::ResourceExhausted("Not enough available MLU memory.")); + VLOG(10) << "Alloc size is " << (alloc_bytes >> 20) + << " MiB, is it Re-alloc: " << realloc; + return alloc_bytes; +} + +size_t MLUInitAllocSize() { return MLUAllocSize(/* realloc = */ false); } + +size_t MLUReallocSize() { return MLUAllocSize(/* realloc = */ true); } + +size_t MLUMinChunkSize() { + // Allow to allocate the minimum chunk size is 256 bytes. + return 1 << 8; +} + +size_t MLUMaxChunkSize() { + size_t max_chunk_size = MLUMaxAllocSize(); + VLOG(10) << "Max chunk size " << (max_chunk_size >> 20) << "M"; + return max_chunk_size; +} + +void MLUMemcpyD2HAsync(void *dst, const void *src, size_t num, + mluStream stream) { + PADDLE_ENFORCE_MLU_SUCCESS(cnrtMemcpyAsync(dst, const_cast(src), num, + stream, cnrtMemcpyDevToHost)); +} + +void MLUMemcpyD2HSync(void *dst, const void *src, size_t num) { + PADDLE_ENFORCE_MLU_SUCCESS( + cnrtMemcpy(dst, const_cast(src), num, cnrtMemcpyDevToHost)); +} + +void MLUMemcpyH2DAsync(void *dst, const void *src, size_t num, + mluStream stream) { + PADDLE_ENFORCE_MLU_SUCCESS(cnrtMemcpyAsync(dst, const_cast(src), num, + stream, cnrtMemcpyHostToDev)); +} +void MLUMemcpyH2DSync(void *dst, const void *src, size_t num) { + PADDLE_ENFORCE_MLU_SUCCESS( + cnrtMemcpy(dst, const_cast(src), num, cnrtMemcpyHostToDev)); +} + +void MLUMemcpyD2DAsync(void *dst, const void *src, size_t num, + mluStream stream) { + PADDLE_ENFORCE_MLU_SUCCESS(cnrtMemcpyAsync(dst, const_cast(src), num, + stream, cnrtMemcpyDevToDev)); +} +void MLUMemcpyD2DSync(void *dst, const void *src, size_t num) { + PADDLE_ENFORCE_MLU_SUCCESS( + cnrtMemcpy(dst, const_cast(src), num, cnrtMemcpyDevToDev)); +} + +void MLUMemcpyPeerAsync(void *dst, int dst_device, const void *src, + int src_device, size_t num, mluStream stream) { + PADDLE_ENFORCE_MLU_SUCCESS(cnrtMemcpyPeerAsync( + dst, dst_device, const_cast(src), src_device, num, stream)); +} + +void MLUMemcpyPeerSync(void *dst, int dst_device, const void *src, + int src_device, size_t num) { + PADDLE_ENFORCE_MLU_SUCCESS(cnrtMemcpyPeer( + dst, dst_device, const_cast(src), src_device, num)); +} + +void MLUMemsetAsync(void *dst, int value, size_t count, mluStream stream) { + PADDLE_ENFORCE_MLU_SUCCESS(cnrtMemsetAsync(dst, value, count, stream)); +} + +void MLUStreamSync(mluStream stream) { + PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueSync(stream)); +} + +static void RaiseNonOutOfMemoryError(cnrtStatus *status) { + if (*status == cnrtErrorNoMem) { + *status = cnrtSuccess; + } + PADDLE_ENFORCE_MLU_SUCCESS(*status); + + *status = cnrtGetLastError(); + if (*status == cnrtErrorNoMem) { + *status = cnrtSuccess; + } + PADDLE_ENFORCE_MLU_SUCCESS(*status); +} + +class RecordedMLUMallocHelper { + private: + explicit RecordedMLUMallocHelper(int dev_id, uint64_t limit_size = 0) + : dev_id_(dev_id), limit_size_(limit_size) { + if (NeedRecord()) { + mtx_.reset(new std::mutex()); + } + } + + DISABLE_COPY_AND_ASSIGN(RecordedMLUMallocHelper); + + public: + static RecordedMLUMallocHelper *Instance(int dev_id) { + std::call_once(once_flag_, [] { + int dev_cnt = GetMLUDeviceCount(); + instances_.reserve(dev_cnt); + for (int i = 0; i < dev_cnt; ++i) { + instances_.emplace_back( + new RecordedMLUMallocHelper(i, FLAGS_gpu_memory_limit_mb << 20)); + } + }); + + PADDLE_ENFORCE_GE( + dev_id, 0, + platform::errors::OutOfRange( + "Device id must be not less than 0, but got %d.", dev_id)); + PADDLE_ENFORCE_LT( + dev_id, instances_.size(), + platform::errors::OutOfRange("Device id %d exceeds mlu card number %d.", + dev_id, instances_.size())); + return instances_[dev_id].get(); + } + + /** + * Try to allocate `size` mlu memory. Only cnrtErrorNoMem + * or cnrtSuccess would be returned, and the cnrtGetLastError() flag + * would be clear. + */ + cnrtStatus Malloc(void **ptr, size_t size) { + LockGuardPtr lock(mtx_); + if (UNLIKELY(NeedRecord() && cur_size_.load() + size > limit_size_)) { + return cnrtErrorNoMem; + } + + MLUDeviceGuard guard(dev_id_); + auto result = cnrtMalloc(ptr, size); + if (result == cnrtSuccess) { + cur_size_.fetch_add(size); + STAT_INT_ADD("STAT_mlu" + std::to_string(dev_id_) + "_mem_size", size); + return cnrtSuccess; + } else { + RaiseNonOutOfMemoryError(&result); + // Non out of memory error would be raised inside + // RaiseNonOutOfMemoryError. + // Therefore, we can return cnrtErrorNoMem directly here. + return cnrtErrorNoMem; + } + } + + /** + * Free mlu memory. Usually, free is not allowed to raise error. + * If it does raise error, the process should be crashed. + */ + void Free(void *ptr, size_t size) { + MLUDeviceGuard guard(dev_id_); + auto err = cnrtFree(ptr); + PADDLE_ENFORCE_MLU_SUCCESS(err); + if (NeedRecord()) { + cur_size_.fetch_sub(size); + } + STAT_INT_SUB("STAT_mlu" + std::to_string(dev_id_) + "_mem_size", size); + } + + bool GetMemInfo(size_t *avail, size_t *total, size_t *actual_avail, + size_t *actual_total) { + { + MLUDeviceGuard guard(dev_id_); + auto result = cnrtMemGetInfo(actual_avail, actual_total); + if (result != cnrtSuccess) { + *actual_avail = 0; + } + RaiseNonOutOfMemoryError(&result); + } + + if (NeedRecord()) { + std::lock_guard guard(*mtx_); + *avail = std::min(*actual_avail, limit_size_ - cur_size_.load()); + *total = std::min(*actual_total, limit_size_); + return *total < *actual_total; + } else { + *avail = *actual_avail; + *total = *actual_total; + return false; + } + } + + inline bool NeedRecord() const { return limit_size_ != 0; } + + uint64_t RecordedSize() const { return cur_size_.load(); } + + uint64_t LimitSize() const { return limit_size_; } + + private: + const int dev_id_; + const uint64_t limit_size_; + std::atomic cur_size_{0}; + + mutable std::unique_ptr mtx_; + + static std::once_flag once_flag_; + static std::vector> instances_; +}; // NOLINT + +std::once_flag RecordedMLUMallocHelper::once_flag_; +std::vector> + RecordedMLUMallocHelper::instances_; + +cnrtStatus RecordedMLUMalloc(void **ptr, size_t size, int dev_id) { + return RecordedMLUMallocHelper::Instance(dev_id)->Malloc(ptr, size); +} + +void RecordedMLUFree(void *p, size_t size, int dev_id) { + return RecordedMLUMallocHelper::Instance(dev_id)->Free(p, size); +} + +bool RecordedMLUMemGetInfo(size_t *avail, size_t *total, size_t *actual_avail, + size_t *actual_total, int dev_id) { + return RecordedMLUMallocHelper::Instance(dev_id)->GetMemInfo( + avail, total, actual_avail, actual_total); +} + +uint64_t RecordedMLUMallocSize(int dev_id) { + return RecordedMLUMallocHelper::Instance(dev_id)->RecordedSize(); +} + +bool IsMLUMallocRecorded(int dev_id) { + return RecordedMLUMallocHelper::Instance(dev_id)->NeedRecord(); +} + +void EmptyCache(void) { + std::vector devices = GetMLUSelectedDevices(); + for (auto device : devices) { + memory::Release(MLUPlace(device)); + } +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device/mlu/mlu_info.h b/paddle/fluid/platform/device/mlu/mlu_info.h new file mode 100644 index 0000000000000000000000000000000000000000..4588dd66677200106d0e6cf1d6a868d0e1b52c90 --- /dev/null +++ b/paddle/fluid/platform/device/mlu/mlu_info.h @@ -0,0 +1,160 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifdef PADDLE_WITH_MLU +#include +#include +#include +#include + +namespace paddle { + +using cnStatus = CNresult; +using cnrtStatus = cnrtRet_t; +using cnnlStatus = cnnlStatus_t; +using mluStream = cnrtQueue_t; +using mluCnnlHandle = cnnlHandle_t; +using mluEventHandle = CNnotifier; +using mluDeviceHandle = CNdev; + +namespace platform { + +//! Get the driver version of the ith MLU. +int GetMLUDriverVersion(int id); + +//! Get the runtime version of the ith MLU. +int GetMLURuntimeVersion(int id); + +//! Get the total number of MLU devices in system. +int GetMLUDeviceCount(); + +//! Get a list of device ids from environment variable or use all. +std::vector GetMLUSelectedDevices(); + +//! Get the current MLU device id in system. +int GetMLUCurrentDeviceId(); + +//! Set the MLU device id for next execution. +void SetMLUDeviceId(int device_id); + +//! Get a handle of device ids. +void GetMLUDeviceHandle(int device_ordinal, mluDeviceHandle* device); + +//! Get the compute capability of the ith MLU (format: major * 10 + minor) +int GetMLUComputeCapability(int id); + +//! Get the memory usage of current MLU device. +void MLUMemoryUsage(size_t* available, size_t* total); + +//! Get the available memory to allocate, which is the size of available mlu +//! minus reserving. +size_t MLUAvailableMemToAlloc(); + +//! Get the maximum allocation size of current MLU device. +size_t MLUMaxAllocSize(); + +//! Get the initial allocation size of current MLU device. +size_t MLUInitAllocSize(); + +//! Get the re-allocation size of current MLU device. +size_t MLUReallocSize(); + +//! Get the minimum chunk size for MLU buddy allocator. +size_t MLUMinChunkSize(); + +//! Get the maximum chunk size for MLU buddy allocator. +size_t MLUMaxChunkSize(); + +//! Copy memory from address device to host asynchronously. +void MLUMemcpyD2HAsync(void* dst, const void* src, size_t num, + mluStream stream); + +//! Copy memory from address device to host synchronously. +void MLUMemcpyD2HSync(void* dst, const void* src, size_t num); + +//! Copy memory from address host to device asynchronously. +void MLUMemcpyH2DAsync(void* dst, const void* src, size_t num, + mluStream stream); + +//! Copy memory from address host to device synchronously. +void MLUMemcpyH2DSync(void* dst, const void* src, size_t num); + +//! Copy memory from address device to device asynchronously in a single device. +void MLUMemcpyD2DAsync(void* dst, const void* src, size_t num, + mluStream stream); + +//! Copy memory from address device to device synchronously in a single device. +void MLUMemcpyD2DSync(void* dst, const void* src, size_t num); + +//! Copy memory from one device to another device asynchronously. +void MLUMemcpyPeerAsync(void* dst, int dst_place, const void* src, + int src_place, size_t num, mluStream stream); + +//! Copy memory from one device to another device synchronously. +void MLUMemcpyPeerSync(void* dst, int dst_place, const void* src, int src_place, + size_t num); + +//! Set memory dst with value count size asynchronously +void MLUMemsetAsync(void* dst, int value, size_t count, mluStream stream); + +//! Blocks until stream has completed all operations. +void MLUStreamSync(mluStream stream); + +//! MLUMalloc with recorded info +cnrtStatus RecordedMLUMalloc(void** ptr, size_t size, int dev_id); + +//! MLUFree with recorded info +void RecordedMLUFree(void* p, size_t size, int dev_id); + +//! Get available and total mlu memory with considering limitation +bool RecordedMLUMemGetInfo(size_t* avail, size_t* total, size_t* actual_avail, + size_t* actual_total, int dev_id); + +//! Get recorded mluMalloc size. If record is disabled, return 0. +uint64_t RecordedMLUMallocSize(int dev_id); + +bool IsMLUMallocRecorded(int dev_id); + +//! Empty idle cached memory held by the allocator. +void EmptyCache(void); + +class MLUDeviceGuard { + public: + explicit inline MLUDeviceGuard(int dev_id) { + int prev_id = platform::GetMLUCurrentDeviceId(); + if (prev_id != dev_id) { + prev_id_ = prev_id; + platform::SetMLUDeviceId(dev_id); + } + } + + inline ~MLUDeviceGuard() { + if (prev_id_ != -1) { + platform::SetMLUDeviceId(prev_id_); + } + } + + MLUDeviceGuard(const MLUDeviceGuard& o) = delete; + MLUDeviceGuard& operator=(const MLUDeviceGuard& o) = delete; + + private: + int prev_id_{-1}; +}; + +} // namespace platform +} // namespace paddle + +#endif diff --git a/paddle/fluid/platform/device/mlu/mlu_stream.cc b/paddle/fluid/platform/device/mlu/mlu_stream.cc new file mode 100644 index 0000000000000000000000000000000000000000..7a27a49250a1ee9e58a2f76ce902b167e7aeb027 --- /dev/null +++ b/paddle/fluid/platform/device/mlu/mlu_stream.cc @@ -0,0 +1,82 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/device/mlu/mlu_stream.h" +#include "paddle/fluid/platform/device/mlu/device_context.h" + +namespace paddle { +namespace platform { +namespace stream { + +bool MLUStream::Init(const MLUPlace& place, const int priority) { + PADDLE_ENFORCE_EQ(is_mlu_place(place), true, + platform::errors::InvalidArgument( + "MLU stream must be created using mlu place.")); + place_ = place; + MLUDeviceGuard guard(place_.device); + PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueCreate(&stream_)); + callback_manager_.reset(new StreamCallbackManager(stream_)); + VLOG(3) << "MLUStream Init stream: " << stream_; + return true; +} + +void MLUStream::Destroy() { + MLUDeviceGuard guard(place_.device); + Wait(); + WaitCallback(); + if (stream_) { + PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueDestroy(stream_)); + } + stream_ = nullptr; +} + +void MLUStream::Wait() const { + PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueSync(stream_)); +} + +MLUStream* get_current_mlu_stream(int deviceId) { +#ifdef PADDLE_WITH_MLU + if (deviceId == -1) { + deviceId = platform::GetMLUCurrentDeviceId(); + } + auto& pool = platform::DeviceContextPool::Instance(); + platform::Place device = MLUPlace(deviceId); + auto stream = static_cast(pool.Get(device)) + ->context() + ->Stream() + .get(); + return stream; +#else + PADDLE_THROW(platform::errors::Unavailable( + "Paddle is not compiled with MLU. Cannot visit mlu current stream.")); + return nullptr; +#endif +} + +MLUStream* set_current_mlu_stream(MLUStream* stream) { +#ifdef PADDLE_WITH_MLU + auto& device = stream->GetPlace(); + auto& pool = platform::DeviceContextPool::Instance(); + return static_cast(pool.Get(device)) + ->context() + ->SetStream(stream); +#else + PADDLE_THROW(platform::errors::Unavailable( + "Paddle is not compiled with MLU. Cannot visit mlu current stream.")); + return nullptr; +#endif +} +} // namespace stream +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device/mlu/mlu_stream.h b/paddle/fluid/platform/device/mlu/mlu_stream.h new file mode 100644 index 0000000000000000000000000000000000000000..3f4b27e370f2e729c84cf8d5a9ccdefb6d1a4e1e --- /dev/null +++ b/paddle/fluid/platform/device/mlu/mlu_stream.h @@ -0,0 +1,102 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include + +#include "paddle/fluid/platform/device/mlu/enforce.h" +#include "paddle/fluid/platform/device/mlu/mlu_info.h" +#include "paddle/fluid/platform/macros.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/platform/stream_callback_manager.h" + +namespace paddle { +namespace platform { +namespace stream { + +#ifdef PADDLE_WITH_MLU +class MLUStream final { + public: + MLUStream() = default; + explicit MLUStream(const MLUPlace& place, const int priority = 0) { + Init(place, priority); + } + virtual ~MLUStream() { Destroy(); } + + bool Init(const MLUPlace& place, const int priority = 0); + + template + void AddCallback(Callback&& callback) const { + // TODO(mlu): mlu not support AddCallback + callback_manager_->AddCallback(callback); + } + + template + void RecordEvent(mluEventHandle event, Callback callback) const { + callback(); + PADDLE_ENFORCE_MLU_SUCCESS(cnPlaceNotifier(event, stream_)); + } + + void RecordEvent(mluEventHandle event) const { + PADDLE_ENFORCE_MLU_SUCCESS(cnPlaceNotifier(event, stream_)); + } + + void WaitEvent(mluEventHandle event) const { + PADDLE_ENFORCE_MLU_SUCCESS(cnWaitNotifier(event)); + } + + void Wait() const; + void WaitCallback() const { callback_manager_->Wait(); } + + const mluStream& raw_stream() const { return stream_; } + + void Destroy(); + + bool Query() const { + cnrtStatus stat = cnrtQueueQuery(stream_); + if (stat == cnrtSuccess) { + return true; + } + if (stat == cnrtErrorNotReady) { + return false; + } + PADDLE_ENFORCE_MLU_SUCCESS(stat); + return false; + } + + void Synchronize() const { + PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueSync(stream_)); + } + + const MLUPlace& GetPlace() const { return place_; } + + private: + MLUPlace place_; + mluStream stream_{nullptr}; + int priority_{0}; + std::unique_ptr> callback_manager_; + + DISABLE_COPY_AND_ASSIGN(MLUStream); +}; + +MLUStream* get_current_mlu_stream(int deviceId); +MLUStream* set_current_mlu_stream(MLUStream* stream); + +#endif + +} // namespace stream +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 206bef12aac95e0a111e15afbd1a0533e913e7e9..60442eb4a0e9e25eea40ac1b937693de6da5820a 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -16,6 +16,10 @@ limitations under the License. */ #include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h" #include "paddle/fluid/platform/cuda_device_guard.h" #endif +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/device_context.h" +#include "paddle/fluid/platform/device/mlu/device_context_allocator.h" +#endif #ifdef PADDLE_WITH_IPU #include "paddle/fluid/platform/ipu/ipu_backend.h" #endif @@ -56,6 +60,23 @@ AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) { PADDLE_THROW(platform::errors::PermissionDenied( "Paddle can't use XPU device since it's not compiled with XPU," "Please recompile or reinstall Paddle with XPU support.")); +#endif + } else if (platform::is_mlu_place(place)) { +#ifdef PADDLE_WITH_MLU + auto* default_dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + auto& desired_dev_ctx = + static_cast(dev_ctx); + if (default_dev_ctx->stream() == desired_dev_ctx.stream()) { + return Alloc(place, size); + } else { + return allocation::MLUDeviceContextAllocatorPool::Instance().Alloc( + desired_dev_ctx, size); + } +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't use MLU device since it's not compiled with MLU," + "Please recompile or reinstall Paddle with MLU support.")); #endif } else { return Alloc(place, size); @@ -85,6 +106,8 @@ DeviceType Place2DeviceType(const platform::Place& place) { return platform::DeviceType::CUDA; } else if (platform::is_xpu_place(place)) { return platform::DeviceType::XPU; + } else if (platform::is_mlu_place(place)) { + return platform::DeviceType::MLU; } else { PADDLE_THROW(platform::errors::Unavailable( "Unsupported place %s to convert into platform::DeviceType.", place)); @@ -99,7 +122,8 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) { if (it == device_contexts_.end()) { PADDLE_THROW(platform::errors::Unimplemented( "Place %s is not supported. Please check that your paddle compiles " - "with WITH_GPU, WITH_XPU, WITH_IPU or WITH_ASCEND_CL option or check " + "with WITH_GPU, WITH_XPU, WITH_IPU, WITH_MLU or WITH_ASCEND_CL option " + "or check " "that your train process set the correct device id if you use " "Executor.", place)); @@ -162,6 +186,14 @@ DeviceContextPool::DeviceContextPool( PADDLE_THROW( platform::errors::Unimplemented("XPUPlace is not supported. Please " "re-compile with WITH_XPU option.")); +#endif + } else if (platform::is_mlu_place(p)) { +#ifdef PADDLE_WITH_MLU + EmplaceDeviceContext(&device_contexts_, p); +#else + PADDLE_THROW( + platform::errors::Unimplemented("MLUPlace is not supported. Please " + "re-compile with WITH_MLU option.")); #endif } else if (platform::is_ipu_place(p)) { #ifdef PADDLE_WITH_IPU diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 875132dfe89c4b19faa1b86a65ecc81296129214..72fa525040b5f71d1ecf50fb391d8a4a6b1a8ab0 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -103,7 +103,9 @@ enum DeviceType { XPU = 2, NPU = 3, IPU = 4, - MAX_DEVICE_TYPES = 5, + MLU = 5, + + MAX_DEVICE_TYPES = 6, }; DeviceType Place2DeviceType(const platform::Place& place); @@ -113,6 +115,7 @@ constexpr DeviceType kCUDA = DeviceType::CUDA; constexpr DeviceType kXPU = DeviceType::XPU; constexpr DeviceType kNPU = DeviceType::NPU; constexpr DeviceType kIPU = DeviceType::IPU; +constexpr DeviceType kMLU = DeviceType::MLU; class DeviceContext { public: @@ -165,7 +168,13 @@ template <> struct DefaultDeviceContextType { using TYPE = IPUDeviceContext; }; +#endif +#ifdef PADDLE_WITH_MLU +class MLUDeviceContext; + +template <> +struct DefaultDeviceContextType; #endif #ifdef PADDLE_WITH_XPU diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index a674a6a8acdf205f23bfcfcf53f6dbcb054723d9..2df3d00dc924a2cebd03f3cd01b34b2c14dfa58d 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -434,7 +434,7 @@ PADDLE_DEFINE_EXPORTED_double( // NOTE(zhiqiu): better to share the flags, otherwise we will have too many // flags. #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ - defined(PADDLE_WITH_ASCEND_CL) + defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU) /** * Memory related FLAG @@ -662,8 +662,9 @@ PADDLE_DEFINE_EXPORTED_bool(conv2d_disable_cudnn, false, * Example: * Note: Get host by name time. */ -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_XPU) || \ - defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_HIP) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_XPU) || \ + defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_HIP) || \ + defined(PADDLE_WITH_MLU) PADDLE_DEFINE_EXPORTED_int32(get_host_by_name_time, 120, "The maximum time for get host by name time"); #endif diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index b642f160da21a539558567a1a30ad9bb23167aba..e9d2f8e901e8db2f5ed8e5c292d91e7a7a05c7a9 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -34,6 +34,10 @@ limitations under the License. */ #include "paddle/fluid/platform/device/xpu/xpu_info.h" #endif +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/mlu_info.h" +#endif + #ifdef WITH_WIN_DUMP_DBG #include #include @@ -177,6 +181,14 @@ void InitDevices() { LOG(WARNING) << "Compiled with PADDLE_WITH_IPU, but no IPU found in runtime."; } +#endif +#ifdef PADDLE_WITH_MLU + try { + // use user specified MLUs in single-node multi-process mode. + devices = platform::GetMLUSelectedDevices(); + } catch (const std::exception &exp) { + LOG(WARNING) << "Compiled with WITH_MLU, but no MLU found in runtime."; + } #endif InitDevices(devices); } @@ -203,6 +215,9 @@ void InitDevices(const std::vector devices) { #endif #ifdef PADDLE_WITH_ASCEND_CL places.emplace_back(platform::NPUPlace(devices[i])); +#endif +#ifdef PADDLE_WITH_MLU + places.emplace_back(platform::MLUPlace(devices[i])); #endif } places.emplace_back(platform::CPUPlace()); diff --git a/paddle/fluid/platform/init_test.cc b/paddle/fluid/platform/init_test.cc index dbca7d154954617d9c603ecff5a10858777f8879..5301dd307590b25d457d658b4468998fb71137b0 100644 --- a/paddle/fluid/platform/init_test.cc +++ b/paddle/fluid/platform/init_test.cc @@ -14,13 +14,16 @@ limitations under the License. */ #include "paddle/fluid/platform/init.h" #include "gtest/gtest.h" #include "paddle/fluid/platform/device_context.h" +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/device_context.h" +#endif TEST(InitDevices, CPU) { using paddle::framework::InitDevices; using paddle::platform::DeviceContextPool; #if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_XPU) && \ - !defined(PADDLE_WITH_HIP) + !defined(PADDLE_WITH_HIP) && !defined(PADDLE_WITH_MLU) InitDevices(); DeviceContextPool& pool = DeviceContextPool::Instance(); ASSERT_EQ(pool.size(), 1U); @@ -51,6 +54,18 @@ TEST(InitDevices, XPU) { #endif } +TEST(InitDevices, MLU) { + using paddle::framework::InitDevices; + using paddle::platform::DeviceContextPool; + +#ifdef PADDLE_WITH_MLU + int count = paddle::platform::GetMLUDeviceCount(); + InitDevices(); + DeviceContextPool& pool = DeviceContextPool::Instance(); + ASSERT_EQ(pool.size(), 1U + static_cast(count)); +#endif +} + #ifndef _WIN32 TEST(SignalHandle, SignalHandle) { std::string msg = "Signal raises"; diff --git a/paddle/fluid/platform/monitor.cc b/paddle/fluid/platform/monitor.cc index 1b44cb196547c2d26cdd5ae72c3331022f834657..ea6240b649cad07732b637aeb7273ab922ee326f 100644 --- a/paddle/fluid/platform/monitor.cc +++ b/paddle/fluid/platform/monitor.cc @@ -45,3 +45,21 @@ DEFINE_INT_STATUS(STAT_npu4_mem_size) DEFINE_INT_STATUS(STAT_npu5_mem_size) DEFINE_INT_STATUS(STAT_npu6_mem_size) DEFINE_INT_STATUS(STAT_npu7_mem_size) + +// For MLU +DEFINE_INT_STATUS(STAT_mlu0_mem_size) +DEFINE_INT_STATUS(STAT_mlu1_mem_size) +DEFINE_INT_STATUS(STAT_mlu2_mem_size) +DEFINE_INT_STATUS(STAT_mlu3_mem_size) +DEFINE_INT_STATUS(STAT_mlu4_mem_size) +DEFINE_INT_STATUS(STAT_mlu5_mem_size) +DEFINE_INT_STATUS(STAT_mlu6_mem_size) +DEFINE_INT_STATUS(STAT_mlu7_mem_size) +DEFINE_INT_STATUS(STAT_mlu8_mem_size) +DEFINE_INT_STATUS(STAT_mlu9_mem_size) +DEFINE_INT_STATUS(STAT_mlu10_mem_size) +DEFINE_INT_STATUS(STAT_mlu11_mem_size) +DEFINE_INT_STATUS(STAT_mlu12_mem_size) +DEFINE_INT_STATUS(STAT_mlu13_mem_size) +DEFINE_INT_STATUS(STAT_mlu14_mem_size) +DEFINE_INT_STATUS(STAT_mlu15_mem_size) diff --git a/paddle/fluid/platform/monitor.h b/paddle/fluid/platform/monitor.h index 0eb9448ce0fad4e1caadb3e08140417294d5d0e7..dc9abaf36d8250f0cfe82b3a37b6d3759826f475 100644 --- a/paddle/fluid/platform/monitor.h +++ b/paddle/fluid/platform/monitor.h @@ -197,3 +197,21 @@ class StatRegistry { USE_INT_STAT(STAT_npu5_mem_size); \ USE_INT_STAT(STAT_npu6_mem_size); \ USE_INT_STAT(STAT_npu7_mem_size) + +#define USE_MLU_MEM_STAT \ + USE_INT_STAT(STAT_mlu0_mem_size); \ + USE_INT_STAT(STAT_mlu1_mem_size); \ + USE_INT_STAT(STAT_mlu2_mem_size); \ + USE_INT_STAT(STAT_mlu3_mem_size); \ + USE_INT_STAT(STAT_mlu4_mem_size); \ + USE_INT_STAT(STAT_mlu5_mem_size); \ + USE_INT_STAT(STAT_mlu6_mem_size); \ + USE_INT_STAT(STAT_mlu7_mem_size); \ + USE_INT_STAT(STAT_mlu8_mem_size); \ + USE_INT_STAT(STAT_mlu9_mem_size); \ + USE_INT_STAT(STAT_mlu10_mem_size); \ + USE_INT_STAT(STAT_mlu11_mem_size); \ + USE_INT_STAT(STAT_mlu12_mem_size); \ + USE_INT_STAT(STAT_mlu13_mem_size); \ + USE_INT_STAT(STAT_mlu14_mem_size); \ + USE_INT_STAT(STAT_mlu15_mem_size) diff --git a/paddle/fluid/platform/place.cc b/paddle/fluid/platform/place.cc index ec49134b654e93199c3ef522e5afda6337ff3db8..6251a28823ac3b42371a6cfbfdf641061f2e713b 100644 --- a/paddle/fluid/platform/place.cc +++ b/paddle/fluid/platform/place.cc @@ -34,6 +34,7 @@ class PlacePrinter : public boost::static_visitor<> { os_ << "CUDAPlace(" << p.device << ")"; } void operator()(const XPUPlace &p) { os_ << "XPUPlace(" << p.device << ")"; } + void operator()(const MLUPlace &p) { os_ << "MLUPlace(" << p.device << ")"; } void operator()(const NPUPlace &p) { os_ << "NPUPlace(" << p.device << ")"; } void operator()(const NPUPinnedPlace &p) { os_ << "NPUPinnedPlace"; } void operator()(const IPUPlace &p) { os_ << "IPUPlace(" << p.device << ")"; } @@ -53,6 +54,10 @@ bool is_xpu_place(const Place &p) { return boost::apply_visitor(IsXPUPlace(), p); } +bool is_mlu_place(const Place &p) { + return boost::apply_visitor(IsMLUPlace(), p); +} + bool is_npu_place(const Place &p) { return boost::apply_visitor(IsNPUPlace(), p); } @@ -83,6 +88,8 @@ bool is_same_place(const Place &p1, const Place &p2) { return true; } else if (is_xpu_place(p1)) { return BOOST_GET_CONST(XPUPlace, p1) == BOOST_GET_CONST(XPUPlace, p2); + } else if (is_mlu_place(p1)) { + return BOOST_GET_CONST(MLUPlace, p1) == BOOST_GET_CONST(MLUPlace, p2); } else if (is_npu_place(p1)) { return BOOST_GET_CONST(NPUPlace, p1) == BOOST_GET_CONST(NPUPlace, p2); } else if (is_ipu_place(p1)) { diff --git a/paddle/fluid/platform/place.h b/paddle/fluid/platform/place.h index fadc1e27e8a0ac6116b4d99cc6bc4adbdfbd3907..886eb05813bd85a313a75146191af34f12a81566 100644 --- a/paddle/fluid/platform/place.h +++ b/paddle/fluid/platform/place.h @@ -108,11 +108,25 @@ struct IPUPlace { int device; }; +struct MLUPlace { + MLUPlace() : MLUPlace(0) {} + explicit MLUPlace(int d) : device(d) {} + + inline int GetDeviceId() const { return device; } + // needed for variant equality comparison + inline bool operator==(const MLUPlace &o) const { return device == o.device; } + inline bool operator!=(const MLUPlace &o) const { return !(*this == o); } + inline bool operator<(const MLUPlace &o) const { return device < o.device; } + + int device; +}; + struct IsCUDAPlace : public boost::static_visitor { bool operator()(const CPUPlace &) const { return false; } bool operator()(const XPUPlace &) const { return false; } bool operator()(const NPUPlace &) const { return false; } bool operator()(const NPUPinnedPlace &) const { return false; } + bool operator()(const MLUPlace &) const { return false; } bool operator()(const IPUPlace &) const { return false; } bool operator()(const CUDAPlace &) const { return true; } bool operator()(const CUDAPinnedPlace &) const { return false; } @@ -123,6 +137,7 @@ struct IsCPUPlace : public boost::static_visitor { bool operator()(const XPUPlace &) const { return false; } bool operator()(const NPUPlace &) const { return false; } bool operator()(const NPUPinnedPlace &) const { return false; } + bool operator()(const MLUPlace &) const { return false; } bool operator()(const IPUPlace &) const { return false; } bool operator()(const CUDAPlace &) const { return false; } bool operator()(const CUDAPinnedPlace &) const { return false; } @@ -133,6 +148,7 @@ struct IsCUDAPinnedPlace : public boost::static_visitor { bool operator()(const XPUPlace &) const { return false; } bool operator()(const NPUPlace &) const { return false; } bool operator()(const NPUPinnedPlace &) const { return false; } + bool operator()(const MLUPlace &) const { return false; } bool operator()(const IPUPlace &) const { return false; } bool operator()(const CUDAPlace &) const { return false; } bool operator()(const CUDAPinnedPlace &cuda_pinned) const { return true; } @@ -143,6 +159,7 @@ struct IsXPUPlace : public boost::static_visitor { bool operator()(const XPUPlace &) const { return true; } bool operator()(const NPUPlace &) const { return false; } bool operator()(const NPUPinnedPlace &) const { return false; } + bool operator()(const MLUPlace &) const { return false; } bool operator()(const IPUPlace &) const { return false; } bool operator()(const CUDAPlace &) const { return false; } bool operator()(const CUDAPinnedPlace &) const { return false; } @@ -153,6 +170,7 @@ struct IsNPUPlace : public boost::static_visitor { bool operator()(const XPUPlace &) const { return false; } bool operator()(const NPUPlace &) const { return true; } bool operator()(const NPUPinnedPlace &) const { return false; } + bool operator()(const MLUPlace &) const { return false; } bool operator()(const IPUPlace &) const { return false; } bool operator()(const CUDAPlace &) const { return false; } bool operator()(const CUDAPinnedPlace &) const { return false; } @@ -162,32 +180,48 @@ struct IsNPUPinnedPlace : public boost::static_visitor { bool operator()(const CPUPlace &) const { return false; } bool operator()(const XPUPlace &) const { return false; } bool operator()(const NPUPlace &) const { return false; } + bool operator()(const MLUPlace &) const { return false; } bool operator()(const IPUPlace &) const { return false; } bool operator()(const CUDAPlace &) const { return false; } bool operator()(const CUDAPinnedPlace &) const { return false; } bool operator()(const NPUPinnedPlace &) const { return true; } }; + +struct IsMLUPlace : public boost::static_visitor { + bool operator()(const CPUPlace &) const { return false; } + bool operator()(const XPUPlace &) const { return false; } + bool operator()(const NPUPlace &) const { return false; } + bool operator()(const NPUPinnedPlace &) const { return false; } + bool operator()(const MLUPlace &) const { return true; } + bool operator()(const IPUPlace &) const { return false; } + bool operator()(const CUDAPlace &) const { return false; } + bool operator()(const CUDAPinnedPlace &) const { return false; } +}; struct IsIPUPlace : public boost::static_visitor { bool operator()(const CPUPlace &) const { return false; } bool operator()(const XPUPlace &) const { return false; } bool operator()(const NPUPlace &) const { return false; } bool operator()(const IPUPlace &) const { return true; } + bool operator()(const MLUPlace &) const { return false; } bool operator()(const CUDAPlace &) const { return false; } bool operator()(const CUDAPinnedPlace &) const { return false; } bool operator()(const NPUPinnedPlace &) const { return false; } }; class Place : public boost::variant { + CUDAPinnedPlace, NPUPinnedPlace, IPUPlace, + MLUPlace> { private: - using PlaceBase = boost::variant; + using PlaceBase = + boost::variant; public: Place() = default; Place(const CPUPlace &cpu_place) : PlaceBase(cpu_place) {} // NOLINT Place(const XPUPlace &xpu_place) : PlaceBase(xpu_place) {} // NOLINT Place(const NPUPlace &npu_place) : PlaceBase(npu_place) {} // NOLINT + Place(const MLUPlace &mlu_place) : PlaceBase(mlu_place) {} // NOLINT Place(const IPUPlace &ipu_place) : PlaceBase(ipu_place) {} // NOLINT Place(const CUDAPlace &cuda_place) : PlaceBase(cuda_place) {} // NOLINT Place(const CUDAPinnedPlace &cuda_pinned_place) // NOLINT @@ -208,6 +242,7 @@ using PlaceList = std::vector; bool is_gpu_place(const Place &); bool is_xpu_place(const Place &); bool is_npu_place(const Place &); +bool is_mlu_place(const Place &); bool is_ipu_place(const Place &); bool is_cpu_place(const Place &); bool is_cuda_pinned_place(const Place &); @@ -257,6 +292,16 @@ struct PlaceVisitorWrapper return typename Visitor::result_type(); #endif } + + typename Visitor::result_type operator()(const MLUPlace &mlu) const { +#ifdef PADDLE_WITH_MLU + return visitor_(mlu); +#else + PADDLE_THROW(platform::errors::Unavailable( + "Paddle is not compiled with MLU. Cannot visit mlu device")); +#endif + } + typename Visitor::result_type operator()(const IPUPlace &ipu) const { #ifdef PADDLE_WITH_IPU return visitor_(ipu); diff --git a/paddle/fluid/platform/place_test.cc b/paddle/fluid/platform/place_test.cc index 41e084efa57004c3935e8a9f4200c1e5a4e8f664..ba19f14fb8f870a6f00baa407424eb5c8701fcb1 100644 --- a/paddle/fluid/platform/place_test.cc +++ b/paddle/fluid/platform/place_test.cc @@ -19,6 +19,7 @@ TEST(Place, Equality) { paddle::platform::CPUPlace cpu; paddle::platform::CUDAPlace g0(0), g1(1), gg0(0); paddle::platform::XPUPlace x0(0), x1(1), xx0(0); + paddle::platform::MLUPlace m0(0), m1(1), mm0(0); EXPECT_EQ(cpu, cpu); EXPECT_EQ(g0, g0); @@ -27,9 +28,13 @@ TEST(Place, Equality) { EXPECT_EQ(x0, x0); EXPECT_EQ(x1, x1); EXPECT_EQ(x0, xx0); + EXPECT_EQ(m0, m0); + EXPECT_EQ(m1, m1); + EXPECT_EQ(m0, mm0); EXPECT_NE(g0, g1); EXPECT_NE(x0, x1); + EXPECT_NE(m0, m1); EXPECT_TRUE(paddle::platform::places_are_same_class(g0, gg0)); EXPECT_TRUE(paddle::platform::places_are_same_class(x0, xx0)); @@ -44,6 +49,11 @@ TEST(Place, Print) { ss << paddle::platform::XPUPlace(1); EXPECT_EQ("XPUPlace(1)", ss.str()); } + { + std::stringstream ss; + ss << paddle::platform::MLUPlace(1); + EXPECT_EQ("MLUPlace(1)", ss.str()); + } { std::stringstream ss; ss << paddle::platform::CUDAPlace(1); diff --git a/paddle/fluid/platform/stream_callback_manager.cc b/paddle/fluid/platform/stream_callback_manager.cc index 28aa022fe2f13280c80cbcadf72851e114301295..f6c54c2397b18f731ffa9ac44eca6a5dcaf18533 100644 --- a/paddle/fluid/platform/stream_callback_manager.cc +++ b/paddle/fluid/platform/stream_callback_manager.cc @@ -16,6 +16,10 @@ #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device/npu/npu_info.h" #include "paddle/fluid/platform/enforce.h" +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/enforce.h" +#include "paddle/fluid/platform/device/mlu/mlu_info.h" +#endif namespace paddle { namespace platform { @@ -36,6 +40,9 @@ static void StreamCallbackFunc(gpuStream_t stream, gpuError_t status, #if PADDLE_WITH_ASCEND_CL static void StreamCallbackFunc(void *user_data) #endif +#if PADDLE_WITH_MLU + static void StreamCallbackFunc(void *user_data) +#endif { std::unique_ptr> func( reinterpret_cast *>(user_data)); @@ -77,6 +84,13 @@ void StreamCallbackManager::AddCallback( // TODO(zhiqiu): failed to call aclrtLaunchCallback NPULaunchCallback(StreamCallbackFunc, func, ACL_CALLBACK_BLOCK, stream_); #endif + +#if PADDLE_WITH_MLU + VLOG(3) << "MLULaunchCallback at stream: " << stream_; + LOG(ERROR) << "failed to call MLULaunchCallback, " + << "because mlu not support StreamAddCallback yet. " + << "function: " << func; +#endif } template @@ -84,6 +98,9 @@ void StreamCallbackManager::Wait() const { #if defined(PADDLE_WITH_HIP) || defined(PADDLE_WITH_CUDA) platform::GpuStreamSync(stream_); #endif +#ifdef PADDLE_WITH_MLU + PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueSync(stream_)); +#endif #ifdef PADDLE_WITH_ASCEND_CL NPUStreamSync(stream_); #endif @@ -104,6 +121,9 @@ template struct StreamCallbackManager; #ifdef PADDLE_WITH_ASCEND_CL template struct StreamCallbackManager; #endif +#ifdef PADDLE_WITH_MLU +template struct StreamCallbackManager; +#endif } // namespace platform } // namespace paddle diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index cedf123246268fde80876aabbdb706683553fae8..1d73ecbab5e54ed3d3e0418d5a5edc89b54756c1 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -133,10 +133,12 @@ static const platform::Place PyObjectToPlace(const py::object &place_obj) { return place_obj.cast(); } else if (py::isinstance(place_obj)) { return place_obj.cast(); + } else if (py::isinstance(place_obj)) { + return place_obj.cast(); } else { PADDLE_THROW(platform::errors::InvalidArgument( "Place should be one of " - "Place/CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace")); + "Place/CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace/MLUPlace")); } } @@ -183,10 +185,13 @@ static void InitVarBaseAndTensor( } else if (platform::is_npu_place(place)) { SetTensorFromPyArray( tensor, array, BOOST_GET_CONST(platform::NPUPlace, place), zero_copy); + } else if (platform::is_mlu_place(place)) { + SetTensorFromPyArray( + tensor, array, BOOST_GET_CONST(platform::MLUPlace, place), zero_copy); } else { PADDLE_THROW(platform::errors::InvalidArgument( "Place should be one of " - "CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace")); + "CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace/MLUPlace")); } self->SetDataType(tensor->type()); } @@ -934,6 +939,10 @@ void BindImperative(py::module *m_ptr) { py::arg("value"), py::arg("place"), py::arg("persistable") = false, py::arg("zero_copy") = false, py::arg("name") = "", py::arg("stop_gradient") = -1) + .def("__init__", &InitVarBaseFromNumpyWithArg, + py::arg("value"), py::arg("place"), py::arg("persistable") = false, + py::arg("zero_copy") = false, py::arg("name") = "", + py::arg("stop_gradient") = -1) .def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value")) .def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor"), py::arg("name") = "") @@ -947,6 +956,8 @@ void BindImperative(py::module *m_ptr) { py::arg("tensor"), py::arg("place"), py::arg("name") = "") .def("__init__", &InitVarBaseFromTensorWithArg, py::arg("tensor"), py::arg("place"), py::arg("name") = "") + .def("__init__", &InitVarBaseFromTensorWithArg, + py::arg("tensor"), py::arg("place"), py::arg("name") = "") .def("__init__", &InitVarBaseFromNumpyWithKwargs) .def( "__setitem_varbase__", @@ -1923,6 +1934,16 @@ void BindImperative(py::module *m_ptr) { return new_var; }, py::return_value_policy::copy) + .def("_copy_to", + [](const std::shared_ptr &self, + const platform::MLUPlace &place, bool blocking) { + auto new_var = self->NewVarBase(place, blocking); + if (!blocking) { + IncreaseVarbaseReferenceCountUntilCopyComplete(self, place); + } + return new_var; + }, + py::return_value_policy::copy) .def("_copy_to", [](const std::shared_ptr &self, const platform::Place &place, bool blocking) { @@ -2116,6 +2137,11 @@ void BindImperative(py::module *m_ptr) { self.SetExpectedPlace(*p); VLOG(4) << "Tracer(" << &self << ")" << " set expected place " << *p; + } else if (py::isinstance(obj)) { + auto p = obj.cast(); + self.SetExpectedPlace(*p); + VLOG(4) << "Tracer(" << &self << ")" + << " set expected place " << *p; } else if (py::isinstance(obj)) { auto p = obj.cast(); self.SetExpectedPlace(*p); @@ -2124,7 +2150,7 @@ void BindImperative(py::module *m_ptr) { } else { PADDLE_THROW(platform::errors::InvalidArgument( "Incompatible Place Type: supports XPUPlace, CUDAPlace, " - "CPUPlace, NPUPlace" + "CPUPlace, NPUPlace, MLUPlace" "and CUDAPinnedPlace, " "but got Unknown Type!")); } @@ -2198,6 +2224,19 @@ void BindImperative(py::module *m_ptr) { std::move(attrs), place, trace_backward); } }) + .def("trace", + [](imperative::Tracer &self, const std::string &type, + const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs, + framework::AttributeMap attrs, const platform::MLUPlace &place, + bool trace_backward) { + auto ins_map = ConvertToNameVarBaseMap(ins); + auto outs_map = ConvertToNameVarBaseMap(outs); + { + py::gil_scoped_release release; + self.TraceOp(type, std::move(ins_map), std::move(outs_map), + std::move(attrs), place, trace_backward); + } + }) .def("trace", [](imperative::Tracer &self, const std::string &type, const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs, @@ -2256,6 +2295,7 @@ void BindImperative(py::module *m_ptr) { m.def("varbase_copy", &VarBaseCopy); m.def("varbase_copy", &VarBaseCopy); m.def("varbase_copy", &VarBaseCopy); + m.def("varbase_copy", &VarBaseCopy); m.def( "dygraph_partial_grad", @@ -2397,6 +2437,11 @@ void BindImperative(py::module *m_ptr) { const py::args args, const py::kwargs kwargs) { return imperative::PyLayerApply(place, cls, args, kwargs); }); + m.def("pylayer_apply", + [](const platform::MLUPlace &place, const py::object &cls, + const py::args args, const py::kwargs kwargs) { + return imperative::PyLayerApply(place, cls, args, kwargs); + }); #if defined(PADDLE_WITH_CUDA) m.def(