diff --git a/CMakeLists.txt b/CMakeLists.txt
index 59bc768aa41e1add945092b549e250508ff6716e..8d96c339dadc79baccb1668774259705e3046c72 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -33,11 +33,14 @@ option(WITH_TENSORRT "Compile PaddlePaddle with NVIDIA TensorRT" OFF)
 option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF)
 option(WITH_WIN_DUMP_DBG "Compile with windows core dump debug mode" OFF)
 option(WITH_ASCEND "Compile PaddlePaddle with ASCEND" OFF)
+# NOTE(zhiqiu): WITH_ASCEND_CL can be compiled on x86_64, so we can set WITH_ASCEND=OFF and WITH_ASCEND_CL=ON
+# to develop some acl related functionality on x86
+option(WITH_ASCEND_CL "Compile PaddlePaddle with ASCEND CL" ${WITH_ASCEND})
 option(WITH_ASCEND_CXX11 "Compile PaddlePaddle with ASCEND and CXX11 ABI" OFF)
 if (WITH_GPU AND WITH_XPU)
     message(FATAL_ERROR "Error when compile GPU and XPU at the same time")
 endif()
-if (WITH_GPU AND WITH_ASCEND)
+if (WITH_GPU AND WITH_ASCEND) 
     message(FATAL_ERROR "Error when compile GPU and ASCEND at the same time")
 endif()
diff --git a/cmake/configure.cmake b/cmake/configure.cmake
index 2a1e6897c02e445b799815b0fdc498774e1f37ad..9f1eb16fcf03fc43fccb8370b5253a73b67b20ff 100644
--- a/cmake/configure.cmake
+++ b/cmake/configure.cmake
@@ -82,6 +82,10 @@ if(WITH_ASCEND)
   add_definitions(-DPADDLE_WITH_ASCEND)
 endif()
 
+if(WITH_ASCEND_CL)
+  add_definitions(-DPADDLE_WITH_ASCEND_CL)
+endif()
+
 if(WITH_XPU)
   message(STATUS "Compile with XPU!")
   add_definitions(-DPADDLE_WITH_XPU)
diff --git a/cmake/external/ascend.cmake b/cmake/external/ascend.cmake
index a0b6f480f95ae70333c2f3dd8d20a8050b045425..bddd2023b437b1ce000f2f6a3e0e2fdd66647215 100644
--- a/cmake/external/ascend.cmake
+++ b/cmake/external/ascend.cmake
@@ -21,38 +21,60 @@ else()
   set(ASCEND_DIR /usr/local/Ascend)
 endif()
 
-set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64)
-set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common)
-set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share)
-set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64)
-set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64)
-set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64)
-set(STATIC_ACL_LIB ${ASCEND_ACL_DIR})
-
-set(ASCEND_MS_RUNTIME_PATH ${ASCEND_RUNTIME_DIR} ${ASCEND_ACL_DIR} ${ASCEND_ATC_DIR})
-set(ASCEND_MS_DRIVER_PATH ${ASCEND_DRIVER_DIR} ${ASCEND_DRIVER_COMMON_DIR})
-set(ATLAS_RUNTIME_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64)
-set(ATLAS_RUNTIME_INC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/include)
-set(ATLAS_ACL_DIR ${ASCEND_DIR}/ascend-toolkit/latest/acllib/lib64)
-set(ATLAS_ATC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/atc/lib64)
-set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR})
-
-set(atlas_graph_lib ${ATLAS_RUNTIME_DIR}/libgraph.so)
-set(atlas_ge_runner_lib ${ATLAS_RUNTIME_DIR}/libge_runner.so)
-set(atlas_acl_lib ${ATLAS_RUNTIME_DIR}/libascendcl.so)
-INCLUDE_DIRECTORIES(${ATLAS_RUNTIME_INC_DIR})
-
-if(EXISTS ${ATLAS_RUNTIME_INC_DIR}/graph/ascend_string.h)
-  add_definitions(-DPADDLE_WITH_ASCEND_STRING)
+if(WITH_ASCEND)
+  set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64)
+  set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common)
+  set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share)
+  set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64)
+  set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64)
+  set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64)
+  set(STATIC_ACL_LIB ${ASCEND_ACL_DIR})
+
+  set(ASCEND_MS_RUNTIME_PATH ${ASCEND_RUNTIME_DIR} ${ASCEND_ACL_DIR} ${ASCEND_ATC_DIR})
+  set(ASCEND_MS_DRIVER_PATH 
${ASCEND_DRIVER_DIR} ${ASCEND_DRIVER_COMMON_DIR}) + set(ATLAS_RUNTIME_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64) + set(ATLAS_RUNTIME_INC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/include) + set(ATLAS_ACL_DIR ${ASCEND_DIR}/ascend-toolkit/latest/acllib/lib64) + set(ATLAS_ATC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/atc/lib64) + set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR}) + + set(atlas_graph_lib ${ATLAS_RUNTIME_DIR}/libgraph.so) + set(atlas_ge_runner_lib ${ATLAS_RUNTIME_DIR}/libge_runner.so) + set(atlas_acl_lib ${ATLAS_RUNTIME_DIR}/libascendcl.so) + INCLUDE_DIRECTORIES(${ATLAS_RUNTIME_INC_DIR}) + + if(EXISTS ${ATLAS_RUNTIME_INC_DIR}/graph/ascend_string.h) + add_definitions(-DPADDLE_WITH_ASCEND_STRING) + endif() + + ADD_LIBRARY(ascend_ge SHARED IMPORTED GLOBAL) + SET_PROPERTY(TARGET ascend_ge PROPERTY IMPORTED_LOCATION ${atlas_ge_runner_lib}) + + ADD_LIBRARY(ascend_graph SHARED IMPORTED GLOBAL) + SET_PROPERTY(TARGET ascend_graph PROPERTY IMPORTED_LOCATION ${atlas_graph_lib}) + + ADD_LIBRARY(atlas_acl SHARED IMPORTED GLOBAL) + SET_PROPERTY(TARGET atlas_acl PROPERTY IMPORTED_LOCATION ${atlas_acl_lib}) + + add_custom_target(extern_ascend DEPENDS ascend_ge ascend_graph atlas_acl) endif() -ADD_LIBRARY(ascend_ge SHARED IMPORTED GLOBAL) -SET_PROPERTY(TARGET ascend_ge PROPERTY IMPORTED_LOCATION ${atlas_ge_runner_lib}) +if(WITH_ASCEND_CL) + set(ASCEND_CL_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64) + + set(ascendcl_lib ${ASCEND_CL_DIR}/libascendcl.so) + set(acl_op_compiler_lib ${ASCEND_CL_DIR}/libacl_op_compiler.so) + set(ASCEND_CL_INC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/include) -ADD_LIBRARY(ascend_graph SHARED IMPORTED GLOBAL) -SET_PROPERTY(TARGET ascend_graph PROPERTY IMPORTED_LOCATION ${atlas_graph_lib}) + message(STATUS "ASCEND_CL_INC_DIR ${ASCEND_CL_INC_DIR}") + message(STATUS "ASCEND_CL_DIR ${ASCEND_CL_DIR}") + INCLUDE_DIRECTORIES(${ASCEND_CL_INC_DIR}) -ADD_LIBRARY(atlas_acl SHARED IMPORTED GLOBAL) -SET_PROPERTY(TARGET atlas_acl PROPERTY IMPORTED_LOCATION ${atlas_acl_lib}) + ADD_LIBRARY(ascendcl SHARED IMPORTED GLOBAL) + SET_PROPERTY(TARGET ascendcl PROPERTY IMPORTED_LOCATION ${ascendcl_lib}) -add_custom_target(extern_ascend DEPENDS ascend_ge ascend_graph atlas_acl) + ADD_LIBRARY(acl_op_compiler SHARED IMPORTED GLOBAL) + SET_PROPERTY(TARGET acl_op_compiler PROPERTY IMPORTED_LOCATION ${acl_op_compiler_lib}) + add_custom_target(extern_ascend_cl DEPENDS ascendcl acl_op_compiler) + +endif() diff --git a/cmake/external/protobuf.cmake b/cmake/external/protobuf.cmake index 1466664c1266a74920e8834255ca71f5402500b1..82d64fd022883a75a6d334d3443dc43b4b06a904 100644 --- a/cmake/external/protobuf.cmake +++ b/cmake/external/protobuf.cmake @@ -201,6 +201,9 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) if(WITH_ASCEND AND NOT WITH_ASCEND_CXX11) SET(PROTOBUF_REPOSITORY https://gitee.com/tianjianhe/protobuf.git) SET(PROTOBUF_TAG v3.8.0) +elseif(WITH_ASCEND_CL AND NOT WITH_ASCEND_CXX11) + SET(PROTOBUF_REPOSITORY https://gitee.com/tianjianhe/protobuf.git) + SET(PROTOBUF_TAG v3.8.0) else() SET(PROTOBUF_REPOSITORY ${GIT_URL}/protocolbuffers/protobuf.git) SET(PROTOBUF_TAG 9f75c5aa851cd877fb0d93ccc31b8567a6706546) diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 0343ff3cc292d97dcc77108735baa69c804468af..7dac91e531e4cfd16fed211ef659350262dd3153 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -11,6 +11,7 @@ function(op_library TARGET) set(cu_cc_srcs) set(hip_cc_srcs) set(xpu_cc_srcs) + 
set(npu_cc_srcs) set(cudnn_cu_cc_srcs) set(miopen_cu_cc_srcs) set(cudnn_cu_srcs) @@ -20,6 +21,9 @@ function(op_library TARGET) set(mkldnn_cc_srcs) set(MKLDNN_FILE) set(op_common_deps operator op_registry math_function layer common_infer_shape_functions) + if (WITH_ASCEND_CL) + set(op_common_deps ${op_common_deps} npu_op_runner) + endif() # Option `UNITY` is used to specify that operator `TARGET` will compiles with Unity Build. set(options UNITY) set(oneValueArgs "") @@ -85,6 +89,12 @@ function(op_library TARGET) list(APPEND xpu_cc_srcs ${XPU_FILE}.cc) endif() endif() + if(WITH_ASCEND_CL) + string(REPLACE "_op" "_op_npu" NPU_FILE "${TARGET}") + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${NPU_FILE}.cc) + list(APPEND npu_cc_srcs ${NPU_FILE}.cc) + endif() + endif() else() foreach(src ${op_library_SRCS}) if(WITH_ROCM AND ${src} MATCHES ".*_cudnn_op.cu$") @@ -107,6 +117,8 @@ function(op_library TARGET) list(APPEND cu_cc_srcs ${src}) elseif(WITH_XPU AND ${src} MATCHES ".*_op_xpu.cc$") list(APPEND xpu_cc_srcs ${src}) + elseif(WITH_ASCEND_CL AND ${src} MATCHES ".*_op_npu.cc$") + list(APPEND npu_cc_srcs ${src}) elseif(${src} MATCHES ".*\\.cc$") list(APPEND cc_srcs ${src}) else() @@ -176,7 +188,7 @@ function(op_library TARGET) # Unity Build relies on global option `WITH_UNITY_BUILD` and local option `UNITY`. if(WITH_UNITY_BUILD AND op_library_UNITY) # Combine the cc source files. - compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs}) + compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs}) if(TARGET ${UNITY_TARGET}) # If `UNITY_TARGET` exists, add source files to `UNITY_TARGET`. target_sources(${UNITY_TARGET} PRIVATE ${unity_target_cc_sources}) @@ -187,7 +199,7 @@ function(op_library TARGET) # Add alias library to handle dependencies. add_library(${TARGET} ALIAS ${UNITY_TARGET}) else() - cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} DEPS ${op_library_DEPS} + cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs} DEPS ${op_library_DEPS} ${op_common_deps}) endif() endif() @@ -207,6 +219,7 @@ function(op_library TARGET) # The registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h. # Note that it's enough to just adding one operator to pybind in a *_op.cc file. # And for detail pybind information, please see generated paddle/pybind/pybind.h. 
+    set(ORIGINAL_TARGET ${TARGET})
     file(READ ${TARGET}.cc TARGET_CONTENT)
     string(REGEX MATCH "REGISTER_OPERATOR\\(.*REGISTER_OPERATOR\\(" multi_register "${TARGET_CONTENT}")
     # [ \t\r\n]* is used for blank characters
@@ -239,8 +252,9 @@ function(op_library TARGET)
     list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len)
     list(LENGTH xpu_cc_srcs xpu_cc_srcs_len)
     list(LENGTH miopen_cu_cc_srcs miopen_cu_cc_srcs_len)
+    list(LENGTH npu_cc_srcs npu_cc_srcs_len)
     if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0 AND
-        ${hip_srcs_len} EQUAL 0 AND ${hip_cc_srcs_len} EQUAL 0 AND ${miopen_cu_cc_srcs_len} EQUAL 0 AND ${xpu_cc_srcs_len} EQUAL 0)
+        ${hip_srcs_len} EQUAL 0 AND ${hip_cc_srcs_len} EQUAL 0 AND ${miopen_cu_cc_srcs_len} EQUAL 0 AND ${xpu_cc_srcs_len} EQUAL 0 AND ${npu_cc_srcs_len} EQUAL 0)
       file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
       set(pybind_flag 1)
     endif()
@@ -280,6 +294,26 @@ function(op_library TARGET)
     if (WITH_XPU AND ${xpu_cc_srcs_len} GREATER 0)
       file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, XPU);\n")
     endif()
+
+    if (WITH_ASCEND_CL AND ${npu_cc_srcs_len} GREATER 0)
+      file(READ ${ORIGINAL_TARGET}_npu.cc TARGET_NPU_CONTENT)
+      # It is different from the logic above, be careful
+      string(REGEX MATCH "REGISTER_OP_NPU_KERNEL\\(.*" multi_npu_register "${TARGET_NPU_CONTENT}")
+      # [ \t\r\n]* is used for blank characters
+      string(REGEX MATCH "REGISTER_OP_NPU_KERNEL\\([ \t\r\n]*[a-z0-9_]*," one_npu_register "${multi_npu_register}")
+
+      if (one_npu_register STREQUAL "")
+        string(REPLACE "_op" "" NPU_TARGET "${TARGET}")
+      else ()
+        string(REPLACE "REGISTER_OP_NPU_KERNEL(" "" NPU_TARGET "${one_npu_register}")
+        string(REPLACE "," "" NPU_TARGET "${NPU_TARGET}")
+        # [ \t\r\n]+ is used for blank characters.
+        # Here we use '+' instead of '*' since it is a REPLACE operation.
+ string(REGEX REPLACE "[ \t\r\n]+" "" NPU_TARGET "${NPU_TARGET}") + endif() + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${NPU_TARGET}, NPU);\n") + endif() + # pybind USE_OP_DEVICE_KERNEL for MKLDNN if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0) # Append first implemented MKLDNN activation operator @@ -330,6 +364,7 @@ function(register_operators) file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc") string(REPLACE "_mkldnn" "" OPS "${OPS}") string(REPLACE "_xpu" "" OPS "${OPS}") + string(REPLACE "_npu" "" OPS "${OPS}") string(REPLACE ".cc" "" OPS "${OPS}") list(REMOVE_DUPLICATES OPS) list(LENGTH register_operators_DEPS register_operators_DEPS_len) diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index 6488d29afc5f7f4af72aab1cf2463d900a89fa9d..81fa7d0dfa98f0135ae18e9cf60036afcd76c745 100644 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -274,10 +274,15 @@ if(WITH_BOX_PS) list(APPEND third_party_deps extern_box_ps) endif(WITH_BOX_PS) -if(WITH_ASCEND) +if(WITH_ASCEND OR WITH_ASCEND_CL) include(external/ascend) - list(APPEND third_party_deps extern_ascend) -endif (WITH_ASCEND) + if(WITH_ASCEND) + list(APPEND third_party_deps extern_ascend) + endif() + if(WITH_ASCEND_CL) + list(APPEND third_party_deps extern_ascend_cl) + endif() +endif () if (WITH_PSCORE) include(external/snappy) diff --git a/paddle/fluid/framework/dlpack_tensor.cc b/paddle/fluid/framework/dlpack_tensor.cc index a3fbb008fe4f444b7ad5b1fb3eb695ca4b4c7796..b99ab6b5a7ff195ef7d659598df88467bb158c6e 100644 --- a/paddle/fluid/framework/dlpack_tensor.cc +++ b/paddle/fluid/framework/dlpack_tensor.cc @@ -82,6 +82,11 @@ struct DLContextVisitor : public boost::static_visitor<::DLContext> { platform::errors::Unimplemented("platform::XPUPlace is not supported")); } + inline ::DLContext operator()(const platform::NPUPlace &place) const { + PADDLE_THROW( + platform::errors::Unimplemented("platform::NPUPlace is not supported")); + } + inline ::DLContext operator()(const platform::CUDAPlace &place) const { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) ::DLContext ctx; diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 0acc8a55fa9f8a79c67b7beb732996c86f86ec5a..101991d2c1ba001915c2f029558d42e83fd2ddb6 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -453,6 +453,14 @@ void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx, #else PADDLE_THROW( platform::errors::Unimplemented("No XPU gc found in CPU/GPU paddle")); +#endif + } else if (platform::is_npu_place(place_)) { +#ifdef PADDLE_WITH_ASCEND_CL + // TODO(ascendrc): Support garbage collector on NPUPlace + VLOG(4) << "Skip NPU gc because it is not implemented now."; +#else + PADDLE_THROW(platform::errors::Unimplemented( + "No NPU gc found in CPU/GPU/XPU paddle")); #endif } } diff --git a/paddle/fluid/framework/garbage_collector.cc b/paddle/fluid/framework/garbage_collector.cc index c8b6c7642551756273ebecb83093e1d75d131f2c..8dfbd3c268b866e84192d6e4c86a76265bf651aa 100644 --- a/paddle/fluid/framework/garbage_collector.cc +++ b/paddle/fluid/framework/garbage_collector.cc @@ -86,8 +86,9 @@ StreamGarbageCollector::StreamGarbageCollector(const platform::CUDAPlace &place, PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamCreate(&stream_)); #else PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream_)); + callback_manager_.reset( + new platform::StreamCallbackManager(stream_)); #endif - callback_manager_.reset(new platform::StreamCallbackManager(stream_)); } 
StreamGarbageCollector::~StreamGarbageCollector() { diff --git a/paddle/fluid/framework/garbage_collector.h b/paddle/fluid/framework/garbage_collector.h index 97800865af861f6598a3e74456deef1d0c355786..572c79d21a045b37058c6daf85ba559abf1e8e44 100644 --- a/paddle/fluid/framework/garbage_collector.h +++ b/paddle/fluid/framework/garbage_collector.h @@ -117,7 +117,8 @@ class StreamGarbageCollector : public GarbageCollector { private: gpuStream_t stream_; - std::unique_ptr callback_manager_; + std::unique_ptr> + callback_manager_; }; class CUDAPinnedGarbageCollector : public GarbageCollector { diff --git a/paddle/fluid/framework/library_type.h b/paddle/fluid/framework/library_type.h index 4307e51862df572e013431fceaaf89cc1cf6679c..8fe314cf5f18c5e8cc0a56ca8f231d32b9896aaf 100644 --- a/paddle/fluid/framework/library_type.h +++ b/paddle/fluid/framework/library_type.h @@ -61,6 +61,8 @@ inline LibraryType StringToLibraryType(const char* ctype) { return LibraryType::kPlain; } else if (s == std::string("XPU")) { return LibraryType::kPlain; + } else if (s == std::string("NPU")) { + return LibraryType::kPlain; } else if (s == std::string("CUDA")) { return LibraryType::kPlain; } else { diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h index 472c6f408266af6b47b7fdad2d1c9b3be6ee8cf5..4c529329761227f74e71efab0736d4abcdcc3c1e 100644 --- a/paddle/fluid/framework/op_registry.h +++ b/paddle/fluid/framework/op_registry.h @@ -304,6 +304,9 @@ struct OpKernelRegistrarFunctorEx &places, const BuildStrategy &build_strategy, ir::Graph *graph) : member_(new ParallelExecutorPrivate(places, scope)) { + PADDLE_ENFORCE(places.size() > 0 && !is_npu_place(places[0]), + platform::errors::Unavailable( + "NPU is not supported in ParallelExecutor")); InitP2P(places); ir::InitReaderQueueDeviceCount(graph, *(member_->global_scope_), member_->places_.size()); diff --git a/paddle/fluid/framework/tensor_test.cc b/paddle/fluid/framework/tensor_test.cc index 54f779813063362956130ec9314365f89a234c1e..101463756c0a5143536362c706ae08333673c831 100644 --- a/paddle/fluid/framework/tensor_test.cc +++ b/paddle/fluid/framework/tensor_test.cc @@ -125,25 +125,54 @@ TEST(Tensor, MutableData) { float* p2 = nullptr; // initialization p1 = src_tensor.mutable_data(framework::make_ddim({1, 2, 3}), - platform::CUDAPlace()); + platform::CUDAPlace(0)); auto p1_holder = src_tensor.Holder(); EXPECT_NE(p1, nullptr); // set src_tensor a new dim with large size // momery is supposed to be re-allocated p2 = src_tensor.mutable_data(framework::make_ddim({3, 1024}), - platform::CUDAPlace()); + platform::CUDAPlace(0)); auto p2_holder = src_tensor.Holder(); EXPECT_NE(p2, nullptr); EXPECT_NE(p1_holder.get(), p2_holder.get()); // set src_tensor a new dim with same size // momery block is supposed to be unchanged p1 = src_tensor.mutable_data(framework::make_ddim({2, 2, 3}), - platform::CUDAPlace()); + platform::CUDAPlace(0)); EXPECT_EQ(p1, p2); // set src_tensor a new dim with smaller size // momery block is supposed to be unchanged p2 = src_tensor.mutable_data(framework::make_ddim({2, 2}), - platform::CUDAPlace()); + platform::CUDAPlace(0)); + EXPECT_EQ(p1, p2); + } +#endif +#ifdef PADDLE_WITH_ASCEND_CL + { + framework::Tensor src_tensor; + float* p1 = nullptr; + float* p2 = nullptr; + // initialization + p1 = src_tensor.mutable_data(framework::make_ddim({1, 2, 3}), + platform::NPUPlace(0)); + auto p1_holder = src_tensor.Holder(); + EXPECT_NE(p1, nullptr); + // set src_tensor a new dim with large size + // momery is supposed to be 
re-allocated + p2 = src_tensor.mutable_data(framework::make_ddim({3, 1024}), + platform::NPUPlace(0)); + auto p2_holder = src_tensor.Holder(); + EXPECT_NE(p2, nullptr); + EXPECT_NE(p1_holder.get(), p2_holder.get()); + // set src_tensor a new dim with same size + // momery block is supposed to be unchanged + p1 = src_tensor.mutable_data(framework::make_ddim({2, 2, 3}), + platform::NPUPlace(0)); + EXPECT_EQ(p1, p2); + // set src_tensor a new dim with smaller size + // momery block is supposed to be unchanged + p2 = src_tensor.mutable_data(framework::make_ddim({2, 2}), + platform::NPUPlace(0)); EXPECT_EQ(p1, p2); } #endif @@ -179,7 +208,17 @@ TEST(Tensor, ShareDataWith) { framework::Tensor src_tensor; framework::Tensor dst_tensor; src_tensor.mutable_data(framework::make_ddim({2, 3, 4}), - platform::CUDAPlace()); + platform::CUDAPlace(0)); + dst_tensor.ShareDataWith(src_tensor); + ASSERT_EQ(src_tensor.data(), dst_tensor.data()); + } +#endif +#ifdef PADDLE_WITH_ASCEND_CL + { + framework::Tensor src_tensor; + framework::Tensor dst_tensor; + src_tensor.mutable_data(framework::make_ddim({2, 3, 4}), + platform::NPUPlace(0)); dst_tensor.ShareDataWith(src_tensor); ASSERT_EQ(src_tensor.data(), dst_tensor.data()); } @@ -216,7 +255,34 @@ TEST(Tensor, Slice) { { framework::Tensor src_tensor; src_tensor.mutable_data(framework::make_ddim({6, 9}), - platform::CUDAPlace()); + platform::CUDAPlace(0)); + framework::Tensor slice_tensor = src_tensor.Slice(2, 6); + framework::DDim slice_dims = slice_tensor.dims(); + ASSERT_EQ(arity(slice_dims), 2); + EXPECT_EQ(slice_dims[0], 4); + EXPECT_EQ(slice_dims[1], 9); + + uintptr_t src_data_address = + reinterpret_cast(src_tensor.data()); + uintptr_t src_mutable_data_address = + reinterpret_cast(src_tensor.mutable_data( + src_tensor.dims(), platform::CUDAPlace(0))); + uintptr_t slice_data_address = + reinterpret_cast(slice_tensor.data()); + uintptr_t slice_mutable_data_address = + reinterpret_cast(slice_tensor.mutable_data( + slice_tensor.dims(), platform::CUDAPlace(0))); + EXPECT_EQ(src_data_address, src_mutable_data_address); + EXPECT_EQ(slice_data_address, slice_mutable_data_address); + EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address); + } +#endif + +#ifdef PADDLE_WITH_ASCEND_CL + { + framework::Tensor src_tensor; + src_tensor.mutable_data(framework::make_ddim({6, 9}), + platform::NPUPlace(0)); framework::Tensor slice_tensor = src_tensor.Slice(2, 6); framework::DDim slice_dims = slice_tensor.dims(); ASSERT_EQ(arity(slice_dims), 2); @@ -227,12 +293,12 @@ TEST(Tensor, Slice) { reinterpret_cast(src_tensor.data()); uintptr_t src_mutable_data_address = reinterpret_cast(src_tensor.mutable_data( - src_tensor.dims(), platform::CUDAPlace())); + src_tensor.dims(), platform::NPUPlace(0))); uintptr_t slice_data_address = reinterpret_cast(slice_tensor.data()); uintptr_t slice_mutable_data_address = reinterpret_cast(slice_tensor.mutable_data( - slice_tensor.dims(), platform::CUDAPlace())); + slice_tensor.dims(), platform::NPUPlace(0))); EXPECT_EQ(src_data_address, src_mutable_data_address); EXPECT_EQ(slice_data_address, slice_mutable_data_address); EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address); diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index c6ac30a369859db9de244990231a307074e973ed..d6882b25d22588a1f9fb3b663926350e237a5484 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -97,6 +97,42 @@ void TensorCopy(const Tensor& src, const 
platform::Place& dst_place, "Copy from %s to %s is not supported.", src_place, dst_place)); } #endif +#ifdef PADDLE_WITH_ASCEND_CL + // TODO(zhiqiu): handle different condition like CUDA code below + else if (platform::is_npu_place(src_place) && // NOLINT + platform::is_cpu_place(dst_place)) { + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr, + BOOST_GET_CONST(platform::NPUPlace, src_place), src_ptr, size, + stream); + } + else if (platform::is_cpu_place(src_place) && // NOLINT + platform::is_npu_place(dst_place)) { + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(BOOST_GET_CONST(platform::NPUPlace, dst_place), dst_ptr, + BOOST_GET_CONST(platform::CPUPlace, src_place), src_ptr, size, + stream); + } + else if (platform::is_npu_place(src_place) && // NOLINT + platform::is_npu_place(dst_place)) { + if (src_ptr == dst_ptr) { + VLOG(3) << "Skip copy the same data async from " << src_place << " to " + << dst_place; + return; + } + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(BOOST_GET_CONST(platform::NPUPlace, dst_place), dst_ptr, + BOOST_GET_CONST(platform::NPUPlace, src_place), src_ptr, size, + stream); + } + else { // NOLINT + PADDLE_THROW(platform::errors::Unimplemented( + "Copy from %s to %s is not supported.", src_place, dst_place)); + } +#endif #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) else if (platform::is_cuda_pinned_place(src_place) && // NOLINT platform::is_cuda_pinned_place(dst_place)) { @@ -304,6 +340,35 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, "Copy from %s to %s is not supported.", src_place, dst_place)); } #endif +#ifdef PADDLE_WITH_ASCEND_CL + else if (platform::is_npu_place(src_place) && // NOLINT + platform::is_cpu_place(dst_place)) { /* npu -> cpu*/ + memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr, + BOOST_GET_CONST(platform::NPUPlace, src_place), src_ptr, size, + nullptr); + } + else if (platform::is_cpu_place(src_place) && // NOLINT + platform::is_npu_place(dst_place)) { /* cpu -> npu*/ + memory::Copy(BOOST_GET_CONST(platform::NPUPlace, dst_place), dst_ptr, + BOOST_GET_CONST(platform::CPUPlace, src_place), src_ptr, size, + nullptr); + } + else if (platform::is_npu_place(src_place) && // NOLINT + platform::is_npu_place(dst_place)) { /* npu -> npu*/ + if (src_ptr == dst_ptr) { + VLOG(3) << "Skip copy the same data sync from " << src_place << " to " + << dst_place; + return; + } + memory::Copy(BOOST_GET_CONST(platform::NPUPlace, dst_place), dst_ptr, + BOOST_GET_CONST(platform::NPUPlace, src_place), src_ptr, size, + nullptr); + } + else { // NOLINT + PADDLE_THROW(platform::errors::Unimplemented( + "Copy from %s to %s is not supported.", src_place, dst_place)); + } +#endif #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) else if (platform::is_cuda_pinned_place(src_place) && // NOLINT platform::is_cuda_pinned_place(dst_place)) { @@ -431,6 +496,13 @@ class AnyVisitor : public boost::static_visitor { return GetResultHelper(out, gpu); } + bool GetResult(const framework::Tensor& out, + const platform::NPUPlace& npu) const { + PADDLE_THROW( + platform::errors::Unimplemented("Not supported on place (%s) ", npu)); + // return GetResultHelper(out, npu); + } + bool GetResult(const framework::Tensor& out, const platform::CPUPlace& cpu) const { return *out.data(); @@ -633,6 +705,10 @@ struct BothFalseVisitor : public boost::static_visitor<> { #endif } + void VisitorImpl(const platform::NPUPlace& npu) const 
{ + // TODO(zhiqiu) + } + void VisitorImpl(const platform::CPUPlace& cpu) const { int num = in_.numel(); const bool* in_ptr = in_.data(); diff --git a/paddle/fluid/framework/tensor_util.h b/paddle/fluid/framework/tensor_util.h index fd0f98784ceb0ab9e744f8266b68cdafd98c2329..868d920f13ca8299fb53c0d0506d7e448fd152a3 100644 --- a/paddle/fluid/framework/tensor_util.h +++ b/paddle/fluid/framework/tensor_util.h @@ -157,6 +157,14 @@ void TensorFromVector(const std::vector& src, reinterpret_cast(ctx).stream()); } #endif +#ifdef PADDLE_WITH_ASCEND_CL + else if (platform::is_npu_place(dst_place)) { // NOLINT + memory::Copy( + BOOST_GET_CONST(platform::NPUPlace, dst_place), dst_ptr, src_place, + src_ptr, size, + reinterpret_cast(ctx).stream()); + } +#endif } template @@ -194,6 +202,14 @@ void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx, reinterpret_cast(ctx).stream()); } #endif +#ifdef PADDLE_WITH_ASCEND_CL + else if (platform::is_npu_place(src.place())) { // NOLINT + memory::Copy( + dst_place, dst_ptr, BOOST_GET_CONST(platform::NPUPlace, src.place()), + src_ptr, size, + reinterpret_cast(ctx).stream()); + } +#endif } template diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc index df5ff750c9902f556f93283e584dbd52551b0fda..64f5a9e0cc8771305bcdb9796069ef76d8597802 100644 --- a/paddle/fluid/imperative/gradient_accumulator.cc +++ b/paddle/fluid/imperative/gradient_accumulator.cc @@ -115,6 +115,23 @@ class TensorAddFunctor : public boost::static_visitor<> { } #endif +#ifdef PADDLE_WITH_ASCEND_CL + void operator()(const platform::NPUPlace& place) { + // TODO(zhiqiu): SUPPORT it + PADDLE_THROW(platform::errors::PermissionDenied( + "Gradient accumulation on place (%s) " + "is not supported in imperative mode", + place)); + } +#else + void operator()(const platform::NPUPlace& place) { + PADDLE_THROW(platform::errors::PermissionDenied( + "Gradient accumulation on place (%s) " + "is not supported in imperative mode", + place)); + } +#endif + // there is NO blas in CUDAPinnedPlace void operator()(const platform::CUDAPinnedPlace& place) { PADDLE_THROW(platform::errors::PermissionDenied( diff --git a/paddle/fluid/memory/allocation/CMakeLists.txt b/paddle/fluid/memory/allocation/CMakeLists.txt index 565797d51dd513ac7fb44203f1e8d17955078c67..2ea047fa13c10596995916234ef67e8a276b6b22 100644 --- a/paddle/fluid/memory/allocation/CMakeLists.txt +++ b/paddle/fluid/memory/allocation/CMakeLists.txt @@ -27,6 +27,10 @@ if (WITH_ROCM) cc_test(thread_local_allocator_test SRCS thread_local_allocator_test.cc DEPS thread_local_allocator) endif() +if (WITH_ASCEND_CL) + cc_library(npu_allocator SRCS npu_allocator.cc DEPS allocator npu_info) +endif() + cc_library(retry_allocator SRCS retry_allocator.cc DEPS allocator) if (WITH_GPU OR WITH_ROCM) diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index cbeb263b5f41b96c73d67d9f56a407eecf209815..730efa5c646885026eee1e472205ce723b0fcb1b 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -32,6 +32,7 @@ #ifdef PADDLE_WITH_XPU #include "paddle/fluid/platform/xpu_info.h" #endif +#include "paddle/fluid/platform/npu_info.h" DEFINE_int64( gpu_allocator_retry_time, 10000, @@ -66,6 +67,11 @@ class AllocatorFacadePrivate { InitNaiveBestFitCUDAAllocator(platform::CUDAPlace(dev_id)); } InitNaiveBestFitCUDAPinnedAllocator(); +#endif +#ifdef PADDLE_WITH_ASCEND_CL + for (int dev_id 
= 0; dev_id < platform::GetNPUDeviceCount(); ++dev_id) { + InitNaiveBestFitNPUAllocator(platform::NPUPlace(dev_id)); + } #endif break; } @@ -185,6 +191,12 @@ class AllocatorFacadePrivate { } #endif +#ifdef PADDLE_WITH_ASCEND_CL + void InitNaiveBestFitNPUAllocator(platform::NPUPlace p) { + allocators_[p] = std::make_shared(p); + } +#endif + class ZeroSizeAllocator : public Allocator { public: explicit ZeroSizeAllocator(platform::Place place) : place_(place) {} diff --git a/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc b/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc index 0ada2cafcc16a638cba2e8dbd8d36ce1b219d0b5..3e88d61783c9e67053ef065f61fef5cf991a9b25 100644 --- a/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc +++ b/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc @@ -19,7 +19,10 @@ #include "gflags/gflags.h" #include "glog/logging.h" #include "paddle/fluid/memory/detail/buddy_allocator.h" +#include "paddle/fluid/memory/detail/system_allocator.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/gpu_info.h" +#include "paddle/fluid/platform/npu_info.h" #include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/string/printf.h" @@ -110,6 +113,7 @@ size_t Used(const platform::CPUPlace &place) { return GetCPUBuddyAllocator()->Used(); } +// For kunlun XPU template <> void *Alloc(const platform::XPUPlace &place, size_t size) { #ifdef PADDLE_WITH_XPU @@ -219,6 +223,135 @@ size_t Used(const platform::XPUPlace &place) { #endif } +// For Ascend NPU +#ifdef PADDLE_WITH_ASCEND_CL +class NPUBuddyAllocatorList { + private: + NPUBuddyAllocatorList() : devices_(platform::GetSelectedNPUDevices()) { + auto npu_num = devices_.size(); + allocators_.resize(npu_num); + init_flags_.reserve(npu_num); + for (size_t i = 0; i < npu_num; ++i) { + init_flags_.emplace_back(new std::once_flag()); + } + } + + static NPUBuddyAllocatorList *CreateNewInstance() { + return new NPUBuddyAllocatorList(); + } + + public: + static NPUBuddyAllocatorList *Instance() { + static auto *instance = CreateNewInstance(); + return instance; + } + + BuddyAllocator *Get(int npu_id) { + auto pos = std::distance( + devices_.begin(), std::find(devices_.begin(), devices_.end(), npu_id)); + PADDLE_ENFORCE_LT(pos, devices_.size(), + platform::errors::OutOfRange( + "The index exceeds the size of devices, the size of " + "devices is %d, the index is %d", + devices_.size(), pos)); + + std::call_once(*init_flags_[pos], [this, pos] { + platform::SetNPUDeviceId(devices_[pos]); + allocators_[pos].reset(new BuddyAllocator( + std::unique_ptr( + new detail::NPUAllocator(devices_[pos])), + platform::NPUMinChunkSize(), platform::NPUMaxChunkSize())); + VLOG(10) << "\n\nNOTE:\n" + << "You can set GFlags environment variable " + << "'FLAGS_fraction_of_gpu_memory_to_use' " + << "or 'FLAGS_initial_gpu_memory_in_mb' " + << "or 'FLAGS_reallocate_gpu_memory_in_mb' " + << "to change the memory size for GPU usage.\n" + << "Current 'FLAGS_fraction_of_gpu_memory_to_use' value is " + << FLAGS_fraction_of_gpu_memory_to_use + << ". Current 'FLAGS_initial_gpu_memory_in_mb' value is " + << FLAGS_initial_gpu_memory_in_mb + << ". 
Current 'FLAGS_reallocate_gpu_memory_in_mb' value is " + << FLAGS_reallocate_gpu_memory_in_mb << "\n\n"; + }); + + return allocators_[pos].get(); + } + + private: + std::vector devices_; + std::vector> init_flags_; + std::vector> allocators_; +}; + +BuddyAllocator *GetNPUBuddyAllocator(int npu_id) { + return NPUBuddyAllocatorList::Instance()->Get(npu_id); +} +#endif + +template <> +size_t Used(const platform::NPUPlace &place) { +#ifdef PADDLE_WITH_ASCEND_CL + return GetNPUBuddyAllocator(place.device)->Used(); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "'NPUPlace' is not supported in CPU only device.")); +#endif +} + +template <> +void *Alloc(const platform::NPUPlace &place, size_t size) { +#ifdef PADDLE_WITH_ASCEND_CL + auto *buddy_allocator = GetNPUBuddyAllocator(place.device); + auto *ptr = buddy_allocator->Alloc(size); + if (ptr == nullptr) { + platform::NPUDeviceGuard(place.device); + size_t avail, total; + platform::NPUMemoryUsage(&avail, &total); + PADDLE_THROW(platform::errors::ResourceExhausted( + "Cannot allocate %s in GPU %d, avaliable %s, total %s, GpuMinChunkSize " + "%s, GpuMaxChunkSize %s, GPU memory used: %s.", + string::HumanReadableSize(size), place.device, + string::HumanReadableSize(avail), string::HumanReadableSize(total), + string::HumanReadableSize(buddy_allocator->GetMinChunkSize()), + string::HumanReadableSize(buddy_allocator->GetMaxChunkSize()), + string::HumanReadableSize(Used(place)))); + } else { + if (FLAGS_init_allocated_mem) { + aclrtMemset(ptr, size, 0xEF, size); + } + } + VLOG(10) << "Allocate " << size << " bytes on " << platform::Place(place); + return ptr; +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "'NPUPlace' is not supported in CPU only device.")); +#endif +} + +template <> +void Free(const platform::NPUPlace &place, void *p, + size_t size) { +#ifdef PADDLE_WITH_ASCEND_CL + VLOG(10) << "Free pointer=" << p << " on " << platform::Place(place); + GetNPUBuddyAllocator(place.device)->Free(p); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "'NPUPlace' is not supported in CPU only device.")); +#endif +} + +template <> +uint64_t Release(const platform::NPUPlace &place) { +#ifdef PADDLE_WITH_ASCEND_CL + return GetNPUBuddyAllocator(place.device)->Release(); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "'NPUPlace' is not supported in CPU only device.")); +#endif +} + +// For CUDA #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) class GPUBuddyAllocatorList { private: diff --git a/paddle/fluid/memory/allocation/naive_best_fit_allocator_test.cc b/paddle/fluid/memory/allocation/naive_best_fit_allocator_test.cc index 37da748ee9c965ab829851c7045ad6a1a1e0e93e..1fe85dd699acf18387482d296c2c30f3bb2415cb 100644 --- a/paddle/fluid/memory/allocation/naive_best_fit_allocator_test.cc +++ b/paddle/fluid/memory/allocation/naive_best_fit_allocator_test.cc @@ -61,6 +61,22 @@ TEST(NaiveBestFitAllocatorTest, CudaPinnedAlloc) { } #endif +#ifdef PADDLE_WITH_ASCEND_CL +TEST(NaiveBestFitAllocatorTest, NpuAlloc) { + NaiveBestFitAllocator alloc{platform::NPUPlace(0)}; + { + size_t size = (1 << 20); + auto allocation = alloc.Allocate(size); + } + sleep(10); + alloc.Release(platform::NPUPlace(0)); + + size_t size = (1 << 20); + auto allocation = alloc.Allocate(size); + alloc.Release(platform::NPUPlace(0)); +} +#endif + } // namespace allocation } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/allocation/npu_allocator.cc b/paddle/fluid/memory/allocation/npu_allocator.cc new file mode 100644 
index 0000000000000000000000000000000000000000..4ecdee9bd03352201060911848647b60d3cc0203 --- /dev/null +++ b/paddle/fluid/memory/allocation/npu_allocator.cc @@ -0,0 +1,73 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/memory/allocation/npu_allocator.h" +#include +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/npu_info.h" + +namespace paddle { +namespace memory { +namespace allocation { + +bool NPUAllocator::IsAllocThreadSafe() const { return true; } +void NPUAllocator::FreeImpl(Allocation* allocation) { + PADDLE_ENFORCE_EQ( + BOOST_GET_CONST(platform::NPUPlace, allocation->place()), place_, + platform::errors::PermissionDenied( + "NPU memory is freed in incorrect device. This may be a bug")); + platform::RecordedNPUFree(allocation->ptr(), allocation->size(), + place_.device); + delete allocation; +} + +Allocation* NPUAllocator::AllocateImpl(size_t size) { + std::call_once(once_flag_, + [this] { platform::SetNPUDeviceId(place_.device); }); + + void* ptr; + auto result = platform::RecordedNPUMalloc(&ptr, size, place_.device); + if (LIKELY(result == ACL_ERROR_NONE)) { + return new Allocation(ptr, size, platform::Place(place_)); + } + + size_t avail, total, actual_avail, actual_total; + bool is_limited = platform::RecordedNPUMemGetInfo( + &avail, &total, &actual_avail, &actual_total, place_.device); + + std::string err_msg; + if (is_limited) { + auto limit_size = (total >> 20); + err_msg = string::Sprintf( + "Or set environment variable `FLAGS_gpu_memory_limit_mb` to a larger " + "value. Currently `FLAGS_gpu_memory_limit_mb` is %d, so the maximum " + "GPU memory usage is limited to %d MB.\n" + " The command is `export FLAGS_gpu_memory_limit_mb=xxx`.", + limit_size, limit_size); + } + + PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted( + "\n\nOut of memory error on NPU %d. " + "Cannot allocate %s memory on NPU %d, " + "available memory is only %s.\n\n" + "Please check whether there is any other process using NPU %d.\n" + "1. If yes, please stop them, or start PaddlePaddle on another NPU.\n" + "2. If no, please decrease the batch size of your model. %s\n\n", + place_.device, string::HumanReadableSize(size), place_.device, + string::HumanReadableSize(avail), place_.device, err_msg)); +} + +} // namespace allocation +} // namespace memory +} // namespace paddle diff --git a/paddle/fluid/memory/allocation/npu_allocator.h b/paddle/fluid/memory/allocation/npu_allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..738ec5d3ce120f3d08b887d3a84d4d79a1e9e1d6 --- /dev/null +++ b/paddle/fluid/memory/allocation/npu_allocator.h @@ -0,0 +1,41 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include // NOLINT +#include "paddle/fluid/memory/allocation/allocator.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace memory { +namespace allocation { + +class NPUAllocator : public Allocator { + public: + explicit NPUAllocator(const platform::NPUPlace& place) : place_(place) {} + + bool IsAllocThreadSafe() const override; + + protected: + void FreeImpl(Allocation* allocation) override; + Allocation* AllocateImpl(size_t size) override; + + private: + platform::NPUPlace place_; + std::once_flag once_flag_; +}; + +} // namespace allocation +} // namespace memory +} // namespace paddle diff --git a/paddle/fluid/memory/detail/CMakeLists.txt b/paddle/fluid/memory/detail/CMakeLists.txt index fcae741db3667f4acf9ff33323f3f95710724669..e9631ee739b9b8089a963a6aa84a9837010ad639 100644 --- a/paddle/fluid/memory/detail/CMakeLists.txt +++ b/paddle/fluid/memory/detail/CMakeLists.txt @@ -6,6 +6,8 @@ if(WITH_GPU) nv_library(system_allocator SRCS system_allocator.cc DEPS gflags cpu_info gpu_info place) elseif(WITH_ROCM) hip_library(system_allocator SRCS system_allocator.cc DEPS gflags cpu_info gpu_info place) +elseif(${WITH_ASCEND_CL}) + cc_library(system_allocator SRCS system_allocator.cc DEPS gflags cpu_info npu_info place) else() cc_library(system_allocator SRCS system_allocator.cc DEPS gflags cpu_info place) endif() diff --git a/paddle/fluid/memory/detail/buddy_allocator.cc b/paddle/fluid/memory/detail/buddy_allocator.cc index 50c0b58f3a1dd6eafd4ca86f2378cbd8f4b2e041..55436f451a41ff2a77acddfaff3c5a7c290b7ac2 100644 --- a/paddle/fluid/memory/detail/buddy_allocator.cc +++ b/paddle/fluid/memory/detail/buddy_allocator.cc @@ -21,6 +21,9 @@ limitations under the License. */ #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) DECLARE_uint64(reallocate_gpu_memory_in_mb); #endif +#ifdef PADDLE_WITH_ASCEND_CL +DECLARE_uint64(reallocate_gpu_memory_in_mb); +#endif namespace paddle { namespace memory { @@ -235,6 +238,21 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool( } } #endif +#ifdef PADDLE_WITH_ASCEND_CL + if (system_allocator_->UseGpu()) { + if ((total_used_ + total_free_) == 0) { + // Compute the allocation size for gpu for the first allocation. + allocate_bytes = std::max(platform::NPUInitAllocSize(), request_bytes); + } else { + // Compute the re-allocation size, we store the re-allocation size when + // user set FLAGS_reallocate_gpu_memory_in_mb to fix value. + if (realloc_size_ == 0 || FLAGS_reallocate_gpu_memory_in_mb == 0ul) { + realloc_size_ = platform::NPUReallocSize(); + } + allocate_bytes = std::max(realloc_size_, request_bytes); + } + } +#endif // Allocate a new block void* p = system_allocator_->Alloc(&index, allocate_bytes); diff --git a/paddle/fluid/memory/detail/buddy_allocator.h b/paddle/fluid/memory/detail/buddy_allocator.h index 15e93deffccda8852b371a60ab3e08f9f8b811c2..135c3b6d04f346d361530ad5586e8f11e023d05c 100644 --- a/paddle/fluid/memory/detail/buddy_allocator.h +++ b/paddle/fluid/memory/detail/buddy_allocator.h @@ -26,6 +26,7 @@ limitations under the License. 
*/ #include "paddle/fluid/memory/detail/system_allocator.h" #include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/gpu_info.h" +#include "paddle/fluid/platform/npu_info.h" namespace paddle { namespace memory { diff --git a/paddle/fluid/memory/detail/buddy_allocator_test.cc b/paddle/fluid/memory/detail/buddy_allocator_test.cc index 2dc3e73af24162ebdc7872403fe28d83044920dc..290f3d5d1bcd47b40b8ee35ad45cd103bd11b26e 100644 --- a/paddle/fluid/memory/detail/buddy_allocator_test.cc +++ b/paddle/fluid/memory/detail/buddy_allocator_test.cc @@ -19,14 +19,16 @@ limitations under the License. */ #ifdef WITH_GPERFTOOLS #include "gperftools/profiler.h" #endif +#include +#include + #include "gflags/gflags.h" #include "gtest/gtest.h" #include "paddle/fluid/platform/gpu_info.h" +#include "paddle/fluid/platform/npu_info.h" -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -#include -#include - +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ + defined(PADDLE_WITH_ASCEND_CL) DECLARE_double(fraction_of_gpu_memory_to_use); DECLARE_uint64(initial_gpu_memory_in_mb); DECLARE_uint64(reallocate_gpu_memory_in_mb); @@ -342,6 +344,32 @@ TEST(BuddyAllocator, Release) { } #endif +#ifdef PADDLE_WITH_ASCEND_CL +TEST(BuddyAllocator, NpuFraction) { + // In a 16 GB machine, the pool size will be about 160 MB + FLAGS_fraction_of_gpu_memory_to_use = 0.005; + FLAGS_fraction_of_gpu_memory_to_use = 0.92; + FLAGS_initial_gpu_memory_in_mb = 0; + FLAGS_reallocate_gpu_memory_in_mb = 0; + + BuddyAllocator buddy_allocator( + std::unique_ptr(new NPUAllocator(0)), + platform::NPUMinChunkSize(), platform::NPUMaxChunkSize()); + + // Less than pool size + TestBuddyAllocator(&buddy_allocator, 10); + TestBuddyAllocator(&buddy_allocator, 10 << 10); + TestBuddyAllocator(&buddy_allocator, 10 << 20); + buddy_allocator.Release(); + + // Greater than max chunk size + TestBuddyAllocator(&buddy_allocator, 300 << 20, + /* use_system_allocator = */ true); + TestBuddyAllocator(&buddy_allocator, 1 * static_cast(1 << 30), + /* use_system_allocator = */ true); +} +#endif + } // namespace detail } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/detail/system_allocator.cc b/paddle/fluid/memory/detail/system_allocator.cc index 38baf6c24bab3fb7ca55a15b4f231bf9eba7d82e..c733ba5c68c9bd8623acbc57bd248ebab449ef4c 100644 --- a/paddle/fluid/memory/detail/system_allocator.cc +++ b/paddle/fluid/memory/detail/system_allocator.cc @@ -29,6 +29,8 @@ limitations under the License. */ #include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/gpu_info.h" +#include "paddle/fluid/platform/npu_info.h" + #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/platform/cuda_device_guard.h" #endif @@ -247,6 +249,68 @@ bool CUDAPinnedAllocator::UseGpu() const { return false; } #endif +#ifdef PADDLE_WITH_ASCEND_CL +void* NPUAllocator::Alloc(size_t* index, size_t size) { + if (size <= 0) return nullptr; + + void* p; + auto result = platform::RecordedNPUMalloc(&p, size, npu_id_); + + if (result == ACL_ERROR_NONE) { + *index = 0; + npu_alloc_size_ += size; + return p; + } else { + size_t avail, total, actual_avail, actual_total; + bool is_limited = platform::RecordedNPUMemGetInfo( + &avail, &total, &actual_avail, &actual_total, npu_id_); + + std::string err_msg; + if (is_limited) { + auto limit_size = (total >> 20); + err_msg = string::Sprintf( + "\n 3) Set environment variable `FLAGS_gpu_memory_limit_mb` to a " + "larger value. 
Currently `FLAGS_gpu_memory_limit_mb` is %d, so the " + "maximum GPU memory usage is limited to %d MB.\n" + " The command is `export FLAGS_gpu_memory_limit_mb=xxx`.", + limit_size, limit_size); + } + + PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted( + "\n\nOut of memory error on NPU %d. " + "Cannot allocate %s memory on NPU %d, " + "available memory is only %s.\n\n" + "Please check whether there is any other process using NPU %d.\n" + "1. If yes, please stop them, or start PaddlePaddle on another NPU.\n" + "2. If no, please try one of the following suggestions:\n" + " 1) Decrease the batch size of your model.\n" + " 2) FLAGS_fraction_of_gpu_memory_to_use is %.2lf now, " + "please set it to a higher value but less than 1.0.\n" + " The command is " + "`export FLAGS_fraction_of_gpu_memory_to_use=xxx`.%s\n\n", + npu_id_, string::HumanReadableSize(size), npu_id_, + string::HumanReadableSize(avail), npu_id_, + FLAGS_fraction_of_gpu_memory_to_use, err_msg)); + } +} + +void NPUAllocator::Free(void* p, size_t size, size_t index) { + VLOG(4) << "Free " << p << " size " << size; + PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument( + "The index should be 0, index is %d", index)); + PADDLE_ENFORCE_GE(npu_alloc_size_, size, + platform::errors::InvalidArgument( + "The size of memory (%d) to free exceeds the size of " + "allocated gpu memory (%d)", + size, npu_alloc_size_)); + npu_alloc_size_ -= size; + + platform::RecordedNPUFree(p, size, npu_id_); +} + +bool NPUAllocator::UseGpu() const { return true; } +#endif + } // namespace detail } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/detail/system_allocator.h b/paddle/fluid/memory/detail/system_allocator.h index e332bb670da2357f5ed831e743c20579677b90a5..26711ae4070f5ed72f77519b196c4c354cb049e1 100644 --- a/paddle/fluid/memory/detail/system_allocator.h +++ b/paddle/fluid/memory/detail/system_allocator.h @@ -66,6 +66,22 @@ class CUDAPinnedAllocator : public SystemAllocator { }; #endif +#ifdef PADDLE_WITH_ASCEND_CL + +class NPUAllocator : public SystemAllocator { + public: + explicit NPUAllocator(int npu_id) : npu_id_(npu_id) {} + + virtual void* Alloc(size_t* index, size_t size); + virtual void Free(void* p, size_t size, size_t index); + virtual bool UseGpu() const; + + private: + size_t npu_alloc_size_ = 0; + int npu_id_; +}; +#endif + } // namespace detail } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/detail/system_allocator_test.cc b/paddle/fluid/memory/detail/system_allocator_test.cc index 13854d771a0bf60bfef90515795ee70d9cb7fb73..ead188341dac46bd3eec490015ff934dc8a26af5 100644 --- a/paddle/fluid/memory/detail/system_allocator_test.cc +++ b/paddle/fluid/memory/detail/system_allocator_test.cc @@ -85,3 +85,11 @@ TEST(GPUAllocator, AllocFailure) { } } #endif + +#ifdef PADDLE_WITH_ASCEND_CL +TEST(NPUAllocator, Alloc) { + paddle::memory::detail::NPUAllocator a(0); + TestAllocator(&a, 1 << 20); + TestAllocator(&a, 1); +} +#endif diff --git a/paddle/fluid/memory/memcpy.cc b/paddle/fluid/memory/memcpy.cc index 6f252e1bd0de7232bd1f51e0d3588bf0293f7962..d9a4503cc1e5f7cebbac9062f38739b02c64890b 100644 --- a/paddle/fluid/memory/memcpy.cc +++ b/paddle/fluid/memory/memcpy.cc @@ -196,6 +196,85 @@ void Copy(platform::XPUPlace dst_place, } #endif +#ifdef PADDLE_WITH_ASCEND_CL +template <> +void Copy(platform::NPUPlace dst_place, + void* dst, + platform::CPUPlace src_place, + const void* src, size_t num, + aclrtStream stream) { + if (UNLIKELY(num == 0)) return; + + 
platform::SetNPUDeviceId(dst_place.device);
+  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
+          << dst_place << " by stream(" << stream << ")";
+  if (stream) {
+    platform::RecordEvent record_event("NpuMemcpyAsync:CPU->NPU");
+    platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE, stream);
+  } else {
+    platform::RecordEvent record_event("NpuMemcpySync:CPU->NPU");
+    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE);
+  }
+}
+
+template <>
+void Copy<platform::CPUPlace, platform::NPUPlace>(platform::CPUPlace dst_place,
+                                                  void* dst,
+                                                  platform::NPUPlace src_place,
+                                                  const void* src, size_t num,
+                                                  aclrtStream stream) {
+  if (UNLIKELY(num == 0)) return;
+
+  platform::SetNPUDeviceId(src_place.device);
+  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
+          << dst_place << " by stream(" << stream << ")";
+  if (stream) {
+    platform::RecordEvent record_event("NpuMemcpyAsync:NPU->CPU");
+    platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST, stream);
+  } else {
+    platform::RecordEvent record_event("NpuMemcpySync:NPU->CPU");
+    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST);
+  }
+}
+
+template <>
+void Copy<platform::NPUPlace, platform::NPUPlace>(platform::NPUPlace dst_place,
+                                                  void* dst,
+                                                  platform::NPUPlace src_place,
+                                                  const void* src, size_t num,
+                                                  aclrtStream stream) {
+  if (UNLIKELY(num == 0)) return;
+
+  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
+          << dst_place << " by stream(" << stream << ")";
+  if (dst_place == src_place) {
+    platform::SetNPUDeviceId(src_place.device);
+    if (stream) {
+      platform::RecordEvent record_event("NpuMemcpyAsync(same_npu):NPU->NPU");
+      platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE,
+                               stream);
+    } else {
+      platform::RecordEvent record_event("NpuMemcpySync(same_npu):NPU->NPU");
+      platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE);
+    }
+  } else {
+    if (!platform::NPUCanAccessPeer(dst_place.device, dst_place.device)) {
+      PADDLE_THROW(platform::errors::Unavailable(
+          "Peer access between NPU places is not allowed."));
+    }
+    if (stream) {
+      // TODO(zhiqiu): support peer access?
+      platform::RecordEvent record_event("NpuMemcpyPeerAsync:NPU->NPU");
+      platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE,
+                               stream);
+    } else {
+      platform::RecordEvent record_event("NpuMemcpyPeerSync:NPU->NPU");
+      platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE);
+    }
+  }
+}
+#endif
+
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 static constexpr size_t kMaxGpuAsyncCopyBytes = 64 * 1024;  // 64K
diff --git a/paddle/fluid/memory/memcpy.h b/paddle/fluid/memory/memcpy.h
index 25490f28b659876ddd1e9e0eef7f23062791a05e..c630437224cd093438df5d8d8a58a5c8f6ab2ad2 100644
--- a/paddle/fluid/memory/memcpy.h
+++ b/paddle/fluid/memory/memcpy.h
@@ -52,7 +52,27 @@ void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num);
 template <typename DstPlace, typename SrcPlace>
 void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num,
           gpuStream_t stream);
+#endif
+#ifdef PADDLE_WITH_ASCEND_CL
+/**
+ * \brief Copy memory from one place to another place.
+ *
+ * \param[in] DstPlace Destination allocation place (CPU or NPU).
+ * \param[in] dst Destination memory address.
+ * \param[in] SrcPlace Source allocation place (CPU or NPU).
+ * \param[in] src Source memory address.
+ * \param[in] num memory size in bytes to copy.
+ * \param[in] stream NPU stream.
+ *
+ * \note For NPU memory copy, an NPU stream needs to be specified
+ *       for asynchronous memory copy.
+ * + */ +template +void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, + aclrtStream stream); #endif + } // namespace memory } // namespace paddle diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index ed8787275322859cfe0aff602472cb94bf5033d2..dac8c7b03e5174fe5c6354c3f882bec8fa4b3085 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -123,6 +123,11 @@ if (WITH_ASCEND) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} ascend_wrapper) endif() +if (WITH_ASCEND_CL) + cc_library(npu_op_runner SRCS npu_op_runner.cc DEPS operator npu_info) + set(COMMON_OP_DEPS ${COMMON_OP_DEPS} npu_op_runner) +endif() + # FIXME(typhoonzero): operator deps may not needed. # op_library(lod_tensor_to_array_op DEPS lod_rank_table_op) # op_library(array_to_lod_tensor_op DEPS lod_rank_table_op) diff --git a/paddle/fluid/operators/elementwise/CMakeLists.txt b/paddle/fluid/operators/elementwise/CMakeLists.txt index 06ca98e526e95b414584f9634a3d42f84d6b369f..216a3f79d6f920ff996f0b0788565d52d5bf3aff 100644 --- a/paddle/fluid/operators/elementwise/CMakeLists.txt +++ b/paddle/fluid/operators/elementwise/CMakeLists.txt @@ -8,3 +8,7 @@ register_operators(DEPS op_version_registry) cc_test(test_elementwise_add_op_inplace SRCS test_elementwise_add_op_inplace.cc DEPS op_registry elementwise_add_op scope device_context enforce executor) cc_test(test_elementwise_div_grad_grad SRCS test_elementwise_div_grad_grad.cc DEPS op_registry elementwise_div_op scope device_context enforce executor) cc_test(test_elementwise_add_grad_grad SRCS test_elementwise_add_grad_grad.cc DEPS op_registry elementwise_add_op scope device_context enforce executor) + +if(WITH_ASCEND_CL) +cc_test(elementwise_op_npu_test SRCS elementwise_op_npu_test.cc DEPS op_registry elementwise_add_op elementwise_sub_op scope device_context enforce executor) +endif() diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc b/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc new file mode 100644 index 0000000000000000000000000000000000000000..1e7e5e02c0181f8828a59b9403ac24f40347f8b6 --- /dev/null +++ b/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc @@ -0,0 +1,50 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef PADDLE_WITH_ASCEND_CL +#include +#include + +#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +template +class ElementwiseAddNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + auto runner = NpuOpRunner("Add", {*x, *y}, {*out}, {}); + auto stream = + ctx.template device_context() + .stream(); + runner.Run(stream); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_NPU_KERNEL( + elementwise_add, + ops::ElementwiseAddNPUKernel); +#endif diff --git a/paddle/fluid/operators/elementwise/elementwise_op_npu_test.cc b/paddle/fluid/operators/elementwise/elementwise_op_npu_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3a2a21647083bfda097b75326814fd34d2bdd689 --- /dev/null +++ b/paddle/fluid/operators/elementwise/elementwise_op_npu_test.cc @@ -0,0 +1,181 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifndef _WIN32 +#include +#endif + +#include +#include // NOLINT +#include + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/operators/dropout_op.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/string/printf.h" + +namespace f = paddle::framework; +namespace p = paddle::platform; +namespace m = paddle::operators::math; + +USE_OP(elementwise_add); +USE_OP_DEVICE_KERNEL(elementwise_add, NPU); +USE_OP(elementwise_sub); +USE_OP_DEVICE_KERNEL(elementwise_sub, NPU); + +template +void Compare(f::Scope* scope, const p::DeviceContext& ctx, + std::string op_type) { + // init + auto x = scope->Var("X"); + auto tensor_x = x->GetMutable(); + + auto y = scope->Var("Y"); + auto tensor_y = y->GetMutable(); + + std::vector init_x; + for (int64_t i = 0; i < 10 * 10; ++i) { + init_x.push_back(static_cast(1.0)); + } + + std::vector init_y; + for (int64_t i = 0; i < 10 * 10; ++i) { + init_y.push_back(static_cast(2.0)); + } + + TensorFromVector(init_x, ctx, tensor_x); + tensor_x->Resize({10, 10}); + TensorFromVector(init_y, ctx, tensor_y); + tensor_y->Resize({10, 10}); + + ctx.Wait(); + + auto place = ctx.GetPlace(); + auto out = scope->Var("Out"); + auto tensor_out = out->GetMutable(); + + // run + f::AttributeMap attrs; + auto op = f::OpRegistry::CreateOp(op_type, {{"X", {"X"}}, {"Y", {"Y"}}}, + {{"Out", {"Out"}}}, attrs); + + op->Run(*scope, place); + + std::vector out_vec; + TensorToVector(*tensor_out, ctx, &out_vec); + + ctx.Wait(); + float expected; + if (op_type == "elementwise_add") { + expected = 3.0; + } else if (op_type == "elementwise_sub") { + expected = -1.0; + } + EXPECT_EQ(out_vec.size(), init_x.size()); + for (uint32_t i = 0; i < out_vec.size(); i++) { + EXPECT_EQ(out_vec[i], static_cast(expected)); + } +} + +template +void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx, + std::string op_type) { + // init + auto dout = scope->Var("DOut"); + auto tensor_dout = dout->GetMutable(); + tensor_dout->Resize({2, 3, 5}); + + auto x = scope->Var("X"); + auto tensor_x = x->GetMutable(); + tensor_x->Resize({2, 3, 5}); + + auto y = scope->Var("Y"); + auto tensor_y = y->GetMutable(); + tensor_y->Resize({1, 5}); + + auto dx = scope->Var("DX"); + auto tensor_dx = dx->GetMutable(); + + auto dy = scope->Var("DY"); + auto tensor_dy = dy->GetMutable(); + + std::vector init_dout; + for (int64_t i = 0; i < tensor_dout->numel(); ++i) { + init_dout.push_back(static_cast(1.0)); + } + + TensorFromVector(init_dout, ctx, tensor_dout); + tensor_dout->Resize({2, 3, 5}); + + ctx.Wait(); + + // run + f::AttributeMap attrs; + auto op = f::OpRegistry::CreateOp( + op_type, {{"Out@GRAD", {"DOut"}}, {"X", {"X"}}, {"Y", {"Y"}}}, + {{"X@GRAD", {"DX"}}, {"Y@GRAD", {"DY"}}}, attrs); + + auto place = ctx.GetPlace(); + op->Run(*scope, place); + + std::vector dx_vec; + TensorToVector(*tensor_dx, ctx, &dx_vec); + + std::vector dy_vec; + TensorToVector(*tensor_dy, ctx, &dy_vec); + + ctx.Wait(); + float expected_x, expected_y; + if (op_type == "elementwise_add_grad") { + expected_x = 1.0; + expected_y = 6.0; + } else if (op_type == "elementwise_sub_grad") { + expected_x = 1.0; + expected_y = -6.0; + } + + for (uint32_t i = 0; i < dx_vec.size(); i++) { + EXPECT_EQ(dx_vec[i], static_cast(expected_x)); + } + for (uint32_t i = 0; i < dy_vec.size(); i++) { + EXPECT_EQ(dy_vec[i], static_cast(expected_y)); + } +} + +TEST(elementwise_add, NPU_fp32) { + 
f::Scope scope; + p::NPUDeviceContext ctx(p::NPUPlace(0)); + Compare(&scope, ctx, "elementwise_add"); +} + +TEST(elementwise_sub, NPU_fp32) { + f::Scope scope; + p::NPUDeviceContext ctx(p::NPUPlace(0)); + Compare(&scope, ctx, "elementwise_sub"); +} + +TEST(elementwise_sub, NPU_fp16) { + f::Scope scope; + p::NPUDeviceContext ctx(p::NPUPlace(0)); + Compare(&scope, ctx, "elementwise_sub"); +} + +TEST(elementwise_sub_grad, NPU) { + f::Scope scope; + p::NPUDeviceContext ctx(p::NPUPlace(0)); + CompareGrad(&scope, ctx, "elementwise_sub_grad"); +} diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc b/paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc new file mode 100644 index 0000000000000000000000000000000000000000..e47c38daee8ba028668f88736ca5e7266ee4bb00 --- /dev/null +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc @@ -0,0 +1,171 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_ASCEND_CL +#include +#include + +#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class ElementwiseSubNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* out = ctx.Output("Out"); + + out->mutable_data(ctx.GetPlace()); + + auto runner = NpuOpRunner("Sub", {*x, *y}, {*out}, {}); + + auto stream = + ctx.template device_context() + .stream(); + runner.Run(stream); + } +}; + +template +class ElementwiseSubGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + + dx->mutable_data(ctx.GetPlace()); + dy->mutable_data(ctx.GetPlace()); + + // NOTE(zhiqiu): It seems Ascend Sub follow the broadcast sematics with + // default axis=-1? + // So, the sub_grad should do reduce if needed. + // For example, the shape of each variable in elementwise_sub: + // x, dx: [2, 3, 5] + // y, dy: [1, 5] + // out, dout: [2, 3, 5] + // Then, out = x - y => dx = dout, dy = -dout + // And, the shape of dy can be computed by two stages reduce, + // 1. [2, 3, 5] => [3, 5], ReduceSumD on axis = 0, keep_dims = false. + // 2. [3, 5] => [1, 5], ReduceSumD on axis = 0, keep_dims = true. 
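+    // For dx in the example above, the shapes of dx and dout already match, +    // so neither reduce stage below finds an axis to sum over and dx ends up +    // as a direct copy of dout; dy additionally gets a final Neg (stage 3), +    // since out = x - y implies dy = -reduce_sum(dout) over the broadcast axes.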
+ + auto stream = + ctx.template device_context() + .stream(); + // For dx + // stage 1 + auto reduce_ndim = dout->dims().size() - dx->dims().size(); + std::vector axes; + for (auto i = 0; i < reduce_ndim; ++i) { + axes.push_back(i); + } + auto tmp_dout = dout; + Tensor reduced_dout(dx->type()); + if (axes.size() != 0) { + std::vector reduced_dout_dims; + for (auto i = reduce_ndim; i < dout->dims().size(); ++i) { + reduced_dout_dims.push_back(dout->dims()[i]); + } + reduced_dout.Resize(framework::make_ddim(reduced_dout_dims)); + reduced_dout.mutable_data(ctx.GetPlace()); + auto runner = NpuOpRunner("ReduceSumD", {*dout}, {reduced_dout}, + {{"axes", axes}, {"keep_dims", false}}); + runner.Run(stream); + tmp_dout = &reduced_dout; + } + + // stage 2 + axes.clear(); + for (auto i = 0; i < dx->dims().size(); ++i) { + if (dx->dims()[i] == 1) { + axes.push_back(i); + } + } + if (axes.size() != 0) { + auto runner = NpuOpRunner("ReduceSumD", {*tmp_dout}, {*dx}, + {{"axes", axes}, {"keep_dims", true}}); + runner.Run(stream); + } else { + framework::TensorCopySync(*tmp_dout, ctx.GetPlace(), dx); + } + + // For dy + // stage 1 + reduce_ndim = dout->dims().size() - dy->dims().size(); + axes.clear(); + for (auto i = 0; i < reduce_ndim; ++i) { + axes.push_back(i); + } + tmp_dout = dout; + Tensor reduced_dy(dy->type()); + + if (axes.size() != 0) { + std::vector reduced_dout_dims; + for (auto i = reduce_ndim; i < dout->dims().size(); ++i) { + reduced_dout_dims.push_back(dout->dims()[i]); + } + reduced_dout.Resize(framework::make_ddim(reduced_dout_dims)); + reduced_dout.mutable_data(ctx.GetPlace()); + auto runner = NpuOpRunner("ReduceSumD", {*dout}, {reduced_dout}, + {{"axes", axes}, {"keep_dims", false}}); + runner.Run(stream); + tmp_dout = &reduced_dout; + } + + // stage 2 + axes.clear(); + auto* tmp_dy = tmp_dout; + for (auto i = 0; i < dy->dims().size(); ++i) { + if (dy->dims()[i] == 1) { + axes.push_back(i); + } + } + if (axes.size() != 0) { + reduced_dy.Resize(dy->dims()); + reduced_dy.mutable_data(ctx.GetPlace()); + auto runner = NpuOpRunner("ReduceSumD", {*tmp_dout}, {reduced_dy}, + {{"axes", axes}, {"keep_dims", true}}); + runner.Run(stream); + tmp_dy = &reduced_dy; + } + + // stage 3, negative + auto runner = NpuOpRunner("Neg", {*tmp_dy}, {*dy}, {}); + runner.Run(stream); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_NPU_KERNEL( + elementwise_sub, + ops::ElementwiseSubNPUKernel, + ops::ElementwiseSubNPUKernel); + +REGISTER_OP_NPU_KERNEL( + elementwise_sub_grad, + ops::ElementwiseSubGradNPUKernel, + ops::ElementwiseSubGradNPUKernel); +#endif diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index 5242d03c11c997a59b5a3194c65466cb3c15e637..68179a68574a016b9dde0a67663fa6e1c2967405 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -149,6 +149,13 @@ void set_constant_with_place( PADDLE_THROW(platform::errors::Unimplemented("XPUPlace is not supported")); } +template <> +void set_constant_with_place( + const platform::DeviceContext& context, framework::Tensor* tensor, + float value) { + PADDLE_THROW(platform::errors::Unimplemented("NPUPlace is not supported")); +} + template <> void set_constant_with_place( const platform::DeviceContext& context, framework::Tensor* tensor, diff --git a/paddle/fluid/operators/npu_op_runner.cc b/paddle/fluid/operators/npu_op_runner.cc new file mode 100644 index 
0000000000000000000000000000000000000000..7af6de5224145b991b9f4f17eebbf4c3748fac59 --- /dev/null +++ b/paddle/fluid/operators/npu_op_runner.cc @@ -0,0 +1,260 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/npu_op_runner.h" + +#include +#include + +#include +#include +#include + +#include "acl/acl.h" +#include "acl/acl_op_compiler.h" + +#include "paddle/fluid/framework/framework.pb.h" + +namespace paddle { +namespace operators { + +static std::map + DTYPE_2_ACL_DTYPE = { + {framework::proto::VarType::BOOL, ACL_BOOL}, + {framework::proto::VarType::INT16, ACL_INT16}, + {framework::proto::VarType::INT32, ACL_INT32}, + {framework::proto::VarType::INT64, ACL_INT64}, + {framework::proto::VarType::FP16, ACL_FLOAT16}, + {framework::proto::VarType::FP32, ACL_FLOAT}, + {framework::proto::VarType::FP64, ACL_DOUBLE}, +}; + +static std::map DATA_LAYOUT_2_ACL_FORMAT = { + {DataLayout::kNCHW, ACL_FORMAT_NCHW}, + {DataLayout::kNHWC, ACL_FORMAT_NHWC}, + {DataLayout::kAnyLayout, ACL_FORMAT_ND}, +}; + +aclDataType ConvertToNpuDtype(framework::proto::VarType::Type dtype) { + auto iter = DTYPE_2_ACL_DTYPE.find(dtype); + PADDLE_ENFORCE_NE(iter, DTYPE_2_ACL_DTYPE.end(), + platform::errors::NotFound( + "The data type (%s) cannot be converted to an ACL data type.", + framework::DataTypeToString(dtype))); + return iter->second; +} + +aclFormat ConvertToNpuFormat(DataLayout layout) { + auto iter = DATA_LAYOUT_2_ACL_FORMAT.find(layout); + PADDLE_ENFORCE_NE( + iter, DATA_LAYOUT_2_ACL_FORMAT.end(), + platform::errors::NotFound( + "The data layout (%s) cannot be converted to an ACL format.", layout)); + return iter->second; +} + +NpuOpRunner::NpuOpRunner(std::string op_type) : op_type_(op_type) { + attr_ = aclopCreateAttr(); +} + +NpuOpRunner::NpuOpRunner(std::string op_type, const std::vector &inputs, + const std::vector &outputs, + const AttributeMap &attrs) + : op_type_(op_type) { + attr_ = aclopCreateAttr(); + AddInputs(inputs); + AddOutputs(outputs); + AddAttrs(attrs); +} + +NpuOpRunner::~NpuOpRunner() { + // TODO(zhiqiu): handle free +} + +const std::string &NpuOpRunner::Type() { return op_type_; } + +NpuOpRunner &NpuOpRunner::AddAttr(const std::string &name, + const Attribute &attr) { + if (attr.type() == typeid(bool)) { + PADDLE_ENFORCE_NPU_SUCCESS( + aclopSetAttrBool(attr_, name.c_str(), BOOST_GET_CONST(bool, attr))); + } else if (attr.type() == typeid(int)) { + PADDLE_ENFORCE_NPU_SUCCESS( + aclopSetAttrInt(attr_, name.c_str(), BOOST_GET_CONST(int, attr))); + + } else if (attr.type() == typeid(int64_t)) { + PADDLE_ENFORCE_NPU_SUCCESS( + aclopSetAttrInt(attr_, name.c_str(), BOOST_GET_CONST(int64_t, attr))); + } else if (attr.type() == typeid(float)) { + PADDLE_ENFORCE_NPU_SUCCESS( + aclopSetAttrFloat(attr_, name.c_str(), BOOST_GET_CONST(float, attr))); + } else if (attr.type() == typeid(std::vector)) { + auto a = BOOST_GET_CONST(std::vector, attr); + std::vector cast_a; + for (auto it : a) { + cast_a.push_back(static_cast(it)); + } + 
PADDLE_ENFORCE_NPU_SUCCESS(aclopSetAttrListBool( + attr_, name.c_str(), cast_a.size(), cast_a.data())); + } else if (attr.type() == typeid(std::vector)) { + auto a = BOOST_GET_CONST(std::vector, attr); + std::vector cast_a; + for (auto it : a) { + cast_a.push_back(static_cast(it)); + } + PADDLE_ENFORCE_NPU_SUCCESS( + aclopSetAttrListInt(attr_, name.c_str(), cast_a.size(), cast_a.data())); + } else if (attr.type() == typeid(std::vector)) { + auto a = BOOST_GET_CONST(std::vector, attr); + PADDLE_ENFORCE_NPU_SUCCESS( + aclopSetAttrListInt(attr_, name.c_str(), a.size(), a.data())); + } else if (attr.type() == typeid(std::vector)) { + auto a = BOOST_GET_CONST(std::vector, attr); + PADDLE_ENFORCE_NPU_SUCCESS( + aclopSetAttrListFloat(attr_, name.c_str(), a.size(), a.data())); + } else if (attr.type() == typeid(std::string)) { + auto a = BOOST_GET_CONST(std::string, attr); + PADDLE_ENFORCE_NPU_SUCCESS( + aclopSetAttrString(attr_, name.c_str(), a.c_str())); + } else if (attr.type() == typeid(std::vector)) { + auto a = BOOST_GET_CONST(std::vector, attr); + std::vector s; + for (auto &it : a) { + s.push_back(it.data()); + } + PADDLE_ENFORCE_NPU_SUCCESS( + aclopSetAttrListString(attr_, name.c_str(), s.size(), s.data())); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Cannot convert attribute '%s' to aclopAttr", name)); + } + return *this; +} + +NpuOpRunner &NpuOpRunner::AddAttrs(const AttributeMap &attrs) { + for (const auto &pair : attrs) { + AddAttr(pair.first, pair.second); + } + return *this; +} + +NpuOpRunner &NpuOpRunner::AddInput(const Tensor &tensor) { + // create aclTensorDesc + input_descs_.emplace_back(CreateTensorDesc(tensor)); + // create aclDataBuffer + input_buffers_.emplace_back(CreateDataBuffer(tensor)); + return *this; +} + +NpuOpRunner &NpuOpRunner::AddOutput(const Tensor &tensor) { + // create aclTensorDesc + output_descs_.emplace_back(CreateTensorDesc(tensor)); + // create aclDataBuffer + output_buffers_.emplace_back(CreateDataBuffer(tensor)); + return *this; +} + +NpuOpRunner &NpuOpRunner::AddInputs(const std::vector &tensors) { + for (auto tensor : tensors) { + // create aclTensorDesc + input_descs_.emplace_back(CreateTensorDesc(tensor)); + // create aclDataBuffer + input_buffers_.emplace_back(CreateDataBuffer(tensor)); + } + return *this; +} + +NpuOpRunner &NpuOpRunner::AddOutputs(const std::vector &tensors) { + for (auto tensor : tensors) { + // create aclTensorDesc + output_descs_.emplace_back(CreateTensorDesc(tensor)); + // create aclDataBuffer + output_buffers_.emplace_back(CreateDataBuffer(tensor)); + } + return *this; +} + +aclTensorDesc *NpuOpRunner::GetInputDesc(size_t index) { + PADDLE_ENFORCE_LT(index, input_descs_.size(), + platform::errors::OutOfRange( + "The index should be less than the size of inputs of " + "operator %s, but got index is %d and size is %d", + Type(), index, input_descs_.size())); + return input_descs_[index]; +} + +aclTensorDesc *NpuOpRunner::GetOutputDesc(size_t index) { + PADDLE_ENFORCE_LT(index, output_descs_.size(), + platform::errors::OutOfRange( + "The index should be less than the size of outputs of " + "operator %s, but got index is %d and size is %d", + Type(), index, output_descs_.size())); + return output_descs_[index]; +} + +std::vector &NpuOpRunner::GetInputDescs() { + return input_descs_; +} + +std::vector &NpuOpRunner::GetOutputDescs() { + return output_descs_; +} + +std::vector &NpuOpRunner::GetInputBuffers() { + return input_buffers_; +} + +std::vector &NpuOpRunner::GetOutputBuffers() { + return 
output_buffers_; +} + +aclTensorDesc *NpuOpRunner::CreateTensorDesc(Tensor tensor) { + auto dtype = ConvertToNpuDtype(tensor.type()); + auto format = ConvertToNpuFormat(tensor.layout()); + auto dims = framework::vectorize(tensor.dims()); + + VLOG(4) << dtype << " " << dims.size() << " " << dims[0] << "," << dims[1] + << " " << format; + + auto *desc = aclCreateTensorDesc(dtype, dims.size(), dims.data(), format); + PADDLE_ENFORCE_NOT_NULL( + desc, platform::errors::External("Call aclCreateTensorDesc failed.")); + return desc; +} + +aclDataBuffer *NpuOpRunner::CreateDataBuffer(Tensor tensor) { + void *ptr = tensor.data(); + VLOG(4) << "ptr: " << ptr << ", size: " << tensor.memory_size(); + auto *buffer = aclCreateDataBuffer(ptr, tensor.memory_size()); + PADDLE_ENFORCE_NOT_NULL( + buffer, platform::errors::External("Call aclCreateDataBuffer failed.")); + return buffer; +} + +void NpuOpRunner::Run(aclrtStream stream) { + VLOG(4) << "op_type: " << op_type_; + VLOG(4) << "input_desc.size: " << input_descs_.size(); + VLOG(4) << "output_desc.size: " << output_descs_.size(); + VLOG(4) << "stream: " << stream; + VLOG(4) << "attr: " << attr_; + aclError ret = aclopCompileAndExecute( + op_type_.c_str(), input_descs_.size(), input_descs_.data(), + input_buffers_.data(), output_descs_.size(), output_descs_.data(), + output_buffers_.data(), attr_, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, + stream); + VLOG(4) << "after aclopCompileAndExecute: " << ret; + PADDLE_ENFORCE_NPU_SUCCESS(ret); +} +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/npu_op_runner.h b/paddle/fluid/operators/npu_op_runner.h new file mode 100644 index 0000000000000000000000000000000000000000..c69d8441e5def8b24aea0b094560103bf21a7442 --- /dev/null +++ b/paddle/fluid/operators/npu_op_runner.h @@ -0,0 +1,84 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include + +#include +#include + +#include "acl/acl.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using DataLayout = framework::DataLayout; +using Attribute = framework::Attribute; +using AttributeMap = framework::AttributeMap; + +class NpuOpRunner { + public: + explicit NpuOpRunner(std::string op_type); + explicit NpuOpRunner(std::string op_type, + const std::vector &inputs = {}, + const std::vector &outputs = {}, + const AttributeMap &attrs = {}); + + ~NpuOpRunner(); + + const std::string &Type(); + + NpuOpRunner &AddAttr(const std::string &name, const Attribute &attr); + + NpuOpRunner &AddAttrs(const AttributeMap &attrs); + + NpuOpRunner &AddInput(const Tensor &tensor); + + NpuOpRunner &AddOutput(const Tensor &tensor); + + NpuOpRunner &AddInputs(const std::vector &tensors); + + NpuOpRunner &AddOutputs(const std::vector &tensors); + + aclTensorDesc *GetInputDesc(size_t index); + + aclTensorDesc *GetOutputDesc(size_t index); + + std::vector &GetInputDescs(); + + std::vector &GetOutputDescs(); + + std::vector &GetInputBuffers(); + + std::vector &GetOutputBuffers(); + + void Run(aclrtStream stream); + + private: + aclTensorDesc *CreateTensorDesc(Tensor tensor); + aclDataBuffer *CreateDataBuffer(Tensor tensor); + + private: + std::string op_type_; + std::vector input_buffers_; + std::vector output_buffers_; + std::vector input_descs_; + std::vector output_descs_; + aclopAttr *attr_{nullptr}; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 1e16008f36bb7784ca850cf87d66d66e4ab86c41..584dbd4756aa0928f1f9f8edfdc88e957c1258dc 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -76,6 +76,10 @@ if(WITH_ASCEND) cc_library(ascend_npu_info SRCS ascend_npu_info.cc DEPS gflags glog enforce atlas_acl) endif() +if(WITH_ASCEND_CL) + cc_library(npu_info SRCS npu_info.cc DEPS gflags glog enforce monitor ascendcl acl_op_compiler) +endif() + add_subdirectory(dynload) add_subdirectory(stream) @@ -91,11 +95,20 @@ IF(WITH_GPU OR WITH_ROCM) set(GPU_CTX_DEPS dynload_cuda dynamic_loader cuda_stream) ENDIF() +IF(WITH_ASCEND_CL) + set(NPU_CTX_DEPS npu_stream npu_info) +ENDIF() + IF(WITH_MKLDNN) set(MKLDNN_CTX_DEPS mkldnn) ELSE() set(MKLDNN_CTX_DEPS) ENDIF() + +IF(WITH_ASCEND_CL) +cc_library(stream_callback_manager SRCS stream_callback_manager.cc DEPS simple_threadpool enforce) +ENDIF() + IF(WITH_GPU) nv_library(stream_callback_manager SRCS stream_callback_manager.cc DEPS simple_threadpool enforce) ENDIF() @@ -105,6 +118,8 @@ ENDIF() IF(WITH_GPU OR WITH_ROCM) set(STREAM_CALLBACK_DEPS stream_callback_manager) +ELSEIF(WITH_ASCEND_CL) + set(STREAM_CALLBACK_DEPS stream_callback_manager) ELSE() set(STREAM_CALLBACK_DEPS) ENDIF() @@ -118,7 +133,7 @@ cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost) # memcpy depends on device_context, here add deps individually for # avoiding cycle dependencies cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc xxhash ${STREAM_CALLBACK_DEPS} - place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} + place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${NPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} ${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS}) cc_library(collective_helper SRCS collective_helper.cc 
gen_comm_id_helper.cc DEPS framework_proto device_context enforce) diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 22daaf101cf200891ee98fd2c5c12b944d76825c..a0ade3898c336bf9168c9cc9dfbf02bfe0126fb4 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -78,13 +78,13 @@ bool AllowTF32Cudnn() { return allow_tf32_cudnn; } DeviceContextPool* DeviceContextPool::pool = nullptr; platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) { + VLOG(4) << "DeviceContextPool Get: " << place; auto it = device_contexts_.find(place); if (it == device_contexts_.end()) { PADDLE_THROW(platform::errors::Unimplemented( "Place %s is not supported. Please check that your paddle compiles " - "with WITH_GPU or WITH_XPU option or check that your train process " - "hold the " - "correct gpu_id if you use Executor.", + "with WITH_GPU, WITH_XPU or WITH_ASCEND_CL option or check that " + "your train process set the correct device id if you use Executor.", place)); } return it->second.get().get(); @@ -145,6 +145,14 @@ DeviceContextPool::DeviceContextPool( PADDLE_THROW( platform::errors::Unimplemented("XPUPlace is not supported. Please " "re-compile with WITH_XPU option.")); +#endif + } else if (platform::is_npu_place(p)) { +#ifdef PADDLE_WITH_ASCEND_CL + EmplaceDeviceContext(&device_contexts_, p); +#else + PADDLE_THROW(platform::errors::Unimplemented( + "NPUPlace is not supported. Please " + "re-compile with WITH_ASCEND_CL option.")); #endif } } @@ -229,8 +237,35 @@ Place XPUDeviceContext::GetPlace() const { return place_; } xpu::Context* XPUDeviceContext::x_context() const { return context_; } #endif -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#ifdef PADDLE_WITH_ASCEND_CL +NPUDeviceContext::NPUDeviceContext(NPUPlace place) : place_(place) { + NPUDeviceGuard guard(place_.device); + // PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateContext(&context_, place_.device)); + // NOTE(zhiqiu): Usually, no need to create context explicitly, + // ACL creates a default context which contains 1 default stream + // and 1 sync strean after aclrtSetDevice. + PADDLE_ENFORCE_NPU_SUCCESS(aclrtGetCurrentContext(&context_)); + stream_.reset(new stream::NPUStream(place)); +} + +NPUDeviceContext::~NPUDeviceContext() { + // NPUDeviceGuard guard(place_.device); + // PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyContext(context_)); +} +void NPUDeviceContext::Wait() const { + NPUDeviceGuard guard(place_.device); + PADDLE_ENFORCE_NPU_SUCCESS(aclrtSynchronizeDevice()); +} + +aclrtStream NPUDeviceContext::stream() const { return stream_->raw_stream(); } + +Place NPUDeviceContext::GetPlace() const { return place_; } + +aclrtContext NPUDeviceContext::context() const { return context_; } +#endif + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) class EigenCudaStreamDevice : public Eigen::StreamInterface { public: EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) { @@ -706,6 +741,5 @@ MKLDNNDeviceContext::BlobPtr_t MKLDNNDeviceContext::GetBlob( } #endif - } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 02ad22f780f8d589e04e69b5d11d61d55afd0e42..face048f28e8340b9f6e2b3c3d0fcbc4552eea52 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -57,6 +57,9 @@ limitations under the License. 
*/ #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/platform/stream/cuda_stream.h" #endif +#ifdef PADDLE_WITH_ASCEND_CL +#include "paddle/fluid/platform/stream/npu_stream.h" +#endif #include "unsupported/Eigen/CXX11/Tensor" namespace Eigen { @@ -69,6 +72,11 @@ struct GpuDevice; #include "paddle/fluid/platform/xpu_info.h" #endif +#ifdef PADDLE_WITH_ASCEND_CL +#include "acl/acl.h" +#include "paddle/fluid/platform/npu_info.h" +#endif + namespace paddle { namespace platform { @@ -87,11 +95,13 @@ enum DeviceType { CPU = 0, CUDA = 1, XPU = 2, + NPU = 3, }; constexpr DeviceType kCPU = DeviceType::CPU; constexpr DeviceType kCUDA = DeviceType::CUDA; constexpr DeviceType kXPU = DeviceType::XPU; +constexpr DeviceType kNPU = DeviceType::NPU; class DeviceContext { public: @@ -163,8 +173,52 @@ struct DefaultDeviceContextType { }; #endif -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#ifdef PADDLE_WITH_ASCEND_CL +class NPUDeviceContext : public DeviceContext { + public: + explicit NPUDeviceContext(NPUPlace place); + virtual ~NPUDeviceContext(); + Eigen::DefaultDevice* eigen_device() const { return nullptr; } + Place GetPlace() const override; + aclrtContext context() const; + /*! \brief Wait for all operations completion in the stream. */ + void Wait() const override; + + /*! \brief Return npu stream in the device context. */ + aclrtStream stream() const; + +#ifdef PADDLE_WITH_ASCEND_HCCL + /*! \brief Return bkcl context. */ + HCCLContext_t hccl_context() const { return hccl_context_; } + + /*! \brief Set bkcl context. */ + void set_hccl_context(HCCLContext_t context) { hccl_context_ = context; } +#endif + + private: + NPUPlace place_; + aclrtContext context_; +#ifdef PADDLE_WITH_ASCEND_HCCL + HCCLContext_t hccl_context_; +#endif + + // Need to be the same with other DeviceContext, + // Eventhough eigen_device_ is not used in NPU + // NOTE(zhiqiu): why need? + std::unique_ptr eigen_device_; + std::shared_ptr stream_; + + DISABLE_COPY_AND_ASSIGN(NPUDeviceContext); +}; + +template <> +struct DefaultDeviceContextType { + using TYPE = NPUDeviceContext; +}; +#endif + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) class CudnnWorkspaceHandle; class EigenCudaStreamDevice; diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h index f5045ff004ee9b8391a879c402c679f0078487a2..4828a97e4df4d54000739adff28bc861d2da2213 100644 --- a/paddle/fluid/platform/dynload/cudnn.h +++ b/paddle/fluid/platform/dynload/cudnn.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#ifdef PADDLE_WITH_CUDA #include #include #include // NOLINT @@ -186,3 +187,5 @@ CUDNN_DNN_ROUTINE_EACH_R8(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) } // namespace dynload } // namespace platform } // namespace paddle + +#endif diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index 47ade89ff2df3f53e3d99fbe1b203314b9edabd2..f0809d34d493e93c7a71f396fadde57d15fa23ff 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -45,6 +45,10 @@ limitations under the License. 
*/ #include // NOLINT #endif +#ifdef PADDLE_WITH_ASCEND_CL +#include "acl/acl.h" +#endif // PADDLE_WITH_ASCEND_CL + #include #include #include @@ -970,7 +974,6 @@ DEFINE_CUDA_STATUS_TYPE(cusolverStatus_t, CUSOLVER_STATUS_SUCCESS); #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess); #endif - } // namespace details #define PADDLE_ENFORCE_CUDA_SUCCESS(COND) \ @@ -1204,5 +1207,41 @@ inline void retry_sleep(unsigned millisecond) { #undef DEFINE_CUDA_STATUS_TYPE #endif // PADDLE_WITH_HIP +#ifdef PADDLE_WITH_ASCEND_CL +namespace details { +template +struct NPUStatusType {}; + +#define DEFINE_NPU_STATUS_TYPE(type, success_value) \ + template <> \ + struct NPUStatusType { \ + using Type = type; \ + static constexpr Type kSuccess = success_value; \ + } + +DEFINE_NPU_STATUS_TYPE(aclError, ACL_ERROR_NONE); +} // namespace details + +inline std::string build_npu_error_msg(aclError stat) { + std::ostringstream sout; + sout << " ACL error, the error code is : " << stat << ". "; + return sout.str(); +} + +#define PADDLE_ENFORCE_NPU_SUCCESS(COND) \ + do { \ + auto __cond__ = (COND); \ + using __NPU_STATUS_TYPE__ = decltype(__cond__); \ + constexpr auto __success_type__ = \ + ::paddle::platform::details::NPUStatusType< \ + __NPU_STATUS_TYPE__>::kSuccess; \ + if (UNLIKELY(__cond__ != __success_type__)) { \ + auto __summary__ = ::paddle::platform::errors::External( \ + ::paddle::platform::build_npu_error_msg(__cond__)); \ + __THROW_ERROR_INTERNAL__(__summary__); \ + } \ + } while (0) +#endif // PADDLE_WITH_ASCEND_CL + } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index fa77c0be037df32cd904b47f8d1ed2ee7208eae9..83b9544d23267be9de80ce9cd054a9b40bf892aa 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -45,7 +45,10 @@ DEFINE_bool(check_nan_inf, false, "Checking whether operator produce NAN/INF or not. It will be " "extremely slow so please use this flag wisely."); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +// NOTE(zhiqiu): better to share the flags, otherwise we will have too many +// flags. +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ + defined(PADDLE_WITH_ASCEND_CL) /** * CUDA related related FLAG @@ -84,8 +87,15 @@ DEFINE_string(selected_gpus, "", "share-memory only."); #endif -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#if defined(PADDLE_WITH_ASCEND_CL) +DEFINE_string(selected_npus, "", + "A list of device ids separated by comma, like: 0,1,2,3. " + "This option is useful when doing multi process training and " + "each process have only one device (NPU). If you want to use " + "all visible devices, set this to empty string."); +#endif +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) /** * CUDNN related FLAG * Name: FLAGS_cudnn_deterministic @@ -377,7 +387,10 @@ DEFINE_double( "Default use 50% of CPU memory as the pinned_memory for PaddlePaddle," "reserve the rest for page tables, etc"); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +// NOTE(zhiqiu): better to share the flags, otherwise we will have too many +// flags. 
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ + defined(PADDLE_WITH_ASCEND_CL) /** * Memory related FLAG diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index 3769428c9df86224abfdc4b67c066a3dc553d5c6..2e66e3e36d0b219301a1a8de0a9b495ad9026942 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -102,6 +102,7 @@ static int GetCUDADeviceCountImpl() { } int GetCUDADeviceCount() { + // cache the count static auto dev_cnt = GetCUDADeviceCountImpl(); return dev_cnt; } diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index ea89082733a80fa9e8e79129839f1120f344cc55..ac6988d350f4f38c6e8da2a655c29069b8d0eda6 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -16,6 +16,8 @@ limitations under the License. */ #include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/cpu_info.h" +#include "paddle/fluid/platform/npu_info.h" +#include "paddle/fluid/string/split.h" #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/platform/cuda_device_guard.h" #endif @@ -63,6 +65,7 @@ namespace framework { std::once_flag gflags_init_flag; std::once_flag glog_init_flag; +std::once_flag npu_init_flag; bool InitGflags(std::vector args) { bool successed = false; @@ -145,6 +148,17 @@ void InitDevices() { } catch (const std::exception &exp) { LOG(WARNING) << "Compiled with WITH_XPU, but no XPU found in runtime."; } +#endif +#ifdef PADDLE_WITH_ASCEND_CL + // NOTE(zhiqiu): use singleton to explicitly init and finalize ACL + platform::AclInstance::Instance(); // NOLINT + try { + // use user specified XPUs in single-node multi-process mode. + devices = platform::GetSelectedNPUDevices(); + } catch (const std::exception &exp) { + LOG(WARNING) + << "Compiled with PADDLE_WITH_ASCEND_CL, but no NPU found in runtime."; + } #endif InitDevices(devices); } @@ -165,6 +179,9 @@ void InitDevices(const std::vector devices) { #endif #ifdef PADDLE_WITH_XPU places.emplace_back(platform::XPUPlace(devices[i])); +#endif +#ifdef PADDLE_WITH_ASCEND_CL + places.emplace_back(platform::NPUPlace(devices[i])); #endif } places.emplace_back(platform::CPUPlace()); diff --git a/paddle/fluid/platform/monitor.cc b/paddle/fluid/platform/monitor.cc index 76554012bf51e34fc99db7759404f0e8d6f96cd6..1b44cb196547c2d26cdd5ae72c3331022f834657 100644 --- a/paddle/fluid/platform/monitor.cc +++ b/paddle/fluid/platform/monitor.cc @@ -35,3 +35,13 @@ DEFINE_INT_STATUS(STAT_gpu12_mem_size) DEFINE_INT_STATUS(STAT_gpu13_mem_size) DEFINE_INT_STATUS(STAT_gpu14_mem_size) DEFINE_INT_STATUS(STAT_gpu15_mem_size) + +// For Ascend NPU +DEFINE_INT_STATUS(STAT_npu0_mem_size) +DEFINE_INT_STATUS(STAT_npu1_mem_size) +DEFINE_INT_STATUS(STAT_npu2_mem_size) +DEFINE_INT_STATUS(STAT_npu3_mem_size) +DEFINE_INT_STATUS(STAT_npu4_mem_size) +DEFINE_INT_STATUS(STAT_npu5_mem_size) +DEFINE_INT_STATUS(STAT_npu6_mem_size) +DEFINE_INT_STATUS(STAT_npu7_mem_size) diff --git a/paddle/fluid/platform/monitor.h b/paddle/fluid/platform/monitor.h index b57fae9daac41f37829309c4bc5f58fb2606ca02..0eb9448ce0fad4e1caadb3e08140417294d5d0e7 100644 --- a/paddle/fluid/platform/monitor.h +++ b/paddle/fluid/platform/monitor.h @@ -187,3 +187,13 @@ class StatRegistry { USE_INT_STAT(STAT_gpu13_mem_size); \ USE_INT_STAT(STAT_gpu14_mem_size); \ USE_INT_STAT(STAT_gpu15_mem_size) + +#define USE_NPU_MEM_STAT \ + USE_INT_STAT(STAT_npu0_mem_size); \ + USE_INT_STAT(STAT_npu1_mem_size); \ + USE_INT_STAT(STAT_npu2_mem_size); \ + 
USE_INT_STAT(STAT_npu3_mem_size); \ + USE_INT_STAT(STAT_npu4_mem_size); \ + USE_INT_STAT(STAT_npu5_mem_size); \ + USE_INT_STAT(STAT_npu6_mem_size); \ + USE_INT_STAT(STAT_npu7_mem_size) diff --git a/paddle/fluid/platform/npu_info.cc b/paddle/fluid/platform/npu_info.cc new file mode 100644 index 0000000000000000000000000000000000000000..6920436399312e18cd9634327e080a4d5605038c --- /dev/null +++ b/paddle/fluid/platform/npu_info.cc @@ -0,0 +1,409 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/npu_info.h" +#include +#include +#include + +#include "gflags/gflags.h" + +#include "paddle/fluid/platform/lock_guard_ptr.h" +#include "paddle/fluid/platform/macros.h" +#include "paddle/fluid/platform/monitor.h" +#include "paddle/fluid/string/split.h" + +DECLARE_double(fraction_of_gpu_memory_to_use); +DECLARE_uint64(initial_gpu_memory_in_mb); +DECLARE_uint64(reallocate_gpu_memory_in_mb); +DECLARE_bool(enable_cublas_tensor_op_math); +DECLARE_uint64(gpu_memory_limit_mb); +DECLARE_string(selected_npus); + +constexpr static float fraction_reserve_gpu_memory = 0.05f; + +USE_NPU_MEM_STAT; + +namespace paddle { +namespace platform { + +static int GetNPUDeviceCountImpl() { + uint32_t count; + PADDLE_ENFORCE_NPU_SUCCESS(aclrtGetDeviceCount(&count)); + return count; +} + +int GetNPUDeviceCount() { + static auto dev_cnt = GetNPUDeviceCountImpl(); + return dev_cnt; +} + +int NPUCanAccessPeer(int src, int dst) { + int can = 0; + PADDLE_ENFORCE_NPU_SUCCESS(aclrtDeviceCanAccessPeer(&can, src, dst)); + return can; +} + +// For example, "1.0.1" +std::string GetNPURuntimeVersion(int id) { + PADDLE_ENFORCE_LT(id, GetNPUDeviceCount(), + platform::errors::InvalidArgument( + "Device id must be less than NPU count, " + "but received id is: %d. NPU count is: %d.", + id, GetNPUDeviceCount())); + int major = 0, minor = 0, patch = 0; + PADDLE_ENFORCE_NPU_SUCCESS(aclrtGetVersion(&major, &minor, &patch)); + return string::Sprintf("%d.%d.%d", major, minor, patch); +} + +int GetCurrentNPUDeviceId() { + int device_id; + PADDLE_ENFORCE_NPU_SUCCESS(aclrtGetDevice(&device_id)); + return device_id; +} + +//! Get a list of device ids from environment variable or use all. +std::vector GetSelectedNPUDevices() { + // use user specified NPUs in single-node multi-process mode. + std::vector devices; + if (!FLAGS_selected_npus.empty()) { + auto devices_str = paddle::string::Split(FLAGS_selected_npus, ','); + for (auto id : devices_str) { + devices.push_back(atoi(id.c_str())); + } + } else { + int count = GetNPUDeviceCount(); + for (int i = 0; i < count; ++i) { + devices.push_back(i); + } + } + return devices; +} + +void SetNPUDeviceId(int id) { + PADDLE_ENFORCE_LT(id, GetNPUDeviceCount(), + platform::errors::InvalidArgument( + "Device id must be less than NPU count, " + "but received id is: %d. NPU count is: %d.", + id, GetNPUDeviceCount())); + // NOTE(zihqiu): It is recommended to call aclrtSetDevice and aclrtResetDevice + // pairly. 
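+  // For the devices selected at startup, the matching aclrtResetDevice calls +  // are issued in AclInstance::Finalize, which resets each device before +  // calling aclFinalize.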
+ PADDLE_ENFORCE_NPU_SUCCESS(aclrtSetDevice(id)); +} + +void ResetNPUDeviceId(int id) { + PADDLE_ENFORCE_LT(id, GetNPUDeviceCount(), + platform::errors::InvalidArgument( + "Device id must be less than NPU count, " + "but received id is: %d. NPU count is: %d.", + id, GetNPUDeviceCount())); + PADDLE_ENFORCE_NPU_SUCCESS(aclrtResetDevice(id)); +} + +void NPUMemoryUsage(size_t *available, size_t *total) { + size_t actual_available, actual_total; + RecordedNPUMemGetInfo(available, total, &actual_available, &actual_total, + platform::GetCurrentNPUDeviceId()); +} + +size_t NPUAvailableMemToAlloc() { + size_t total = 0; + size_t available = 0; + NPUMemoryUsage(&available, &total); + size_t reserving = + static_cast(fraction_reserve_gpu_memory * available); + // If available size is less than minimum chunk size, no usable memory exists + size_t available_to_alloc = available - reserving; + size_t min_chunk_size = NPUMinChunkSize(); + if (available_to_alloc < min_chunk_size) { + available_to_alloc = 0; + } + VLOG(10) << "NPU usage " << (available >> 20) << "M/" << (total >> 20) + << "M, " << (available_to_alloc >> 20) << "M available to allocate"; + return available_to_alloc; +} + +size_t NPUMaxAllocSize() { + return std::max(NPUInitAllocSize(), NPUReallocSize()); +} + +static size_t NPUAllocSize(bool realloc) { + size_t available_to_alloc = NPUAvailableMemToAlloc(); + PADDLE_ENFORCE_GT( + available_to_alloc, 0, + platform::errors::ResourceExhausted("Not enough available NPU memory.")); + // If FLAGS_initial_gpu_memory_in_mb is 0, then initial memory will be + // allocated by fraction + size_t flag_mb = realloc ? FLAGS_reallocate_gpu_memory_in_mb + : FLAGS_initial_gpu_memory_in_mb; + size_t alloc_bytes = + (flag_mb > 0ul ? flag_mb << 20 : available_to_alloc * + FLAGS_fraction_of_gpu_memory_to_use); + PADDLE_ENFORCE_GE( + available_to_alloc, alloc_bytes, + platform::errors::ResourceExhausted("Not enough available NPU memory.")); + VLOG(10) << "Alloc size is " << (alloc_bytes >> 20) + << " MiB, is it Re-alloc: " << realloc; + return alloc_bytes; +} + +size_t NPUInitAllocSize() { return NPUAllocSize(/* realloc = */ false); } + +size_t NPUReallocSize() { return NPUAllocSize(/* realloc = */ true); } + +size_t NPUMinChunkSize() { + // Allow to allocate the minimum chunk size is 256 bytes. + return 1 << 8; +} + +size_t NPUMaxChunkSize() { + size_t max_chunk_size = NPUMaxAllocSize(); + VLOG(10) << "Max chunk size " << (max_chunk_size >> 20) << "M"; + return max_chunk_size; +} + +void NPUMemcpyAsync(void *dst, const void *src, size_t count, + enum aclrtMemcpyKind kind, aclrtStream stream, + size_t dst_max_count) { + dst_max_count = dst_max_count ? dst_max_count : count; + VLOG(4) << dst << " " << dst_max_count << " " << src << " " << count << " " + << kind << " " << stream; + PADDLE_ENFORCE_NPU_SUCCESS( + aclrtMemcpyAsync(dst, dst_max_count, src, count, kind, stream)); +} + +void NPUMemcpySync(void *dst, const void *src, size_t count, + enum aclrtMemcpyKind kind, size_t dst_max_count) { + // NOTE(zhiqiu): The default max_count is count + dst_max_count = dst_max_count ? dst_max_count : count; + PADDLE_ENFORCE_NPU_SUCCESS(aclrtMemcpy(dst, dst_max_count, src, count, kind)); +} + +void NPUMemcpyPeerASync(void *dst, int dst_device, const void *src, + size_t count, enum aclrtMemcpyKind kind, + aclrtStream stream, size_t dst_max_count) { + dst_max_count = dst_max_count ? 
dst_max_count : count; + PADDLE_ENFORCE_NPU_SUCCESS( + aclrtMemcpyAsync(dst, dst_max_count, src, count, kind, stream)); +} + +void NPUMemcpyPeerSync(void *dst, int dst_device, const void *src, size_t count, + enum aclrtMemcpyKind kind, size_t dst_max_count) { + // NOTE(zhiqiu): The default max_count is count + dst_max_count = dst_max_count ? dst_max_count : count; + PADDLE_ENFORCE_NPU_SUCCESS(aclrtMemcpy(dst, dst_max_count, src, count, kind)); +} + +void NPUMemsetAsync(void *dst, int value, size_t count, aclrtStream stream, + size_t max_count) { + max_count = max_count ? max_count : count; + PADDLE_ENFORCE_NPU_SUCCESS( + aclrtMemsetAsync(dst, max_count, value, count, stream)); +} + +void NPUStreamSync(aclrtStream stream) { + PADDLE_ENFORCE_NPU_SUCCESS(aclrtSynchronizeStream(stream)); +} + +static void RaiseNonOutOfMemoryError(aclError *status) { + if (*status == ACL_ERROR_BAD_ALLOC) { + *status = ACL_ERROR_NONE; + } + PADDLE_ENFORCE_NPU_SUCCESS(*status); +} + +class RecordedNPUMallocHelper { + private: + explicit RecordedNPUMallocHelper(int dev_id, uint64_t limit_size = 0) + : dev_id_(dev_id), limit_size_(limit_size) { + if (NeedRecord()) { + mtx_.reset(new std::mutex()); + } + } + + DISABLE_COPY_AND_ASSIGN(RecordedNPUMallocHelper); + + public: + static RecordedNPUMallocHelper *Instance(int dev_id) { + std::call_once(once_flag_, [] { + int dev_cnt = GetNPUDeviceCount(); + instances_.reserve(dev_cnt); + for (int i = 0; i < dev_cnt; ++i) { + // NOTE(zhiqiu): share the flags with gpu, avoid more flags. + instances_.emplace_back( + new RecordedNPUMallocHelper(i, FLAGS_gpu_memory_limit_mb << 20)); + } + }); + + PADDLE_ENFORCE_GE( + dev_id, 0, + platform::errors::OutOfRange( + "Device id must be not less than 0, but got %d.", dev_id)); + PADDLE_ENFORCE_LT( + dev_id, instances_.size(), + platform::errors::OutOfRange("Device id %d exceeds npu card number %d.", + dev_id, instances_.size())); + return instances_[dev_id].get(); + } + + /** + * Try to allocate `size` npu memory. Only ACL_ERROR_BAD_ALLOC + * or ACL_ERROR_NONE would be returned. + */ + aclError Malloc(void **ptr, size_t size) { + LockGuardPtr lock(mtx_); + if (UNLIKELY(NeedRecord() && cur_size_ + size > limit_size_)) { + return ACL_ERROR_BAD_ALLOC; + } + + NPUDeviceGuard guard(dev_id_); + auto result = aclrtMalloc(ptr, size, ACL_MEM_MALLOC_HUGE_FIRST); + if (result == ACL_ERROR_NONE) { + if (NeedRecord()) { + cur_size_ += size; + } + STAT_INT_ADD("STAT_npu" + std::to_string(dev_id_) + "_mem_size", size); + return result; + } else { + RaiseNonOutOfMemoryError(&result); + // Non out of memory error would be raised inside + // RaiseNonOutOfMemoryError. Therefore, we can + // return cudaErrorMemoryAllocation directly here. + return ACL_ERROR_BAD_ALLOC; + } + } + + /** + * Free gpu memory. Usually, free is not allowed to raise error. + * If it does raise error, the process should be crashed. 
+ */ + void Free(void *ptr, size_t size) { + NPUDeviceGuard guard(dev_id_); + auto result = aclrtFree(ptr); + PADDLE_ENFORCE_NPU_SUCCESS(result); + if (NeedRecord()) { + std::lock_guard guard(*mtx_); + cur_size_ -= size; + } + STAT_INT_SUB("STAT_npu" + std::to_string(dev_id_) + "_mem_size", size); + } + + bool GetMemInfo(size_t *avail, size_t *total, size_t *actual_avail, + size_t *actual_total) { + { + NPUDeviceGuard guard(dev_id_); + auto result = aclrtGetMemInfo(ACL_HBM_MEM, actual_avail, actual_total); + if (result != ACL_ERROR_NONE) { + *actual_avail = 0; + } + RaiseNonOutOfMemoryError(&result); + } + + if (NeedRecord()) { + std::lock_guard guard(*mtx_); + *avail = std::min(*actual_avail, limit_size_ - cur_size_); + *total = std::min(*actual_total, limit_size_); + return *total < *actual_total; + } else { + *avail = *actual_avail; + *total = *actual_total; + return false; + } + } + + inline bool NeedRecord() const { return limit_size_ != 0; } + + uint64_t RecordedSize() const { + LockGuardPtr lock(mtx_); + return NeedRecord() ? cur_size_ : 0; + } + + uint64_t LimitSize() const { return limit_size_; } + + private: + const int dev_id_; + const uint64_t limit_size_; + uint64_t cur_size_{0}; + + mutable std::unique_ptr mtx_; + + static std::once_flag once_flag_; + static std::vector> instances_; +}; + +std::once_flag RecordedNPUMallocHelper::once_flag_; +std::vector> + RecordedNPUMallocHelper::instances_; + +aclError RecordedNPUMalloc(void **ptr, size_t size, int dev_id) { + return RecordedNPUMallocHelper::Instance(dev_id)->Malloc(ptr, size); +} + +void RecordedNPUFree(void *p, size_t size, int dev_id) { + return RecordedNPUMallocHelper::Instance(dev_id)->Free(p, size); +} + +bool RecordedNPUMemGetInfo(size_t *avail, size_t *total, size_t *actual_avail, + size_t *actual_total, int dev_id) { + return RecordedNPUMallocHelper::Instance(dev_id)->GetMemInfo( + avail, total, actual_avail, actual_total); +} + +uint64_t RecordedNPUMallocSize(int dev_id) { + return RecordedNPUMallocHelper::Instance(dev_id)->RecordedSize(); +} + +bool IsNPUMallocRecorded(int dev_id) { + return RecordedNPUMallocHelper::Instance(dev_id)->NeedRecord(); +} + +AclInstance::~AclInstance() {} + +AclInstance &AclInstance::Instance() { + static AclInstance instance; + return instance; +} + +AclInstance::AclInstance() { + PADDLE_ENFORCE_NPU_SUCCESS(aclInit(nullptr)); + VLOG(4) << "Call aclrtSetDevice "; + // NOTE(zhiqiu): why set devices here? + // Because ACL creates a default context which contains 2 streams + // when calling aclrtSetDevice, so usually we do not need to + // create contexts explicitly. And, for each device, aclrtSetDevice + // needs to be called in pairs with aclrtResetDevice to destroy the default + // context. Here, we use this singleton and static instance to manage + // the devices to make sure they are reset before the program exits. + devices_ = platform::GetSelectedNPUDevices(); + for (auto it = devices_.rbegin(); it != devices_.rend(); ++it) { + SetNPUDeviceId(*it); + VLOG(4) << "Call aclrtSetDevice " << *it; + } +} + +void AclInstance::Finalize() { + // NOTE(zhiqiu): DO NOT perform finalize in destructor + // to avoid problems caused by destructor order of static + // objects.
+ for (size_t i = 0; i < devices_.size(); ++i) { + auto status = aclrtResetDevice(devices_[i]); + VLOG(4) << "Call aclrtResetDevice " << devices_[i] + << " status = " << status; + } + auto status = aclFinalize(); + VLOG(4) << "Call aclFinalize, status = " << status; +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/npu_info.h b/paddle/fluid/platform/npu_info.h new file mode 100644 index 0000000000000000000000000000000000000000..648b18531b2b7f3b5e00b09fed25279c1a68a2d7 --- /dev/null +++ b/paddle/fluid/platform/npu_info.h @@ -0,0 +1,156 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifdef PADDLE_WITH_ASCEND_CL +#include + +#include +#include + +#include "acl/acl.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { + +//! Get the total number of NPU devices in system. +int GetNPUDeviceCount(); + +//! Get the runtime version of the ith NPU +std::string GetNPURuntimeVersion(int id); +//! Check if this device can access peer or not. +int NPUCanAccessPeer(int src, int dst); + +//! Get the current NPU device id in system. +int GetCurrentNPUDeviceId(); + +//! Get the current NPU stream. +int GetCurrentStream(); + +//! Get a list of device ids from environment variable or use all. +std::vector GetSelectedNPUDevices(); + +//! Set the NPU device id for next execution. +void SetNPUDeviceId(int device_id); + +//! Reset the NPU device id for next execution. +void ResetNPUDeviceId(int device_id); + +//! Get the memory usage of current NPU device. +void NPUMemoryUsage(size_t *available, size_t *total); + +//! Get the available memory to allocate, which is the size of available npu +//! minus reserving. +size_t NPUAvailableMemToAlloc(); + +//! Get the maximum allocation size of current NPU device. +size_t NPUMaxAllocSize(); + +//! Get the initial allocation size of current NPU device. +size_t NPUInitAllocSize(); + +//! Get the re-allocation size of current NPU device. +size_t NPUReallocSize(); + +//! Get the minimum chunk size for NPU buddy allocator. +size_t NPUMinChunkSize(); + +//! Get the maximum chunk size for NPU buddy allocator. +size_t NPUMaxChunkSize(); + +//! Copy memory from address src to dst asynchronously. +void NPUMemcpyAsync(void *dst, const void *src, size_t count, + enum aclrtMemcpyKind kind, aclrtStream stream, + size_t dst_max_count = 0); + +//! Copy memory from address src to dst synchronously. +void NPUMemcpySync(void *dst, const void *src, size_t count, + enum aclrtMemcpyKind kind, size_t dst_max_count = 0); + +//! Set memory dst with value count size asynchronously +void NPUMemsetAsync(void *dst, int value, size_t count, aclrtStream stream, + size_t max_count = 0); + +//! Copy memory from one device to another device asynchronously. +void NPUMemcpyPeerAsync(void *dst, int dst_device, const void *src, + int src_device, size_t count, aclrtStream stream, + size_t max_count = 0); + +//! 
Copy memory from one device to another device synchronously. +void NPUMemcpyPeerSync(void *dst, int dst_device, const void *src, + int src_device, size_t count, size_t max_count = 0); + +//! Blocks until stream has completed all operations. +void NPUStreamSync(aclrtStream stream); + +//! aclrtMalloc with recorded info +aclError RecordedNPUMalloc(void **ptr, size_t size, int dev_id); + +//! aclrtFree with recorded info +void RecordedNPUFree(void *p, size_t size, int dev_id); + +//! Get available and total gpu memory with considering limitation +bool RecordedNPUMemGetInfo(size_t *avail, size_t *total, size_t *actual_avail, + size_t *actual_total, int dev_id); + +//! Get recorded actrtMalloc size. If record is disabled, return 0. +uint64_t RecordedNPUMallocSize(int dev_id); + +bool IsNPUMallocRecorded(int dev_id); + +class NPUDeviceGuard { + public: + explicit inline NPUDeviceGuard(int dev_id) { + int prev_id = platform::GetCurrentNPUDeviceId(); + if (prev_id != dev_id) { + prev_id_ = prev_id; + platform::SetNPUDeviceId(dev_id); + } + } + + inline ~NPUDeviceGuard() { + if (prev_id_ != -1) { + platform::SetNPUDeviceId(prev_id_); + } + } + + NPUDeviceGuard(const NPUDeviceGuard &o) = delete; + NPUDeviceGuard &operator=(const NPUDeviceGuard &o) = delete; + + private: + int prev_id_{-1}; +}; + +class AclInstance { + public: + // NOTE(zhiiu): Commonly, exception in destructor is not recommended, so + // no PADDLE_ENFORCE here, call acl API directly. + ~AclInstance(); + AclInstance(const AclInstance &o) = delete; + const AclInstance &operator=(const AclInstance &o) = delete; + static AclInstance &Instance(); + void Finalize(); + + private: + // forbid calling default constructor + AclInstance(); + std::vector devices_; +}; + +} // namespace platform +} // namespace paddle + +#endif diff --git a/paddle/fluid/platform/place.cc b/paddle/fluid/platform/place.cc index b80d2fd1632cd82c231fae724fc4d754b8fed0fc..1cc9fd9fe76341cd495a3580cddbff65f5b0e208 100644 --- a/paddle/fluid/platform/place.cc +++ b/paddle/fluid/platform/place.cc @@ -33,6 +33,7 @@ class PlacePrinter : public boost::static_visitor<> { os_ << "CUDAPlace(" << p.device << ")"; } void operator()(const XPUPlace &p) { os_ << "XPUPlace(" << p.device << ")"; } + void operator()(const NPUPlace &p) { os_ << "NPUPlace(" << p.device << ")"; } void operator()(const CUDAPinnedPlace &p) { os_ << "CUDAPinnedPlace"; } private: @@ -49,6 +50,10 @@ bool is_xpu_place(const Place &p) { return boost::apply_visitor(IsXPUPlace(), p); } +bool is_npu_place(const Place &p) { + return boost::apply_visitor(IsNPUPlace(), p); +} + bool is_cpu_place(const Place &p) { return boost::apply_visitor(IsCPUPlace(), p); } @@ -67,6 +72,8 @@ bool is_same_place(const Place &p1, const Place &p2) { return true; } else if (is_xpu_place(p1)) { return BOOST_GET_CONST(XPUPlace, p1) == BOOST_GET_CONST(XPUPlace, p2); + } else if (is_npu_place(p1)) { + return BOOST_GET_CONST(NPUPlace, p1) == BOOST_GET_CONST(NPUPlace, p2); } else { return BOOST_GET_CONST(CUDAPlace, p1) == BOOST_GET_CONST(CUDAPlace, p2); } diff --git a/paddle/fluid/platform/place.h b/paddle/fluid/platform/place.h index e11ca4159e07e927b11cf1e0c3f6c638b71c4c84..f20fac477d0ec4ef40a3544476e223b6ad97fffa 100644 --- a/paddle/fluid/platform/place.h +++ b/paddle/fluid/platform/place.h @@ -72,16 +72,31 @@ struct XPUPlace { int device; }; +struct NPUPlace { + NPUPlace() : NPUPlace(0) {} + explicit NPUPlace(int d) : device(d) {} + + inline int GetDeviceId() const { return device; } + // needed for variant equality comparison + inline 
bool operator==(const NPUPlace &o) const { return device == o.device; } + inline bool operator!=(const NPUPlace &o) const { return !(*this == o); } + inline bool operator<(const NPUPlace &o) const { return device < o.device; } + + int device; +}; + struct IsCUDAPlace : public boost::static_visitor { bool operator()(const CPUPlace &) const { return false; } bool operator()(const XPUPlace &) const { return false; } - bool operator()(const CUDAPlace &gpu) const { return true; } + bool operator()(const NPUPlace &) const { return false; } + bool operator()(const CUDAPlace &) const { return true; } bool operator()(const CUDAPinnedPlace &) const { return false; } }; struct IsCPUPlace : public boost::static_visitor { - bool operator()(const CPUPlace &cpu) const { return true; } + bool operator()(const CPUPlace &) const { return true; } bool operator()(const XPUPlace &) const { return false; } + bool operator()(const NPUPlace &) const { return false; } bool operator()(const CUDAPlace &) const { return false; } bool operator()(const CUDAPinnedPlace &) const { return false; } }; @@ -89,27 +104,38 @@ struct IsCPUPlace : public boost::static_visitor { struct IsCUDAPinnedPlace : public boost::static_visitor { bool operator()(const CPUPlace &) const { return false; } bool operator()(const XPUPlace &) const { return false; } + bool operator()(const NPUPlace &) const { return false; } bool operator()(const CUDAPlace &) const { return false; } bool operator()(const CUDAPinnedPlace &cuda_pinned) const { return true; } }; struct IsXPUPlace : public boost::static_visitor { bool operator()(const CPUPlace &) const { return false; } - bool operator()(const XPUPlace &xpu) const { return true; } + bool operator()(const XPUPlace &) const { return true; } + bool operator()(const NPUPlace &) const { return false; } bool operator()(const CUDAPlace &) const { return false; } bool operator()(const CUDAPinnedPlace &) const { return false; } }; -class Place - : public boost::variant { +struct IsNPUPlace : public boost::static_visitor { + bool operator()(const CPUPlace &) const { return false; } + bool operator()(const XPUPlace &) const { return false; } + bool operator()(const NPUPlace &) const { return true; } + bool operator()(const CUDAPlace &) const { return false; } + bool operator()(const CUDAPinnedPlace &) const { return false; } +}; + +class Place : public boost::variant { private: using PlaceBase = - boost::variant; + boost::variant; public: Place() = default; Place(const CPUPlace &cpu_place) : PlaceBase(cpu_place) {} // NOLINT Place(const XPUPlace &xpu_place) : PlaceBase(xpu_place) {} // NOLINT + Place(const NPUPlace &npu_place) : PlaceBase(npu_place) {} // NOLINT Place(const CUDAPlace &cuda_place) : PlaceBase(cuda_place) {} // NOLINT Place(const CUDAPinnedPlace &cuda_pinned_place) // NOLINT : PlaceBase(cuda_pinned_place) {} @@ -126,6 +152,7 @@ using PlaceList = std::vector; bool is_gpu_place(const Place &); bool is_xpu_place(const Place &); +bool is_npu_place(const Place &); bool is_cpu_place(const Place &); bool is_cuda_pinned_place(const Place &); bool places_are_same_class(const Place &, const Place &); @@ -153,6 +180,16 @@ struct PlaceVisitorWrapper #endif } + typename Visitor::result_type operator()(const NPUPlace &npu) const { +#ifdef PADDLE_WITH_ASCEND + return visitor_(npu); +#else + PADDLE_THROW(platform::errors::Unavailable( + "Paddle is not compiled with NPU. 
Cannot visit npu device")); + return typename Visitor::result_type(); +#endif + } + typename Visitor::result_type operator()(const CUDAPlace &cuda) const { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) return visitor_(cuda); diff --git a/paddle/fluid/platform/stream/CMakeLists.txt b/paddle/fluid/platform/stream/CMakeLists.txt index c0595eb415da6c4c6c543dd713100714dca19dc2..e1e3e49ce9cbc04298cdb47e29e060ebffc88ba1 100644 --- a/paddle/fluid/platform/stream/CMakeLists.txt +++ b/paddle/fluid/platform/stream/CMakeLists.txt @@ -1,3 +1,7 @@ IF(WITH_GPU OR WITH_ROCM) cc_library(cuda_stream SRCS cuda_stream.cc DEPS enforce boost) ENDIF() + +IF(WITH_ASCEND_CL) +cc_library(npu_stream SRCS npu_stream.cc DEPS enforce boost stream_callback_manager) +ENDIF() diff --git a/paddle/fluid/platform/stream/cuda_stream.cc b/paddle/fluid/platform/stream/cuda_stream.cc index fc51a08c2aa24873de58929031dc44f18ac4509f..6c6a47fadb5f40c6dbb54ea173b80b55d08ba8fc 100644 --- a/paddle/fluid/platform/stream/cuda_stream.cc +++ b/paddle/fluid/platform/stream/cuda_stream.cc @@ -49,8 +49,8 @@ bool CUDAStream::Init(const Place& place, const Priority& priority) { cudaStreamCreateWithPriority(&stream_, kDefaultFlag, 0)); #endif } - callback_manager_.reset(new StreamCallbackManager(stream_)); - VLOG(3) << "CUDAStream Init stream: " << stream_ + callback_manager_.reset(new StreamCallbackManager(stream_)); + VLOG(3) << "GPUStream Init stream: " << stream_ << ", priority: " << static_cast(priority); return true; } diff --git a/paddle/fluid/platform/stream/cuda_stream.h b/paddle/fluid/platform/stream/cuda_stream.h index d9375492519d8c26c487326e5325efa0ea961de0..46bbe94b080f965aed5ba08423512777f84e3ec0 100644 --- a/paddle/fluid/platform/stream/cuda_stream.h +++ b/paddle/fluid/platform/stream/cuda_stream.h @@ -101,7 +101,7 @@ class CUDAStream final { cudaStream_t stream_{nullptr}; #endif Priority priority_{Priority::kNormal}; - std::unique_ptr callback_manager_; + std::unique_ptr> callback_manager_; DISABLE_COPY_AND_ASSIGN(CUDAStream); }; diff --git a/paddle/fluid/platform/stream/npu_stream.cc b/paddle/fluid/platform/stream/npu_stream.cc new file mode 100644 index 0000000000000000000000000000000000000000..2664ac7194bf2b5996a7237fcee57ba96fa48523 --- /dev/null +++ b/paddle/fluid/platform/stream/npu_stream.cc @@ -0,0 +1,51 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/platform/stream/npu_stream.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/npu_info.h" + +namespace paddle { +namespace platform { +namespace stream { + +bool NPUStream::Init(const Place& place) { + PADDLE_ENFORCE_EQ(is_npu_place(place), true, + platform::errors::InvalidArgument( + "NPU stream must be created using npu place.")); + place_ = place; + NPUDeviceGuard guard(BOOST_GET_CONST(NPUPlace, place_).device); + PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateStream(&stream_)); + callback_manager_.reset(new StreamCallbackManager(stream_)); + VLOG(3) << "NPUStream Init stream: " << stream_; + return true; +} + +void NPUStream::Destroy() { + NPUDeviceGuard guard(BOOST_GET_CONST(NPUPlace, place_).device); + Wait(); + WaitCallback(); + if (stream_) { + PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyStream(stream_)); + } + stream_ = nullptr; +} + +void NPUStream::Wait() const { + PADDLE_ENFORCE_NPU_SUCCESS(aclrtSynchronizeStream(stream_)); +} + +} // namespace stream +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/stream/npu_stream.h b/paddle/fluid/platform/stream/npu_stream.h new file mode 100644 index 0000000000000000000000000000000000000000..7e5d574acecf54f8ecf2476db7bbe177f34a9196 --- /dev/null +++ b/paddle/fluid/platform/stream/npu_stream.h @@ -0,0 +1,76 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include + +#include "paddle/fluid/platform/macros.h" +#include "paddle/fluid/platform/npu_info.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/platform/stream_callback_manager.h" + +namespace paddle { +namespace platform { +namespace stream { + +#ifdef PADDLE_WITH_ASCEND_CL + +class NPUStream final { + public: + NPUStream() = default; + explicit NPUStream(const Place& place) { Init(place); } + virtual ~NPUStream() { Destroy(); } + + bool Init(const Place& place); + + template + void AddCallback(Callback&& callback) const { + callback_manager_->AddCallback(callback); + } + + template + void RecordEvent(aclrtEvent ev, Callback callback) const { + callback(); + PADDLE_ENFORCE_NPU_SUCCESS(aclrtRecordEvent(ev, stream_)); + } + + void RecordEvent(aclrtEvent ev) const { + PADDLE_ENFORCE_NPU_SUCCESS(aclrtRecordEvent(ev, stream_)); + } + + void WaitEvent(aclrtEvent ev) const { + PADDLE_ENFORCE_NPU_SUCCESS(aclrtStreamWaitEvent(stream_, ev)); + } + + void Wait() const; + void WaitCallback() const { callback_manager_->Wait(); } + + aclrtStream raw_stream() const { return stream_; } + void Destroy(); + + private: + Place place_; + aclrtStream stream_{nullptr}; + std::unique_ptr> callback_manager_; + + DISABLE_COPY_AND_ASSIGN(NPUStream); +}; + +#endif + +} // namespace stream +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/stream_callback_manager.cc b/paddle/fluid/platform/stream_callback_manager.cc index d6b106dc582d51d4f3339dfdac4f782ec1942fa5..287c8fc37e005a4bee8092b51cc3ec20cb814540 100644 --- a/paddle/fluid/platform/stream_callback_manager.cc +++ b/paddle/fluid/platform/stream_callback_manager.cc @@ -21,11 +21,18 @@ namespace platform { #ifdef PADDLE_WITH_HIP static void StreamCallbackFunc(gpuStream_t stream, gpuError_t status, void *user_data) -#elif CUDA_VERSION >= 10000 -static void CUDART_CB StreamCallbackFunc(void *user_data) +#endif +#ifdef PADDLE_WITH_CUDA +#if CUDA_VERSION >= 10000 + static void CUDART_CB StreamCallbackFunc(void *user_data) #else -static void CUDART_CB StreamCallbackFunc(cudaStream_t stream, - cudaError_t status, void *user_data) + static void CUDART_CB + StreamCallbackFunc(cudaStream_t stream, cudaError_t status, void *user_data) +#endif +#endif + +#if PADDLE_WITH_ASCEND_CL + static void StreamCallbackFunc(void *user_data) #endif { std::unique_ptr> func( @@ -33,10 +40,13 @@ static void CUDART_CB StreamCallbackFunc(cudaStream_t stream, (*func)(); } -StreamCallbackManager::StreamCallbackManager(const gpuStream_t stream) +template +StreamCallbackManager::StreamCallbackManager(const Stream stream) : stream_(stream), thread_pool_(1) {} -void StreamCallbackManager::AddCallback(std::function callback) const { +template +void StreamCallbackManager::AddCallback( + std::function callback) const { auto *callback_func = new std::function(std::move(callback)); auto *func = new std::function([this, callback_func] { std::lock_guard lock(mtx_); @@ -45,23 +55,37 @@ void StreamCallbackManager::AddCallback(std::function callback) const { (*callback_func)(); }); }); + #ifdef PADDLE_WITH_HIP PADDLE_ENFORCE_CUDA_SUCCESS( hipStreamAddCallback(stream_, StreamCallbackFunc, func, 0)); -#elif CUDA_VERSION >= 10000 +#endif +#ifdef PADDLE_WITH_CUDA +#if CUDA_VERSION >= 10000 PADDLE_ENFORCE_CUDA_SUCCESS( cudaLaunchHostFunc(stream_, StreamCallbackFunc, func)); #else PADDLE_ENFORCE_CUDA_SUCCESS( cudaStreamAddCallback(stream_, StreamCallbackFunc, func, 0)); #endif +#endif + +#if PADDLE_WITH_ASCEND_CL + 
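+ // Assumption: ACL_CALLBACK_BLOCK makes subsequent tasks on the stream wait
+ // until the host callback below has finished, mirroring the blocking
+ // semantics of the CUDA/HIP stream callbacks registered above.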
PADDLE_ENFORCE_NPU_SUCCESS(aclrtLaunchCallback(StreamCallbackFunc, func, + ACL_CALLBACK_BLOCK, stream_)); +#endif } -void StreamCallbackManager::Wait() const { +template +void StreamCallbackManager::Wait() const { #ifdef PADDLE_WITH_HIP PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream_)); -#else +#endif +#ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream_)); +#endif +#ifdef PADDLE_WITH_ASCEND_CL + PADDLE_ENFORCE_NPU_SUCCESS(aclrtSynchronizeStream(stream_)); #endif { std::lock_guard lock(mtx_); @@ -71,5 +95,15 @@ void StreamCallbackManager::Wait() const { } } +#ifdef PADDLE_WITH_CUDA +template struct StreamCallbackManager; +#endif +#ifdef PADDLE_WITH_HIP +template struct StreamCallbackManager; +#endif +#ifdef PADDLE_WITH_ASCEND_CL +template struct StreamCallbackManager; +#endif + } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/stream_callback_manager.h b/paddle/fluid/platform/stream_callback_manager.h index 56e8f83b5a51c1d04a19dc2c0a85a312ad461e3f..1b960f188ec3045d7362c5f4fb850b7ba6a0a85e 100644 --- a/paddle/fluid/platform/stream_callback_manager.h +++ b/paddle/fluid/platform/stream_callback_manager.h @@ -37,9 +37,10 @@ namespace platform { // NOTE(zjl): clean StreamCallbackManager to make compilation faster // Make StreamCallbackManager thread-safe +template class StreamCallbackManager { public: - explicit StreamCallbackManager(const gpuStream_t stream); + explicit StreamCallbackManager(const Stream stream); ~StreamCallbackManager() = default; @@ -48,7 +49,7 @@ class StreamCallbackManager { void Wait() const; private: - const gpuStream_t stream_; + const Stream stream_; mutable ::ThreadPool thread_pool_; mutable std::mutex mtx_; mutable std::future last_future_; diff --git a/paddle/fluid/pybind/global_value_getter_setter.cc b/paddle/fluid/pybind/global_value_getter_setter.cc index e8ba16398d2b00041ab7706c2cc3334a9c113186..bc8d1e5b40585dd8a44255b33c835be12c473cec 100644 --- a/paddle/fluid/pybind/global_value_getter_setter.cc +++ b/paddle/fluid/pybind/global_value_getter_setter.cc @@ -88,10 +88,17 @@ DECLARE_uint64(reallocate_gpu_memory_in_mb); // others DECLARE_bool(sync_nccl_allreduce); #endif + #ifdef PADDLE_WITH_XPU // device management DECLARE_string(selected_xpus); #endif + +#ifdef PADDLE_WITH_ASCEND_CL +// device management +DECLARE_string(selected_npus); +#endif + #ifdef PADDLE_WITH_DISTRIBUTE DECLARE_int32(rpc_send_thread_num); DECLARE_int32(rpc_get_thread_num); @@ -374,6 +381,11 @@ static void RegisterGlobalVarGetterSetter() { #ifdef PADDLE_WITH_XPU REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_selected_xpus); #endif + +#ifdef PADDLE_WITH_ASCEND_CL + REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_selected_npus); +#endif + #ifdef PADDLE_WITH_DITRIBUTE REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_rpc_send_thread_num, FLAGS_rpc_get_thread_num, diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 215c81a00e89a8a267b203faad30fcbd9cce00ab..428c7c2420b9878ca9ed1288dfb8ceea06cf97a8 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -107,6 +107,10 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/gpu_info.h" #endif +#ifdef PADDLE_WITH_ASCEND_CL +#include "paddle/fluid/platform/npu_info.h" +#endif + #ifdef PADDLE_WITH_XPU #include "paddle/fluid/platform/xpu_info.h" #endif @@ -163,6 +167,14 @@ bool IsCompiledWithXPU() { #endif } +bool IsCompiledWithNPU() { +#ifndef PADDLE_WITH_ASCEND_CL + return false; +#else + return true; +#endif +} + bool IsCompiledWithMKLDNN() { #ifndef PADDLE_WITH_MKLDNN return false; @@ -569,6 +581,11 @@ PYBIND11_MODULE(core_noavx, m) { make_ddim(x_dim), make_ddim(y_dim), -1)); }); +#ifdef PADDLE_WITH_ASCEND_CL + m.def("_npu_finalize", + []() { platform::AclInstance::Instance().Finalize(); }); +#endif + m.def( "_append_python_callable_object_and_return_id", [](py::object py_obj) -> size_t { @@ -641,6 +658,10 @@ PYBIND11_MODULE(core_noavx, m) { [](framework::Tensor &self, paddle::platform::CPUPlace &place) { self.mutable_data(place); }) + .def("_alloc_float", + [](framework::Tensor &self, paddle::platform::NPUPlace &place) { + self.mutable_data(place); + }) .def("_alloc_double", [](framework::Tensor &self, paddle::platform::CPUPlace &place) { self.mutable_data(place); @@ -688,12 +709,19 @@ PYBIND11_MODULE(core_noavx, m) { return reinterpret_cast(self.mutable_data(place, type)); }) .def("_clear", &framework::Tensor::clear) + .def("_mutable_data", + [](framework::Tensor &self, paddle::platform::NPUPlace &place, + paddle::framework::proto::VarType::Type type) { + return reinterpret_cast(self.mutable_data(place, type)); + }) .def("set", SetTensorFromPyArray, py::arg("array"), py::arg("place"), py::arg("zero_copy") = false) .def("set", SetTensorFromPyArray, py::arg("array"), py::arg("place"), py::arg("zero_copy") = false) .def("set", SetTensorFromPyArray, py::arg("array"), py::arg("place"), py::arg("zero_copy") = false) + .def("set", SetTensorFromPyArray, + py::arg("array"), py::arg("place"), py::arg("zero_copy") = false) .def("set", SetTensorFromPyArray, py::arg("array"), py::arg("place"), py::arg("zero_copy") = false, R"DOC( @@ -701,7 +729,7 @@ PYBIND11_MODULE(core_noavx, m) { Args: lod (numpy.ndarray): The data to set. - place (CPUPlace|CUDAPlace|XPUPlace|CUDAPinnedPlace): The place where the + place (CPUPlace|CUDAPlace|XPUPlace|CUDAPinnedPlace|NPUPlace): The place where the LoDTensor is to be set. zero_copy (bool, optional): Whether to share memory with the input numpy array. This parameter only works with CPUPlace. Default: False. @@ -1429,6 +1457,18 @@ All parameter, weight, gradient are variables in Paddle. return new paddle::platform::XPUDeviceContext(place); #endif }) + .def_static("create", + [](paddle::platform::NPUPlace& place) + -> paddle::platform::DeviceContext* { +#ifndef PADDLE_WITH_ASCEND_CL + PADDLE_THROW( + platform::errors::PermissionDenied( + "Cannot use NPUPlace in CPU/GPU/XPU version, " + "Please recompile or reinstall Paddle with NPU support.")); +#else + return new paddle::platform::NPUDeviceContext(place); +#endif + }) .def_static("create", [](paddle::platform::CUDAPlace& place) -> paddle::platform::DeviceContext* { @@ -1529,6 +1569,7 @@ All parameter, weight, gradient are variables in Paddle. .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) .def("_get_device_id", @@ -1598,6 +1639,7 @@ All parameter, weight, gradient are variables in Paddle. #ifdef PADDLE_WITH_XPU m.def("get_xpu_device_count", platform::GetXPUDeviceCount); #endif + py::class_(m, "CPUPlace", R"DOC( CPUPlace is a descriptor of a device. 
It represents a CPU device on which a tensor will be allocated and a model will run. @@ -1613,6 +1655,7 @@ All parameter, weight, gradient are variables in Paddle. .def("_type", &PlaceIndex) .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) .def("_equals", @@ -1650,6 +1693,8 @@ All parameter, weight, gradient are variables in Paddle. &IsSamePlace) .def("_equals", &IsSamePlace) + .def("_equals", + &IsSamePlace) .def("_equals", &IsSamePlace) .def("_equals", @@ -1657,6 +1702,65 @@ All parameter, weight, gradient are variables in Paddle. .def("__repr__", string::to_string) .def("__str__", string::to_string); + // NPUPlace + py::class_(m, "NPUPlace", R"DOC( + NPUPlace is a descriptor of a device. + It represents a NPU device on which a tensor will be allocated and a model will run. + + Examples: + .. code-block:: python + import paddle + npu_place = paddle.NPUPlace(0) + + )DOC") + .def("__init__", + [](platform::NPUPlace &self, int dev_id) { +#ifdef PADDLE_WITH_ASCEND_CL + if (UNLIKELY(dev_id < 0)) { + LOG(ERROR) << string::Sprintf( + "Invalid NPUPlace(%d), device id must be 0 or " + "positive integer", + dev_id); + std::exit(-1); + } + if (UNLIKELY(dev_id >= platform::GetNPUDeviceCount())) { + if (platform::GetNPUDeviceCount() == 0) { + LOG(ERROR) << "Cannot use NPU because there is no NPU " + "detected on your " + "machine."; + std::exit(-1); + } else { + LOG(ERROR) << string::Sprintf( + "Invalid NPUPlace(%d), must inside [0, %d), because NPU " + "number on your machine is %d", + dev_id, platform::GetNPUDeviceCount(), + platform::GetNPUDeviceCount()); + std::exit(-1); + } + } + new (&self) platform::NPUPlace(dev_id); +#else + LOG(ERROR) << string::Sprintf( + "Cannot use NPU because you have installed CPU/GPU version " + "PaddlePaddle.\n" + "If you want to use NPU, please try to install NPU version " + "PaddlePaddle by: pip install paddlepaddle-xpu\n" + "If you only have CPU, please change NPUPlace(%d) to be " + "CPUPlace().\n", + dev_id); + std::exit(-1); +#endif + }) + .def("_type", &PlaceIndex) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", + &IsSamePlace) + .def("__str__", string::to_string); + py::class_(m, "Place") .def(py::init<>()) .def("_type", &PlaceIndex) @@ -1664,6 +1768,7 @@ All parameter, weight, gradient are variables in Paddle. .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) .def("is_gpu_place", [](platform::Place &self) { return platform::is_gpu_place(self); }) @@ -1671,6 +1776,8 @@ All parameter, weight, gradient are variables in Paddle. [](platform::Place &self) { return platform::is_cpu_place(self); }) .def("is_xpu_place", [](platform::Place &self) { return platform::is_xpu_place(self); }) + .def("is_npu_place", + [](platform::Place &self) { return platform::is_npu_place(self); }) .def("is_cuda_pinned_place", [](platform::Place &self) { return platform::is_cuda_pinned_place(self); @@ -1683,6 +1790,10 @@ All parameter, weight, gradient are variables in Paddle. 
[](platform::Place &self) { return BOOST_GET_CONST(platform::XPUPlace, self).device; }) + .def("npu_device_id", + [](platform::Place &self) { + return BOOST_GET_CONST(platform::NPUPlace, self).device; + }) .def("set_place", [](platform::Place &self, const platform::Place &other) { self = other; }) .def("set_place", @@ -1702,6 +1813,10 @@ All parameter, weight, gradient are variables in Paddle. const platform::CUDAPinnedPlace &cuda_pinned_place) { self = cuda_pinned_place; }) + .def("set_place", + [](platform::Place &self, const platform::NPUPlace &npu_place) { + self = npu_place; + }) .def("__repr__", string::to_string) .def("__str__", string::to_string); @@ -1726,6 +1841,9 @@ All parameter, weight, gradient are variables in Paddle. .def("run", [](OperatorBase &self, const Scope &scope, const platform::XPUPlace &place) { self.Run(scope, place); }) + .def("run", + [](OperatorBase &self, const Scope &scope, + const platform::NPUPlace &place) { self.Run(scope, place); }) .def("run", [](OperatorBase &self, const Scope &scope, const platform::CUDAPlace &place) { self.Run(scope, place); }) @@ -1828,6 +1946,7 @@ All parameter, weight, gradient are variables in Paddle. m.def("is_compiled_with_cuda", IsCompiledWithCUDA); m.def("is_compiled_with_ascend", IsCompiledWithAscend); m.def("is_compiled_with_rocm", IsCompiledWithROCM); + m.def("is_compiled_with_npu", IsCompiledWithNPU); m.def("is_compiled_with_xpu", IsCompiledWithXPU); m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN); m.def("supports_bfloat16", SupportsBfloat16); diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index 5f25217007017cff963ffb6a7a701f0618b44d79..ab1dd8a180b5b677d2e59523ab4365cb11bf6cd8 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -294,6 +294,22 @@ void SetTensorFromPyArrayT( PADDLE_THROW(platform::errors::PermissionDenied( "Cannot use XPUPlace in CPU/GPU version, " "Please recompile or reinstall Paddle with XPU support.")); +#endif + } else if (paddle::platform::is_npu_place(place)) { +#ifdef PADDLE_WITH_ASCEND_CL + platform::Place tmp_place = place; + platform::NPUDeviceGuard guard( + BOOST_GET_CONST(platform::NPUPlace, tmp_place).device); + auto dst = self->mutable_data(place); + platform::NPUMemcpySync(dst, array.data(), array.nbytes(), + ACL_MEMCPY_HOST_TO_DEVICE); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &ctx = *pool.Get(place); + ctx.Wait(); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Cannot use NPUPlace in CPU/GPU/XPU version. " + "Please recompile or reinstall Paddle with NPU support.")); #endif } else { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) diff --git a/paddle/testing/paddle_gtest_main.cc b/paddle/testing/paddle_gtest_main.cc index 38ed76a87cd3e46145d4a1a5e679174a41a4ee86..a886f7a0298373d170f72e2cee9c973d2f931941 100644 --- a/paddle/testing/paddle_gtest_main.cc +++ b/paddle/testing/paddle_gtest_main.cc @@ -16,6 +16,7 @@ limitations under the License. 
*/ #include "gtest/gtest.h" #include "paddle/fluid/memory/allocation/allocator_strategy.h" #include "paddle/fluid/platform/init.h" +#include "paddle/fluid/platform/npu_info.h" int main(int argc, char** argv) { paddle::memory::allocation::UseAllocatorStrategyGFlag(); @@ -38,11 +39,13 @@ int main(int argc, char** argv) { } #endif -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ + defined(PADDLE_WITH_ASCEND_CL) envs.push_back("fraction_of_gpu_memory_to_use"); envs.push_back("initial_gpu_memory_in_mb"); envs.push_back("reallocate_gpu_memory_in_mb"); envs.push_back("allocator_strategy"); + envs.push_back("selected_gpus"); #elif __clang__ envs.push_back("use_mkldnn"); envs.push_back("initial_cpu_memory_in_mb"); @@ -61,6 +64,10 @@ int main(int argc, char** argv) { undefok.push_back("initial_cpu_memory_in_mb"); #endif +#if defined(PADDLE_WITH_ASCEND_CL) + envs.push_back("selected_npus"); +#endif + char* env_str = nullptr; if (envs.size() > 0) { std::string env_string = "--tryfromenv="; @@ -93,6 +100,10 @@ int main(int argc, char** argv) { int ret = RUN_ALL_TESTS(); +#ifdef PADDLE_WITH_ASCEND_CL + paddle::platform::AclInstance::Instance().Finalize(); +#endif + if (env_str) free(env_str); if (undefok_str) free(undefok_str); diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 02725751cb6694a774cd224c53a0bea6b8dc680b..17bf2d544f31dfd749d7583dc75d5fa95b96a0bf 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -238,6 +238,7 @@ from .framework import ParamAttr #DEFINE_ALIAS from .framework import create_parameter #DEFINE_ALIAS from .framework import CPUPlace #DEFINE_ALIAS from .framework import CUDAPlace #DEFINE_ALIAS +from .framework import NPUPlace #DEFINE_ALIAS from .framework import CUDAPinnedPlace #DEFINE_ALIAS from .framework import grad #DEFINE_ALIAS @@ -262,6 +263,7 @@ from .device import set_device from .device import get_device from .device import is_compiled_with_cuda #DEFINE_ALIAS from .device import is_compiled_with_xpu +from .device import is_compiled_with_npu from .device import XPUPlace # from .tensor.tensor import Tensor #DEFINE_ALIAS # from .tensor.tensor import LoDTensor #DEFINE_ALIAS diff --git a/python/paddle/device.py b/python/paddle/device.py index 81b1dfcc745a4adadc68e3391c4b07a0d6cb4b0a..d5e4406454b1eb1ea9e1b681c0485496179e5869 100644 --- a/python/paddle/device.py +++ b/python/paddle/device.py @@ -32,12 +32,28 @@ __all__ = [ # 'cuda_places', # 'CUDAPinnedPlace', # 'CUDAPlace', - 'is_compiled_with_cuda' + 'is_compiled_with_cuda', + 'is_compiled_with_npu' ] _cudnn_version = None +def is_compiled_with_npu(): + """ + Whether this whl package can be used to run the model on NPU. + + Returns (bool): `True` if NPU is supported, otherwise `False`. + + Examples: + .. 
code-block:: python + + import paddle + support_npu = paddle.is_compiled_with_npu() + """ + return core.is_compiled_with_npu() + + def is_compiled_with_xpu(): """ Whether paddle was built with WITH_XPU=ON to support Baidu Kunlun @@ -165,6 +181,7 @@ def set_device(device): device_id = device_info_list[1] device_id = int(device_id) place = core.XPUPlace(device_id) + framework._set_expected_place(place) return place diff --git a/python/paddle/distributed/fleet/ascend_utils.py b/python/paddle/distributed/fleet/ascend_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..23e812041c8b2d2230539bdf8e02252d720ba3d0 --- /dev/null +++ b/python/paddle/distributed/fleet/ascend_utils.py @@ -0,0 +1,125 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import paddle +from paddle.distributed.fleet.launch_utils import get_cluster, logger, get_host_name_ip, DeviceMode + + +def _get_ascend_rankfile(rank_table_file_path): + """ + Args: + rank_table_file_path: ascend npu rank file json + { + "status": "completed", + "version": "1.0", + "server_count": "2", + "server_list": [ + { + "server_id": "192.168.24.217", + "device": [ + { + "device_id": "0", + "device_ip": "192.1.184.23", + "rank_id": "0" + }, + { + "device_id": "1", + "device_ip": "192.2.21.93", + "rank_id": "1" + } + ] + }, + { + "server_id": "192.168.26.177", + "device": [ + { + "device_id": "0", + "device_ip": "192.1.94.132", + "rank_id": "2" + }, + { + "device_id": "1", + "device_ip": "192.2.94.30", + "rank_id": "3" + } + ] + } + ] + } + + Returns: + node_ips: node ip list + device_count: number of npu per machine + + """ + json_data = None + with open(rank_table_file_path) as json_file: + json_data = json.load(json_file) + + node_ips = [] + device_count = 0 + server_list = json_data['server_list'] + for server in server_list: + node_ips.append(server['server_id']) + device_list = server['device'] + device_count = len(device_list) + + return node_ips, device_count + + +def get_cloud_cluster(rank_table_file=None, + device_mode=DeviceMode.ASCEND_NPU, + devices_per_proc=None, + start_port=6070): + """ + Args: + rank_table_file: string, ascend npu rank file path + device_mode: DeviceMode(Int) + devices_per_proc:list + start_port: the start port of current runtime env + """ + if rank_table_file: + # multi trainers + node_ips, device_count = _get_ascend_rankfile(rank_table_file) + node_index = os.environ.get("PADDLE_TRAINER_ID") + node_ip = None + if node_index is None: + _, node_ip = get_host_name_ip() + else: + node_ip = node_ips[int(node_index)] + + assert node_ip in node_ips, "Can't find your local ip {%s} in node_ips: {%s}" \ + % (node_ip, node_ips) + else: + # single trainer (single ascend card) + node_ips = ["127.0.0.1"] + node_ip = node_ips[0] + device_count = 1 + devices_per_proc = None + + if devices_per_proc is None: + devices_per_proc = [str(x) for x in range(device_count)] + + free_ports = [ + x for x in range(start_port, 
start_port + len(devices_per_proc)) + ] + + trainer_endpoints = [] + for ip in node_ips: + trainer_endpoints.append(["%s:%d" % (ip, port) for port in free_ports]) + + return get_cluster(node_ips, node_ip, trainer_endpoints, device_mode, + devices_per_proc) diff --git a/python/paddle/distributed/fleet/launch.py b/python/paddle/distributed/fleet/launch.py index d6f4227a92380a3dd05e30c25f00c4a3fda428b1..bd5b67005ba92770ffebd466e3516e55ab7d2141 100644 --- a/python/paddle/distributed/fleet/launch.py +++ b/python/paddle/distributed/fleet/launch.py @@ -73,6 +73,7 @@ from paddle.distributed.fleet import launch_utils # TODO(danleifeng): Don't import * from a module from paddle.distributed.fleet.launch_utils import * import paddle.distributed.fleet.cloud_utils as cloud_utils +import paddle.distributed.fleet.ascend_utils as ascend_utils def _print_arguments(args): @@ -120,7 +121,7 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra default=None, help="It's for ascend npu training." "For example:" - "--ascend_npus=\"0,1,2,3\" will launch four training processes each bound to one gpu." + "--ascend_npus=\"0,1,2,3\" will launch four training processes each bound to one npu." ) if fluid.core.is_compiled_with_cuda(): @@ -237,6 +238,13 @@ def launch_collective(args): cluster, pod = cloud_utils.get_cloud_cluster( args.ips, device_mode, devices_per_proc, start_port) logger.debug("get cluster from cloud:{}".format(cluster)) + elif device_mode == DeviceMode.ASCEND_NPU: + # for ascend + cluster, pod = ascend_utils.get_cloud_cluster( + rank_table_file=os.getenv("RANK_TABLE_FILE", None), + device_mode=device_mode, + devices_per_proc=devices_per_proc, + start_port=start_port) else: # trainers_num = 1 or not use paddlecloud ips="a,b" cluster, pod = get_cluster_from_args(args, device_mode, diff --git a/python/paddle/distributed/fleet/launch_utils.py b/python/paddle/distributed/fleet/launch_utils.py index 2d2807bce28156d4c49dc9124ca81dc3c59cce9e..9f6c186b353399fdbc4c1310337d43d1d314b681 100644 --- a/python/paddle/distributed/fleet/launch_utils.py +++ b/python/paddle/distributed/fleet/launch_utils.py @@ -593,8 +593,8 @@ def get_ascend_npus(npus): if npus is None: count = fluid.core.NPUDevice.get_device_count() if count <= 0: - return ret - ret = [x for x in range(count)] + return None + ret = [str(x) for x in range(count)] else: ret = [x.strip() for x in npus.split(',')] return ret diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py index 978899604eaf8c2ee45c03f866f2d5a081a7e502..824225fd776d1363d79e2218959507df8668bcee 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py @@ -214,7 +214,8 @@ class AscendOptimizer(Optimizer): parameter_list=None, no_grad_set=None, auto_dp=False, - rank_table_file=None): + rank_table_file=None, + precision_mode="must_keep_origin_dtype"): minimized = None if self.inner_opt: minimized = self.inner_opt.minimize( @@ -234,7 +235,7 @@ class AscendOptimizer(Optimizer): config = { "ge.exec.deviceId": str(fleet.local_device_ids()), "ge.graphRunMode": "1", - "ge.exec.precision_mode": "must_keep_origin_dtype", + "ge.exec.precision_mode": precision_mode, } # if multi trainers if rank_table_file and fleet.world_size() > 1: diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py 
b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py index f2ecaf4843829e231b50f160511681a9e2280405..19b5e910db29937845a48ba790b7939db4b250bb 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py +++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py @@ -200,7 +200,8 @@ class AscendParserBase(object): def _accumulated_op_id(self): global global_cnt global_cnt += 1 - return "." + str(global_cnt) + name = "." + str(global_cnt) + return name def _create_ge_tensor(self, shape, dtype, value): tensor_desc = core.GETensorDesc( @@ -1622,10 +1623,14 @@ class MulGradParser(AscendParserBase): "unsqueeze" + self._accumulated_op_id(), "Unsqueeze").set_input("x", y).set_attr_vec_int32("axes", [0]) + y_stack = core.GEOperatorFactory.create_operator( + "stack" + self._accumulated_op_id(), + "TileWithAxis").set_input("x", y_unsqueeze).set_attr_int32( + "axis", 0).set_attr_int32("tiles", shape_out_grad[0]) x_grad = core.GEOperatorFactory.create_operator( self.parser_name + self._accumulated_op_id(), "BatchMatMul").set_input("x1", out_grad).set_input( - "x2", y_unsqueeze).set_attr_bool( + "x2", y_stack).set_attr_bool( "adj_x1", False).set_attr_bool("adj_x2", True) y_grad = core.GEOperatorFactory.create_operator( self.parser_name + self._accumulated_op_id(), diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index ae3418687853b948eea36e3bb1bd4c8e80d29901..6dd1478dc1f45fb388fbc4ca978db30522e058d4 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -68,7 +68,8 @@ from .input import embedding, one_hot from . import distribute_lookup_table from .param_attr import ParamAttr, WeightNormParamAttr from .data_feeder import DataFeeder -from .core import LoDTensor, LoDTensorArray, CPUPlace, XPUPlace, CUDAPlace, CUDAPinnedPlace, Scope, _Scope +from .core import LoDTensor, LoDTensorArray, Scope, _Scope +from .core import CPUPlace, XPUPlace, CUDAPlace, CUDAPinnedPlace, NPUPlace from .incubate import fleet from .incubate import data_generator from .transpiler import DistributeTranspiler, \ @@ -124,6 +125,7 @@ __all__ = framework.__all__ + executor.__all__ + \ 'XPUPlace', 'CUDAPlace', 'CUDAPinnedPlace', + 'NPUPlace', 'Tensor', 'ParamAttr', 'WeightNormParamAttr', @@ -232,6 +234,16 @@ def __bootstrap__(): 'gpu_memory_limit_mb', 'conv2d_disable_cudnn', ] + + if core.is_compiled_with_npu(): + read_env_flags += [ + 'selected_npus', + 'fraction_of_gpu_memory_to_use', + 'initial_gpu_memory_in_mb', + 'reallocate_gpu_memory_in_mb', + 'gpu_memory_limit_mb', + ] + core.init_gflags(["--tryfromenv=" + ",".join(read_env_flags)]) core.init_glog(sys.argv[0]) # don't init_p2p when in unittest to save time. 
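The changes above expose the NPU place and flags to Python (core.NPUPlace re-exported from paddle.fluid and paddle, plus the selected_npus gflag). A minimal usage sketch, assuming a wheel built with WITH_ASCEND_CL=ON and an NPU visible as device 0, following the same static-graph pattern as the unit tests added later in this patch:

import paddle

paddle.enable_static()
if paddle.is_compiled_with_npu():       # new API introduced by this patch
    place = paddle.NPUPlace(0)          # device id 0, as in the NPU unit tests
    exe = paddle.static.Executor(place)
    exe.run(paddle.static.default_startup_program())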
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index da326ec074c1d960a74b591d66e1f78f8ead5bc2..9c85cc6cd5db692afcad25685a471e35a60e317c 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -1213,6 +1213,7 @@ class Executor(object): # In distributed training, the compiled program is saved in Program._graph has_compiled_graph = isinstance(program._graph, compiler.CompiledProgram) + if has_compiled_graph: program._graph._compile(scope, self.place) # _graph in program does not support inference since the _graph is optimized diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index d5c01d20a918247524910ac820d9e0c5a1e9e885..499f0873dc3bb15fc9eb945ab41f992d435bba2e 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -6201,7 +6201,7 @@ def _get_paddle_place(place): if place is None: return place if isinstance(place, (core.Place, core.XPUPlace, core.CPUPlace, - core.CUDAPinnedPlace, core.CUDAPlace)): + core.CUDAPinnedPlace, core.CUDAPlace, core.NPUPlace)): return place if not isinstance(place, str): @@ -6211,9 +6211,11 @@ def _get_paddle_place(place): place = place.lower() if (place == "cpu"): return core.CPUPlace() + if (place == "device"): return core.Place() + # GPU avaliable_gpu_place = re.match(r'gpu:\d+', place) if place == "gpu_pinned" or place == "gpu" or avaliable_gpu_place: if not core.is_compiled_with_cuda(): @@ -6229,6 +6231,8 @@ def _get_paddle_place(place): device_id = place_info_list[1] device_id = int(device_id) return core.CUDAPlace(device_id) + + # XPU avaliable_xpu_place = re.match(r'xpu:\d+', place) if avaliable_xpu_place: if not core.is_compiled_with_xpu(): @@ -6239,9 +6243,22 @@ def _get_paddle_place(place): device_id = place_info_list[1] device_id = int(device_id) return core.XPUPlace(device_id) + + # NPU + avaliable_npu_place = re.match(r'npu:\d+', place) + if avaliable_npu_place: + if not core.is_compiled_with_npu(): + raise ValueError( + "The device should not be {}, since PaddlePaddle is " \ + "not compiled with NPU".format(avaliable_npu_place)) + place_info_list = place.split(':', 1) + device_id = place_info_list[1] + device_id = int(device_id) + return core.NPUPlace(device_id) + raise ValueError( - "paddle support CPUPlace, CUDAPlace,CUDAPinnedPlace and XPUPlace, Please check your Place Input" - ) + "Paddle supports CPUPlace, CUDAPlace,CUDAPinnedPlace, XPUPlace and NPUPlace, but received {}.". 
+ format(place)) def _get_paddle_place_list(places): diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 792a2d32326da2db6c638802b3507eccd84f3e56..e8669fd295162dc9d7ef4278e6f6cd3757845dd4 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -625,6 +625,10 @@ if (WITH_XPU) add_subdirectory(xpu) endif() +if (WITH_ASCEND_CL) + add_subdirectory(npu) +endif() + if (WITH_MKLDNN) add_subdirectory(mkldnn) endif() diff --git a/python/paddle/fluid/tests/unittests/npu/CMakeLists.txt b/python/paddle/fluid/tests/unittests/npu/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..f71e04c09aa38b8cf7b3a167b84d4dc0e6cc3ec7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/CMakeLists.txt @@ -0,0 +1,6 @@ +file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +foreach(TEST_OP ${TEST_OPS}) + py_test_modules(${TEST_OP} MODULES ${TEST_OP}) +endforeach(TEST_OP) diff --git a/python/paddle/fluid/tests/unittests/npu/test_elementwise_add_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_elementwise_add_op_npu.py new file mode 100644 index 0000000000000000000000000000000000000000..47da4fdb23ec49924fbfb1b5cc4b02e2355d287e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_elementwise_add_op_npu.py @@ -0,0 +1,162 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest, _set_use_system_allocator +import paddle +import paddle.fluid as fluid + +paddle.enable_static() + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestElementwiseAddOp(OpTest): + def setUp(self): + self.set_npu() + self.op_type = "elementwise_add" + self.place = paddle.NPUPlace(0) + self.init_dtype() + self.init_input_output() + self.init_kernel_type() + self.init_axis() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn} + self.outputs = {'Out': self.out} + + def set_npu(self): + self.__class__.use_npu = True + + def init_kernel_type(self): + self.use_mkldnn = False + + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.out = np.add(self.x, self.y) + + def init_dtype(self): + self.dtype = np.float32 + + def init_axis(self): + self.axis = -1 + + def test_check_output(self): + self.check_output_with_place(self.place, check_dygraph=False) + + # TODO(ascendrc): Test grad op after it is implemented. 
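+ # NOTE: check_dygraph=False above skips OpTest's imperative (dygraph) consistency
+ # check; the assumption is that only static-graph execution is exercised for
+ # these early NPU kernels.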
+ # def test_check_grad_normal(self): + # self.check_grad_with_place( + # self.place, ['X', 'Y'], + # 'Out', + # max_relative_error=0.006, + # check_dygraph=False) + # + # def test_check_grad_ingore_x(self): + # self.check_grad_with_place( + # self.place, ['Y'], + # 'Out', + # no_grad_set=set("X"), + # max_relative_error=0.006, + # check_dygraph=False) + # + # def test_check_grad_ingore_y(self): + # self.check_grad_with_place( + # self.place, ['X'], + # 'Out', + # no_grad_set=set("Y"), + # max_relative_error=0.006,check_dygraph=False) + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestAddAPI(unittest.TestCase): + def test_name(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name="x", shape=[2, 3], dtype="float32") + y = paddle.static.data(name='y', shape=[2, 3], dtype='float32') + + y_1 = paddle.add(x, y, name='add_res') + self.assertEqual(('add_res' in y_1.name), True) + + def test_static(self): + with paddle.static.program_guard(paddle.static.Program()): + + x_np = np.array([2, 3, 4]).astype('float32') + y_np = np.array([1, 5, 2]).astype('float32') + + x = paddle.static.data(name="x", shape=[3], dtype='float32') + y = paddle.static.data(name="y", shape=[3], dtype='float32') + + x_reshape = paddle.reshape(x, [3, 1]) + y_reshape = paddle.reshape(y, [3, 1]) + z = paddle.add(x_reshape, y_reshape) + z = paddle.reshape(z, shape=[3]) + + place = paddle.NPUPlace(0) + exe = paddle.static.Executor(place) + x_value, y_value, z_value = exe.run(feed={"x": x_np, + "y": y_np}, + fetch_list=[x, y, z]) + + z_expected = np.array([3., 8., 6.]) + self.assertEqual( + (x_value == x_np).all(), + True, + msg="x_value = {}, but expected {}".format(x_value, x_np)) + self.assertEqual( + (y_value == y_np).all(), + True, + msg="y_value = {}, but expected {}".format(y_value, y_np)) + self.assertEqual( + (z_value == z_expected).all(), + True, + msg="z_value = {}, but expected {}".format(z_value, z_expected)) + + def test_backward(self): + # TODO(ascendrc): Test backward after add grad npu op implemented. + pass + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestAddError(unittest.TestCase): + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program()): + # the input of elementwise_add must be Variable. + x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.NPUPlace(0)) + y1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.NPUPlace(0)) + self.assertRaises(TypeError, paddle.add, x1, y1) + + # the input dtype must be float16 or float32 or float64 or int32 or int64 + x2 = paddle.static.data( + name='x2', shape=[3, 4, 5, 6], dtype="uint8") + y2 = paddle.static.data( + name='y2', shape=[3, 4, 5, 6], dtype="uint8") + self.assertRaises(TypeError, paddle.add, x2, y2) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/npu/test_elementwise_sub_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_elementwise_sub_op_npu.py new file mode 100644 index 0000000000000000000000000000000000000000..8c6c7b46f49f2725b646202998095adef3a65e63 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_elementwise_sub_op_npu.py @@ -0,0 +1,224 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid + +paddle.enable_static() + +SEED = 2021 + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestElementwiseSubOp(OpTest): + def setUp(self): + self.set_npu() + self.op_type = "elementwise_sub" + self.place = paddle.NPUPlace(0) + self.init_dtype() + self.init_input_output() + self.init_kernel_type() + self.init_axis() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn} + self.outputs = {'Out': self.out} + + def set_npu(self): + self.__class__.use_npu = True + + def init_kernel_type(self): + self.use_mkldnn = False + + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.out = np.subtract(self.x, self.y) + + def init_dtype(self): + self.dtype = np.float32 + + def init_axis(self): + self.axis = 0 + + def test_check_output(self): + self.check_output_with_place(self.place, check_dygraph=False) + + # TODO(ascendrc): For grad tests, OpTest raises FatalError:Segmentation fault + # when call op.run, which may be caused by system environment exception + # and the exact cause has not be located. 
+ # def test_check_grad_normal(self): + # self.check_grad_with_place( + # self.place, ['X', 'Y'], + # 'Out', + # max_relative_error=0.006, + # check_dygraph=False) + # + # def test_check_grad_ingore_x(self): + # self.check_grad_with_place( + # self.place, ['Y'], + # 'Out', + # no_grad_set=set("X"), + # max_relative_error=0.006, + # check_dygraph=False) + # + # def test_check_grad_ingore_y(self): + # self.check_grad_with_place( + # self.place, ['X'], + # 'Out', + # no_grad_set=set("Y"), + # max_relative_error=0.006,check_dygraph=False) + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestSubtractAPI(unittest.TestCase): + def test_name(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name="x", shape=[2, 3], dtype="float32") + y = paddle.static.data(name='y', shape=[2, 3], dtype='float32') + + y_1 = paddle.subtract(x, y, name='add_res') + self.assertEqual(('add_res' in y_1.name), True) + + def test_static(self): + with paddle.static.program_guard(paddle.static.Program()): + + x_np = np.array([2, 3, 4]).astype('float32') + y_np = np.array([1, 5, 2]).astype('float32') + + x = paddle.static.data(name="x", shape=[3], dtype='float32') + y = paddle.static.data(name="y", shape=[3], dtype='float32') + + x_reshape = paddle.reshape(x, [3, 1]) + y_reshape = paddle.reshape(y, [3, 1]) + z = paddle.subtract(x_reshape, y_reshape) + z = paddle.reshape(z, shape=[3]) + + place = paddle.NPUPlace(0) + exe = paddle.static.Executor(place) + x_value, y_value, z_value = exe.run(feed={"x": x_np, + "y": y_np}, + fetch_list=[x, y, z]) + + z_expected = np.array([1., -2., 2.]) + self.assertEqual( + (x_value == x_np).all(), + True, + msg="x_value = {}, but expected {}".format(x_value, x_np)) + self.assertEqual( + (y_value == y_np).all(), + True, + msg="y_value = {}, but expected {}".format(y_value, y_np)) + self.assertEqual( + (z_value == z_expected).all(), + True, + msg="z_value = {}, but expected {}".format(z_value, z_expected)) + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestSubtractError(unittest.TestCase): + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program()): + # the input of elementwise_add must be Variable. 
+ x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.NPUPlace(0)) + y1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.NPUPlace(0)) + self.assertRaises(TypeError, paddle.subtract, x1, y1) + + # the input dtype must be float16 or float32 or float64 or int32 or int64 + x2 = paddle.static.data( + name='x2', shape=[3, 4, 5, 6], dtype="uint8") + y2 = paddle.static.data( + name='y2', shape=[3, 4, 5, 6], dtype="uint8") + self.assertRaises(TypeError, paddle.subtract, x2, y2) + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestSubtractNet(unittest.TestCase): + def _test(self, run_npu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + a_np = np.random.random(size=(32, 32)).astype('float32') + b_np = np.random.random(size=(32, 32)).astype('float32') + label_np = np.random.randint(2, size=(32, 1)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + a = paddle.static.data(name="a", shape=[32, 32], dtype='float32') + b = paddle.static.data(name="b", shape=[32, 32], dtype='float32') + label = paddle.static.data( + name="label", shape=[32, 1], dtype='int64') + + sum = paddle.add(a, b) + c = paddle.assign(b) + z = paddle.subtract(sum, c) + + fc_1 = fluid.layers.fc(input=z, size=128) + prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.reduce_mean(cost) + sgd = fluid.optimizer.SGD(learning_rate=0.01) + sgd.minimize(loss) + + if run_npu: + place = paddle.NPUPlace(0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + for epoch in range(100): + + pred_res, loss_res = exe.run( + main_prog, + feed={"a": a_np, + "b": b_np, + "label": label_np}, + fetch_list=[prediction, loss]) + if epoch % 10 == 0: + print("Epoch {} | Prediction[0]: {}, Loss: {}".format( + epoch, pred_res[0], loss_res)) + + return pred_res, loss_res + + def test_npu(self): + npu_pred, npu_loss = self._test(True) + cpu_pred, cpu_loos = self._test(False) + + self.assertTrue(np.allclose(npu_pred, cpu_pred)) + self.assertTrue(np.allclose(npu_loss, cpu_loos)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/npu/test_npu_place.py b/python/paddle/fluid/tests/unittests/npu/test_npu_place.py new file mode 100644 index 0000000000000000000000000000000000000000..3f71fad2b9c1084148d8b0a28e556cc0bf5f366e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_npu_place.py @@ -0,0 +1,61 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
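+ # These tests cover the Place plumbing added above (set_place, is_npu_place,
+ # npu_device_id) and the expected failure when an NPU place is used with
+ # CompiledProgram (ParallelExecutor).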
+ +from __future__ import print_function + +import unittest +import paddle +import numpy as np +from paddle.fluid import core + +paddle.enable_static() + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestNpuPlace(unittest.TestCase): + def test(self): + p = core.Place() + p.set_place(paddle.NPUPlace(0)) + + self.assertTrue(p.is_npu_place()) + self.assertEqual(p.npu_device_id(), 0) + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestNpuPlaceError(unittest.TestCase): + def test_static(self): + # NPU is not supported in ParallelExecutor + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + + x_np = np.array([2, 3, 4]).astype('float32') + y_np = np.array([1, 5, 2]).astype('float32') + + x = paddle.static.data(name="x", shape=[3], dtype='float32') + y = paddle.static.data(name="y", shape=[3], dtype='float32') + z = paddle.add(x, y) + + compiled_prog = paddle.static.CompiledProgram(prog) + place = paddle.NPUPlace(0) + exe = paddle.static.Executor(place) + + with self.assertRaisesRegex(RuntimeError, + "NPU is not supported in ParallelExecutor"): + exe.run(compiled_prog, feed={"x": x_np, "y": y_np}) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index dff96a8cbc3c4a4bd885fcd63f3314c98c7b465d..569c4316880df6f148880d59a8110934cb93e234 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -266,7 +266,10 @@ class OpTest(unittest.TestCase): np.random.seed(123) random.seed(124) - cls._use_system_allocator = _set_use_system_allocator(True) + if paddle.is_compiled_with_npu(): + cls._use_system_allocator = _set_use_system_allocator(False) + else: + cls._use_system_allocator = _set_use_system_allocator(True) @classmethod def tearDownClass(cls): @@ -298,6 +301,9 @@ class OpTest(unittest.TestCase): def is_rocm_op_test(): return core.is_compiled_with_rocm() + def is_npu_op_test(): + return hasattr(cls, "use_npu") and cls.use_npu == True + if not hasattr(cls, "op_type"): raise AssertionError( "This test do not have op_type in class attrs, " @@ -319,7 +325,8 @@ class OpTest(unittest.TestCase): and not hasattr(cls, 'exist_fp64_check_grad') \ and not is_xpu_op_test() \ and not is_mkldnn_op_test() \ - and not is_rocm_op_test(): + and not is_rocm_op_test() \ + and not is_npu_op_test(): raise AssertionError( "This test of %s op needs check_grad with fp64 precision." % cls.op_type) @@ -1216,7 +1223,8 @@ class OpTest(unittest.TestCase): # Check inplace for given op, its grad op, its grad_grad op, etc. # No effect on original OpTest # Currently not support ParallelExecutor on XPUPlace. 
- if not paddle.is_compiled_with_xpu(): + if not paddle.is_compiled_with_xpu( + ) and not paddle.is_compiled_with_npu(): self.check_inplace_output_with_place( place, no_check_set=no_check_set, inplace_atol=inplace_atol) diff --git a/python/paddle/fluid/tests/unittests/test_device.py b/python/paddle/fluid/tests/unittests/test_device.py index 195337e80defa930a67d6d9e08dc585d07cdb6fa..08697a080445e606f17bdde83384eef391713721 100644 --- a/python/paddle/fluid/tests/unittests/test_device.py +++ b/python/paddle/fluid/tests/unittests/test_device.py @@ -15,54 +15,39 @@ from __future__ import print_function import unittest -from op_test import OpTest -import numpy as np +import paddle import paddle.fluid as fluid import paddle.fluid.core as core import paddle.fluid.framework as framework -import warnings -import paddle class TestStaticDeviceManage(unittest.TestCase): - def test_cpu_device(self): - paddle.set_device('cpu') + def _test_device(self, device_name, device_class): + paddle.set_device(device_name) + out1 = paddle.zeros(shape=[1, 3], dtype='float32') out2 = paddle.ones(shape=[1, 3], dtype='float32') out3 = paddle.concat(x=[out1, out2], axis=0) - exe = paddle.fluid.Executor() + + exe = paddle.static.Executor() exe.run(paddle.fluid.default_startup_program()) res = exe.run(fetch_list=[out3]) + device = paddle.get_device() - self.assertEqual(isinstance(exe.place, core.CPUPlace), True) - self.assertEqual(device, "cpu") + self.assertEqual(isinstance(exe.place, device_class), True) + self.assertEqual(device, device_name) + + def test_cpu_device(self): + self._test_device("cpu", core.CPUPlace) def test_gpu_device(self): if core.is_compiled_with_cuda(): - out1 = paddle.zeros(shape=[1, 3], dtype='float32') - out2 = paddle.ones(shape=[1, 3], dtype='float32') - out3 = paddle.concat(x=[out1, out2], axis=0) - paddle.set_device('gpu:0') - exe = paddle.fluid.Executor() - exe.run(paddle.fluid.default_startup_program()) - res = exe.run(fetch_list=[out3]) - device = paddle.get_device() - self.assertEqual(isinstance(exe.place, core.CUDAPlace), True) - self.assertEqual(device, "gpu:0") + self._test_device("gpu:0", core.CUDAPlace) def test_xpu_device(self): if core.is_compiled_with_xpu(): - out1 = paddle.zeros(shape=[1, 3], dtype='float32') - out2 = paddle.ones(shape=[1, 3], dtype='float32') - out3 = paddle.concat(x=[out1, out2], axis=0) - paddle.set_device('xpu:0') - exe = paddle.fluid.Executor() - exe.run(paddle.fluid.default_startup_program()) - res = exe.run(fetch_list=[out3]) - device = paddle.get_device() - self.assertEqual(isinstance(exe.place, core.XPUPlace), True) - self.assertEqual(device, "xpu:0") + self._test_device("xpu:0", core.XPUPlace) class TestImperativeDeviceManage(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_fleet_launch_ascend2.sh b/python/paddle/fluid/tests/unittests/test_fleet_launch_ascend2.sh new file mode 100644 index 0000000000000000000000000000000000000000..2e9c1e6995399e94e43c54168bec8d86533ab2ff --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fleet_launch_ascend2.sh @@ -0,0 +1,103 @@ +#!/bin/bash + +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +RANK_TABLE_FILE_NAME="rank_table_file.json" +cat > ${RANK_TABLE_FILE_NAME} <