From 7e049108c547f0b07a59546e967778940192c04e Mon Sep 17 00:00:00 2001
From: Leo Chen
Date: Tue, 9 Feb 2021 16:58:40 +0800
Subject: [PATCH] [feature] support npu operator (#30951)

[feature] support npu operator
---
 cmake/operators.cmake                         |  23 +-
 paddle/fluid/framework/library_type.h         |   2 +
 paddle/fluid/framework/op_registry.h          |   3 +
 paddle/fluid/framework/operator.cc            |  10 +
 paddle/fluid/framework/tensor_test.cc         |  82 +++++-
 paddle/fluid/framework/tensor_util.h          |  16 ++
 paddle/fluid/memory/memcpy.cc                 |  79 ++++++
 paddle/fluid/memory/memcpy.h                  |  20 ++
 paddle/fluid/operators/CMakeLists.txt         |   5 +
 .../operators/elementwise/CMakeLists.txt      |   1 +
 .../elementwise/elementwise_add_op_npu.cc     |  58 +++++
 .../elementwise_add_op_npu_test.cc            |  83 +++++++
 paddle/fluid/operators/npu_op_runner.cc       | 235 ++++++++++++++++++
 paddle/fluid/operators/npu_op_runner.h        |  83 +++++++
 paddle/fluid/platform/CMakeLists.txt          |  15 +-
 paddle/fluid/platform/device_context.cc       |  14 +-
 paddle/fluid/platform/device_context.h        |   3 +
 paddle/fluid/platform/init.cc                 |   3 +
 paddle/fluid/platform/stream/CMakeLists.txt   |   6 +-
 paddle/fluid/platform/stream/npu_stream.cc    |  52 ++++
 paddle/fluid/platform/stream/npu_stream.h     |  76 ++++++
 .../fluid/platform/stream_callback_manager.cc |  34 ++-
 .../fluid/platform/stream_callback_manager.h  |   7 +-
 23 files changed, 888 insertions(+), 22 deletions(-)
 create mode 100644 paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc
 create mode 100644 paddle/fluid/operators/elementwise/elementwise_add_op_npu_test.cc
 create mode 100644 paddle/fluid/operators/npu_op_runner.cc
 create mode 100644 paddle/fluid/operators/npu_op_runner.h
 create mode 100644 paddle/fluid/platform/stream/npu_stream.cc
 create mode 100644 paddle/fluid/platform/stream/npu_stream.h

diff --git a/cmake/operators.cmake b/cmake/operators.cmake
index 824daf77519..3fc75c3031f 100644
--- a/cmake/operators.cmake
+++ b/cmake/operators.cmake
@@ -11,12 +11,16 @@ function(op_library TARGET)
   set(miopen_hip_cc_srcs)
   set(cu_cc_srcs)
   set(xpu_cc_srcs)
+  set(npu_cc_srcs)
   set(cudnn_cu_cc_srcs)
   set(cudnn_cu_srcs)
   set(CUDNN_FILE)
   set(mkldnn_cc_srcs)
   set(MKLDNN_FILE)
   set(op_common_deps operator op_registry math_function layer common_infer_shape_functions)
+  if (WITH_ASCEND_CL)
+    set(op_common_deps ${op_common_deps} npu_op_runner)
+  endif()
   # Option `UNITY` is used to specify that operator `TARGET` will compile with Unity Build.
   set(options UNITY)
   set(oneValueArgs "")
@@ -84,6 +88,12 @@ function(op_library TARGET)
         list(APPEND xpu_cc_srcs ${XPU_FILE}.cc)
       endif()
     endif()
+    if(WITH_ASCEND_CL)
+      string(REPLACE "_op" "_op_npu" NPU_FILE "${TARGET}")
+      if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${NPU_FILE}.cc)
+        list(APPEND npu_cc_srcs ${NPU_FILE}.cc)
+      endif()
+    endif()
   else()
     foreach(src ${op_library_SRCS})
       if (WITH_ROCM_PLATFORM AND ${src} MATCHES ".*\\.hip.cu$")
@@ -106,6 +116,8 @@ function(op_library TARGET)
         list(APPEND cu_cc_srcs ${src})
       elseif(WITH_XPU AND ${src} MATCHES ".*_op_xpu.cc$")
         list(APPEND xpu_cc_srcs ${src})
+      elseif(WITH_ASCEND_CL AND ${src} MATCHES ".*_op_npu.cc$")
+        list(APPEND npu_cc_srcs ${src})
       elseif(${src} MATCHES ".*\\.cc$")
         list(APPEND cc_srcs ${src})
       else()
@@ -170,7 +182,7 @@ function(op_library TARGET)
   # Unity Build relies on global option `WITH_UNITY_BUILD` and local option `UNITY`.
   if(WITH_UNITY_BUILD AND op_library_UNITY)
     # Combine the cc source files.
-    compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs})
+    compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs})
     if(TARGET ${UNITY_TARGET})
       # If `UNITY_TARGET` exists, add source files to `UNITY_TARGET`.
       target_sources(${UNITY_TARGET} PRIVATE ${unity_target_cc_sources})
@@ -181,7 +193,7 @@ function(op_library TARGET)
     # Add alias library to handle dependencies.
     add_library(${TARGET} ALIAS ${UNITY_TARGET})
   else()
-    cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} DEPS ${op_library_DEPS}
+    cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs} DEPS ${op_library_DEPS}
       ${op_common_deps})
   endif()
 endif()
@@ -230,10 +242,11 @@ function(op_library TARGET)
   list(LENGTH cu_cc_srcs cu_cc_srcs_len)
   list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len)
   list(LENGTH xpu_cc_srcs xpu_cc_srcs_len)
+  list(LENGTH npu_cc_srcs npu_cc_srcs_len)
   list(LENGTH hip_cu_srcs hip_cu_srcs_len)
   list(LENGTH miopen_hip_cc_srcs miopen_hip_cc_srcs_len)
   if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0 AND
-      ${hip_cu_srcs_len} EQUAL 0 AND ${miopen_hip_cc_srcs_len} EQUAL 0 AND ${xpu_cc_srcs_len} EQUAL 0)
+      ${hip_cu_srcs_len} EQUAL 0 AND ${miopen_hip_cc_srcs_len} EQUAL 0 AND ${xpu_cc_srcs_len} EQUAL 0 AND ${npu_cc_srcs_len} EQUAL 0)
     file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
     set(pybind_flag 1)
   endif()
@@ -273,6 +286,9 @@ function(op_library TARGET)
   if (WITH_XPU AND ${xpu_cc_srcs_len} GREATER 0)
     file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, XPU);\n")
   endif()
+  if (WITH_ASCEND_CL AND ${npu_cc_srcs_len} GREATER 0)
+    file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, NPU);\n")
+  endif()
   # pybind USE_OP_DEVICE_KERNEL for MKLDNN
   if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
     # Append first implemented MKLDNN activation operator
@@ -323,6 +339,7 @@ function(register_operators)
   file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
   string(REPLACE "_mkldnn" "" OPS "${OPS}")
   string(REPLACE "_xpu" "" OPS "${OPS}")
+  string(REPLACE "_npu" "" OPS "${OPS}")
   string(REPLACE ".cc" "" OPS "${OPS}")
   list(REMOVE_DUPLICATES OPS)
   list(LENGTH register_operators_DEPS register_operators_DEPS_len)
diff --git a/paddle/fluid/framework/library_type.h b/paddle/fluid/framework/library_type.h
index 4307e51862d..8fe314cf5f1 100644
--- a/paddle/fluid/framework/library_type.h
+++ b/paddle/fluid/framework/library_type.h
@@ -61,6 +61,8 @@ inline LibraryType StringToLibraryType(const char* ctype) {
     return LibraryType::kPlain;
   } else if (s == std::string("XPU")) {
     return LibraryType::kPlain;
+  } else if (s == std::string("NPU")) {
+    return LibraryType::kPlain;
   } else if (s == std::string("CUDA")) {
     return LibraryType::kPlain;
   } else {
diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h
index e32ab8c7442..6975dd7a214 100644
--- a/paddle/fluid/framework/op_registry.h
+++ b/paddle/fluid/framework/op_registry.h
@@ -304,6 +304,9 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
+#define REGISTER_OP_NPU_KERNEL(op_type, ...)                      \
+  REGISTER_OP_KERNEL(op_type, NPU, ::paddle::platform::NPUPlace,  \
+                     __VA_ARGS__)
+
diff --git a/paddle/fluid/framework/tensor_test.cc b/paddle/fluid/framework/tensor_test.cc
--- a/paddle/fluid/framework/tensor_test.cc
+++ b/paddle/fluid/framework/tensor_test.cc
@@ ... @@ TEST(Tensor, MutableData) {
 #ifdef PADDLE_WITH_CUDA
   {
     framework::Tensor src_tensor;
     float* p1 = nullptr;
     float* p2 = nullptr;
     // initialization
     p1 = src_tensor.mutable_data<float>(framework::make_ddim({1, 2, 3}),
-                                        platform::CUDAPlace());
+                                        platform::CUDAPlace(0));
     auto p1_holder = src_tensor.Holder();
     EXPECT_NE(p1, nullptr);
     // set src_tensor a new dim with large size
     // memory is supposed to be re-allocated
     p2 = src_tensor.mutable_data<float>(framework::make_ddim({3, 1024}),
-                                        platform::CUDAPlace());
+                                        platform::CUDAPlace(0));
     auto p2_holder = src_tensor.Holder();
     EXPECT_NE(p2, nullptr);
     EXPECT_NE(p1_holder.get(), p2_holder.get());
     // set src_tensor a new dim with same size
     // memory block is supposed to be unchanged
     p1 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2, 3}),
-                                        platform::CUDAPlace());
+                                        platform::CUDAPlace(0));
     EXPECT_EQ(p1, p2);
     // set src_tensor a new dim with smaller size
     // memory block is supposed to be unchanged
     p2 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2}),
-                                        platform::CUDAPlace());
+                                        platform::CUDAPlace(0));
+    EXPECT_EQ(p1, p2);
+  }
+#endif
+#ifdef PADDLE_WITH_ASCEND_CL
+  {
+    framework::Tensor src_tensor;
+    float* p1 = nullptr;
+    float* p2 = nullptr;
+    // initialization
+    p1 = src_tensor.mutable_data<float>(framework::make_ddim({1, 2, 3}),
+                                        platform::NPUPlace(0));
+    auto p1_holder = src_tensor.Holder();
+    EXPECT_NE(p1, nullptr);
+    // set src_tensor a new dim with large size
+    // memory is supposed to be re-allocated
+    p2 = src_tensor.mutable_data<float>(framework::make_ddim({3, 1024}),
+                                        platform::NPUPlace(0));
+    auto p2_holder = src_tensor.Holder();
+    EXPECT_NE(p2, nullptr);
+    EXPECT_NE(p1_holder.get(), p2_holder.get());
+    // set src_tensor a new dim with same size
+    // memory block is supposed to be unchanged
+    p1 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2, 3}),
+                                        platform::NPUPlace(0));
+    EXPECT_EQ(p1, p2);
+    // set src_tensor a new dim with smaller size
+    // memory block is supposed to be unchanged
+    p2 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2}),
+                                        platform::NPUPlace(0));
     EXPECT_EQ(p1, p2);
   }
 #endif
@@ -179,7 +208,17 @@ TEST(Tensor, ShareDataWith) {
     framework::Tensor src_tensor;
     framework::Tensor dst_tensor;
     src_tensor.mutable_data<int>(framework::make_ddim({2, 3, 4}),
-                                 platform::CUDAPlace());
+                                 platform::CUDAPlace(0));
+    dst_tensor.ShareDataWith(src_tensor);
+    ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
+  }
+#endif
+#ifdef PADDLE_WITH_ASCEND_CL
+  {
+    framework::Tensor src_tensor;
+    framework::Tensor dst_tensor;
+    src_tensor.mutable_data<int>(framework::make_ddim({2, 3, 4}),
+                                 platform::NPUPlace(0));
     dst_tensor.ShareDataWith(src_tensor);
     ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
   }
@@ -216,7 +255,34 @@ TEST(Tensor, Slice) {
   {
     framework::Tensor src_tensor;
     src_tensor.mutable_data<double>(framework::make_ddim({6, 9}),
-                                    platform::CUDAPlace());
+                                    platform::CUDAPlace(0));
+    framework::Tensor slice_tensor = src_tensor.Slice(2, 6);
+    framework::DDim slice_dims = slice_tensor.dims();
+    ASSERT_EQ(arity(slice_dims), 2);
+    EXPECT_EQ(slice_dims[0], 4);
+    EXPECT_EQ(slice_dims[1], 9);
+
+    uintptr_t src_data_address =
+        reinterpret_cast<uintptr_t>(src_tensor.data<double>());
+    uintptr_t src_mutable_data_address =
+        reinterpret_cast<uintptr_t>(src_tensor.mutable_data<double>(
+            src_tensor.dims(), platform::CUDAPlace(0)));
+    uintptr_t slice_data_address =
+        reinterpret_cast<uintptr_t>(slice_tensor.data<double>());
+    uintptr_t slice_mutable_data_address =
+        reinterpret_cast<uintptr_t>(slice_tensor.mutable_data<double>(
+            slice_tensor.dims(), platform::CUDAPlace(0)));
+    EXPECT_EQ(src_data_address, src_mutable_data_address);
+    EXPECT_EQ(slice_data_address, slice_mutable_data_address);
+    EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address);
+  }
+#endif
+
+#ifdef PADDLE_WITH_ASCEND_CL
+  {
+    framework::Tensor src_tensor;
+    src_tensor.mutable_data<double>(framework::make_ddim({6, 9}),
+                                    platform::NPUPlace(0));
     framework::Tensor slice_tensor = src_tensor.Slice(2, 6);
     framework::DDim slice_dims = slice_tensor.dims();
     ASSERT_EQ(arity(slice_dims), 2);
     EXPECT_EQ(slice_dims[0], 4);
     EXPECT_EQ(slice_dims[1], 9);
 
     uintptr_t src_data_address =
         reinterpret_cast<uintptr_t>(src_tensor.data<double>());
     uintptr_t src_mutable_data_address =
         reinterpret_cast<uintptr_t>(src_tensor.mutable_data<double>(
-            src_tensor.dims(), platform::CUDAPlace()));
+            src_tensor.dims(), platform::NPUPlace(0)));
     uintptr_t slice_data_address =
         reinterpret_cast<uintptr_t>(slice_tensor.data<double>());
     uintptr_t slice_mutable_data_address =
         reinterpret_cast<uintptr_t>(slice_tensor.mutable_data<double>(
-            slice_tensor.dims(), platform::CUDAPlace()));
+            slice_tensor.dims(), platform::NPUPlace(0)));
     EXPECT_EQ(src_data_address, src_mutable_data_address);
     EXPECT_EQ(slice_data_address, slice_mutable_data_address);
     EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address);
diff --git a/paddle/fluid/framework/tensor_util.h b/paddle/fluid/framework/tensor_util.h
index 50644370bc6..e782963f188 100644
--- a/paddle/fluid/framework/tensor_util.h
+++ b/paddle/fluid/framework/tensor_util.h
@@ -158,6 +158,14 @@ void TensorFromVector(const std::vector<T>& src,
         reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
   }
 #endif
+#ifdef PADDLE_WITH_ASCEND_CL
+  else if (platform::is_npu_place(dst_place)) {  // NOLINT
+    memory::Copy(
+        BOOST_GET_CONST(platform::NPUPlace, dst_place), dst_ptr, src_place,
+        src_ptr, size,
+        reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream());
+  }
+#endif
 }
 
 template <typename T>
@@ -195,6 +203,14 @@ void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx,
         reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
   }
 #endif
+#ifdef PADDLE_WITH_ASCEND_CL
+  else if (platform::is_npu_place(src.place())) {  // NOLINT
+    memory::Copy(
+        dst_place, dst_ptr, BOOST_GET_CONST(platform::NPUPlace, src.place()),
+        src_ptr, size,
+        reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream());
+  }
+#endif
 }
 
 template <typename T>
diff --git a/paddle/fluid/memory/memcpy.cc b/paddle/fluid/memory/memcpy.cc
index b17da7f69a9..22dd7eb48a4 100644
--- a/paddle/fluid/memory/memcpy.cc
+++ b/paddle/fluid/memory/memcpy.cc
@@ -198,6 +198,85 @@ void Copy<platform::XPUPlace, platform::XPUPlace>(platform::XPUPlace dst_place,
 }
 #endif
 
+#ifdef PADDLE_WITH_ASCEND_CL
+template <>
+void Copy<platform::NPUPlace, platform::CPUPlace>(platform::NPUPlace dst_place,
+                                                  void* dst,
+                                                  platform::CPUPlace src_place,
+                                                  const void* src, size_t num,
+                                                  aclrtStream stream) {
+  if (UNLIKELY(num == 0)) return;
+
+  platform::SetNPUDeviceId(dst_place.device);
+  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
+          << dst_place << " by stream(" << stream << ")";
+  if (stream) {
+    platform::RecordEvent record_event("NpuMemcpyAsync:CPU->NPU");
+    platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE, stream);
+  } else {
+    platform::RecordEvent record_event("NpuMemcpySync:CPU->NPU");
+    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE);
+  }
+}
+
+template <>
+void Copy<platform::CPUPlace, platform::NPUPlace>(platform::CPUPlace dst_place,
+                                                  void* dst,
+                                                  platform::NPUPlace src_place,
+                                                  const void* src, size_t num,
+                                                  aclrtStream stream) {
+  if (UNLIKELY(num == 0)) return;
+
+  platform::SetNPUDeviceId(src_place.device);
+  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
+          << dst_place << " by stream(" << stream << ")";
+  if (stream) {
+    platform::RecordEvent record_event("NpuMemcpyAsync:NPU->CPU");
+    platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST, stream);
+  } else {
+    platform::RecordEvent record_event("NpuMemcpySync:NPU->CPU");
+    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST);
+  }
+}
+
+template <>
+void Copy<platform::NPUPlace, platform::NPUPlace>(platform::NPUPlace dst_place,
+                                                  void* dst,
+                                                  platform::NPUPlace src_place,
+                                                  const void* src, size_t num,
+                                                  aclrtStream stream) {
+  if (UNLIKELY(num == 0)) return;
+
+  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
+          << dst_place << " by stream(" << stream << ")";
+  if (dst_place == src_place) {
+    platform::SetNPUDeviceId(src_place.device);
+    if (stream) {
platform::RecordEvent record_event("NpuMemcpyAsync(same_npu):NPU->NPU"); + platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE, + stream); + } else { + platform::RecordEvent record_event("NpuMemcpySync(same_npu):NPU->NPU"); + platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE); + } + } else { + if (!platform::NPUCanAccessPeer(dst_place.device, dst_place.device)) { + PADDLE_THROW(platform::errors::Unavailable( + "Peer access between NPU places is not allowed.")); + } + if (stream) { + // TODO(zhiqiu): support peer access? + platform::RecordEvent record_event("NpuMemcpyPeerAsync:NPU->NPU"); + platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE, + stream); + } else { + platform::RecordEvent record_event("NpuMemcpyPeerSync:NPU->NPU"); + platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE); + } + } +} +#endif + #ifdef PADDLE_WITH_CUDA static constexpr size_t kMaxGpuAsyncCopyBytes = 64 * 1024; // 64K diff --git a/paddle/fluid/memory/memcpy.h b/paddle/fluid/memory/memcpy.h index 7b2b8eb0662..bff4f8cfd7d 100644 --- a/paddle/fluid/memory/memcpy.h +++ b/paddle/fluid/memory/memcpy.h @@ -52,6 +52,26 @@ void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num); template void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, cudaStream_t stream); +#endif +#ifdef PADDLE_WITH_ASCEND_CL + +/** + * \brief Copy memory from one place to another place. + * + * \param[in] DstPlace Destination allocation place (CPU or NPU). + * \param[in] dst Destination memory address. + * \param[in] SrcPlace Source allocation place (CPU or NPU). + * \param[in] src Source memory address. + * \param[in] num memory size in bytes to copy. + * \param[in] stream NPU stream. + * + * \note For NPU memory copy, NPU stream need to be specified + * for asynchronously memory copy. + * + */ +template +void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, + aclrtStream stream); #endif } // namespace memory diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index f46320acf16..17234edb116 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -119,6 +119,11 @@ if (WITH_ASCEND) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} ascend_wrapper) endif() +if (WITH_ASCEND_CL) + cc_library(npu_op_runner SRCS npu_op_runner.cc DEPS operator npu_info) + set(COMMON_OP_DEPS ${COMMON_OP_DEPS} npu_op_runner) +endif() + # FIXME(typhoonzero): operator deps may not needed. 
 # op_library(lod_tensor_to_array_op DEPS lod_rank_table_op)
 # op_library(array_to_lod_tensor_op DEPS lod_rank_table_op)
diff --git a/paddle/fluid/operators/elementwise/CMakeLists.txt b/paddle/fluid/operators/elementwise/CMakeLists.txt
index 06ca98e526e..d3f7290aada 100644
--- a/paddle/fluid/operators/elementwise/CMakeLists.txt
+++ b/paddle/fluid/operators/elementwise/CMakeLists.txt
@@ -8,3 +8,4 @@ register_operators(DEPS op_version_registry)
 cc_test(test_elementwise_add_op_inplace SRCS test_elementwise_add_op_inplace.cc DEPS op_registry elementwise_add_op scope device_context enforce executor)
 cc_test(test_elementwise_div_grad_grad SRCS test_elementwise_div_grad_grad.cc DEPS op_registry elementwise_div_op scope device_context enforce executor)
 cc_test(test_elementwise_add_grad_grad SRCS test_elementwise_add_grad_grad.cc DEPS op_registry elementwise_add_op scope device_context enforce executor)
+cc_test(elementwise_add_op_npu_test SRCS elementwise_add_op_npu_test.cc DEPS op_registry elementwise_add_op scope device_context enforce executor)
diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc b/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc
new file mode 100644
index 00000000000..6e48b84c9c6
--- /dev/null
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc
@@ -0,0 +1,58 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef PADDLE_WITH_ASCEND_CL
+#include <memory>
+#include <string>
+
+#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
+#include "paddle/fluid/operators/npu_op_runner.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class ElementwiseAddNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* out = ctx.Output<Tensor>("Out");
+
+    out->mutable_data<T>(ctx.GetPlace());
+
+    // TODO(zhiqiu): get the attr information of Ascend op and
+    // convert paddle AttributeMap to Ascend attrs.
+    // Ascend op add has no attribute ?
+    // int axis = ctx.Attr<int>("axis");
+
+    // NOTE(zhiqiu): the order of inputs and outputs is important
+    auto runner = NpuOpRunner("Add", {*x, *y}, {*out}, {});
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+    runner.Run(stream);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_NPU_KERNEL(
+    elementwise_add,
+    ops::ElementwiseAddNPUKernel<paddle::platform::NPUDeviceContext, float>);
+#endif
diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op_npu_test.cc b/paddle/fluid/operators/elementwise/elementwise_add_op_npu_test.cc
new file mode 100644
index 00000000000..64915ef394d
--- /dev/null
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op_npu_test.cc
@@ -0,0 +1,83 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifndef _WIN32
+#include <unistd.h>
+#endif
+
+#include <memory>
+#include <thread>  // NOLINT
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/operators/dropout_op.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/string/printf.h"
+
+namespace f = paddle::framework;
+namespace p = paddle::platform;
+namespace m = paddle::operators::math;
+
+USE_OP(elementwise_add);
+USE_OP_DEVICE_KERNEL(elementwise_add, NPU);
+
+void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
+  // init
+  auto x = scope->Var("X");
+  auto tensor_x = x->GetMutable<f::LoDTensor>();
+  tensor_x->Resize({10, 10});
+
+  auto y = scope->Var("Y");
+  auto tensor_y = y->GetMutable<f::LoDTensor>();
+  tensor_y->Resize({10, 10});
+
+  std::vector<float> init;
+  for (int64_t i = 0; i < 10 * 10; ++i) {
+    init.push_back(1.0);
+  }
+
+  TensorFromVector(init, ctx, tensor_x);
+  TensorFromVector(init, ctx, tensor_y);
+
+  auto place = ctx.GetPlace();
+  auto out = scope->Var("Out");
+  auto tensor_out = out->GetMutable<f::LoDTensor>();
+  tensor_out->Resize({10, 10});
+  tensor_out->mutable_data<float>(place);  // allocate
+
+  // run
+  f::AttributeMap attrs;
+  auto op =
+      f::OpRegistry::CreateOp("elementwise_add", {{"X", {"X"}}, {"Y", {"Y"}}},
+                              {{"Out", {"Out"}}}, attrs);
+
+  op->Run(*scope, place);
+
+  std::vector<float> out_vec;
+  TensorToVector(*tensor_out, ctx, &out_vec);
+
+  EXPECT_EQ(out_vec.size(), init.size());
+  for (uint32_t i = 0; i < out_vec.size(); i++) {
+    EXPECT_EQ(out_vec[i], 2.0);
+  }
+}
+
+TEST(elementwise_add, NPU) {
+  f::Scope scope;
+  p::NPUDeviceContext ctx(p::NPUPlace(0));
+  Compare(&scope, ctx);
+}
diff --git a/paddle/fluid/operators/npu_op_runner.cc b/paddle/fluid/operators/npu_op_runner.cc
new file mode 100644
index 00000000000..7eb0ff68e61
--- /dev/null
+++ b/paddle/fluid/operators/npu_op_runner.cc
@@ -0,0 +1,235 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/npu_op_runner.h"
+
+#include <paddle/fluid/framework/data_type.h>
+#include <paddle/fluid/framework/operator.h>
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "acl/acl.h"
+#include "paddle/fluid/framework/framework.pb.h"
+
+namespace paddle {
+namespace operators {
+
+static std::map<framework::proto::VarType::Type, aclDataType>
+    DTYPE_2_ACL_DTYPE = {
+        {framework::proto::VarType::BOOL, ACL_BOOL},
+        {framework::proto::VarType::INT16, ACL_INT16},
+        {framework::proto::VarType::INT32, ACL_INT32},
+        {framework::proto::VarType::INT64, ACL_INT64},
+        {framework::proto::VarType::FP16, ACL_FLOAT16},
+        {framework::proto::VarType::FP32, ACL_FLOAT},
+        {framework::proto::VarType::FP64, ACL_DOUBLE},
+};
+
+static std::map<DataLayout, aclFormat> DATA_LAYOUT_2_ACL_FORMAT = {
+    {DataLayout::kNCHW, ACL_FORMAT_NCHW},
+    {DataLayout::kNHWC, ACL_FORMAT_NHWC},
+    {DataLayout::kAnyLayout, ACL_FORMAT_ND},
+};
+
+aclDataType ConvertToNpuDtype(framework::proto::VarType::Type dtype) {
+  auto iter = DTYPE_2_ACL_DTYPE.find(dtype);
+  PADDLE_ENFORCE_NE(iter, DTYPE_2_ACL_DTYPE.end(),
+                    platform::errors::NotFound(
+                        "The data type (%s) can not convert to ACL data type.",
+                        framework::DataTypeToString(dtype)));
+  return iter->second;
+}
+
+aclFormat ConvertToNpuFormat(DataLayout layout) {
+  auto iter = DATA_LAYOUT_2_ACL_FORMAT.find(layout);
+  PADDLE_ENFORCE_NE(
+      iter, DATA_LAYOUT_2_ACL_FORMAT.end(),
+      platform::errors::NotFound(
+          "The data layout (%s) can not convert to ACL data format.", layout));
+  return iter->second;
+}
+
+NpuOpRunner::NpuOpRunner(std::string op_type) : op_type_(op_type) {}
+
+NpuOpRunner::NpuOpRunner(std::string op_type,
+                         const std::vector<Tensor> &inputs,
+                         const std::vector<Tensor> &outputs,
+                         const AttributeMap &attrs)
+    : op_type_(op_type) {
+  AddInputs(inputs);
+  AddOutputs(outputs);
+  AddAttrs(attrs);
+}
+
+NpuOpRunner::~NpuOpRunner() {
+  // TODO(zhiqiu): handle free
+}
+
+const std::string &NpuOpRunner::Type() { return op_type_; }
+
+NpuOpRunner &NpuOpRunner::AddAttr(const std::string &name,
+                                  const Attribute &attr) {
+  // attr_ is created lazily, so runners without attributes never allocate it
+  if (!attr_) {
+    attr_ = aclopCreateAttr();
+  }
+  if (attr.type() == typeid(bool)) {
+    PADDLE_ENFORCE_NPU_SUCCESS(
+        aclopSetAttrBool(attr_, name.c_str(), BOOST_GET_CONST(bool, attr)));
+  } else if (attr.type() == typeid(int)) {
+    PADDLE_ENFORCE_NPU_SUCCESS(
+        aclopSetAttrInt(attr_, name.c_str(), BOOST_GET_CONST(int, attr)));
+  } else if (attr.type() == typeid(int64_t)) {
+    PADDLE_ENFORCE_NPU_SUCCESS(
+        aclopSetAttrInt(attr_, name.c_str(), BOOST_GET_CONST(int64_t, attr)));
+  } else if (attr.type() == typeid(float)) {
+    PADDLE_ENFORCE_NPU_SUCCESS(
+        aclopSetAttrFloat(attr_, name.c_str(), BOOST_GET_CONST(float, attr)));
+  } else if (attr.type() == typeid(std::vector<bool>)) {
+    auto a = BOOST_GET_CONST(std::vector<bool>, attr);
+    std::vector<uint8_t> cast_a;
+    for (auto it : a) {
+      cast_a.push_back(static_cast<uint8_t>(it));
+    }
+    PADDLE_ENFORCE_NPU_SUCCESS(aclopSetAttrListBool(
+        attr_, name.c_str(), cast_a.size(), cast_a.data()));
+  } else if (attr.type() == typeid(std::vector<int>)) {
+    auto a = BOOST_GET_CONST(std::vector<int>, attr);
+    std::vector<int64_t> cast_a;
+    for (auto it : a) {
+      cast_a.push_back(static_cast<int64_t>(it));
+    }
+    PADDLE_ENFORCE_NPU_SUCCESS(aclopSetAttrListInt(
+        attr_, name.c_str(), cast_a.size(), cast_a.data()));
+  } else if (attr.type() == typeid(std::vector<int64_t>)) {
+    auto a = BOOST_GET_CONST(std::vector<int64_t>, attr);
+    PADDLE_ENFORCE_NPU_SUCCESS(
+        aclopSetAttrListInt(attr_, name.c_str(), a.size(), a.data()));
+  } else if (attr.type() == typeid(std::vector<float>)) {
+    auto a = BOOST_GET_CONST(std::vector<float>, attr);
+    PADDLE_ENFORCE_NPU_SUCCESS(
+        aclopSetAttrListFloat(attr_, name.c_str(), a.size(), a.data()));
+  } else if (attr.type() == typeid(std::string)) {
+    auto a = BOOST_GET_CONST(std::string, attr);
+    PADDLE_ENFORCE_NPU_SUCCESS(
+        aclopSetAttrString(attr_, name.c_str(), a.c_str()));
+  } else if (attr.type() == typeid(std::vector<std::string>)) {
+    auto a = BOOST_GET_CONST(std::vector<std::string>, attr);
+    std::vector<const char *> s;
+    for (auto &it : a) {
+      s.push_back(it.data());
+    }
+    PADDLE_ENFORCE_NPU_SUCCESS(
+        aclopSetAttrListString(attr_, name.c_str(), s.size(), s.data()));
+  } else {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Can not convert attribute '%s' to aclopAttr.", name));
+  }
+  return *this;
+}
+
+NpuOpRunner &NpuOpRunner::AddAttrs(const AttributeMap &attrs) {
+  for (const auto &pair : attrs) {
+    AddAttr(pair.first, pair.second);
+  }
+  return *this;
+}
+
+NpuOpRunner &NpuOpRunner::AddInput(const Tensor &tensor) {
+  // create aclTensorDesc
+  input_descs_.emplace_back(CreateTensorDesc(tensor));
+  // create aclDataBuffer
+  input_buffers_.emplace_back(CreateDataBuffer(tensor));
+  return *this;
+}
+
+NpuOpRunner &NpuOpRunner::AddOutput(const Tensor &tensor) {
+  // create aclTensorDesc
+  output_descs_.emplace_back(CreateTensorDesc(tensor));
+  // create aclDataBuffer
+  output_buffers_.emplace_back(CreateDataBuffer(tensor));
+  return *this;
+}
+
+NpuOpRunner &NpuOpRunner::AddInputs(const std::vector<Tensor> &tensors) {
+  for (auto tensor : tensors) {
+    // create aclTensorDesc
+    input_descs_.emplace_back(CreateTensorDesc(tensor));
+    // create aclDataBuffer
+    input_buffers_.emplace_back(CreateDataBuffer(tensor));
+  }
+  return *this;
+}
+
+NpuOpRunner &NpuOpRunner::AddOutputs(const std::vector<Tensor> &tensors) {
+  for (auto tensor : tensors) {
+    // create aclTensorDesc
+    output_descs_.emplace_back(CreateTensorDesc(tensor));
+    // create aclDataBuffer
+    output_buffers_.emplace_back(CreateDataBuffer(tensor));
+  }
+  return *this;
+}
+
+aclTensorDesc *NpuOpRunner::GetInputDesc(size_t index) {
+  PADDLE_ENFORCE_LT(index, input_descs_.size(),
+                    platform::errors::OutOfRange(
+                        "The index should be less than the size of inputs of "
+                        "operator %s, but got index %d and size %d",
+                        Type(), index, input_descs_.size()));
+  return input_descs_[index];
+}
+
+aclTensorDesc *NpuOpRunner::GetOutputDesc(size_t index) {
+  PADDLE_ENFORCE_LT(index, output_descs_.size(),
+                    platform::errors::OutOfRange(
+                        "The index should be less than the size of outputs of "
+                        "operator %s, but got index %d and size %d",
+                        Type(), index, output_descs_.size()));
+  return output_descs_[index];
+}
+
+std::vector<aclTensorDesc *> &NpuOpRunner::GetInputDescs() {
+  return input_descs_;
+}
+
+std::vector<aclTensorDesc *> &NpuOpRunner::GetOutputDescs() {
+  return output_descs_;
+}
+
+std::vector<aclDataBuffer *> &NpuOpRunner::GetInputBuffers() {
+  return input_buffers_;
+}
+
+std::vector<aclDataBuffer *> &NpuOpRunner::GetOutputBuffers() {
+  return output_buffers_;
+}
+
+aclTensorDesc *NpuOpRunner::CreateTensorDesc(Tensor tensor) {
+  auto dtype = ConvertToNpuDtype(tensor.type());
+  auto format = ConvertToNpuFormat(tensor.layout());
+  auto dims = framework::vectorize(tensor.dims());
+
+  auto *desc = aclCreateTensorDesc(dtype, dims.size(), dims.data(), format);
+  PADDLE_ENFORCE_NOT_NULL(
+      desc, platform::errors::External("Call aclCreateTensorDesc failed."));
+  return desc;
+}
+
+aclDataBuffer *NpuOpRunner::CreateDataBuffer(Tensor tensor) {
+  auto *buffer =
+      aclCreateDataBuffer(tensor.Holder()->ptr(), tensor.memory_size());
+  PADDLE_ENFORCE_NOT_NULL(
+      buffer, platform::errors::External("Call aclCreateDataBuffer failed."));
+  return buffer;
+}
+
+void NpuOpRunner::Run(aclrtStream stream) {
+  // aclopExecuteV2 only enqueues the op on `stream`; execution is asynchronous
+  aclError ret = aclopExecuteV2(op_type_.c_str(), input_descs_.size(),
+                                input_descs_.data(), input_buffers_.data(),
+                                output_descs_.size(), output_descs_.data(),
+                                output_buffers_.data(), attr_, stream);
+  PADDLE_ENFORCE_NPU_SUCCESS(ret);
+}
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/npu_op_runner.h b/paddle/fluid/operators/npu_op_runner.h
new file mode 100644
index 00000000000..2e68226ed07
--- /dev/null
+++ b/paddle/fluid/operators/npu_op_runner.h
@@ -0,0 +1,83 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <paddle/fluid/framework/operator.h>
+
+#include <string>
+#include <vector>
+
+#include "acl/acl.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using DataLayout = framework::DataLayout;
+using Attribute = framework::Attribute;
+using AttributeMap = framework::AttributeMap;
+
+class NpuOpRunner {
+ public:
+  explicit NpuOpRunner(std::string op_type);
+  NpuOpRunner(std::string op_type, const std::vector<Tensor> &inputs = {},
+              const std::vector<Tensor> &outputs = {},
+              const AttributeMap &attrs = {});
+
+  ~NpuOpRunner();
+
+  const std::string &Type();
+
+  NpuOpRunner &AddAttr(const std::string &name, const Attribute &attr);
+
+  NpuOpRunner &AddAttrs(const AttributeMap &attrs);
+
+  NpuOpRunner &AddInput(const Tensor &tensor);
+
+  NpuOpRunner &AddOutput(const Tensor &tensor);
+
+  NpuOpRunner &AddInputs(const std::vector<Tensor> &tensors);
+
+  NpuOpRunner &AddOutputs(const std::vector<Tensor> &tensors);
+
+  aclTensorDesc *GetInputDesc(size_t index);
+
+  aclTensorDesc *GetOutputDesc(size_t index);
+
+  std::vector<aclTensorDesc *> &GetInputDescs();
+
+  std::vector<aclTensorDesc *> &GetOutputDescs();
+
+  std::vector<aclDataBuffer *> &GetInputBuffers();
+
+  std::vector<aclDataBuffer *> &GetOutputBuffers();
+
+  void Run(aclrtStream stream);
+
+ private:
+  aclTensorDesc *CreateTensorDesc(Tensor tensor);
+  aclDataBuffer *CreateDataBuffer(Tensor tensor);
+
+ private:
+  std::string op_type_;
+  std::vector<aclDataBuffer *> input_buffers_;
+  std::vector<aclDataBuffer *> output_buffers_;
+  std::vector<aclTensorDesc *> input_descs_;
+  std::vector<aclTensorDesc *> output_descs_;
+  aclopAttr *attr_{nullptr};
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt
index f3331349fde..27389c4fd65 100644
--- a/paddle/fluid/platform/CMakeLists.txt
+++ b/paddle/fluid/platform/CMakeLists.txt
@@ -90,15 +90,28 @@ IF(WITH_GPU)
     set(GPU_CTX_DEPS dynload_cuda dynamic_loader cuda_stream)
 ENDIF()
 
+IF(WITH_ASCEND_CL)
+    set(NPU_CTX_DEPS npu_stream npu_info)
+ENDIF()
+
 IF(WITH_MKLDNN)
     set(MKLDNN_CTX_DEPS mkldnn)
 ELSE()
     set(MKLDNN_CTX_DEPS)
 ENDIF()
 
+IF(WITH_GPU)
 nv_library(stream_callback_manager SRCS stream_callback_manager.cc DEPS simple_threadpool enforce)
+ENDIF()
+
+IF(WITH_ASCEND_CL)
+cc_library(stream_callback_manager SRCS stream_callback_manager.cc DEPS simple_threadpool enforce atlas_acl)
+ENDIF()
+
 IF(WITH_GPU)
 set(STREAM_CALLBACK_DEPS stream_callback_manager)
+ELSEIF(WITH_ASCEND_CL)
+  set(STREAM_CALLBACK_DEPS stream_callback_manager)
 ELSE()
 set(STREAM_CALLBACK_DEPS)
 ENDIF()
@@ -112,7 +125,7 @@ cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost)
 # memcpy depends on device_context, here add deps individually for
 # avoiding cycle dependencies
 cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc xxhash ${STREAM_CALLBACK_DEPS}
-    place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}
+    place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${NPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}
     ${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS})
 
 cc_library(collective_helper SRCS collective_helper.cc DEPS framework_proto device_context enforce)
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 24182b837f1..79e606596f9 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -87,9 +87,8 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
   if (it == device_contexts_.end()) {
     PADDLE_THROW(platform::errors::Unimplemented(
         "Place %s is not supported. Please check that your paddle compiles "
-        "with WITH_GPU or WITH_XPU option or check that your train process "
-        "hold the "
-        "correct gpu_id if you use Executor.",
+        "with WITH_GPU, WITH_XPU or WITH_ASCEND_CL option or check that "
+        "your training process sets the correct device id if you use Executor.",
         place));
   }
   return it->second.get().get();
@@ -150,6 +149,14 @@ DeviceContextPool::DeviceContextPool(
       PADDLE_THROW(
           platform::errors::Unimplemented("XPUPlace is not supported. Please "
                                           "re-compile with WITH_XPU option."));
+#endif
+    } else if (platform::is_npu_place(p)) {
+#ifdef PADDLE_WITH_ASCEND_CL
+      EmplaceDeviceContext<NPUDeviceContext, NPUPlace>(&device_contexts_, p);
+#else
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "NPUPlace is not supported. Please "
+          "re-compile with WITH_ASCEND_CL option."));
 #endif
     }
   }
@@ -679,6 +686,5 @@ MKLDNNDeviceContext::BlobPtr_t<KeyBlob> MKLDNNDeviceContext::GetBlob(
 }
 #endif
 
-
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index a4e584eeffa..0b4ac60c836 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -47,6 +47,9 @@ limitations under the License. */
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/stream/cuda_stream.h"
 #endif
+#ifdef PADDLE_WITH_ASCEND_CL
+#include "paddle/fluid/platform/stream/npu_stream.h"
+#endif
 #include "unsupported/Eigen/CXX11/Tensor"
 
 namespace Eigen {
diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc
index 1c8b05768a4..45a31ad840f 100644
--- a/paddle/fluid/platform/init.cc
+++ b/paddle/fluid/platform/init.cc
@@ -181,6 +181,9 @@ void InitDevices(const std::vector<int> devices) {
 #endif
 #ifdef PADDLE_WITH_XPU
     places.emplace_back(platform::XPUPlace(devices[i]));
+#endif
+#ifdef PADDLE_WITH_ASCEND_CL
+    places.emplace_back(platform::NPUPlace(devices[i]));
 #endif
   }
   places.emplace_back(platform::CPUPlace());
diff --git a/paddle/fluid/platform/stream/CMakeLists.txt b/paddle/fluid/platform/stream/CMakeLists.txt
index 78a7313bded..21e48b6d4ac 100644
--- a/paddle/fluid/platform/stream/CMakeLists.txt
+++ b/paddle/fluid/platform/stream/CMakeLists.txt
@@ -1,3 +1,7 @@
 IF(WITH_GPU)
-cc_library(cuda_stream SRCS cuda_stream.cc DEPS enforce boost)
+cc_library(cuda_stream SRCS cuda_stream.cc DEPS enforce boost stream_callback_manager)
+ENDIF()
+
+IF(WITH_ASCEND_CL)
+cc_library(npu_stream SRCS npu_stream.cc DEPS enforce boost stream_callback_manager)
 ENDIF()
diff --git a/paddle/fluid/platform/stream/npu_stream.cc b/paddle/fluid/platform/stream/npu_stream.cc
new file mode 100644
index 00000000000..1a07a1ed837
--- /dev/null
+++ b/paddle/fluid/platform/stream/npu_stream.cc
@@ -0,0 +1,52 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/platform/stream/npu_stream.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/npu_info.h"
+
+namespace paddle {
+namespace platform {
+namespace stream {
+
+bool NPUStream::Init(const Place& place) {
+  PADDLE_ENFORCE_EQ(is_npu_place(place), true,
+                    platform::errors::InvalidArgument(
+                        "NPU stream must be created using npu place."));
+  place_ = place;
+  NPUDeviceGuard guard(BOOST_GET_CONST(NPUPlace, place_).device);
+  PADDLE_ENFORCE_NPU_SUCCESS(aclrtCreateStream(&stream_));
+  callback_manager_.reset(new StreamCallbackManager<aclrtStream>(stream_));
+  VLOG(3) << "NPUStream Init stream: " << stream_;
+  return true;
+}
+
+void NPUStream::Destroy() {
+  NPUDeviceGuard guard(BOOST_GET_CONST(NPUPlace, place_).device);
+  Wait();
+  WaitCallback();
+  if (stream_) {
+    PADDLE_ENFORCE_NPU_SUCCESS(aclrtDestroyStream(stream_));
+  }
+  stream_ = nullptr;
+}
+
+void NPUStream::Wait() const {
+  PADDLE_ENFORCE_NPU_SUCCESS(aclrtSynchronizeStream(stream_));
+}
+
+}  // namespace stream
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/fluid/platform/stream/npu_stream.h b/paddle/fluid/platform/stream/npu_stream.h
new file mode 100644
index 00000000000..7e5d574acec
--- /dev/null
+++ b/paddle/fluid/platform/stream/npu_stream.h
@@ -0,0 +1,76 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include "paddle/fluid/platform/macros.h"
+#include "paddle/fluid/platform/npu_info.h"
+#include "paddle/fluid/platform/place.h"
+#include "paddle/fluid/platform/stream_callback_manager.h"
+
+namespace paddle {
+namespace platform {
+namespace stream {
+
+#ifdef PADDLE_WITH_ASCEND_CL
+
+class NPUStream final {
+ public:
+  NPUStream() = default;
+  explicit NPUStream(const Place& place) { Init(place); }
+  virtual ~NPUStream() { Destroy(); }
+
+  bool Init(const Place& place);
+
+  template <typename Callback>
+  void AddCallback(Callback&& callback) const {
+    callback_manager_->AddCallback(callback);
+  }
+
+  template <typename Callback>
+  void RecordEvent(aclrtEvent ev, Callback callback) const {
+    callback();
+    PADDLE_ENFORCE_NPU_SUCCESS(aclrtRecordEvent(ev, stream_));
+  }
+
+  void RecordEvent(aclrtEvent ev) const {
+    PADDLE_ENFORCE_NPU_SUCCESS(aclrtRecordEvent(ev, stream_));
+  }
+
+  void WaitEvent(aclrtEvent ev) const {
+    PADDLE_ENFORCE_NPU_SUCCESS(aclrtStreamWaitEvent(stream_, ev));
+  }
+
+  void Wait() const;
+  void WaitCallback() const { callback_manager_->Wait(); }
+
+  aclrtStream raw_stream() const { return stream_; }
+  void Destroy();
+
+ private:
+  Place place_;
+  aclrtStream stream_{nullptr};
+  std::unique_ptr<StreamCallbackManager<aclrtStream>> callback_manager_;
+
+  DISABLE_COPY_AND_ASSIGN(NPUStream);
+};
+
+#endif
+
+}  // namespace stream
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/fluid/platform/stream_callback_manager.cc b/paddle/fluid/platform/stream_callback_manager.cc
index 365216566b2..76128e9a8f4 100644
--- a/paddle/fluid/platform/stream_callback_manager.cc
+++ b/paddle/fluid/platform/stream_callback_manager.cc
@@ -19,22 +19,30 @@
 namespace paddle {
 namespace platform {
 
+#ifdef PADDLE_WITH_CUDA
 #if CUDA_VERSION >= 10000
 static void CUDART_CB StreamCallbackFunc(void *user_data)
 #else
 static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
                                          cudaError_t status, void *user_data)
 #endif
+#endif
+
+#ifdef PADDLE_WITH_ASCEND_CL
+static void StreamCallbackFunc(void *user_data)
+#endif
 {
   std::unique_ptr<std::function<void()>> func(
       reinterpret_cast<std::function<void()> *>(user_data));
   (*func)();
 }
 
-StreamCallbackManager::StreamCallbackManager(const cudaStream_t stream)
+template <typename Stream>
+StreamCallbackManager<Stream>::StreamCallbackManager(const Stream stream)
     : stream_(stream), thread_pool_(1) {}
 
-void StreamCallbackManager::AddCallback(std::function<void()> callback) const {
+template <typename Stream>
+void StreamCallbackManager<Stream>::AddCallback(
+    std::function<void()> callback) const {
   auto *callback_func = new std::function<void()>(std::move(callback));
   auto *func = new std::function<void()>([this, callback_func] {
     std::lock_guard<std::mutex> lock(mtx_);
@@ -43,6 +51,7 @@ void StreamCallbackManager::AddCallback(std::function<void()> callback) const {
       (*callback_func)();
     });
   });
+#ifdef PADDLE_WITH_CUDA
 #if CUDA_VERSION >= 10000
   PADDLE_ENFORCE_CUDA_SUCCESS(
       cudaLaunchHostFunc(stream_, StreamCallbackFunc, func));
@@ -50,10 +59,22 @@ void StreamCallbackManager::AddCallback(std::function<void()> callback) const {
   PADDLE_ENFORCE_CUDA_SUCCESS(
       cudaStreamAddCallback(stream_, StreamCallbackFunc, func, 0));
 #endif
+#endif
+
+#ifdef PADDLE_WITH_ASCEND_CL
+  PADDLE_ENFORCE_NPU_SUCCESS(aclrtLaunchCallback(StreamCallbackFunc, func,
+                                                 ACL_CALLBACK_BLOCK, stream_));
+#endif
 }
 
-void StreamCallbackManager::Wait() const {
+template <typename Stream>
+void StreamCallbackManager<Stream>::Wait() const {
+#ifdef PADDLE_WITH_CUDA
   PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream_));
+#endif
+#ifdef PADDLE_WITH_ASCEND_CL
+  PADDLE_ENFORCE_NPU_SUCCESS(aclrtSynchronizeStream(stream_));
+#endif
   {
     std::lock_guard<std::mutex> lock(mtx_);
     if (last_future_.valid()) {
@@ -62,5 +83,12 @@ void StreamCallbackManager::Wait() const {
   }
 }
 
+#ifdef PADDLE_WITH_CUDA
+template class StreamCallbackManager<cudaStream_t>;
+#endif
+#ifdef PADDLE_WITH_ASCEND_CL
+template class StreamCallbackManager<aclrtStream>;
+#endif
+
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/stream_callback_manager.h b/paddle/fluid/platform/stream_callback_manager.h
index 8668bcb1131..4838e5969a8 100644
--- a/paddle/fluid/platform/stream_callback_manager.h
+++ b/paddle/fluid/platform/stream_callback_manager.h
@@ -15,8 +15,10 @@
 #pragma once
 
 #include <ThreadPool.h>
+#ifdef PADDLE_WITH_CUDA
 #include <cuda.h>
 #include <cuda_runtime.h>
+#endif
 #include <functional>
 #include <future>  // NOLINT
 #include <memory>
@@ -29,9 +31,10 @@ namespace platform {
 
 // NOTE(zjl): clean StreamCallbackManager to make compilation faster
 // Make StreamCallbackManager thread-safe
+template <typename Stream>
 class StreamCallbackManager {
  public:
-  explicit StreamCallbackManager(const cudaStream_t stream);
+  explicit StreamCallbackManager(const Stream stream);
 
   ~StreamCallbackManager() = default;
 
@@ -40,7 +43,7 @@ class StreamCallbackManager {
   void Wait() const;
 
  private:
-  const cudaStream_t stream_;
+  const Stream stream_;
   mutable ::ThreadPool thread_pool_;
   mutable std::mutex mtx_;
   mutable std::future<void> last_future_;
-- 
GitLab
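
Usage sketch: the elementwise_add kernel in this patch is the canonical pattern for wiring an Ascend op through NpuOpRunner; the fragment below restates it for reviewers. The op name "Mul" and the attribute-free call are illustrative assumptions, not part of the patch.

    // Inside some framework::OpKernel<T>::Compute(ctx), in a build with
    // WITH_ASCEND_CL=ON; Tensor is paddle::operators::Tensor.
    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    auto* out = ctx.Output<Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());  // allocate on the NPU place first

    // Inputs/outputs bind by position, attributes by name.
    auto runner = NpuOpRunner("Mul", {*x, *y}, {*out}, {});

    // Run() only enqueues the op; synchronize via the device context or a
    // stream callback before reading the result on the host.
    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();
    runner.Run(stream);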